Skip to content
Snippets Groups Projects
Commit 13383928 authored by Haichen Shen's avatar Haichen Shen
Browse files

add var binding for expr

parent 816419be
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,8 @@
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include <vector>
#include "./expr.h"
#include "./expr_node.h"
......@@ -18,6 +20,14 @@ namespace tvm {
*/
Expr Simplify(Expr src);
/*!
* \brief replace the variables in expression src by specification from dict
* \param src The source expression
* \param dict The specification for variable replacement
* \return the new expression with variable replaced
*/
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
......@@ -55,6 +65,47 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
fvisit(expr);
}
/*!
* \brief transform the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
* \return the new expression after transformation
*/
template<typename FVisit>
inline Expr Transform(const Expr& expr, FVisit fvisit) {
// TODO(tqchen) change to stack based impl.
std::vector<Expr> children;
switch (expr.node_type()) {
case kBinaryOpNode: {
const auto* n = expr.Get<BinaryOpNode>();
Expr e = Transform(n->lhs, fvisit);
children.push_back(e);
children.push_back(Transform(n->rhs, fvisit));
break;
}
case kUnaryOpNode: {
const auto* n = expr.Get<UnaryOpNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
children.push_back(Transform(n->indices[i], fvisit));
}
break;
}
default: break;
}
Expr ret = fvisit(expr, children);
return ret;
}
} // namespace tvm
#endif // TVM_EXPR_UTIL_H_
......@@ -146,6 +146,68 @@ Expr Simplify(Expr src) {
return cexpr.AsExpr();
}
Expr ExprWithNewChildren(Expr src, std::vector<Expr> children) {
if (children.size()) {
switch (src.node_type()) {
case kBinaryOpNode: {
const auto* n = src.Get<BinaryOpNode>();
if (n->lhs == children[0] && n->rhs == children[0])
return src;
return (*n->op)(children[0], children[1]);
}
case kUnaryOpNode: {
const auto* n = src.Get<UnaryOpNode>();
if (n->src == children[0])
return src;
return (*n->op)(children[0]);
}
case kReduceNode: {
const auto* n = src.Get<ReduceNode>();
if (n->src == children[0])
return src;
return (n->op)->Reduce(children[0], n->rdom);
}
case kTensorReadNode: {
const auto* n = src.Get<TensorReadNode>();
bool same = true;
for (size_t i = 0; i < n->indices.size(); ++i) {
if (n->indices[i] != children[i]) {
same = false;
break;
}
}
if (same)
return src;
Array<Expr> indices(children);
return n->tensor(indices);
}
default: {
return src;
}
}
}
return src;
}
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict) {
auto replace = [&](Expr e, std::vector<Expr> children) {
switch (e.node_type()) {
case kVarNode: {
auto it = dict.find(e);
if (it != dict.end()) {
return it->second;
}
return e;
}
default: {
return ExprWithNewChildren(e, children);
}
}
};
return Transform(src, replace);
}
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
......
......@@ -13,6 +13,12 @@ DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
namespace tvm {
Expr UnaryOp::operator()(Expr src) const {
auto nptr = std::make_shared<UnaryOpNode>(this, std::move(src));
nptr->Verify();
return Expr(std::move(nptr));
}
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
auto nptr = std::make_shared<BinaryOpNode>(
this, std::move(lhs), std::move(rhs));
......
......@@ -30,6 +30,25 @@ TEST(Expr, Simplify) {
CHECK(os.str() == "((x * 100) + 1000)");
}
TEST(Expr, Bind) {
using namespace tvm;
Var x("x"), y("y"), z("z");
Var i("i"), j("j");
Tensor A({y, z}, "A");
Expr e1 = x * 5;
std::unordered_map<Expr, Expr> dict = {{x, y * 10 + z}};
std::ostringstream os1, os2;
os1 << Bind(e1, dict);
CHECK(os1.str() == "(((y * 10) + z) * 5)");
Expr e2 = A(i, j);
dict.clear();
dict[i] = 64 * x;
dict[j] = z + 16 * y;
os2 << Bind(e2, dict);
CHECK(os2.str() == "A[(64 * x), (z + (16 * y))]");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment