Commit f2406eae by Wei Chen Committed by Jared Roesch

Create closure object for GlobalVar (#3411)

parent e9634ead
......@@ -482,8 +482,8 @@ class Prelude:
with open(prelude_file) as prelude:
prelude = fromtext(prelude.read())
self.mod.update(prelude)
self.id = self.mod["id"]
self.compose = self.mod["compose"]
self.id = self.mod.get_global_var("id")
self.compose = self.mod.get_global_var("compose")
def __init__(self, mod):
......
......@@ -231,8 +231,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
}
void VisitExpr_(const GlobalVarNode* gvar) {
// TODO(wweic): Support Load GlobalVar into a register
LOG(FATAL) << "Loading GlobalVar into register is not yet supported";
auto var = GetRef<GlobalVar>(gvar);
auto func = this->context->module->Lookup(var);
auto it = this->context->global_map.find(var);
CHECK(it != this->context->global_map.end());
// Allocate closure with zero free vars
Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
}
void VisitExpr_(const IfNode* if_node) {
......
......@@ -250,6 +250,47 @@ def test_let_scalar():
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
def test_compose():
mod = relay.Module()
p = Prelude(mod)
compose = p.compose
# remove all functions to not have pattern match to pass vm compilation
# TODO(wweic): remove the hack and implement pattern match
for v, _ in mod.functions.items():
if v.name_hint == 'compose':
continue
mod[v] = relay.const(0)
# add_one = fun x -> x + 1
sb = relay.ScopeBuilder()
x = relay.var('x', 'float32')
x1 = sb.let('x1', x)
xplusone = x1 + relay.const(1.0, 'float32')
sb.ret(xplusone)
body = sb.get()
add_one = relay.GlobalVar("add_one")
add_one_func = relay.Function([x], body)
# add_two = compose(add_one, add_one)
sb = relay.ScopeBuilder()
y = relay.var('y', 'float32')
add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
add_two_res = add_two_func(y)
sb.ret(add_two_res)
add_two_body = sb.get()
mod[add_one] = add_one_func
f = relay.Function([y], add_two_body)
mod[mod.entry_func] = f
x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod)(x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
def test_closure():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment