From 54593ca1cc90e98abb09e64d73f1694f8082c757 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Fri, 31 Mar 2017 21:47:54 -0700
Subject: [PATCH] [LANG/GPU] Cross Thread Reduction (#79)

* [LANG/GPU] Cross Thread Reduction.

* Fix doxygen error

* Upgrade verilog testcase to new one
---
 include/tvm/expr.h                              |   2 +-
 include/tvm/ir.h                                |  30 ++
 include/tvm/ir_pass.h                           |   7 +
 include/tvm/schedule.h                          |   8 +
 python/tvm/build.py                             |   3 +-
 python/tvm/schedule.py                          |  18 ++
 src/api/api_lang.cc                             |   7 +
 src/api/api_pass.cc                             |   1 +
 src/codegen/codegen_c.cc                        |  82 ++++--
 src/codegen/codegen_c.h                         |  15 +-
 src/codegen/codegen_opencl.cc                   |  10 +-
 src/codegen/codegen_opencl.h                    |   5 +-
 src/codegen/codegen_source_base.cc              |   1 +
 src/lang/ir.cc                                  |  26 ++
 src/lang/tensor.cc                              |   1 -
 src/op/compute_op.cc                            | 127 ++++++--
 src/op/op_util.cc                               |  24 +-
 src/op/op_util.h                                |  10 +-
 src/op/scan_op.cc                               |   3 +-
 src/pass/ir_util.h                              |  15 +
 src/pass/lower_thread_allreduce.cc              | 275 ++++++++++++++++++
 src/runtime/thread_storage_scope.h              |  14 +-
 src/schedule/message_passing.cc                 |  12 +-
 src/schedule/schedule_dataflow_rewrite.cc       |  12 +-
 src/schedule/schedule_lang.cc                   |  19 ++
 tests/python/integration/test_reduce.py         |  52 +++-
 .../{ => unittest}/test_buffer_doublebuff.py    |  12 +-
 .../{ => unittest}/test_buffer_doublebuff.v     |   0
 .../{ => unittest}/test_buffer_fifo.py          |   4 +-
 tests/verilog/{ => unittest}/test_buffer_fifo.v |   0
 .../{ => unittest}/test_buffer_linebuff.py      |  10 +-
 .../{ => unittest}/test_buffer_linebuff.v       |   0
 32 files changed, 692 insertions(+), 113 deletions(-)
 create mode 100644 src/pass/lower_thread_allreduce.cc
 rename tests/verilog/{ => unittest}/test_buffer_doublebuff.py (89%)
 rename tests/verilog/{ => unittest}/test_buffer_doublebuff.v (100%)
 rename tests/verilog/{ => unittest}/test_buffer_fifo.py (94%)
 rename tests/verilog/{ => unittest}/test_buffer_fifo.v (100%)
 rename tests/verilog/{ => unittest}/test_buffer_linebuff.py (92%)
 rename tests/verilog/{ => unittest}/test_buffer_linebuff.v (100%)

diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 91efe1727..7162c92d4 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -299,7 +299,7 @@ inline const char* IterVarType2String(IterVarType t) {
   switch (t) {
     case kDataPar: return "DataPar";
     case kThreadIndex: return "ThreadIndex";
-    case kCommReduce: return "CommRedude";
+    case kCommReduce: return "CommReduce";
    case kOrdered: return "Ordered";
     case kOpaque: return "Opaque";
     case kUnrolled: return "Unrolled";
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 5fdc6fa21..a70de3586 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -42,6 +42,21 @@ struct Reduce : public ExprNode<Reduce> {
   static Expr make(std::string op, Expr src,
                    Array<IterVar> rdom,
                    Expr condition = const_true());
+  /*!
+   * \brief Get initial value for reduction.
+   * \param op The operator.
+   * \param type The data type.
+   * \return The initial value that can be assigned to reduction.
+   */
+  static Expr InitValue(const std::string& op, Type type);
+  /*!
+   * \brief Combine two values with given reduction.
+   * \param op The operator.
+   * \param a The left operand.
+   * \param b The right operand.
+   * \return The combined reduction result.
+   */
+  static Expr Combine(const std::string& op, Expr a, Expr b);
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
@@ -86,6 +101,10 @@ constexpr const char* thread_extent = "thread_extent";
  * \brief Mark launching of a virtual thread.
  */
 constexpr const char* virtual_thread = "virtual_thread";
+/*!
+ * \brief Mark the scope as volatile access for certain handle.
+ */
+constexpr const char* volatile_scope = "volatile_scope";
 /*!
  * \brief Mark storage scope of buffers
  */
@@ -164,6 +183,17 @@ constexpr const char* tvm_call_packed = "tvm_call_packed";
  *  }
  */
 constexpr const char* tvm_storage_sync = "tvm_storage_sync";
+/*!
+ * \brief See pseudo code
+ *
+ *  Expr tvm_thread_allreduce(std::string op, Expr value, Expr cond,
+ *                            Var thread_idx1, thread_idx2...) {
+ *    // constrained by requiring the other thread_idx to remain the same.
+ *    return reduce(op, value, cond,
+ *      over [thread_idx1, thread_idx2] passed by any caller)
+ *  }
+ */
+constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
 
 /*! \brief The field id of each field in array */
 enum TVMArrayFieldKind {
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 082400b61..9f57724f7 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -234,6 +234,13 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
  */
 LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
 
+/*!
+ * \brief Lower cross thread allreduce in the stmt.
+ * \param f The device function to be lowered.
+ * \param warp_size The size of a warp, within which no sync is needed.
+ * \return Transformed function.
+ */
+LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 93b93a62c..5f7c0e0eb 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -73,6 +73,14 @@ class Stage : public NodeRef {
    * \return reference to self.
    */
   Stage& compute_root();   // NOLINT(*)
+  /*!
+   * \brief Rebase the parent iter var as rebased variable.
+   *
+   * \param parent The parent iteration domain.
+   * \param rebased The variable to be used in rebase.
+   * \return reference to self.
+   */
+  Stage& rebase(IterVar parent, IterVar rebased);
   /*!
    * \brief Split the parent by factor, generate
    * \param parent The parent iteration domain.
diff --git a/python/tvm/build.py b/python/tvm/build.py
index ec5a0dba1..588ead632 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -71,7 +71,6 @@ def lower(sch,
     return fapi
 
 
-
 def build(sch,
           args=None,
           target="llvm",
@@ -128,6 +127,8 @@ def build(sch,
     fsplits = [x for x in fsplits]
     for i in range(1, len(fsplits)):
         fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
+        warp_size = 32 if target == "cuda" else 1
+        fsplits[i] = ir_pass.LowerThreadAllreduce(fsplits[i], warp_size)
 
     if len(fsplits) > 1:
         mhost = codegen.build(fsplits[0], target_host)
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index dcddafa4c..7d3562b3b 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -112,6 +112,24 @@ class Schedule(NodeBase):
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
+    def rebase(self, parent, rebased):
+        """Rebase parent by an existing thread axis.
+
+        Parameters
+        ----------
+        parent : IterVar
+            The parent iter var.
+
+        rebased : IterVar
+            The rebased iter var.
+
+        Returns
+        -------
+        rebased : IterVar
+            The rebased iter var.
+ """ + _api_internal._StageRebase(self, parent, rebased) + return rebased + def split(self, parent, factor=None, outer=None): """Split the stage either by factor providing outer scope, or both diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 933adc872..cb7437333 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -219,6 +219,13 @@ TVM_REGISTER_API(_StageSetScope) .set_scope(args[1]); }); +TVM_REGISTER_API(_StageRebase) +.set_body([](TVMArgs args, TVMRetValue* ret) { + IterVar outer, inner; + args[0].operator Stage() + .rebase(args[1], args[2]); + }); + TVM_REGISTER_API(_StageSplitByFactor) .set_body([](TVMArgs args, TVMRetValue* ret) { IterVar outer, inner; diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 22624f608..96e94f2b8 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -73,6 +73,7 @@ REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(LoopPartition); REGISTER_PASS1(RemoveNoOp); REGISTER_PASS2(SplitPipeline); +REGISTER_PASS2(LowerThreadAllreduce); REGISTER_PASS1(NarrowChannelAccess); } // namespace ir } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index b288ab82e..95b7901f4 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -88,14 +88,26 @@ void CodeGenC::PrintSSAAssign( } // Print a reference expression to a buffer. -void CodeGenC::PrintBufferRef( +std::string CodeGenC::GetBufferRef( const Variable* buffer, - Type t, Expr index, - std::ostream& os) { // NOLINT(*) + Type t, Expr index) { + std::ostringstream os; std::string vid = GetVarID(buffer); + std::string scope; + if (alloc_storage_scope_.count(buffer)) { + scope = alloc_storage_scope_.at(buffer); + } + bool is_vol = volatile_buf_.count(buffer); if (t.lanes() == 1) { - if (!HandleTypeMatch(buffer, t)) { + if (!HandleTypeMatch(buffer, t) || is_vol) { os << "(("; + if (is_vol) { + os << "volatile "; + } + if (scope.length() != 0) { + PrintStorageScope(scope, os); + } + os << ' '; PrintType(t, os); os << "*)" << vid << ')'; } else { @@ -107,17 +119,24 @@ void CodeGenC::PrintBufferRef( } else { // Buffer declared as vector type. 
     // optimize for case where it is in register,
-    if (HandleTypeMatch(buffer, t)) {
+    if (HandleTypeMatch(buffer, t) && !is_vol) {
       // optimize for constant access
       int offset;
       if (arith::GetConstInt(index, &offset)) {
         CHECK_EQ(offset % t.lanes(), 0)
             << "Find unaligned vector load to a vector type";
         os << vid << '[' << (offset / t.lanes()) << ']';
-        return;
+        return os.str();
       }
     }
     os << "((";
+    if (is_vol) {
+      os << "volatile ";
+    }
+    if (scope.length() != 0) {
+      PrintStorageScope(scope, os);
+    }
+    os << ' ';
     PrintType(t, os);
     os << "*)(";
     if (!HandleTypeMatch(buffer, t.element_of())) {
@@ -129,6 +148,7 @@ void CodeGenC::PrintBufferRef(
     PrintExpr(index, os);
     os << "))[0]";
   }
+  return os.str();
 }
 
 
@@ -162,18 +182,17 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
          << " = " << value << ";\n";
 }
 
-void CodeGenC::PrintVecLoad(const Variable* buffer,
-                            Type t, Expr base,
-                            std::ostream& os) {
-  PrintBufferRef(buffer, t, base, os);
+std::string CodeGenC::GetVecLoad(const Variable* buffer,
+                                 Type t, Expr base) {
+  return GetBufferRef(buffer, t, base);
 }
 
 void CodeGenC::PrintVecStore(const Variable* buffer,
                              Type t, Expr base,
                              const std::string& value) {
+  std::string ref = GetBufferRef(buffer, t, base);
   this->PrintIndent();
-  PrintBufferRef(buffer, t, base, stream);
-  stream << " = " << value << ";\n";
+  stream << ref << " = " << value << ";\n";
 }
 
 void CodeGenC::PrintThreadIndexExpr(
@@ -483,24 +502,21 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
 
 void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
   int lanes = op->type.lanes();
-  std::string svalue = GetUniqueName("_");
-  // delcare type.
-  this->PrintIndent();
-  this->PrintType(op->type, stream);
-  stream << ' ' << svalue;
   if (op->type.lanes() == 1) {
-    stream << " = ";
-    this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
-    stream << ";\n";
+    std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
+    os << ref;
   } else {
     Expr base;
     if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
-      stream << " = ";
-      this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
-      stream << ";\n";
+      std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
+      os << ref;
     } else {
-      // Load elements seperately
-      stream << ";\n";
+      // load separately.
+      std::string svalue = GetUniqueName("_");
+      this->PrintIndent();
+      this->PrintType(op->type, stream);
+      stream << ' ' << svalue << ";\n";
       std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
       std::string vid = GetVarID(op->buffer_var.get());
       Type elem_type = op->type.element_of();
@@ -518,18 +534,18 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
         value_temp << ']';
         PrintVecElemStore(svalue, op->type, i, value_temp.str());
       }
+      os << svalue;
     }
   }
-  os << svalue;
 }
 
 void CodeGenC::VisitStmt_(const Store* op) {
   Type t = op->value.type();
   if (t.lanes() == 1) {
     std::string value = this->PrintExpr(op->value);
+    std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index);
     this->PrintIndent();
-    this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
-    stream << " = " << value << ";\n";
+    stream << ref << " = " << value << ";\n";
   } else {
     Expr base;
     if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
@@ -577,7 +593,13 @@ void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {  // NOLINT(*
 }
 
 void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
-  LOG(FATAL) << "Select: not supported ";
+  os << "(";
+  PrintExpr(op->condition, os);
+  os << " ? ";
+  PrintExpr(op->true_value, os);
+  os << " : ";
+  PrintExpr(op->false_value, os);
+  os << ")";
 }
 
 void CodeGenC::VisitStmt_(const LetStmt* op) {
@@ -649,6 +671,10 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
     const Variable* v = op->node.as<Variable>();
     CHECK(v);
     alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+  } else if (op->type_key == ir::attr::volatile_scope) {
+    const Variable* v = op->node.as<Variable>();
+    CHECK(v);
+    volatile_buf_.insert(v);
   }
   this->PrintStmt(op->body);
 }
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
index bd7ef9ba3..e682e089b 100644
--- a/src/codegen/codegen_c.h
+++ b/src/codegen/codegen_c.h
@@ -13,6 +13,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include <unordered_set>
 #include "./codegen_source_base.h"
 
 namespace tvm {
@@ -132,9 +133,8 @@ class CodeGenC :
       const std::string&op, Type op_type,
       Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
   // print vector load
-  virtual void PrintVecLoad(const Variable* buffer,
-                            Type t, Expr base,
-                            std::ostream& os);  // NOLINT(*)
+  virtual std::string GetVecLoad(const Variable* buffer,
+                                 Type t, Expr base);
   // print vector store
   virtual void PrintVecStore(const Variable* buffer,
                              Type t, Expr base,
@@ -149,9 +149,8 @@ class CodeGenC :
 
  protected:
   // print reference to a buffer as type t in index.
-  void PrintBufferRef(const Variable* buffer,
-                      Type t, Expr index,
-                      std::ostream& os);  // NOLINT(*)
+  std::string GetBufferRef(const Variable* buffer,
+                           Type t, Expr index);
   /*!
    * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
@@ -172,9 +171,11 @@ class CodeGenC :
 
  private:
  /*! \brief whether to print in SSA form */
-  bool print_ssa_form_{true};
+  bool print_ssa_form_{false};
   /*! \brief the data type of allocated buffers */
   std::unordered_map<const Variable*, Type> handle_data_type_;
+  /*! \brief set of volatile buf access */
+  std::unordered_set<const Variable*> volatile_buf_;
 };
 
 }  // namespace codegen
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 715f72068..5b3c93c7a 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -95,12 +95,13 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
   os << GetVarID(buffer) << " + ";
   PrintExpr(base, os);
 }
-void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
-                                 Type t, Expr base,
-                                 std::ostream& os) {
+std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
+                                      Type t, Expr base) {
+  std::ostringstream os;
   os << "vload" << t.lanes() << "(0, ";
   PrintVecAddr(buffer, t, base, os);
   os << ")";
+  return os.str();
 }
 
 void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
@@ -121,7 +122,8 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
   }
 }
 
-void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
+void CodeGenOpenCL::PrintStorageScope(
+    const std::string& scope, std::ostream& os) {  // NOLINT(*)
   if (scope == "global") {
     os << "__global";
   } else if (scope == "shared") {
diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h
index 55168fdfe..fdd8d5615 100644
--- a/src/codegen/codegen_opencl.h
+++ b/src/codegen/codegen_opencl.h
@@ -24,9 +24,8 @@ class CodeGenOpenCL : public CodeGenC {
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintStorageSync(const std::string& scope) final;  // NOLINT(*)
   void PrintType(Type t, std::ostream& os) const final;  // NOLINT(*)
-  void PrintVecLoad(const Variable* buffer,
-                    Type t, Expr base,
-                    std::ostream& os) final;  // NOLINT(*)
+  std::string GetVecLoad(const Variable* buffer,
+                         Type t, Expr base) final;
   void PrintVecStore(const Variable* buffer,
                      Type t, Expr base,
                      const std::string& value) final;  // NOLINT(*)
diff --git a/src/codegen/codegen_source_base.cc b/src/codegen/codegen_source_base.cc
index cf3a6ec5a..2066e90bb 100644
--- a/src/codegen/codegen_source_base.cc
+++ b/src/codegen/codegen_source_base.cc
@@ -35,6 +35,7 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
 }
 
 std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
+  LOG(INFO) << "ssa get id";
   if (name_alloc_map_.count(src)) return src;
   auto it = ssa_assign_map_.find(src);
   if (it != ssa_assign_map_.end()) {
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index 55a4d7a0d..f7aa94b09 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -61,6 +61,32 @@ Expr Reduce::make(std::string op, Expr source,
   return Expr(n);
 }
 
+Expr Reduce::InitValue(const std::string& op, Type type) {
+  if (op == "Add") {
+    return make_zero(type);
+  } else if (op == "Max") {
+    return type.min();
+  } else if (op == "Min") {
+    return type.max();
+  } else {
+    LOG(FATAL) << "Unsupported reduction " << op;
+    return Expr();
+  }
+}
+
+Expr Reduce::Combine(const std::string& op, Expr a, Expr b) {
+  if (op == "Add") {
+    return Add::make(a, b);
+  } else if (op == "Max") {
+    return Max::make(a, b);
+  } else if (op == "Min") {
+    return Min::make(a, b);
+  } else {
+    LOG(FATAL) << "Unsupported reduction " << op;
+    return Expr();
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(Reduce);
 TVM_REGISTER_NODE_TYPE(AttrStmt);
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 962960f56..3d894a04a 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -20,7 +20,6 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
-
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index e2467bc32..185714971 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -174,19 +174,8 @@ void MakeReduction(const ComputeOpNode* op,
   }
   const Reduce* reduce = op->body.as<Reduce>();
   CHECK(reduce);
-  Expr init_value, update_value;
-  if (reduce->op == "Add") {
-    init_value = make_zero(reduce->type);
-    update_value = Add::make(t(args), reduce->source);
-  } else if (reduce->op == "Max") {
-    init_value = reduce->type.min();
-    update_value = Max::make(t(args), reduce->source);
-  } else if (reduce->op == "Min") {
-    init_value = reduce->type.max();
-    update_value = Min::make(t(args), reduce->source);
-  } else {
-    LOG(FATAL) << "Unsupported reduction " << reduce->op;
-  }
+  Expr init_value = Reduce::InitValue(reduce->op, reduce->type);
+  Expr update_value = Reduce::Combine(reduce->op, t(args), reduce->source);
   *init = Provide::make(t->op, t->value_index, init_value, args);
   *provide = Provide::make(t->op, t->value_index, update_value, args);
   if (!is_one(reduce->condition)) {
@@ -194,15 +183,6 @@ void MakeReduction(const ComputeOpNode* op,
   }
 }
 
-Stmt MakeProvide(const ComputeOpNode* op,
-                 const Tensor& t) {
-  Array<Expr> args;
-  for (IterVar iv : op->axis) {
-    args.push_back(iv->var);
-  }
-  return Provide::make(t->op, t->value_index, op->body, args);
-}
-
 Stmt Substitute(Stmt s,
                 const std::unordered_map<IterVar, Expr>& value_map) {
   Map<Var, Expr> temp;
@@ -212,11 +192,107 @@ Stmt Substitute(Stmt s,
   return ir::Substitute(s, temp);
 }
 
+// Cross Thread reduction marker.
+bool IsCrossThreadReduction(const ComputeOpNode* self,
+                            const Stage& stage) {
+  std::unordered_set<IterVar> rebase_thread;
+  for (IterVarRelation rel : stage->relations) {
+    if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (s->parent->iter_type == kCommReduce &&
+          s->rebased->iter_type == kThreadIndex) {
+        rebase_thread.insert(s->rebased);
+      }
+    }
+  }
+  if (rebase_thread.size() == 0) return false;
+  // Verify correctness of leaf nest.
+  bool reduce_start = false;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce";
+    } else if (rebase_thread.count(iv)) {
+      reduce_start = true;
+    } else {
+      CHECK(!reduce_start)
+          << "Cross thread reduce cannot swap with normal data axis";
+    }
+  }
+  return true;
+}
+
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map) {
+  Array<Expr> args;
+  for (IterVar iv : self->axis) {
+    args.push_back(iv->var);
+  }
+  const Reduce* reduce = self->body.as<Reduce>();
+  CHECK(reduce);
+  std::unordered_map<IterVar, Expr> value_map;
+  auto nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
+  auto conds = op::MakeBoundCheck(
+      stage, dom_map, false,
+      std::unordered_set<IterVar>(), value_map);
+  Expr cond = reduce->condition;
+  for (Expr v : conds) {
+    cond = cond && v;
+  }
+  Var res_handle("reduce_temp", Handle());
+  Array<Expr> freduce_args;
+  freduce_args.push_back(StringImm::make(reduce->op));
+  freduce_args.push_back(reduce->source);
+  freduce_args.push_back(cond);
+
+  std::vector<Expr> thread_head_check;
+  for (IterVarRelation rel : stage->relations) {
+    if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (s->parent->iter_type == kCommReduce &&
+          s->rebased->iter_type == kThreadIndex) {
+        freduce_args.push_back(s->rebased->var);
+        thread_head_check.push_back(s->rebased->var == 0);
+      }
+    }
+  }
+  Stmt reduce_body = Store::make(
+      res_handle, Call::make(
+          reduce->type,
+          ir::intrinsic::tvm_thread_allreduce,
+          freduce_args, Call::Intrinsic),
+      0);
+  Stmt assign_body = Provide::make(
+      stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
+  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
+  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
+  Stmt body = Allocate::make(
+      res_handle, reduce->type, {1}, const_true(),
+      Block::make(reduce_body, assign_body));
+  body = AttrStmt::make(
+      res_handle, attr::storage_scope, StringImm::make("local"), body);
+  body = Substitute(body, value_map);
+  return MergeNest(nest, body);
+}
+
+Stmt MakeProvide(const ComputeOpNode* op,
+                 const Tensor& t) {
+  Array<Expr> args;
+  for (IterVar iv : op->axis) {
+    args.push_back(iv->var);
+  }
+  return Provide::make(t->op, t->value_index, op->body, args);
+}
+
 Stmt ComputeOpNode::BuildProvide(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map) const {
   CHECK_EQ(stage->op.operator->(), this);
+  if (IsCrossThreadReduction(this, stage)) {
+    // specially handle cross thread reduction.
+    return MakeCrossThreadReduction(this, stage, dom_map);
+  }
   Stmt init, provide;
   if (this->reduce_axis.size() == 0) {
     provide = MakeProvide(this, stage->op.output(0));
@@ -227,9 +303,9 @@ Stmt ComputeOpNode::BuildProvide(
     std::unordered_map<IterVar, Expr> value_map;
     auto nest = op::MakeLoopNest(
         stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-    nest.push_back(op::MakeBoundCheck(
+    nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
         stage, dom_map, false,
-        std::unordered_set<IterVar>(), value_map));
+        std::unordered_set<IterVar>(), value_map)));
     provide = Substitute(provide, value_map);
 
     if (init.defined()) {
@@ -266,7 +342,8 @@ Stmt ComputeOpNode::BuildProvide(
           stage, dom_map, begin_loop, true,
           skip_iter, &init_value_map);
       init_nest.push_back(
-          op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map));
+          op::MakeIfNest(
+              op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
       init = Substitute(init, init_value_map);
       init = MergeNest(init_nest, init);
       // common nest
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 487be17cc..640652008 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -160,37 +160,45 @@ void PassUpBoundCheck(const Stage& s,
   }
 }
 
-std::vector<Stmt> MakeBoundCheck(
+std::vector<Expr> MakeBoundCheck(
     const Stage& stage,
     const Map<IterVar, Range>& dom_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter,
     const std::unordered_map<IterVar, Expr>& value_map) {
-  Stmt no_op = Evaluate::make(0);
   std::unordered_map<IterVar, bool> bound_state;
   for (IterVar iv : stage->leaf_iter_vars) {
     bound_state[iv] = false;
   }
   PassUpBoundCheck(stage, dom_map, &bound_state);
-  // insert conditions
-  std::vector<Stmt> nest;
+  std::vector<Expr> preds;
   for (IterVar iv : stage->op->root_iter_vars()) {
     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
     Range dom = dom_map.at(iv);
     if (bound_state.at(iv)) {
-      Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
-      nest.emplace_back(IfThenElse::make(condition, no_op));
+      preds.emplace_back(
+          ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent);
     }
     CHECK(iv->dom.defined());
     if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
-      nest.emplace_back(IfThenElse::make(condition, no_op));
+      preds.emplace_back(
+          ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent);
     }
   }
+  return preds;
+}
+
+std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
+  Stmt no_op = Evaluate::make(0);
+  std::vector<Stmt> nest;
+  for (const Expr& cond : predicates) {
+    nest.emplace_back(IfThenElse::make(cond, no_op));
+  }
   return nest;
 }
 
+
 // replacer to replace tensors
 class TensorReplacer : public ir::IRMutator {
  public:
diff --git a/src/op/op_util.h b/src/op/op_util.h
index ca37c0d5f..914815f9a 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -43,13 +43,21 @@ MakeLoopNest(const Stage& stage,
  * \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
  * \param skip_iter Whether skip certain iteration.
  * \param value_map The result value of each IterVar.
+ * \return List of predicates that we need to check.
  */
-std::vector<Stmt>
+std::vector<Expr>
 MakeBoundCheck(const Stage& stage,
                const Map<IterVar, Range>& dom_map,
                bool skip_ivar_domain,
               const std::unordered_set<IterVar>& skip_iter,
               const std::unordered_map<IterVar, Expr>& value_map);
+
+/*!
+ * \brief Create a nest of ifs that check the predicates.
+ *
+ * \param predicates The predicates to be checked.
+ * \return The nest of If statements that check the predicates.
+ */
+std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
 
 /*!
  * \brief Replace the tensor reference in stmt by the replace map.
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
index f34aee579..45031082d 100644
--- a/src/op/scan_op.cc
+++ b/src/op/scan_op.cc
@@ -269,7 +269,8 @@ Stmt ScanOpNode::BuildProvide(
       stage, dom_map, 0, false, empty, &vmap);
   nest[begin_scan].push_back(init);
   nest.push_back(
-      op::MakeBoundCheck(stage, dom_map, false, empty, vmap));
+      op::MakeIfNest(
+          op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
   return MergeNest(nest, provide);
 }
 }  // namespace tvm
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 2fbff8099..47a57a9ed 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -70,6 +70,21 @@ inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
   return body;
 }
 
+
+/*!
+ * \brief Combine a sequence of operations.
+ * \param seq The sequence.
+ * \return The combined Stmt.
+ */
+inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
+  if (seq.size() == 0) return Evaluate::make(0);
+  Stmt body = seq[0];
+  for (size_t i = 1; i < seq.size(); ++i) {
+    body = Block::make(body, seq[i]);
+  }
+  return body;
+}
+
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_PASS_IR_UTIL_H_
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
new file mode 100644
index 000000000..1b70c52ab
--- /dev/null
+++ b/src/pass/lower_thread_allreduce.cc
@@ -0,0 +1,275 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ *  Lower allreduce to device implementable IR.
+ * \file lower_thread_allreduce.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./ir_util.h"
+#include "../arithmetic/compute_expr.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+class ThreadAllreduceBuilder : public IRMutator {
+ public:
+  explicit ThreadAllreduceBuilder(int warp_size)
+      : warp_size_(warp_size) {}
+
+  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
+    if (op->type_key == attr::thread_extent) {
+      thread_extents_.push_back(op);
+      Stmt ret = IRMutator::Mutate_(op, s);
+      thread_extents_.pop_back();
+      return ret;
+    } else if (op->type_key == attr::storage_scope) {
+      Stmt ret = IRMutator::Mutate_(op, s);
+      op = ret.as<AttrStmt>();
+      const Variable* v = op->node.as<Variable>();
+      if (alloc_remap_.count(v)) {
+        return op->body;
+      } else {
+        return ret;
+      }
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const Store* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Store>();
+    const Call* call = op->value.as<Call>();
+    if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
+      return MakeAllreduce(op, call);
+    } else {
+      return stmt;
+    }
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Allocate>();
+    auto it = alloc_remap_.find(op->buffer_var.get());
+    if (it != alloc_remap_.end()) {
+      const Allocate* repl = it->second.as<Allocate>();
+      // use volatile access to shared buffer.
+      stmt = AttrStmt::make(
+          repl->buffer_var, attr::volatile_scope, 1, op->body);
+      stmt = Allocate::make(
+          repl->buffer_var, repl->type,
+          repl->extents, repl->condition, stmt);
+      stmt = AttrStmt::make(
+          repl->buffer_var, attr::storage_scope,
+          StringImm::make("shared"), stmt);
+      return stmt;
+    } else {
+      return stmt;
+    }
+  }
+  Expr Mutate_(const Load* op, const Expr& e) final {
+    auto it = load_remap_.find(op->buffer_var.get());
+    if (it != load_remap_.end()) {
+      CHECK(is_zero(op->index));
+      return it->second;
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+
+ private:
+  // Thread entry
+  struct ThreadEntry {
+    runtime::ThreadScope scope;
+    IterVar iv;
+    int extent;
+    // comparator
+    bool operator<(const ThreadEntry& other) const {
+      return scope.dim_index < other.scope.dim_index;
+    }
+  };
+  // make allreduce.
+  Stmt MakeAllreduce(const Store* op, const Call* call) {
+    const std::string& op_code = call->args[0].as<StringImm>()->value;
+    Expr value = call->args[1];
+    Expr cond = call->args[2];
+    if (!is_one(cond)) {
+      value = Select::make(
+          cond, value, Reduce::InitValue(op_code, value.type()));
+    }
+
+    std::unordered_set<const Variable*> reduce_index_;
+    for (size_t i = 3; i < call->args.size(); ++i) {
+      const Variable* v = call->args[i].as<Variable>();
+      CHECK(v);
+      reduce_index_.insert(v);
+    }
+    size_t nmatch = 0;
+    std::vector<ThreadEntry> vred, vpar;
+    for (const AttrStmt* attr : thread_extents_) {
+      ThreadEntry e;
+      IterVar iv(attr->node.node_);
+      e.scope = runtime::ThreadScope::make(iv->thread_tag);
+      e.iv = iv;
+      CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+          << "Need constant extent for thread group";
+      CHECK_LE(e.scope.rank, 1);
+      CHECK_GE(e.scope.dim_index, 0)
+          << "vthread does not work with cross thread reduction";
+      if (e.scope.rank == 1) {
+        if (reduce_index_.count(iv->var.get())) {
+          vred.push_back(e);
+          ++nmatch;
+        } else {
+          vpar.push_back(e);
+        }
+      }
+    }
+    CHECK_EQ(nmatch, reduce_index_.size())
+        << "Not all reduce indices are present in the context";
+    std::sort(vred.begin(), vred.end());
+    std::sort(vpar.begin(), vpar.end());
+    // the size of each index.
+    int reduce_extent, group_extent;
+    int threadx_extent = 1;
+    Expr reduce_index = FlattenThread(vred, &reduce_extent);
+    Expr group_index = FlattenThread(vpar, &group_extent);
+    if (reduce_extent == 1) {
+      // special case, no reduction is needed.
+      return Store::make(op->buffer_var, value, 0);
+    }
+    // Whether the threadIdx.x is involved in reduction.
+    if (vred[0].scope.dim_index == 0) {
+      threadx_extent = vred[0].extent;
+    }
+    Var shared_buf("red_buf", Handle());
+    std::vector<Stmt> seq;
+    seq.emplace_back(Store::make(
+        shared_buf, value,
+        BufIndex(reduce_index, group_index, reduce_extent)));
+    seq.emplace_back(SyncThread());
+    seq.emplace_back(MakeBufAllreduce(
+        op_code, value.type(), shared_buf,
+        reduce_index, group_index, reduce_extent, threadx_extent));
+    CHECK(!load_remap_.count(op->buffer_var.get()));
+    load_remap_[op->buffer_var.get()] =
+        Load::make(
+            value.type(), shared_buf,
+            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent));
+    alloc_remap_[op->buffer_var.get()] =
+        Allocate::make(shared_buf, value.type(),
+                       {Expr(group_extent), Expr(reduce_extent)},
+                       const_true(), Evaluate::make(0));
+    return MergeSeq(seq);
+  }
+  // make allreduce within the shared buffer.
+  Stmt MakeBufAllreduce(const std::string& op,
+                        Type type,
+                        Var shared_buf,
+                        Expr reduce_index,
+                        Expr group_index,
+                        int reduce_extent,
+                        int threadx_extent) {
+    // Get next power of two
+    int reduce_align = 1;
+    while (reduce_extent > reduce_align) {
+      reduce_align = reduce_align << 1;
+    }
+    CHECK_GT(reduce_align, 1);
+    std::vector<Stmt> seq;
+
+    Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
+    // make reduction
+    auto freduce = [&](int offset) {
+      Expr b = Load::make(
+          type, shared_buf,
+          BufIndex(reduce_index + offset, group_index, reduce_extent));
+      Expr a = Load::make(type, shared_buf, buf_index);
+      return Store::make(shared_buf, Reduce::Combine(op, a, b), buf_index);
+    };
+    // Step one, check for the boundary condition.
+    if (reduce_align > reduce_extent) {
+      // reduction with the boundary condition
+      reduce_align = reduce_align >> 1;
+      Expr cond = reduce_index < (reduce_extent - reduce_align);
+      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(SyncThread());
+    }
+    CHECK(threadx_extent >= 1 && warp_size_ >= 1);
+    // normal synchronization
+    while (reduce_align > threadx_extent ||
+           reduce_align > warp_size_) {
+      reduce_align = reduce_align >> 1;
+      Expr cond = reduce_index < reduce_align;
+      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(SyncThread());
+    }
+    // in-warp synchronization.
+    std::vector<Stmt> in_warp_seq;
+    Expr in_warp_cond = reduce_index < (reduce_align >> 1);
+    while (reduce_align > 1) {
+      reduce_align = reduce_align >> 1;
+      in_warp_seq.emplace_back(freduce(reduce_align));
+    }
+    if (in_warp_seq.size() != 0) {
+      Stmt warp_body = MergeSeq(in_warp_seq);
+      seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+    }
+    return MergeSeq(seq);
+  }
+  // Flatten the thread index.
+  // Also return the total extent of the flattened threads.
+  Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
+                     int* out_total_extent) {
+    int& total_extent = *out_total_extent;
+    total_extent = 1;
+    if (tvec.size() == 0) {
+      return make_zero(Int(32));
+    }
+
+    Expr ret;
+    for (const ThreadEntry& e : tvec) {
+      if (ret.defined()) {
+        ret = ret + e.iv->var * total_extent;
+      } else {
+        CHECK_EQ(total_extent, 1);
+        ret = e.iv->var;
+      }
+      total_extent *= e.extent;
+    }
+    return ret;
+  }
+  // sync thread op.
+  static Stmt SyncThread() {
+    return Evaluate::make(
+        Call::make(Int(32), intrinsic::tvm_storage_sync,
+                   {StringImm::make("shared")},
+                   Call::Intrinsic));
+  }
+  // The local buffer index.
+  static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
+    if (!is_zero(group_index)) {
+      return ir::Simplify(group_index * reduce_extent + reduce_index);
+    } else {
+      return reduce_index;
+    }
+  }
+  // The warp size of the device.
+  int warp_size_{1};
+  // surrounding scope of thread extent.
+  std::vector<const AttrStmt*> thread_extents_;
+  // The load remap
+  std::unordered_map<const Variable *, Expr> load_remap_;
+  // Allocate remap
+  std::unordered_map<const Variable *, Stmt> alloc_remap_;
+};
+
+LoweredFunc
+LowerThreadAllreduce(LoweredFunc f, int warp_size) {
+  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
+  n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
+  return LoweredFunc(n);
+}
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index da623567b..b523b9318 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -112,18 +112,10 @@ class ThreadAxisConfig {
       arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
       filled[ts.rank * 3 + ts.dim_index] = true;
     }
-    work_dim_ = 3;
+    work_dim_ = 1;
     for (int i = 0; i < 3; ++i) {
-      if (!filled[i]) {
-        for (int j = i; j < 3; ++j) {
-          CHECK(!filled[j] && !filled[j + 3])
-              << "Invalid thread group configuration";
-        }
-        work_dim_ = i;
-        break;
-      } else {
-        CHECK(filled[i])
-            << "Must have both threadIdx and blockIdx";
+      if (filled[i] || filled[i + 3]) {
+        work_dim_ = i + 1;
       }
     }
   }
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
index 68d28df2c..51d1aa229 100644
--- a/src/schedule/message_passing.cc
+++ b/src/schedule/message_passing.cc
@@ -75,8 +75,16 @@ void PassDownDomain(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      state[r->rebased] = Range::make_with_min_extent(
-          0, state.at(r->parent)->extent);
+      Range res = Range::make_with_min_extent(
+          0, state.at(r->parent)->extent);
+      if (r->rebased->dom.defined()) {
+        Range rebase_rng = r->rebased->dom;
+        bool match = is_zero(rebase_rng->min);
+        if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
+        CHECK(match) << r->rebased
+                     << " does not match parent scope's range";
+      }
+      state[r->rebased] = res;
     } else {
       LOG(FATAL) << "unknown relation type";
     }
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index b577f0a43..9545a35cd 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -305,8 +305,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
     }
   }
   // predicate generation, copy not touched axis.
+  const Reduce* reduce = compute_op->body.as<Reduce>();
+  CHECK(reduce) << "Can only rfactor non-inline reductions";
+  Expr predicate = reduce->condition;
   std::unordered_map<const Variable*, Expr> vsub;
-  Expr predicate;
   for (IterVar iv : compute_op->reduce_axis) {
     if (!touch_map.count(iv)) {
       n->reduce_axis.push_back(iv);
@@ -316,10 +318,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       vsub[iv->var.get()] = index;
       if (!index.same_as(iv->var)) {
         Expr cond = (index < dom_map.at(iv)->extent);
-        if (predicate.defined()) {
-          predicate = predicate && cond;
-        } else {
+        if (is_one(predicate)) {
           predicate = cond;
+        } else {
+          predicate = predicate && cond;
         }
       }
     }
@@ -333,8 +335,6 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       n->reduce_axis.push_back(IterVar(ncpy));
     }
   }
-  const Reduce* reduce = compute_op->body.as<Reduce>();
-  CHECK(reduce) << "Can only rfactor non-inline reductions";
   n->body = Reduce::make(reduce->op,
                          VarReplacer(vsub).Mutate(reduce->source),
                          n->reduce_axis,
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 318e9b057..723588e20 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -136,6 +136,25 @@ Stage& Stage::compute_root() {   // NOLINT(*)
   return *this;
 }
 
+Stage& Stage::rebase(IterVar parent, IterVar rebased) {  // NOLINT(*)
+  CHECK(parent->iter_type == kDataPar ||
+        parent->iter_type == kCommReduce)
+      << "Cannot rebase " << IterVarType2String(parent->iter_type);
+  CHECK(rebased->iter_type == kThreadIndex)
+      << "Cannot rebase by " << IterVarType2String(rebased->iter_type)
+      << ", only thread axis is allowed so far";
+  ArrayNode* all_vars = (*this)->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = (*this)->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+  (*this)->relations.push_back(RebaseNode::make(parent, rebased));
+  // add vars to all vars
+  all_vars->data.push_back(rebased.node_);
+  // replace the position.
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, rebased.node_);
+  return *this;
+}
+
 Stage& Stage::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
   CheckSplit(operator->(), parent, IterVar());
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 726cd3f11..fbb3c9f10 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -51,7 +51,7 @@ def test_rfactor():
     n = tvm.convert(1027)
     A = tvm.placeholder((n,), name='A')
     k = tvm.reduce_axis((0, n))
-    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
+    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
     kf = tvm.reduce_axis((0, 4))
     # schedule
     s = tvm.Schedule(B.op)
@@ -78,6 +78,56 @@ def test_rfactor():
 
     check_target()
 
+
+def test_rfactor_threads():
+    nn = 1027
+    mm = 10
+    n = tvm.convert(nn)
+    m = tvm.convert(mm)
+    A = tvm.placeholder((m, n), name='A')
+    k = tvm.reduce_axis((0, n))
+    nthread = 16
+    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
+    tx = tvm.thread_axis((0, nthread), "threadIdx.x")
+    ty = tvm.thread_axis((0, nthread), "threadIdx.y")
+    bx = tvm.thread_axis(None, "blockIdx.x")
+    # schedule
+    s = tvm.Schedule(B.op)
+    ko, kf = s[B].split(k, factor=nthread)
+    BF = s.rfactor(B, kf)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
+    s[B].rebase(xi, ty)
+    s[B].rebase(s[B].op.reduce_axis[0], tx)
+    s[BF].compute_at(s[B], tx)
+
+    # one line to build the function.
+    def check_target(device, host="stackvm"):
+        if not tvm.codegen.enabled(device):
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        fapi = tvm.lower(s, args=[A, B])
+        fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)
+        fsum = tvm.build(fapi,
+                         target=device,
+                         name="mysum")
+        print(fsum.imported_modules[0].get_source())
+        # launch the kernel.
+        n = nn
+        m = mm
+        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
+        fsum(a, b)
+        res = np.sum(a.asnumpy(), axis=1)
+        res[:2] = 0
+        np.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    if tvm.module.enabled("opencl"):
+        tvm.module.init_opencl()
+    check_target("cuda")
+    check_target("opencl")
+
 if __name__ == "__main__":
+    test_rfactor_threads()
     test_rfactor()
     test_sum()
diff --git a/tests/verilog/test_buffer_doublebuff.py b/tests/verilog/unittest/test_buffer_doublebuff.py
similarity index 89%
rename from tests/verilog/test_buffer_doublebuff.py
rename to tests/verilog/unittest/test_buffer_doublebuff.py
index e0439d9c9..7d8cb1d98 100644
--- a/tests/verilog/test_buffer_doublebuff.py
+++ b/tests/verilog/unittest/test_buffer_doublebuff.py
@@ -35,11 +35,11 @@ def test_buffer_doublebuff():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Leave the following signals set to true
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     write_valid.put_int(1)
 
     # Main simulation loop
@@ -50,15 +50,15 @@ def test_buffer_doublebuff():
         if (write_idx < len(test_data)):
             write_advance.put_int(0)
             if (write_ready.get_int()):
-                write_data.put_int(test_data[write_idx])
-                write_addr.put_int(write_idx%window_width)
+                write_data.put_int(int(test_data[write_idx]))
+                write_addr.put_int(write_idx % window_width)
                 if (write_idx%window_width==window_width-1):
                     write_advance.put_int(1)
                 write_idx += 1
         else:
             write_advance.put_int(0)
             write_valid.put_int(0)
-
+
         # correctness checks
         if (read_data_valid.get_int()):
             assert(read_data.get_int()==test_data[read_idx])
@@ -66,7 +66,7 @@ def test_buffer_doublebuff():
             read_idx += 1
 
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_doublebuff.v b/tests/verilog/unittest/test_buffer_doublebuff.v
similarity index 100%
rename from tests/verilog/test_buffer_doublebuff.v
rename to tests/verilog/unittest/test_buffer_doublebuff.v
diff --git a/tests/verilog/test_buffer_fifo.py b/tests/verilog/unittest/test_buffer_fifo.py
similarity index 94%
rename from tests/verilog/test_buffer_fifo.py
rename to tests/verilog/unittest/test_buffer_fifo.py
index 3255ceafb..f95fe9796 100644
--- a/tests/verilog/test_buffer_fifo.py
+++ b/tests/verilog/unittest/test_buffer_fifo.py
@@ -27,7 +27,7 @@ def test_buffer_fifo():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Main simulation loop
@@ -46,7 +46,7 @@ def test_buffer_fifo():
             assert(read_data.get_int()==test_data[read_idx])
             read_idx += 1
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_fifo.v b/tests/verilog/unittest/test_buffer_fifo.v
similarity index 100%
rename from tests/verilog/test_buffer_fifo.v
rename to tests/verilog/unittest/test_buffer_fifo.v
diff --git a/tests/verilog/test_buffer_linebuff.py b/tests/verilog/unittest/test_buffer_linebuff.py
similarity index 92%
rename from tests/verilog/test_buffer_linebuff.py
rename to tests/verilog/unittest/test_buffer_linebuff.py
index da01f3fc0..b4d2b34c1 100644
--- a/tests/verilog/test_buffer_linebuff.py
+++ b/tests/verilog/unittest/test_buffer_linebuff.py
@@ -33,11 +33,11 @@ def test_buffer_linebuff():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Leave the following signals set to true
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     write_advance.put_int(1)
     write_valid.put_int(1)
 
@@ -48,12 +48,12 @@ def test_buffer_linebuff():
         # write logic
        if (write_idx < len(test_data)):
             if (write_ready.get_int()):
-                write_data.put_int(test_data[write_idx])
+                write_data.put_int(int(test_data[write_idx]))
                 write_idx += 1
         else:
             write_advance.put_int(0)
             write_valid.put_int(0)
-
+
         # correctness checks
         if (read_data_valid.get_int()):
             # Derive convolution window indices
@@ -67,7 +67,7 @@ def test_buffer_linebuff():
             read_idx += 1
 
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_linebuff.v b/tests/verilog/unittest/test_buffer_linebuff.v
similarity index 100%
rename from tests/verilog/test_buffer_linebuff.v
rename to tests/verilog/unittest/test_buffer_linebuff.v
-- 
GitLab
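
For readers of this patch, the end-to-end usage of the new cross-thread reduction machinery is exercised by test_rfactor_threads in tests/python/integration/test_reduce.py above. Below is a minimal sketch of the schedule pattern, distilled from that test and slightly simplified (the where=(i>1) predicate is dropped and the shapes and factor of 16 are the test's illustrative choices, not requirements of the feature):

    import tvm

    nn, mm, nthread = 1027, 10, 16
    n, m = tvm.convert(nn), tvm.convert(mm)
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

    tx = tvm.thread_axis((0, nthread), "threadIdx.x")
    ty = tvm.thread_axis((0, nthread), "threadIdx.y")
    bx = tvm.thread_axis(None, "blockIdx.x")

    s = tvm.Schedule(B.op)
    # split the reduction so each thread owns a slice, then rfactor the
    # slice into its own stage BF holding the per-thread partial sums
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    # bind the data-parallel axis to blockIdx.x / threadIdx.y
    xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
    s[B].rebase(xi, ty)
    # rebasing the remaining kCommReduce axis onto threadIdx.x is the new
    # primitive: ComputeOpNode::BuildProvide detects it and emits the
    # tvm_thread_allreduce intrinsic instead of a serial reduce loop
    s[B].rebase(s[B].op.reduce_axis[0], tx)
    s[BF].compute_at(s[B], tx)
    fapi = tvm.lower(s, args=[A, B])

The LowerThreadAllreduce pass, wired into python/tvm/build.py with warp_size=32 for CUDA and 1 otherwise, then expands the intrinsic into the volatile shared-memory tree reduction implemented in src/pass/lower_thread_allreduce.cc, halving the active thread count each step and skipping synchronization once the reduction fits within a single warp.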