Unverified Commit 415f7c4b by Tianqi Chen Committed by GitHub

[RELAY][PASS] Make FoldConst context and target invariant (#2114)

parent a3cfa5ff
......@@ -82,8 +82,6 @@ class ScheduleGetter :
}
}
readable_name_stream_ << "fused";
// enter the target context
TargetContext target_ctx(target_);
cache_node->outputs = this->VisitExpr(prim_func->body);
cache_node->func_name = readable_name_stream_.str();
CachedFunc cfunc(cache_node);
......@@ -284,6 +282,9 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// Enforce use the target.
TargetContext target_ctx(key->target);
CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target);
auto cache_node = make_node<CachedFuncNode>(
......
......@@ -107,6 +107,10 @@ Expr FoldConstant(const Expr& expr) {
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::create("llvm");
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext fresh_build_ctx(build_config());
return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr);
}
......
import numpy as np
import tvm
from tvm import relay
......@@ -19,7 +20,13 @@ def test_fold_const():
y = relay.add(x, relay.const(c_folded))
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
def fail(x):
raise RuntimeError()
# the fold constant should work on any context.
with tvm.build_config(add_lower_pass=[(0, fail)]):
with tvm.target.create("cuda"):
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.alpha_equal(zz, zexpected)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment