Unverified Commit 415f7c4b by Tianqi Chen Committed by GitHub

[RELAY][PASS] Make FoldConst context and target invariant (#2114)

parent a3cfa5ff
...@@ -82,8 +82,6 @@ class ScheduleGetter : ...@@ -82,8 +82,6 @@ class ScheduleGetter :
} }
} }
readable_name_stream_ << "fused"; readable_name_stream_ << "fused";
// enter the target context
TargetContext target_ctx(target_);
cache_node->outputs = this->VisitExpr(prim_func->body); cache_node->outputs = this->VisitExpr(prim_func->body);
cache_node->func_name = readable_name_stream_.str(); cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node); CachedFunc cfunc(cache_node);
...@@ -284,6 +282,9 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -284,6 +282,9 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0; value->use_count = 0;
cache_[key] = value; cache_[key] = value;
} }
// Enforce use the target.
TargetContext target_ctx(key->target);
CHECK(!value->cached_func.defined()); CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target); auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>( auto cache_node = make_node<CachedFuncNode>(
......
...@@ -107,6 +107,10 @@ Expr FoldConstant(const Expr& expr) { ...@@ -107,6 +107,10 @@ Expr FoldConstant(const Expr& expr) {
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
Target target = Target::create("llvm"); Target target = Target::create("llvm");
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config());
return ConstantFolder(CreateInterpreter( return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr); Module(nullptr), ctx, target)).Mutate(expr);
} }
......
import numpy as np import numpy as np
import tvm
from tvm import relay from tvm import relay
...@@ -19,7 +20,13 @@ def test_fold_const(): ...@@ -19,7 +20,13 @@ def test_fold_const():
y = relay.add(x, relay.const(c_folded)) y = relay.add(x, relay.const(c_folded))
z = relay.add(y, relay.const(c_data)) z = relay.add(y, relay.const(c_data))
return relay.Function([x], z) return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
def fail(x):
raise RuntimeError()
# the fold constant should work on any context.
with tvm.build_config(add_lower_pass=[(0, fail)]):
with tvm.target.create("cuda"):
zz = relay.ir_pass.fold_constant(before())
zexpected = expected() zexpected = expected()
assert relay.ir_pass.alpha_equal(zz, zexpected) assert relay.ir_pass.alpha_equal(zz, zexpected)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment