Skip to content
Snippets Groups Projects
Commit c5395a1f authored by tqchen's avatar tqchen
Browse files

Expose testcase as bound inference to python, now push toward the goal!

parent 54c18a6c
No related branches found
No related tags found
No related merge requests found
......@@ -239,7 +239,8 @@ def _init_function_module(root_namespace):
module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace]
"_pass_" : sys.modules["%s.ir_pass" % root_namespace],
"_schedule_" : sys.modules["%s.schedule" % root_namespace]
}
for name in op_names:
......
......@@ -7,10 +7,10 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm {
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
......@@ -21,6 +21,12 @@ using RetValue = APIVariantValue;
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to schedule pass.
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include "../schedule/bound.h"
#include "./c_api_registry.h"
namespace tvm {
namespace schedule {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0)); \
}) \
REGISTER_SCHEDULE_PASS1(InferBound);
} // namespace schedule
} // namespace tvm
......@@ -143,7 +143,7 @@ void InferBound(const Schedule& sch,
}
std::unordered_map<IterVar, Range> InferBound(Schedule sch) {
Map<IterVar, Range> InferBound(Schedule sch) {
return {};
}
......
......@@ -19,7 +19,7 @@ namespace schedule {
* \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable
*/
std::unordered_map<IterVar, Range> InferBound(Schedule sch);
Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule
} // namespace tvm
......
import tvm
def test_bound_inference():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA1.split(A1.op.dim_var[0], factor=8)
sA2.compute_at(sA1, xi)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
print(bounds)
if __name__ == "__main__":
test_bound_inference()
......@@ -48,4 +48,3 @@ if __name__ == "__main__":
test_schedule_create()
test_reorder()
test_tile()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment