Skip to content
Snippets Groups Projects
Commit 56e10eb0 authored by tqchen's avatar tqchen
Browse files

Tensor API

parent 5f829774
No related branches found
No related tags found
No related merge requests found
......@@ -23,10 +23,10 @@ class ArrayNode : public Node {
return "ArrayNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
LOG(FATAL) << "need to specially handle list";
LOG(FATAL) << "need to specially handle list attrs";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
LOG(FATAL) << "need to specially handle list";
// Do nothing, specially handled
}
};
......
......@@ -141,7 +141,7 @@ struct BinaryOpNode : public ExprNode {
}
};
/*! \brief Binary mapping operator */
/*! \brief Reduction operator operator */
struct ReduceNode : public ExprNode {
public:
/*! \brief The operator */
......@@ -178,6 +178,43 @@ struct ReduceNode : public ExprNode {
}
};
/*! \brief Tensor read operator */
struct TensorReadNode : public ExprNode {
public:
/*! \brief The tensor to be read from */
Tensor tensor;
/*! \brief The indices of read */
Array<Expr> indices;
/*! \brief constructor, do not use constructor */
TensorReadNode() {
node_type_ = kTensorReadNode;
}
TensorReadNode(Tensor && tensor, Array<Expr> && indices)
: tensor(std::move(tensor)), indices(std::move(indices)) {
node_type_ = kReduceNode;
dtype_ = tensor.dtype();
}
~TensorReadNode() {
this->Destroy();
}
const char* type_key() const override {
return "TensorReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, tensor.dtype());
for (size_t i = 0; i < indices.size(); ++i) {
CHECK_EQ(indices[i].dtype(), kInt32);
}
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("tensor", &tensor);
fvisit("indices", &indices);
}
};
} // namespace tvm
#endif // TVM_EXPR_NODE_H_
......@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <string>
#include <type_traits>
#include "./expr.h"
#include "./array.h"
......@@ -19,15 +20,14 @@ class TensorNode : public Node {
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index on each dimension */
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {
}
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
......@@ -42,20 +42,104 @@ class TensorNode : public Node {
}
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr (Var x)> f) {
return [f](const Array<Var>& i) { return f(i[0]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
}
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public NodeRef {
public:
explicit Tensor(Array<Expr> shape);
inline size_t ndim() const;
/*! \brief default constructor, used internally */
Tensor() {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
DataType dtype = kFloat32);
/*!
* \brief constructor of intermediate result.
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// same constructor, specialized for different fcompute function
Tensor(Array<Expr> shape, std::function<Expr(Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
/*! \return The dimension of the tensor */
inline size_t ndim() const {
return static_cast<const TensorNode*>(node_.get())->shape.size();
}
/*! \return The name of the tensor */
inline const std::string& name() const {
return static_cast<const TensorNode*>(node_.get())->name;
}
/*! \return The data type tensor */
inline DataType dtype() const {
return static_cast<const TensorNode*>(node_.get())->dtype;
}
/*! \return The source expression of intermediate tensor */
inline const Expr& source() const {
return static_cast<const TensorNode*>(node_.get())->source;
}
/*! \return The internal dimension index used by source expression */
inline const Array<Var>& dim_index() const {
return static_cast<const TensorNode*>(node_.get())->dim_index;
}
/*! \return The shape of the tensor */
inline const Array<Expr>& shape() const {
return static_cast<const TensorNode*>(node_.get())->shape;
}
/*!
* \brief Take elements from the tensor
* \param args The indices
* \return the result expression representing tensor read.
*/
template<typename... Args>
inline Expr operator()(Args&& ...args) const {
Array<Expr> indices{std::forward<Args>(args)...};
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read";
return Expr{};
return operator()(indices);
}
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
Expr operator()(Array<Expr> indices) const;
// printt function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t.shape()
<< ", source=" << t.source()
<< ", name=" << t.name() << ')';
return os;
}
};
} // namespace tvm
#endif // TVM_TENSOR_H_
......@@ -22,7 +22,9 @@ Expr Range::extent() const {
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
index.push_back(Var("reduction_index"));
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
......
......@@ -55,6 +55,11 @@ void Expr::Print(std::ostream& os) const {
os << ", " << n->rdom << ')';
return;
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices;
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
......
......@@ -43,5 +43,6 @@ TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(TensorReadNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.cc
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <memory>
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, DataType dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->shape = std::move(shape);
size_t ndim = node->shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_index" << i;
dim_index.push_back(Var(os.str()));
}
node->dim_index = Array<Var>(dim_index);
node->source = fcompute(node->dim_index);
node->dtype = node->source.dtype();
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
auto node = std::make_shared<TensorReadNode>();
node->tensor = *this;
node->indices = std::move(indices);
return Expr(std::move(node));
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
......@@ -5,9 +5,14 @@
TEST(Tensor, Basic) {
using namespace tvm;
Var m, n, k;
Tensor A({m, k});
Tensor B({n, k});
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
RDomain rd({{0, l}});
auto C = Tensor({m, n}, [&](Var i, Var j) {
return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
}, "C");
}
int main(int argc, char ** argv) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment