diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index 471bda5634b4501985cf094a8bdafbe4c5e88e13..d5e8d8f6d8310a488291e2f35661cf32ac2662dc 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -80,7 +80,7 @@ inline bool GetConstInt(Expr e, int* out) {
   }                                                             \
   uint64_t ua = 0, ub = 0;                                      \
   if (GetConst(a, &ua) && GetConst(b, &ub)) {                   \
-    return ir::UIntImm::make(a.type(), ua + ub);                \
+    return ir::UIntImm::make(a.type(), ua OP ub);               \
   }                                                             \
 
 template<>
diff --git a/src/arithmetic/modular.cc b/src/arithmetic/modular.cc
index c487701064f9682f7f4e3a1f50d4e44ae5645b4e..19383f7051a0f5d96855989eae142b5641b480b7 100644
--- a/src/arithmetic/modular.cc
+++ b/src/arithmetic/modular.cc
@@ -113,7 +113,7 @@ class ModularEvaluator
  private:
   const std::unordered_map<
       const Variable*, ModularEntry>& mod_map_;
-
+  friend struct ModularEntry;
   // simplify the base by putting it in range.
   static int BaseSimplify(int base, int coeff) {
     if (coeff == 0) return base;
@@ -136,6 +136,15 @@ class ModularEvaluator
   }
 };
 
+ModularEntry ModularEntry::Add(const ModularEntry& a,
+                               const ModularEntry& b) {
+  ModularEntry ret;
+  ret.coeff = ModularEvaluator::ZeroAwareGCD(a.coeff, b.coeff);
+  ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff);
+  return ret;
+}
+
+
 ModularEntry EvalModular(
     const Expr& e,
     const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
diff --git a/src/arithmetic/modular.h b/src/arithmetic/modular.h
index bb51901a65f36b98b64c717b08bbc5768983d50d..a152f63732b89278e0481dd42efd0af8993308c2 100644
--- a/src/arithmetic/modular.h
+++ b/src/arithmetic/modular.h
@@ -37,6 +37,14 @@ struct ModularEntry {
     e.base = 0; e.coeff = 1;
     return e;
   }
+  /*!
+   * \brief Add two modular entries together to get a new modular entry.
+   * \param a The left operand.
+   * \param b The right operand.
+   * \return The combined modular entry.
+   */
+  static ModularEntry Add(const ModularEntry& a,
+                          const ModularEntry& b);
 };
 
 /*!
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index d7c0a35d3b79d4b59879129b10a7be3f8c6279eb..f5da1b0c2d64296e6ca22de9e0bfd884663d491a 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -102,8 +102,14 @@ void CodeGenLLVM::InitGlobalContext() {
   gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
 }
 
-void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
+void CodeGenLLVM::InitFuncState() {
   var_map_.clear();
+  align_map_.clear();
+  alloc_storage_scope_.clear();
+}
+
+void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
+  this->InitFuncState();
   CHECK(!module_->getFunction(f->name))
       << "Function " << f->name << "already exists in module";
   std::vector<llvm::Type*> arg_type;
@@ -163,6 +169,7 @@ class FPassManager : public llvm::legacy::FunctionPassManager {
     llvm::legacy::FunctionPassManager::add(p);
   }
 };
+
 class MPassManager : public llvm::legacy::PassManager {
  public:
   // override add to allow messaging
@@ -245,25 +252,26 @@ void CodeGenLLVM::AddAliasInfo(
   int base = 0, width = 0;
   // create meta-data for alias analysis
   // Use a group of binary tree ranges.
-  const Ramp* ramp = index.as<Ramp>();
-  if (ramp) {
-    int base, stride;
-    if (arith::GetConstInt(ramp->base, &base) &&
-        arith::GetConstInt(ramp->stride, &stride)) {
-      int xwith = ramp->lanes * stride;
-      width = 1;
-      while (width < xwith) {
-        width *= 2;
-      }
-      while (base % width) {
-        base -= base % width;
-        width *= 2;
+  if (index.defined()) {
+    const Ramp* ramp = index.as<Ramp>();
+    if (ramp) {
+      int base, stride;
+      if (arith::GetConstInt(ramp->base, &base) &&
+          arith::GetConstInt(ramp->stride, &stride)) {
+        int xwith = ramp->lanes * stride;
+        width = 1;
+        while (width < xwith) {
+          width *= 2;
+        }
+        while (base % width) {
+          base -= base % width;
+          width *= 2;
+        }
       }
+    } else {
+      if (arith::GetConstInt(index, &base)) width = 1;
     }
-  } else {
-    if (arith::GetConstInt(index, &base)) width = 1;
   }
-
   llvm::MDNode* meta = md_tbaa_root_;
   std::ostringstream buffer_addr;
   buffer_addr << buffer;
@@ -283,12 +291,12 @@
 }
 
 llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
-  llvm::Constant* init = llvm::UndefValue::get(
+  llvm::Constant* undef = llvm::UndefValue::get(
       llvm::VectorType::get(value->getType(), lanes));
   llvm::Constant* zero = ConstInt32(0);
-  value = builder_->CreateInsertElement(init, value, zero);
+  value = builder_->CreateInsertElement(undef, value, zero);
   llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
-  return builder_->CreateShuffleVector(value, init, mask);
+  return builder_->CreateShuffleVector(value, undef, mask);
 }
 
 llvm::Value* CodeGenLLVM::CreateBufferPtr(
@@ -684,6 +692,38 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
   return nullptr;
 }
 
+int CodeGenLLVM::NativeVectorBits(const std::string& storage_scope) const {
+  // By default, we ask the buffer to be aligned to 64 bytes
+  return 64 * 8;
+}
+
+void CodeGenLLVM::GetAlignment(
+    Type t, const Variable* buf_var, const Expr& index,
+    int* p_alignment, int* p_native_bits) {
+  int& alignment = *p_alignment;
+  int& native_bits = *p_native_bits;
+  // The storage scope.
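+  // It stays empty when no storage_scope attribute was recorded for buf_var.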
+  std::string scope;
+  auto it = alloc_storage_scope_.find(buf_var);
+  if (it != alloc_storage_scope_.end()) {
+    scope = it->second;
+  }
+  arith::ModularEntry m = EvalModular(index, align_map_);
+  native_bits = NativeVectorBits(scope);
+  alignment = t.element_of().bits();
+  // find alignment
+  while ((m.coeff & 1) == 0 &&
+         (m.base & 1) == 0 &&
+         alignment < native_bits) {
+    m.coeff /= 2;
+    m.base /= 2;
+    alignment *= 2;
+  }
+  CHECK_EQ(alignment % 8, 0)
+      << "Load from memory that is not aligned to 8 bits";
+  alignment /= 8;
+}
+
 // visitor overrides
 llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
   return GetVarValue(op);
@@ -849,7 +889,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
   llvm::Value* v = MakeValue(op->value);
   CHECK(!var_map_.count(op->var.get()));
+  CHECK(!align_map_.count(op->var.get()));
   var_map_[op->var.get()] = v;
+  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
   return MakeValue(op->body);
 }
 
@@ -872,25 +914,254 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return value;
 }
 
+void CodeGenLLVM::Scalarize(
+    const Expr& e,
+    std::function<void(int i, llvm::Value* v)> f) {
+  const Ramp* ramp = e.as<Ramp>();
+  Type t = e.type();
+  if (ramp) {
+    for (int i = 0; i < t.lanes(); ++i) {
+      Expr offset = arith::ComputeExpr<Add>(
+          ramp->base,
+          arith::ComputeExpr<Mul>(ramp->stride, i));
+      f(i, MakeValue(offset));
+    }
+  } else {
+    llvm::Value* index = MakeValue(e);
+    for (int i = 0; i < t.lanes(); ++i) {
+      f(i, builder_->CreateExtractElement(index, ConstInt32(i)));
+    }
+  }
+}
+
+llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
+  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  std::vector<llvm::Constant*> indices;
+  for (int i = lanes; i != 0; --i) {
+    indices.push_back(ConstInt32(i - 1));
+  }
+  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecSlice(
+    llvm::Value* vec, int begin, int lanes) {
+  int total_lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  CHECK_LE(begin + lanes, total_lanes);
+  if (lanes == total_lanes && begin == 0) return vec;
+  std::vector<llvm::Constant*> indices;
+  for (int i = 0; i < lanes; ++i) {
+    indices.push_back(ConstInt32(begin + i));
+  }
+  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
+  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  if (target_lanes == lanes) return vec;
+  CHECK_GT(target_lanes, lanes);
+  int pad_lanes = target_lanes - lanes;
+  llvm::Constant* undef = llvm::UndefValue::get(
+      llvm::VectorType::get(vec->getType()->getVectorElementType(), pad_lanes));
+  std::vector<llvm::Constant*> indices;
+  for (int i = 0; i < target_lanes; ++i) {
+    indices.push_back(ConstInt32(i));
+  }
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecConcat(
+    std::vector<llvm::Value*> vec) {
+  CHECK_NE(vec.size(), 0U);
+  int target_lanes = 0;
+  for (llvm::Value* v : vec) {
+    target_lanes += static_cast<int>(v->getType()->getVectorNumElements());
+  }
+  // tree shape merging
+  while (vec.size() != 1) {
+    std::vector<llvm::Value*> merged;
+    for (size_t i = 0; i < vec.size() - 1; i += 2) {
+      llvm::Value* v1 = vec[i];
+      llvm::Value* v2 = vec[i + 1];
+      int w1 = static_cast<int>(v1->getType()->getVectorNumElements());
+      int w2 = static_cast<int>(v2->getType()->getVectorNumElements());
+      int w = std::max(w1, w2);
+      v1 = CreateVecPad(v1, w);
+      v2 = CreateVecPad(v2, w);
+      std::vector<llvm::Constant*> indices;
+      for (int i = 0; i < w * 2; ++i) {
+        indices.push_back(ConstInt32(i));
+      }
+      merged.push_back(
+          builder_->CreateShuffleVector(
+              v1, v2, llvm::ConstantVector::get(indices)));
+    }
+    if (vec.size() % 2 == 1) {
+      merged.push_back(vec.back());
+    }
+    vec = merged;
+  }
+  return CreateVecSlice(vec[0], 0, target_lanes);
+}
+
 llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
   Type t = op->type;
-  CHECK(!t.is_vector());
-
+  const Ramp* ramp = op->index.as<Ramp>();
+  llvm::Value* buf = GetVarValue(op->buffer_var.get());
   if (t.is_scalar()) {
     llvm::LoadInst* inst = builder_->CreateAlignedLoad(
-        CreateBufferPtr(
-            t,
-            GetVarValue(op->buffer_var.get()),
-            MakeValue(op->index)),
+        CreateBufferPtr(t, buf, MakeValue(op->index)),
         data_layout_->getTypeAllocSize(LLVMType(t)));
     AddAliasInfo(inst, op->buffer_var.get(), op->index);
     return inst;
+  } else if (ramp && is_one(ramp->stride)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    int total_lanes = t.lanes();
+    int step = native_bits / t.bits();
+
+    std::vector<llvm::Value*> loads;
+    for (int offset = 0; offset < total_lanes; offset += step) {
+      int lanes = std::min(step, total_lanes - offset);
+      Expr base = arith::ComputeExpr<Add>(
+          ramp->base, make_const(ramp->base.type(), offset));
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
+      llvm::Type* vtype = llvm::VectorType::get(
+          LLVMType(t.element_of()), lanes)->getPointerTo();
+      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
+          builder_->CreatePointerCast(ptr, vtype), alignment);
+      AddAliasInfo(inst, op->buffer_var.get(),
+                   Ramp::make(base, make_const(base.type(), 1), lanes));
+      loads.push_back(inst);
+    }
+    return CreateVecConcat(loads);
+  } else if (ramp && is_const(ramp->stride, 2)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    arith::ModularEntry e = arith::EvalModular(ramp->base, align_map_);
+    Type bt = ramp->base.type();
+    int first_shift, next_shift;
+    // If the base is even and the native vector width is at least twice
+    // the load width, it is safe to load both halves in place.
+    if (e.coeff % 2 == 0 &&
+        e.base % 2 == 0 &&
+        native_bits >= t.bits() * 2) {
+      first_shift = 0;
+      next_shift = 0;
+    } else if (e.coeff % 2 == 0 && e.base % 2 == 1) {
+      // odd base, shift both loads to the left.
+      first_shift = -1;
+      next_shift = -1;
+    } else {
+      // Safe fallback: keep the first load in place, shift the next load left.
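+      // The shuffle below then picks every other lane out of the two loads.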
+      first_shift = 0;
+      next_shift = -1;
+    }
+    llvm::Value* first = MakeValue(Load::make(
+        t, op->buffer_var,
+        Ramp::make(arith::ComputeExpr<Add>(
+            ramp->base, make_const(bt, first_shift)),
+            make_const(bt, 1), ramp->lanes)));
+    llvm::Value* next = MakeValue(Load::make(
+        t, op->buffer_var,
+        Ramp::make(arith::ComputeExpr<Add>(
+            ramp->base, make_const(bt, ramp->lanes + next_shift)),
+            make_const(bt, 1), ramp->lanes)));
+    // shuffle
+    std::vector<llvm::Constant*> indices;
+    int target_index = 0;
+    for (int i = 0; i < ramp->lanes; ++i) {
+      int idx = first_shift + i;
+      if (idx == target_index) {
+        indices.push_back(ConstInt32(i));
+        target_index += 2;
+      }
+    }
+    for (int i = 0; i < ramp->lanes; ++i) {
+      int idx = ramp->lanes + next_shift + i;
+      if (idx == target_index) {
+        indices.push_back(ConstInt32(i + ramp->lanes));
+        target_index += 2;
+      }
+    }
+    CHECK_EQ(indices.size(), static_cast<size_t>(ramp->lanes));
+    return builder_->CreateShuffleVector(
+        first, next, llvm::ConstantVector::get(indices));
+  } else if (ramp && is_const(ramp->stride, -1)) {
+    int lanes = ramp->type.lanes();
+    Expr neg_ramp = Ramp::make(
+        arith::ComputeExpr<Sub>(
+            ramp->base,
+            make_const(ramp->base.type(), lanes - 1)),
+        make_const(ramp->base.type(), 1),
+        lanes);
+    // load value then flip
+    llvm::Value* v = MakeValue(Load::make(t, op->buffer_var, neg_ramp));
+    return CreateVecFlip(v);
+  } else {
+    llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
+    Scalarize(op->index, [&](int i, llvm::Value* offset) {
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
+      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
+          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
+      AddAliasInfo(inst, op->buffer_var.get(), Expr());
+      ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
+    });
+    return ret;
+  }
+}
+
+// stmts
+void CodeGenLLVM::VisitStmt_(const Store* op) {
+  llvm::Value* value = MakeValue(op->value);
+  Type t = op->value.type();
+  const Ramp* ramp = op->index.as<Ramp>();
+  llvm::Value* buf = GetVarValue(op->buffer_var.get());
+
+  if (t.is_scalar()) {
+    llvm::StoreInst* inst = builder_->CreateAlignedStore(
+        value,
+        CreateBufferPtr(t, buf, MakeValue(op->index)),
+        data_layout_->getTypeAllocSize(value->getType()));
+    AddAliasInfo(inst, op->buffer_var.get(), op->index);
+  } else if (ramp && is_one(ramp->stride)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    int total_lanes = t.lanes();
+    int step = native_bits / t.bits();
+    // vector store.
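+    // Emit the store in slices of at most step lanes so each slice fits
+    // the native vector width.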
+    for (int offset = 0; offset < total_lanes; offset += step) {
+      int lanes = std::min(step, total_lanes - offset);
+      Expr base = arith::ComputeExpr<Add>(
+          ramp->base, make_const(ramp->base.type(), offset));
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
+      llvm::Type* vtype = llvm::VectorType::get(
+          LLVMType(t.element_of()), lanes)->getPointerTo();
+      llvm::StoreInst* inst = builder_->CreateAlignedStore(
+          CreateVecSlice(value, offset, lanes),
+          builder_->CreatePointerCast(ptr, vtype), alignment);
+      AddAliasInfo(inst, op->buffer_var.get(),
+                   Ramp::make(base, make_const(base.type(), 1), lanes));
+    }
   } else {
-    LOG(FATAL) << "not yet supported";
-    return nullptr;
+    Scalarize(op->index, [&](int i, llvm::Value* offset) {
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
+      llvm::StoreInst* inst = builder_->CreateAlignedStore(
+          builder_->CreateExtractElement(value, ConstInt32(i)),
+          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
+      AddAliasInfo(inst, op->buffer_var.get(), Expr());
+    });
   }
 }
 
+
 llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
   if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
     return CreateCallPacked(op);
@@ -904,24 +1175,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
   }
 }
 
-// stmts
-void CodeGenLLVM::VisitStmt_(const Store* op) {
-  llvm::Value* value = MakeValue(op->value);
-  Type t = op->value.type();
-  CHECK(!t.is_vector());
-  if (t.is_scalar()) {
-    llvm::StoreInst* inst = builder_->CreateAlignedStore(
-        value,
-        CreateBufferPtr(
-            t,
-            GetVarValue(op->buffer_var.get()),
-            MakeValue(op->index)),
-        data_layout_->getTypeAllocSize(value->getType()));
-    AddAliasInfo(inst, op->buffer_var.get(), op->index);
-  } else {
-    LOG(FATAL) << "not yet supported";
-  }
-}
 
 void CodeGenLLVM::VisitStmt_(const For* op) {
   CHECK(is_zero(op->min));
@@ -986,6 +1239,11 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
 }
 
 void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
+  if (op->type_key == ir::attr::storage_scope) {
+    const Variable* v = op->node.as<Variable>();
+    CHECK(v);
+    alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+  }
   this->VisitStmt(op->body);
 }
 
@@ -1014,7 +1272,9 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
 void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
   llvm::Value* v = MakeValue(op->value);
   CHECK(!var_map_.count(op->var.get()));
+  CHECK(!align_map_.count(op->var.get()));
   var_map_[op->var.get()] = v;
+  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
   this->VisitStmt(op->body);
 }
 void CodeGenLLVM::VisitStmt_(const Block* op) {
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
index aed75a866d169b583263debab4d1161ca323222f..63009fd1ec1553de23284dce06ccc8d41be6253d 100644
--- a/src/codegen/llvm/codegen_llvm.h
+++ b/src/codegen/llvm/codegen_llvm.h
@@ -14,6 +14,7 @@
 #include <vector>
 #include <string>
 #include "./llvm_common.h"
+#include "../../arithmetic/modular.h"
 
 namespace tvm {
 namespace codegen {
@@ -109,18 +110,29 @@ class CodeGenLLVM :
   virtual llvm::Value* CreateCallExtern(const Call* op);
   // create call into tvm packed function.
   virtual llvm::Value* CreateCallPacked(const Call* op);
-
+  // Scalarize e by iterating elements of e.
+  // f is a callback that takes index and v.
+  virtual void Scalarize(const Expr& e,
+                         std::function<void(int i, llvm::Value* v)> f);
  protected:
   /*!
   * \param t The original type.
   * \return LLVM type of t
   */
   llvm::Type* LLVMType(const Type& t) const;
+  // initialize the function state.
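+  // (clears var_map_, align_map_ and alloc_storage_scope_ for the next function)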
+  void InitFuncState();
+  // Get alignment given index.
+  void GetAlignment(
+      Type t, const Variable* buf_var, const Expr& index,
+      int* p_alignment, int* p_native_bits);
   // do a scalarize call with f
   llvm::Value* CreateScalarizedCall(
       const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
   // apply optimization on the module.
   virtual void Optimize();
+  // Get the maximum storage alignment in bits of a buffer pointer, given its storage scope.
+  virtual int NativeVectorBits(const std::string& storage_scope) const;
   // The IRBuilder.
   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
   // The current function
@@ -162,6 +174,8 @@ class CodeGenLLVM :
   llvm::Function* f_tvm_parallel_for_{nullptr};
   // The acting body
   llvm::BasicBlock* block_{nullptr};
+  /*! \brief the storage scope of allocation */
+  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
 
  private:
   // comparison op
@@ -178,6 +192,11 @@ class CodeGenLLVM :
   llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
   llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
   llvm::Value* GetPackedFuncHandle(const std::string& str);
+  // Vector concatenation.
+  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
+  llvm::Value* CreateVecFlip(llvm::Value* vec);
+  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
+  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
   // Create parallel for.
   void CreateParallelFor(const For* op);
   // Create serial for
@@ -197,6 +216,8 @@ class CodeGenLLVM :
   std::unordered_map<const Variable*, llvm::Value*> var_map_;
   // global strings
   std::unordered_map<std::string, llvm::Constant*> str_map_;
+  // The alignment information
+  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
   // The local module_context
   llvm::GlobalVariable* gv_mod_ctx_{nullptr};
   // global to packed function handle
diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc
index 18f57217d1ad38f13466e8e6f6f247aec925d5d0..7cc029cb20a69ad8c502d39ee0f3156616b52811 100644
--- a/src/pass/vectorize_loop.cc
+++ b/src/pass/vectorize_loop.cc
@@ -355,7 +355,9 @@ class Vectorizer : public IRMutator {
     const Ramp* a_ramp = a.as<Ramp>();
     if (a.type().lanes() == 1 && b_ramp) {
       return Ramp::make(
-          arith::ComputeExpr<T>(a, b_ramp->base), b_ramp->stride, b_ramp->lanes);
+          arith::ComputeExpr<T>(a, b_ramp->base),
+          arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
+          b_ramp->lanes);
     }
     if (b.type().lanes() == 1 && a_ramp) {
       return Ramp::make(
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index bf2a028260554fbd9fad121e86fbdda44f67f6d7..e120bf8ec15f0c858fa03b29e40f53537e733103 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -2,13 +2,15 @@ import tvm
 import numpy as np
 
 def test_llvm_add_pipeline():
-    n = tvm.Var('n')
+    nn = 1024
+    n = tvm.convert(nn)
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
     C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
     s = tvm.Schedule(C.op)
-    s[C].parallel(C.op.axis[0])
-
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    s[C].parallel(xo)
+    s[C].vectorize(xi)
     def check_llvm():
         if not tvm.codegen.enabled("llvm"):
             return
@@ -16,16 +18,71 @@ def test_llvm_add_pipeline():
         f = tvm.build(s, [A, B, C], "llvm")
         ctx = tvm.cpu(0)
         # launch the kernel.
-        n = 1027 * 1024
+        n = nn
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        for i in range(1000):
-            f(a, b, c)
+        f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy())
     check_llvm()
 
 
+def test_llvm_flip_pipeline():
+    def check_llvm(nn, base):
+        if not tvm.codegen.enabled("llvm"):
+            return
+        n = tvm.convert(nn)
+        A = tvm.placeholder((n + base), name='A')
+        C = tvm.compute((n,), lambda i: A(nn + base - i - 1), name='C')
+        s = tvm.Schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        f(a, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy()[::-1][:n])
+    check_llvm(4, 0)
+    check_llvm(128, 8)
+    check_llvm(3, 0)
+    check_llvm(128, 1)
+
+
+def test_llvm_madd_pipeline():
+    def check_llvm(nn, base, stride):
+        if not tvm.codegen.enabled("llvm"):
+            return
+        n = tvm.convert(nn)
+        A = tvm.placeholder((n + base, stride), name='A')
+        C = tvm.compute((n, stride), lambda i, j: A(base + i, j) + 1, name='C')
+        s = tvm.Schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), ctx)
+        f(a, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy()[base:] + 1)
+    check_llvm(64, 0, 2)
+    check_llvm(4, 0, 1)
+    check_llvm(4, 0, 3)
+
+
 if __name__ == "__main__":
     test_llvm_add_pipeline()
+    test_llvm_flip_pipeline()
+    test_llvm_madd_pipeline()