Commit 0068781d authored by tqchen

Check in Tensor API on python

parent bcea8f6f
Subproject commit f72e313118a61b0cc49987b9eebfc77300d2de0d
Subproject commit bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the domain in AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <ir/Range.h>
#include <memory>
#include "./base.h"
#include "./expr.h"
namespace tvm {
/*! \brief container class of reduction domain */
class RDomainNode;
/*!
 * \brief same as Halide::IR::Range,
 *  except it provides a constructor with (begin, end).
 *
 * \note Halide's traditional Range has a constructor with
 *  (begin, extent), which does not match the convention in e.g. python.
 *  We decided to correct it by removing the constructor in HalideIR,
 *  and adding it back in TVM's range.
 */
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The internal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomain";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
};
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
// overload print function
inline std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r->domain << ")";
return os;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
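A minimal usage sketch of the (begin, end) convention from Python, assuming the Range API registered later in this commit (c_api_lang.cc) is exposed as tvm.Range:

    r1 = tvm.Range(10)      # [0, 10): a single argument is treated as end, begin = 0
    r2 = tvm.Range(2, 10)   # [2, 10): (begin, end), not Halide's (begin, extent)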
@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::Internal::Stmt;
using Var = Halide::VarExpr;
} // namespace tvm
......
@@ -6,11 +6,12 @@
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
-#include <tvm/array.h>
-#include <ir/FunctionBase.h>
#include <string>
#include <vector>
#include <type_traits>
+#include <tvm/array.h>
+#include <ir/FunctionBase.h>
#include "./base.h"
#include "./expr.h"
@@ -46,6 +47,7 @@ class Tensor : public FunctionRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
@@ -101,14 +103,14 @@ class Tensor : public FunctionRef {
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
public:
-  /*! \brief The shape of the tensor */
-  Array<Expr> shape;
  /*! \brief optional name of the tensor */
  std::string name;
  /*! \brief data type in the content of the tensor */
  Type dtype;
  /*! \brief The index representing each dimension, used by source expression. */
  Array<Var> dim_var;
+  /*! \brief The shape of the tensor */
+  Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
@@ -117,13 +119,17 @@ class TensorNode : public Node {
return "Tensor";
}
  void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("shape", &shape);
    v->Visit("name", &name);
    v->Visit("dtype", &dtype);
    v->Visit("dim_var", &dim_var);
+    v->Visit("shape", &shape);
    v->Visit("source", &source);
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Array<Var> dim_var,
Expr source);
};
// implementations
......
@@ -7,3 +7,4 @@ from . import expr
from . import stmt
from . import make
from . import collections
from . import tensor
@@ -107,13 +107,13 @@ def convert(value):
    """Convert a value to expression."""
    if isinstance(value, Number):
        return const(value)
-    elif isinstance(value, list):
+    elif isinstance(value, (list, tuple)):
        value = [convert(x) for x in value]
        return _function_internal._Array(*value)
    else:
        if not isinstance(value, NodeBase):
            raise ValueError("don't know how to handle type %s" % type(value))
        return value

def _push_arg(arg):
    a = ArgVariant()
@@ -172,7 +172,7 @@ def _make_function(handle, name):
        """TVM function"""
        cargs = []
        for x in args:
-            if isinstance(x, list):
+            if isinstance(x, (list, tuple)):
                cargs.append(convert(x))
            else:
                cargs.append(x)
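A small sketch of what the (list, tuple) change above enables; convert here is the internal helper from this file, not a public API:

    arr = convert((1, 2, 3))    # tuples now become an _Array node too
    same = convert([1, 2, 3])   # previously only lists were accepted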
......
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
......
@@ -54,6 +54,7 @@ class LogicalExpr(Expr):
@register_node("Variable")
class Var(Expr):
    pass
@register_node
@@ -162,6 +163,12 @@ class Broadcast(Expr):
@register_node
class Call(Expr):
    Extern = 0
    ExternCPlusPlus = 1
    PureExtern = 2
    Halide = 3
    Intrinsic = 4
    PureIntrinsic = 5
    pass
@register_node
......
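These constants mirror Halide's Call::CallType enum; a hedged sketch of how they feed the _make_Call binding added in this commit (argument order taken from this diff, with the trailing value_index copied from tensor.py below):

    import tvm
    from tvm import make as _make, expr as _expr

    A = tvm.Tensor((3, 4), name='A')
    # equivalent to A(1, 2) in tensor.py below
    call = _make.Call(A.dtype, A.name, (1, 2), _expr.Call.Halide, A, 0)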
@@ -35,33 +35,45 @@ def convert(value):
    """Convert a value to expression."""
    if isinstance(value, _Number):
        return const(value)
-    elif isinstance(value, list):
+    elif isinstance(value, (list, tuple)):
        value = [convert(x) for x in value]
        return _function_internal._Array(*value)
    else:
        return value
-def Range(begin, **kwargs):
-    """Create a TVM Range object.
-
-    User can either call:
-    Range(10) to get a range in [0, 10)
-    or
-    Range(begin=1, extent=10), to get a range in [1, 11)
-
-    Parameters
-    ----------
-    begin : Expr
-        The beginning of the range.
-
-    extent : optional, Expr
-        The extent (i.e. the length) of the range.
-    """
-    if "extent" in kwargs:
-        return _function_internal._Range(begin, kwargs["extent"])
-    else:
-        return _function_internal._Range(0, begin);
+def Tensor(shape, fcompute=None, dtype=None, name="TensorObj"):
+    """Construct a tensor object in dataflow.
+
+    Parameters
+    ----------
+    shape: Tuple of Expr
+        The shape of the tensor
+
+    fcompute: lambda function of *indices -> value
+        Specifies the input source expression
+
+    dtype: str, optional
+        The data type of the tensor; must be specified when fcompute is not given.
+
+    name: str, optional
+        The name hint of the tensor
+
+    Returns
+    -------
+    tensor: tensor.Tensor
+        The created tensor
+    """
+    ndim = len(shape)
+    dim_var = [Var("dim_var%d" % i) for i in range(ndim)]
+    if fcompute:
+        source = fcompute(*dim_var)
+        return _function_internal._Tensor(
+            shape, name, source.dtype, dim_var, source)
+    else:
+        dtype = float32 if dtype is None else dtype
+        return _function_internal._Tensor(
+            shape, name, dtype, None, None)
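A sketch of the two construction modes described above (n is assumed to be a Var):

    n = Var('n')
    A = Tensor((n, n), name='A')                     # input tensor; dtype defaults to float32
    B = Tensor((n, n), lambda i, j: A(i, j) + 1.0)   # computed; one index var per dimension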
_init_function_module("tvm")
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import make as _make
from . import expr as _expr
@register_node
class Tensor(NodeBase):
    """Tensor object, to construct, see function.Tensor"""
    def __call__(self, *indices):
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError("Need to provide %d indices in tensor slice" % ndim)
        return _make.Call(self.dtype, self.name, indices, _expr.Call.Halide, self, 0)

    @property
    def ndim(self):
        return len(self.shape)
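The user-facing counterpart of the make.Call sketch earlier, with shapes assumed for illustration:

    A = tvm.Tensor((3, 4), name='A')
    assert A.ndim == 2
    expr = A(1, 2)   # two indices for two dimensions; A(1) raises ValueError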
@@ -4,6 +4,8 @@
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <tvm/domain.h>
#include <tvm/tensor.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
@@ -13,30 +15,22 @@ DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg);
namespace tvm {
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
-TVM_REGISTER_API(_const)
-.set_body([](const ArgStack& args, RetValue *ret) {
-    if (args.at(0).type_id == kLong) {
-      *ret = make_const(args.at(1), args.at(0).operator int64_t());
-    } else if (args.at(0).type_id == kDouble) {
-      *ret = make_const(args.at(1), args.at(0).operator double());
-    } else {
-      LOG(FATAL) << "only accept int or float";
-    }
-  })
-.add_argument("src", "Number", "source number")
-.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
-    if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
+    if (sptr->is_type<TensorNode>()) {
+      os << args.at(0).operator Tensor();
+    } else if (sptr->is_type<RDomainNode>()) {
+      os << args.at(0).operator RDomain();
+    } else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
@@ -47,46 +41,11 @@ TVM_REGISTER_API(format_str)
})
.add_argument("expr", "Node", "expression to be printed");
-TVM_REGISTER_API(_Array)
-.set_body([](const ArgStack& args, RetValue *ret) {
-    std::vector<std::shared_ptr<Node> > data;
-    for (size_t i = 0; i < args.size(); ++i) {
-      CHECK(args.at(i).type_id == kNodeHandle);
-      data.push_back(args.at(i).sptr);
-    }
-    auto node = std::make_shared<ArrayNode>();
-    node->data = std::move(data);
-    ret->type_id = kNodeHandle;
-    ret->sptr = node;
-  });
-TVM_REGISTER_API(_ArrayGetItem)
-.set_body([](const ArgStack& args, RetValue *ret) {
-    CHECK(args.at(0).type_id == kNodeHandle);
-    int64_t i = args.at(1);
-    auto& sptr = args.at(0).sptr;
-    CHECK(sptr->is_type<ArrayNode>());
-    auto* n = static_cast<const ArrayNode*>(sptr.get());
-    CHECK_LT(static_cast<size_t>(i), n->data.size())
-        << "out of bound of array";
-    ret->sptr = n->data[i];
-    ret->type_id = kNodeHandle;
-  });
-TVM_REGISTER_API(_ArraySize)
+TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
    CHECK(args.at(0).type_id == kNodeHandle);
-    auto& sptr = args.at(0).sptr;
-    CHECK(sptr->is_type<ArrayNode>());
-    *ret = static_cast<int64_t>(
-        static_cast<const ArrayNode*>(sptr.get())->data.size());
-  });
-TVM_REGISTER_API(_Range)
-.set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = Range(args.at(0), args.at(1));
+    *ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
  })
-.add_argument("min", "Expr", "beginning of the range.")
-.add_argument("extent", "Expr", "extent of the range");
+.add_argument("src", "NodeBase", "the node base");
} // namespace tvm
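A hedged sketch of _raw_ptr from Python; the exposure path through _function_internal is assumed from the naming convention in this commit:

    from tvm import _function_internal
    a = tvm.Var('a')
    # the same node handle yields the same underlying address
    assert _function_internal._raw_ptr(a) == _function_internal._raw_ptr(a)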
@@ -29,6 +29,16 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
args.at(1),
args.at(2),
static_cast<Call::CallType>(args.at(3).operator int()),
args.at(4));
});
TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Allocate::make(args.at(0),
@@ -91,7 +101,6 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
-// TODO(tqchen) Call;
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/domain.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle)
<< "need content of array to be NodeBase";
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 1) {
*ret = Range(0, args.at(0));
} else {
*ret = Range(args.at(0), args.at(1));
}
})
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range");
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
TVM_REGISTER_API(_RDomain)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = RDomain(args.at(0).operator Domain());
});
} // namespace tvm
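A hedged sketch of the _RDomain binding above; passing a plain list works because of the (list, tuple) convert change earlier in this commit:

    from tvm import _function_internal
    rd = _function_internal._RDomain([tvm.Range(0, 8), tvm.Range(0, 4)])
    # index vars are named reduction_index0, reduction_index1 (see domain.cc below)
    print(tvm.format_str(rd))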
@@ -80,8 +80,12 @@ struct APIVariantValue {
return *this;
}
  inline APIVariantValue& operator=(const NodeRef& ref) {
-    type_id = kNodeHandle;
-    this->sptr = ref.node_;
+    if (ref.node_.get() == nullptr) {
+      type_id = kNull;
+    } else {
+      type_id = kNodeHandle;
+      this->sptr = ref.node_;
+    }
    return *this;
  }
inline APIVariantValue& operator=(const Type& value) {
......
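This null check is likely what lets an unset attribute surface as None in Python, which the test at the end of this commit relies on:

    A = tvm.Tensor((3,), name='A')   # input tensor: created without fcompute
    assert A.source is None          # a null NodeRef maps to kNull, i.e. None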
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/base.h>
#include <tvm/domain.h>
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
- * \file expr_node.cc
+ * \file ir_node.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
......
@@ -42,6 +42,20 @@ Expr Tensor::operator()(Array<Expr> indices) const {
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
}
Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype,
Array<Var> dim_var,
Expr source) {
auto n = std::make_shared<TensorNode>();
n->shape = shape;
n->name = name;
n->dtype = dtype;
n->dim_var = dim_var;
n->source = source;
return Tensor(n);
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
import tvm

def test_tensor():
    m = tvm.Var('m')
    n = tvm.Var('n')
    l = tvm.Var('l')
    A = tvm.Tensor((m, l), name='A')
    B = tvm.Tensor((n, l), name='B')
    T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
    print(tvm.format_str(T.source))
    assert(tuple(T.shape) == (m, n, l))
    assert(A.source is None)

if __name__ == "__main__":
    test_tensor()