Skip to content
Snippets Groups Projects
Commit 7591714a authored by tqchen's avatar tqchen
Browse files

checkin initial of itervar

parent 70d93028
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ namespace tvm {
/*! \brief container class of reduction domain */
class RDomainNode;
class IterDomainNode;
/*!
* \brief same as Halide::IR::Range
......@@ -40,6 +41,9 @@ class Range : public Halide::IR::Range {
static Range make_with_min_extent(Expr min, Expr extent);
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
......@@ -83,6 +87,20 @@ class RDomain : public NodeRef {
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional domain.
*/
class IterVarNode : public Node {
/*! \brief The */
Var var;
/*! \brief the domain of iteration */
Range dom;
/*! \brief additional tag on the iteration variable */
std::string tag;
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
......
......@@ -20,6 +20,7 @@ using Halide::Internal::ExprNode;
using Halide::Internal::StmtNode;
using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
using Halide::DeviceAPI;
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
......
......@@ -38,8 +38,36 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return body;
}
void MakeLoop(const DimSplitNode* op,
const Split& s,
Scope<AttrKey, Expr>* pscope,
std::vector<Stmt>* nest) {
auto& scope = *pscope;
Expr out_min = scope[{op->var, "min"}];
Expr out_ext = scope[{op->var, "extent"}];
Expr stride = op->factor;
Var offset(s->var->name_hint + ".offset", Int(32));
// for loop with stride
// TODO(tqchen) split the loop to deal with tails
nest->emplace_back(
For::make(
offset, out_min, out_ext,
ForType::Parallel, DeviceAPI::None, Stmt()));
Expr in_min = offset + out_min;
Expr in_ext = min(stride, out_ext - offset);
// declare min and extent of the corresponding variable
nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt()));
nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt()));
// declare this is the loop
nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt()));
// setup the scope.
pscope->Push({op->var, "min"}, in_min);
pscope->Push({op->var, "extent"}, in_ext);
}
Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body;
}
......@@ -50,10 +78,17 @@ class InjectRealize : public IRMutator {
: sch_(sch) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr &&
op->type_key == "Split" &&
op->type_key == "split" &&
op->node == sch_->attach_parent) {
return AttrStmt::make(
op->node, op->type_key, op->value,
......@@ -66,6 +101,7 @@ class InjectRealize : public IRMutator {
private:
// the operations to be carried
Schedule sch_;
Scope<AttrKey, Expr> attr_scope_;
};
} // namespace
......
......@@ -8,7 +8,7 @@ def test_schedule_create():
B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
sch = tvm.Schedule(T, scope="shared")
sch = tvm.Schedule(T.op, scope="shared")
tk1 = tvm.Split(T.op.dim_var[0], 10)
assert isinstance(sch, tvm.schedule.Schedule)
assert isinstance(tk1, tvm.schedule.DimSplit)
......
......@@ -21,6 +21,7 @@ def test_tensor_reduce():
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(C.op.body)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment