Skip to content
Snippets Groups Projects
Commit 61de73b4 authored by tqchen's avatar tqchen
Browse files

Finalize tensor and operation

parent 605813e4
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>
#include "./expr.h"
#include "./schedule.h"
namespace tvm {
namespace ir {
......@@ -50,6 +51,14 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir
} // namespace tvm
......
......@@ -9,43 +9,15 @@
#include <string>
#include "./expr.h"
#include "./domain.h"
#include "./tensor.h"
namespace tvm {
// internal node container for Operation
class OperationNode;
/*! \brief Split over input domain */
class Operation : public NodeRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OperationNode* operator->() const;
};
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief optional name of the operation */
std::string name;
};
/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
class ComputeOpNode : public OperationNode {
public:
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
......@@ -54,6 +26,12 @@ class ComputeOpNode : public OperationNode {
const char* type_key() const final {
return "ComputeOp";
}
size_t num_outputs() const final {
return 1;
}
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
......@@ -66,9 +44,43 @@ class ComputeOpNode : public OperationNode {
Expr body);
};
// Implementations of inline functions
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
}
} // namespace tvm
......
......@@ -14,12 +14,14 @@
#include "./base.h"
#include "./expr.h"
#include "./operation.h"
#include "./domain.h"
namespace tvm {
// Internal node container of Tensor
class TensorNode;
// internal node container for Operation
class OperationNode;
using Halide::IR::FunctionRef;
......@@ -68,57 +70,24 @@ class Tensor : public FunctionRef {
friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr(Var x)> f) {
return [f] (const Array<Var>& i) { return f(i[0]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
}
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
}
/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OperationNode* operator->() const;
/*!
* \brief get the i-th output of the operation.
* \param i the output index.
* \return The i-th output.
*/
Tensor output(size_t i) const;
};
/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
......@@ -158,7 +127,31 @@ class TensorNode : public FunctionBaseNode {
int value_index);
};
// implementations
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief optional name of the operation */
std::string name;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;
};
// Implementations of inline functions
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
}
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(node_.get());
......
......@@ -10,12 +10,9 @@
namespace tvm {
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
auto op_node = std::make_shared<ComputeOpNode>();
node->name = name;
node->shape = shape;
// compute dimension.
size_t ndim = node->shape.size();
size_t ndim = shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
......@@ -32,10 +29,8 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
op_node->domain = Domain(dom);
op_node->body = fcompute(op_node->dim_var);
op_node->name = name;
node->dtype = op_node->body.type();
node->op = Operation(op_node);
node->value_index = 0;
return Tensor(node);
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(Domain domain,
......@@ -50,6 +45,37 @@ Operation ComputeOpNode::make(Domain domain,
return Operation(n);
}
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0);
std::vector<Expr> shape;
for (size_t i = 0; i < domain.size(); ++i) {
shape.push_back(domain[i]->extent);
}
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file schedule_ops.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(std::vector<Tensor> tensors)
: tensors_(tensors) {}
std::vector<Tensor> tensors_;
};
} // namespace
} // namespace ir
} // namespace tvm
......@@ -8,7 +8,7 @@ def test_tensor():
B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(T.source)
print(T.op.body)
assert(tuple(T.shape) == (m, n, l))
assert(A.source is None)
......@@ -21,7 +21,7 @@ def test_tensor_reduce():
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(C.source)
print(C.op.body)
if __name__ == "__main__":
test_tensor()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment