From 56e10eb09f0d048b5eec4b13dccff6dfeb4c703d Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Wed, 19 Oct 2016 14:27:07 -0700
Subject: [PATCH] Tensor API

---
 include/tvm/array.h      |   4 +-
 include/tvm/expr_node.h  |  39 ++++++++++++++-
 include/tvm/tensor.h     | 104 +++++++++++++++++++++++++++++++++++----
 src/expr/domain.cc       |   4 +-
 src/expr/expr.cc         |   5 ++
 src/expr/expr_node.cc    |   1 +
 src/expr/tensor.cc       |  48 ++++++++++++++++++
 tests/cpp/tensor_test.cc |  11 +++--
 8 files changed, 199 insertions(+), 17 deletions(-)
 create mode 100644 src/expr/tensor.cc

diff --git a/include/tvm/array.h b/include/tvm/array.h
index 4484d3b89..db5e0d7af 100644
--- a/include/tvm/array.h
+++ b/include/tvm/array.h
@@ -23,10 +23,10 @@ class ArrayNode : public Node {
     return "ArrayNode";
   }
   void VisitAttrs(AttrVisitor* visitor) override {
-    LOG(FATAL) << "need to specially handle list";
+    LOG(FATAL) << "need to specially handle list attrs";
   }
   void VisitNodeRefFields(FNodeRefVisit fvisit) override {
-    LOG(FATAL) << "need to specially handle list";
+    // Do nothing, specially handled
   }
 };
 
diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h
index 371d00908..d5dec0bb8 100644
--- a/include/tvm/expr_node.h
+++ b/include/tvm/expr_node.h
@@ -141,7 +141,7 @@ struct BinaryOpNode : public ExprNode {
   }
 };
 
-/*! \brief Binary mapping operator */
+/*! \brief Reduction operator */
 struct ReduceNode : public ExprNode {
  public:
   /*! \brief The operator */
@@ -178,6 +178,43 @@ struct ReduceNode : public ExprNode {
   }
 };
 
+/*! \brief Tensor read operator */
+struct TensorReadNode : public ExprNode {
+ public:
+  /*! \brief The tensor to be read from */
+  Tensor tensor;
+  /*! \brief The indices of read */
+  Array<Expr> indices;
+  /*! \brief default constructor, do not call directly */
+  TensorReadNode() {
+    node_type_ = kTensorReadNode;
+  }
+  TensorReadNode(Tensor && tensor, Array<Expr> && indices)
+      : tensor(std::move(tensor)), indices(std::move(indices)) {
+    node_type_ = kTensorReadNode;  // was kReduceNode: wrong tag breaks Expr::Print dispatch
+    dtype_ = this->tensor.dtype();  // must read the member; the parameter was moved from above
+  }
+  ~TensorReadNode() {
+    this->Destroy();
+  }
+  const char* type_key() const override {
+    return "TensorReadNode";
+  }
+  void Verify() const override {
+    CHECK_EQ(dtype_, tensor.dtype());
+    for (size_t i = 0; i < indices.size(); ++i) {
+      CHECK_EQ(indices[i].dtype(), kInt32);
+    }
+  }
+  void VisitAttrs(AttrVisitor* visitor) override {
+    visitor->Visit("dtype", &dtype_);
+  }
+  void VisitNodeRefFields(FNodeRefVisit fvisit) override {
+    fvisit("tensor", &tensor);
+    fvisit("indices", &indices);
+  }
+};
+
 }  // namespace tvm
 
 #endif  // TVM_EXPR_NODE_H_
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index 861202306..bc06ed3ee 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -7,6 +7,7 @@
 #define TVM_TENSOR_H_
 
 #include <string>
+#include <type_traits>
 #include "./expr.h"
 #include "./array.h"
 
@@ -19,15 +20,14 @@ class TensorNode : public Node {
   std::string name;
   /*! \brief data type in the content of the tensor */
   DataType dtype;
-  /*! \brief The index on each dimension */
+  /*! \brief The index representing each dimension, used by source expression. */
   Array<Var> dim_index;
   /*! \brief The shape of the tensor */
   Array<Expr> shape;
   /*! \brief source expression */
   Expr source;
   /*! \brief constructor */
-  TensorNode() {
-  }
+  TensorNode() {}
   const char* type_key() const override {
     return "TensorNode";
   }
@@ -42,20 +42,104 @@ class TensorNode : public Node {
   }
 };
 
+/*! \brief The compute function to specify the input source of a Tensor */
+using FCompute = std::function<Expr (const Array<Var>& i)>;
+
+// converters from other functions into fcompute
+inline FCompute GetFCompute(std::function<Expr (Var x)> f) {
+  return [f](const Array<Var>& i) { return f(i[0]); };
+}
+inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) {
+  return [f](const Array<Var>& i) { return f(i[0], i[1]); };
+}
+inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) {
+  return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
+}
+inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) {
+  return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
+}
+
+/*!
+ * \brief Tensor structure representing a possible input,
+ *  or intermediate computation result.
+ */
 class Tensor : public NodeRef {
  public:
-  explicit Tensor(Array<Expr> shape);
-  inline size_t ndim() const;
-
+  /*! \brief default constructor, used internally */
+  Tensor() {}
+  /*!
+   * \brief constructor of input tensor
+   * \param shape Shape of the tensor.
+   * \param name optional name of the Tensor.
+   * \param dtype The data type of the input tensor.
+   */
+  explicit Tensor(Array<Expr> shape,
+                  std::string name = "tensor",
+                  DataType dtype = kFloat32);
+  /*!
+   * \brief constructor of intermediate result.
+   * \param shape Shape of the tensor.
+   * \param fcompute The compute function to create the tensor.
+   * \param name The optional name of the tensor.
+   */
+  Tensor(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+  // same constructor, specialized for different fcompute function
+  Tensor(Array<Expr> shape, std::function<Expr(Var)> f, std::string name = "tensor")
+      :Tensor(shape, GetFCompute(f), name) {}
+  Tensor(Array<Expr> shape, std::function<Expr(Var, Var)> f, std::string name = "tensor")
+      :Tensor(shape, GetFCompute(f), name) {}
+  Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var)> f, std::string name = "tensor")
+      :Tensor(shape, GetFCompute(f), name) {}
+  Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor")
+        :Tensor(shape, GetFCompute(f), name) {}
+  /*! \return The dimension of the tensor */
+  inline size_t ndim() const {
+    return static_cast<const TensorNode*>(node_.get())->shape.size();
+  }
+  /*! \return The name of the tensor */
+  inline const std::string& name() const {
+    return static_cast<const TensorNode*>(node_.get())->name;
+  }
+  /*! \return The data type of the tensor */
+  inline DataType dtype() const {
+    return static_cast<const TensorNode*>(node_.get())->dtype;
+  }
+  /*! \return The source expression of intermediate tensor */
+  inline const Expr& source() const {
+    return static_cast<const TensorNode*>(node_.get())->source;
+  }
+  /*! \return The internal dimension index used by source expression */
+  inline const Array<Var>& dim_index() const {
+    return static_cast<const TensorNode*>(node_.get())->dim_index;
+  }
+  /*! \return The shape of the tensor */
+  inline const Array<Expr>& shape() const {
+    return static_cast<const TensorNode*>(node_.get())->shape;
+  }
+  /*!
+   * \brief Take elements from the tensor
+   * \param args The indices
+   * \return the result expression representing tensor read.
+   */
   template<typename... Args>
   inline Expr operator()(Args&& ...args) const {
     Array<Expr> indices{std::forward<Args>(args)...};
-    CHECK_EQ(ndim(), indices.size())
-        << "Tensor dimension mismatch in read";
-    return Expr{};
+    return operator()(indices);
+  }
+  /*!
+   * \brief Take elements from the tensor
+   * \param indices the indices.
+   * \return the result expression representing tensor read.
+   */
+  Expr operator()(Array<Expr> indices) const;
+  // print function
+  friend std::ostream& operator<<(std::ostream &os, const Tensor& t) {  // NOLINT(*)
+    os << "Tensor(shape=" << t.shape()
+       << ", source=" << t.source()
+       << ", name=" << t.name() << ')';
+    return os;
   }
 };
 
-
 }  // namespace tvm
 #endif  // TVM_TENSOR_H_
diff --git a/src/expr/domain.cc b/src/expr/domain.cc
index 1fd51bd0c..e8c56e414 100644
--- a/src/expr/domain.cc
+++ b/src/expr/domain.cc
@@ -22,7 +22,9 @@ Expr Range::extent() const {
 RDomain::RDomain(Domain domain) {
   std::vector<Var> index;
   for (size_t i = 0; i < domain.size(); ++i) {
-    index.push_back(Var("reduction_index"));
+    std::ostringstream os;
+    os << "reduction_index" << i;
+    index.push_back(Var(os.str()));
   }
   Array<Var> idx(index);
   node_ = std::make_shared<RDomainNode>(
diff --git a/src/expr/expr.cc b/src/expr/expr.cc
index d0479d3f6..1121c8191 100644
--- a/src/expr/expr.cc
+++ b/src/expr/expr.cc
@@ -55,6 +55,11 @@ void Expr::Print(std::ostream& os) const {
       os << ", " << n->rdom << ')';
       return;
     }
+    case kTensorReadNode: {
+      const auto* n = Get<TensorReadNode>();
+      os << n->tensor.name() << n->indices;
+      return;
+    }
     default: {
       LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
     }
diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc
index cc05472c9..33a53bd04 100644
--- a/src/expr/expr_node.cc
+++ b/src/expr/expr_node.cc
@@ -43,5 +43,6 @@ TVM_REGISTER_NODE_TYPE(FloatNode);
 TVM_REGISTER_NODE_TYPE(UnaryOpNode);
 TVM_REGISTER_NODE_TYPE(BinaryOpNode);
 TVM_REGISTER_NODE_TYPE(ReduceNode);
+TVM_REGISTER_NODE_TYPE(TensorReadNode);
 
 }  // namespace tvm
diff --git a/src/expr/tensor.cc b/src/expr/tensor.cc
new file mode 100644
index 000000000..3067a7425
--- /dev/null
+++ b/src/expr/tensor.cc
@@ -0,0 +1,48 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file tensor.cc
+ */
+#include <tvm/tensor.h>
+#include <tvm/expr_node.h>
+#include <memory>
+
+namespace tvm {
+
+Tensor::Tensor(Array<Expr> shape, std::string name, DataType dtype) {
+  auto node = std::make_shared<TensorNode>();
+  node->name = std::move(name);
+  node->dtype = dtype;
+  node->shape = std::move(shape);
+  node_ = std::move(node);
+}
+
+Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) {
+  auto node = std::make_shared<TensorNode>();
+  node->name = std::move(name);
+  node->shape = std::move(shape);
+  size_t ndim = node->shape.size();
+  std::vector<Var> dim_index;
+  for (size_t i = 0; i < ndim; ++i) {
+    std::ostringstream os;
+    os << "dim_index" << i;
+    dim_index.push_back(Var(os.str()));
+  }
+  node->dim_index = Array<Var>(dim_index);
+  node->source = fcompute(node->dim_index);
+  node->dtype = node->source.dtype();
+  node_ = std::move(node);
+}
+
+// Build a TensorReadNode expression reading this tensor at the given indices.
+Expr Tensor::operator()(Array<Expr> indices) const {
+  CHECK_EQ(ndim(), indices.size())
+      << "Tensor dimension mismatch in read: "
+      << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  auto node = std::make_shared<TensorReadNode>();
+  node->tensor = *this;
+  node->indices = std::move(indices);
+  return Expr(std::move(node));
+}
+
+TVM_REGISTER_NODE_TYPE(TensorNode);
+
+}  // namespace tvm
diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc
index 5281bccdd..814bcd5aa 100644
--- a/tests/cpp/tensor_test.cc
+++ b/tests/cpp/tensor_test.cc
@@ -5,9 +5,14 @@
 
 TEST(Tensor, Basic) {
   using namespace tvm;
-  Var m, n, k;
-  Tensor A({m, k});
-  Tensor B({n, k});
+  Var m("m"), n("n"), l("l");
+  Tensor A({m, l}, "A");
+  Tensor B({n, l}, "B");
+  RDomain rd({{0, l}});
+
+  auto C = Tensor({m, n}, [&](Var i, Var j) {
+      return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
+    }, "C");
 }
 
 int main(int argc, char ** argv) {
-- 
GitLab