From 56e10eb09f0d048b5eec4b13dccff6dfeb4c703d Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Wed, 19 Oct 2016 14:27:07 -0700 Subject: [PATCH] Tensor API --- include/tvm/array.h | 4 +- include/tvm/expr_node.h | 39 ++++++++++++++- include/tvm/tensor.h | 104 +++++++++++++++++++++++++++++++++++---- src/expr/domain.cc | 4 +- src/expr/expr.cc | 5 ++ src/expr/expr_node.cc | 1 + src/expr/tensor.cc | 48 ++++++++++++++++++ tests/cpp/tensor_test.cc | 11 +++-- 8 files changed, 199 insertions(+), 17 deletions(-) create mode 100644 src/expr/tensor.cc diff --git a/include/tvm/array.h b/include/tvm/array.h index 4484d3b89..db5e0d7af 100644 --- a/include/tvm/array.h +++ b/include/tvm/array.h @@ -23,10 +23,10 @@ class ArrayNode : public Node { return "ArrayNode"; } void VisitAttrs(AttrVisitor* visitor) override { - LOG(FATAL) << "need to specially handle list"; + LOG(FATAL) << "need to specially handle list attrs"; } void VisitNodeRefFields(FNodeRefVisit fvisit) override { - LOG(FATAL) << "need to specially handle list"; + // Do nothing, specially handled } }; diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h index 371d00908..d5dec0bb8 100644 --- a/include/tvm/expr_node.h +++ b/include/tvm/expr_node.h @@ -141,7 +141,7 @@ struct BinaryOpNode : public ExprNode { } }; -/*! \brief Binary mapping operator */ +/*! \brief Reduction operator operator */ struct ReduceNode : public ExprNode { public: /*! \brief The operator */ @@ -178,6 +178,43 @@ struct ReduceNode : public ExprNode { } }; +/*! \brief Tensor read operator */ +struct TensorReadNode : public ExprNode { + public: + /*! \brief The tensor to be read from */ + Tensor tensor; + /*! \brief The indices of read */ + Array<Expr> indices; + /*! \brief constructor, do not use constructor */ + TensorReadNode() { + node_type_ = kTensorReadNode; + } + TensorReadNode(Tensor && tensor, Array<Expr> && indices) + : tensor(std::move(tensor)), indices(std::move(indices)) { + node_type_ = kReduceNode; + dtype_ = tensor.dtype(); + } + ~TensorReadNode() { + this->Destroy(); + } + const char* type_key() const override { + return "TensorReadNode"; + } + void Verify() const override { + CHECK_EQ(dtype_, tensor.dtype()); + for (size_t i = 0; i < indices.size(); ++i) { + CHECK_EQ(indices[i].dtype(), kInt32); + } + } + void VisitAttrs(AttrVisitor* visitor) override { + visitor->Visit("dtype", &dtype_); + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("tensor", &tensor); + fvisit("indices", &indices); + } +}; + } // namespace tvm #endif // TVM_EXPR_NODE_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 861202306..bc06ed3ee 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -7,6 +7,7 @@ #define TVM_TENSOR_H_ #include <string> +#include <type_traits> #include "./expr.h" #include "./array.h" @@ -19,15 +20,14 @@ class TensorNode : public Node { std::string name; /*! \brief data type in the content of the tensor */ DataType dtype; - /*! \brief The index on each dimension */ + /*! \brief The index representing each dimension, used by source expression. */ Array<Var> dim_index; /*! \brief The shape of the tensor */ Array<Expr> shape; /*! \brief source expression */ Expr source; /*! \brief constructor */ - TensorNode() { - } + TensorNode() {} const char* type_key() const override { return "TensorNode"; } @@ -42,20 +42,104 @@ class TensorNode : public Node { } }; +/*! \brief The compute function to specify the input source of a Tensor */ +using FCompute = std::function<Expr (const Array<Var>& i)>; + +// converters from other functions into fcompute +inline FCompute GetFCompute(std::function<Expr (Var x)> f) { + return [f](const Array<Var>& i) { return f(i[0]); }; +} +inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) { + return [f](const Array<Var>& i) { return f(i[0], i[1]); }; +} +inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) { + return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); }; +} +inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) { + return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; +} + +/*! + * \brief Tensor structure representing a possible input, + * or intermediate computation result. + */ class Tensor : public NodeRef { public: - explicit Tensor(Array<Expr> shape); - inline size_t ndim() const; - + /*! \brief default constructor, used internally */ + Tensor() {} + /*! + * \brief constructor of input tensor + * \param shape Shape of the tensor. + * \param name optional name of the Tensor. + * \param dtype The data type of the input tensor. + */ + explicit Tensor(Array<Expr> shape, + std::string name = "tensor", + DataType dtype = kFloat32); + /*! + * \brief constructor of intermediate result. + * \param shape Shape of the tensor. + * \param fcompute The compute function to create the tensor. + * \param name The optional name of the tensor. + */ + Tensor(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"); + // same constructor, specialized for different fcompute function + Tensor(Array<Expr> shape, std::function<Expr(Var)> f, std::string name = "tensor") + :Tensor(shape, GetFCompute(f), name) {} + Tensor(Array<Expr> shape, std::function<Expr(Var, Var)> f, std::string name = "tensor") + :Tensor(shape, GetFCompute(f), name) {} + Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var)> f, std::string name = "tensor") + :Tensor(shape, GetFCompute(f), name) {} + Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor") + :Tensor(shape, GetFCompute(f), name) {} + /*! \return The dimension of the tensor */ + inline size_t ndim() const { + return static_cast<const TensorNode*>(node_.get())->shape.size(); + } + /*! \return The name of the tensor */ + inline const std::string& name() const { + return static_cast<const TensorNode*>(node_.get())->name; + } + /*! \return The data type tensor */ + inline DataType dtype() const { + return static_cast<const TensorNode*>(node_.get())->dtype; + } + /*! \return The source expression of intermediate tensor */ + inline const Expr& source() const { + return static_cast<const TensorNode*>(node_.get())->source; + } + /*! \return The internal dimension index used by source expression */ + inline const Array<Var>& dim_index() const { + return static_cast<const TensorNode*>(node_.get())->dim_index; + } + /*! \return The shape of the tensor */ + inline const Array<Expr>& shape() const { + return static_cast<const TensorNode*>(node_.get())->shape; + } + /*! + * \brief Take elements from the tensor + * \param args The indices + * \return the result expression representing tensor read. + */ template<typename... Args> inline Expr operator()(Args&& ...args) const { Array<Expr> indices{std::forward<Args>(args)...}; - CHECK_EQ(ndim(), indices.size()) - << "Tensor dimension mismatch in read"; - return Expr{}; + return operator()(indices); + } + /*! + * \brief Take elements from the tensor + * \param indices the indices. + * \return the result expression representing tensor read. + */ + Expr operator()(Array<Expr> indices) const; + // printt function + friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*) + os << "Tensor(shape=" << t.shape() + << ", source=" << t.source() + << ", name=" << t.name() << ')'; + return os; } }; - } // namespace tvm #endif // TVM_TENSOR_H_ diff --git a/src/expr/domain.cc b/src/expr/domain.cc index 1fd51bd0c..e8c56e414 100644 --- a/src/expr/domain.cc +++ b/src/expr/domain.cc @@ -22,7 +22,9 @@ Expr Range::extent() const { RDomain::RDomain(Domain domain) { std::vector<Var> index; for (size_t i = 0; i < domain.size(); ++i) { - index.push_back(Var("reduction_index")); + std::ostringstream os; + os << "reduction_index" << i; + index.push_back(Var(os.str())); } Array<Var> idx(index); node_ = std::make_shared<RDomainNode>( diff --git a/src/expr/expr.cc b/src/expr/expr.cc index d0479d3f6..1121c8191 100644 --- a/src/expr/expr.cc +++ b/src/expr/expr.cc @@ -55,6 +55,11 @@ void Expr::Print(std::ostream& os) const { os << ", " << n->rdom << ')'; return; } + case kTensorReadNode: { + const auto* n = Get<TensorReadNode>(); + os << n->tensor.name() << n->indices; + return; + } default: { LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name(); } diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc index cc05472c9..33a53bd04 100644 --- a/src/expr/expr_node.cc +++ b/src/expr/expr_node.cc @@ -43,5 +43,6 @@ TVM_REGISTER_NODE_TYPE(FloatNode); TVM_REGISTER_NODE_TYPE(UnaryOpNode); TVM_REGISTER_NODE_TYPE(BinaryOpNode); TVM_REGISTER_NODE_TYPE(ReduceNode); +TVM_REGISTER_NODE_TYPE(TensorReadNode); } // namespace tvm diff --git a/src/expr/tensor.cc b/src/expr/tensor.cc new file mode 100644 index 000000000..3067a7425 --- /dev/null +++ b/src/expr/tensor.cc @@ -0,0 +1,48 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file tensor.cc + */ +#include <tvm/tensor.h> +#include <tvm/expr_node.h> +#include <memory> + +namespace tvm { + +Tensor::Tensor(Array<Expr> shape, std::string name, DataType dtype) { + auto node = std::make_shared<TensorNode>(); + node->name = std::move(name); + node->dtype = dtype; + node->shape = std::move(shape); + node_ = std::move(node); +} + +Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) { + auto node = std::make_shared<TensorNode>(); + node->name = std::move(name); + node->shape = std::move(shape); + size_t ndim = node->shape.size(); + std::vector<Var> dim_index; + for (size_t i = 0; i < ndim; ++i) { + std::ostringstream os; + os << "dim_index" << i; + dim_index.push_back(Var(os.str())); + } + node->dim_index = Array<Var>(dim_index); + node->source = fcompute(node->dim_index); + node->dtype = node->source.dtype(); + node_ = std::move(node); +} + +Expr Tensor::operator()(Array<Expr> indices) const { + CHECK_EQ(ndim(), indices.size()) + << "Tensor dimension mismatch in read" + << "ndim = " << ndim() << ", indices.size=" << indices.size(); + auto node = std::make_shared<TensorReadNode>(); + node->tensor = *this; + node->indices = std::move(indices); + return Expr(std::move(node)); +} + +TVM_REGISTER_NODE_TYPE(TensorNode); + +} // namespace tvm diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index 5281bccdd..814bcd5aa 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -5,9 +5,14 @@ TEST(Tensor, Basic) { using namespace tvm; - Var m, n, k; - Tensor A({m, k}); - Tensor B({n, k}); + Var m("m"), n("n"), l("l"); + Tensor A({m, l}, "A"); + Tensor B({n, l}, "B"); + RDomain rd({{0, l}}); + + auto C = Tensor({m, n}, [&](Var i, Var j) { + return sum(A(i, rd.i0()) * B(j, rd.i0()), rd); + }, "C"); } int main(int argc, char ** argv) { -- GitLab