Skip to content
Snippets Groups Projects
Commit 01a7ce0c authored by Tianqi Chen's avatar Tianqi Chen Committed by GitHub
Browse files

[RUNTIME] Add Function, Unify TVMTypeCode and TVMArgTypeID (#24)

parent 4f1473f3
No related branches found
No related tags found
No related merge requests found
......@@ -16,9 +16,9 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) {
if (args.at(0).type_code == kInt) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
} else if (args.at(0).type_code == kFloat) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
......@@ -31,19 +31,19 @@ TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle)
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
......@@ -51,12 +51,12 @@ TVM_REGISTER_API(_ArrayGetItem)
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
......@@ -68,21 +68,21 @@ TVM_REGISTER_API(_Map)
CHECK_EQ(args.size() % 2, 0U);
MapNode::ContainerType data;
for (size_t i = 0; i < args.size(); i += 2) {
CHECK(args.at(i).type_id == kNodeHandle)
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args.at(i + 1).type_id == kNodeHandle)
CHECK(args.at(i + 1).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_MapSize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -91,8 +91,8 @@ TVM_REGISTER_API(_MapSize)
TVM_REGISTER_API(_MapGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -100,13 +100,13 @@ TVM_REGISTER_API(_MapGetItem)
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
ret->sptr = (*it).second;
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_MapCount)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -115,7 +115,7 @@ TVM_REGISTER_API(_MapCount)
TVM_REGISTER_API(_MapItems)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
......@@ -125,7 +125,7 @@ TVM_REGISTER_API(_MapItems)
rkvs->data.push_back(kv.second);
}
ret->sptr = rkvs;
ret->type_id = kNodeHandle;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(Range)
......
......@@ -9,25 +9,25 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <tvm/runtime/runtime.h>
#include <memory>
#include <limits>
#include <string>
#include <vector>
#include "../base/common.h"
using ArgVariant = TVMArg;
using ArgVariantID = TVMArgTypeID;
namespace tvm {
inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) {
case kNull: return "Null";
case kLong: return "Long";
case kDouble: return "Double";
case kStr: return "Str";
inline const char* TVMTypeCode2Str(int type_code) {
switch (type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kHandle: return "Handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
......@@ -96,72 +96,83 @@ inline std::string NodeTypeName() {
class APIVariantValue {
public:
/*! \brief the type id */
ArgVariantID type_id{kNull};
int type_code{kNull};
/*! \brief shared pointer container */
std::shared_ptr<Node> sptr;
/*! \brief string container */
std::string str;
/*! \brief the variant holder */
ArgVariant v_union;
TVMValue v_union;
/*! \brief std::function */
runtime::PackedFunc::FType func;
// constructor
APIVariantValue() {}
APIVariantValue() {
}
// clear value
inline void Clear() {
}
// assign op
inline APIVariantValue& operator=(double value) {
type_id = kDouble;
v_union.v_double = value;
type_code = kFloat;
v_union.v_float64 = value;
return *this;
}
inline APIVariantValue& operator=(std::nullptr_t value) {
type_id = kNull;
type_code = kHandle;
v_union.v_handle = value;
return *this;
}
inline APIVariantValue& operator=(int64_t value) {
type_id = kLong;
v_union.v_long = value;
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(bool value) {
type_id = kLong;
v_union.v_long = value;
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_id = kStr;
type_code = kStr;
str = std::move(value);
v_union.v_str = str.c_str();
return *this;
}
inline APIVariantValue& operator=(const NodeRef& ref) {
if (ref.node_.get() == nullptr) {
type_id = kNull;
type_code = kNull;
} else {
type_id = kNodeHandle;
type_code = kNodeHandle;
this->sptr = ref.node_;
}
return *this;
}
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
type_code = kFuncHandle;
this->func = f.body();
return *this;
}
inline APIVariantValue& operator=(const Type& value) {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
if (type_code == kNull) return T();
CHECK_EQ(type_code, kNodeHandle);
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return Expr(operator int());
if (type_id == kDouble) {
if (type_code == kNull) {
return Expr();
}
if (type_code == kInt) return Expr(operator int());
if (type_code == kFloat) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle);
CHECK_EQ(type_code, kNodeHandle);
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
......@@ -171,52 +182,58 @@ class APIVariantValue {
}
}
inline operator double() const {
CHECK_EQ(type_id, kDouble);
return v_union.v_double;
CHECK_EQ(type_code, kFloat);
return v_union.v_float64;
}
inline operator int64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator uint64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator int() const {
CHECK_EQ(type_id, kLong);
CHECK_LE(v_union.v_long,
CHECK_EQ(type_code, kInt);
CHECK_LE(v_union.v_int64,
std::numeric_limits<int>::max());
return v_union.v_long;
return v_union.v_int64;
}
inline operator bool() const {
CHECK_EQ(type_id, kLong)
<< "expect boolean(int) but get " << TypeId2Str(type_id);
return v_union.v_long != 0;
CHECK_EQ(type_code, kInt)
<< "expect boolean(int) but get "
<< TVMTypeCode2Str(type_code);
return v_union.v_int64 != 0;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr)
<< "expect Str but get " << TypeId2Str(type_id);
CHECK_EQ(type_code, kStr)
<< "expect Str but get "
<< TVMTypeCode2Str(type_code);
return str;
}
inline operator Type() const {
return String2Type(operator std::string());
}
inline operator runtime::PackedFunc() const {
CHECK_EQ(type_code, kFuncHandle);
return runtime::PackedFunc(func);
}
};
// common defintiion of API function.
using APIFunction = std::function<
using APIFunc = std::function<
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct APIFunctionReg
: public dmlc::FunctionRegEntryBase<APIFunctionReg,
APIFunction> {
struct APIFuncReg
: public dmlc::FunctionRegEntryBase<APIFuncReg,
APIFunc> {
};
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFunctionReg, APIFunctionReg, TypeName) \
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
} // namespace tvm
......
......@@ -7,7 +7,6 @@
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/module.h>
#include <string>
#include <unordered_map>
......
......@@ -3,7 +3,8 @@
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/runtime.h>
#include <algorithm>
#include "./runtime_base.h"
#include "./device_api.h"
......@@ -34,7 +35,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete arr;
}
inline void VerifyType(TVMDataType dtype) {
inline void VerifyType(TVMType dtype) {
CHECK_GE(dtype.lanes, 1U);
if (dtype.type_code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U);
......@@ -98,7 +99,7 @@ int TVMContextEnabled(TVMContext ctx,
int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMDataType dtype,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out) {
TVMArray* arr = nullptr;
......@@ -166,3 +167,19 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
});
API_END();
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc::FType*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args) {
API_BEGIN();
(*static_cast<const PackedFunc::FType*>(func))(
args, arg_type_codes, num_args);
API_END();
}
......@@ -7,7 +7,7 @@
#define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h>
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
namespace tvm {
namespace runtime {
......
......@@ -6,7 +6,7 @@
#ifndef TVM_RUNTIME_RUNTIME_BASE_H_
#define TVM_RUNTIME_RUNTIME_BASE_H_
#include <tvm/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/runtime.h>
TEST(PackedFunc, Basic) {
using namespace tvm::runtime;
int x = 0;
void* handle = &x;
TVMArray a;
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) {
CHECK(num_args == 3);
CHECK(args[0].v_float64 == 1.0);
CHECK(type_codes[0] == kFloat);
CHECK(args[1].v_handle == &a);
CHECK(type_codes[1] == kHandle);
CHECK(args[2].v_handle == &x);
CHECK(type_codes[2] == kHandle);
})(1.0, &a, handle);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -15,11 +15,11 @@ def mock_test_add():
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x)
# compile to IR
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
x = np.array(x)
y = tvm.nd.array(x, ctx=ctx)
f = tvm.codegen.DummyHelloFunction()
f(y, 10)
if __name__ == "__main__":
test_function()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment