diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index dc094e00e05b38eea6e313128fab4614614f8638..b10e9f2e2ea3faee6875cdc184c02e4103281335 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -82,8 +82,6 @@ class ScheduleGetter : } } readable_name_stream_ << "fused"; - // enter the target context - TargetContext target_ctx(target_); cache_node->outputs = this->VisitExpr(prim_func->body); cache_node->func_name = readable_name_stream_.str(); CachedFunc cfunc(cache_node); @@ -284,6 +282,9 @@ class CompileEngineImpl : public CompileEngineNode { value->use_count = 0; cache_[key] = value; } + // Enforce use of the target. + TargetContext target_ctx(key->target); + CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); auto cache_node = make_node<CachedFuncNode>( diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index a5d514b76556e73188d184abc5dc8c162a9923dd..6237bcdce7a83956088949af1d42508b14ba7616 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -107,6 +107,10 @@ Expr FoldConstant(const Expr& expr) { ctx.device_type = kDLCPU; ctx.device_id = 0; Target target = Target::create("llvm"); + // use a fresh build context + // in case we are already in a build context. 
+ BuildConfigContext fresh_build_ctx(build_config()); + return ConstantFolder(CreateInterpreter( Module(nullptr), ctx, target)).Mutate(expr); } diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 4d9e397be975a7ebfcee3b97f68bc3e42fa4d84a..250cfc70cc28374330cf582262842d073117c7c3 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -1,4 +1,5 @@ import numpy as np +import tvm from tvm import relay @@ -19,7 +20,13 @@ def test_fold_const(): y = relay.add(x, relay.const(c_folded)) z = relay.add(y, relay.const(c_data)) return relay.Function([x], z) - zz = relay.ir_pass.fold_constant(before()) + + def fail(x): + raise RuntimeError() + # constant folding should work in any context. + with tvm.build_config(add_lower_pass=[(0, fail)]): + with tvm.target.create("cuda"): + zz = relay.ir_pass.fold_constant(before()) zexpected = expected() assert relay.ir_pass.alpha_equal(zz, zexpected)