Unverified Commit 6a4d71ff by Jared Roesch Committed by GitHub

[Relay][Runtime] Add VM compiler. (#3139)

* Implement the VM compiler

* Fix issues

* Fix ASF headers

* Fix test issue

* Apply typo fixes.

* Update src/relay/backend/vm/compiler.cc

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

* Refactor compiler

* Fix

* Fix

* Fix in benchmark

* Fix

* Address comments
parent 5f5bf797
......@@ -65,6 +65,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
......@@ -593,6 +594,18 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
namespace vm {
/*! \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
* \return The constructed virtual machine.
*/
runtime::vm::VirtualMachine CompileModule(const Module& mod);
} // namespace vm
} // namespace relay
} // namespace tvm
......
......@@ -265,7 +265,7 @@ struct Instruction {
Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr) = delete;
Instruction& operator=(const Instruction& instr);
~Instruction();
friend std::ostream& operator<<(std::ostream& os, const Instruction&);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/backend/vm/inline_primitives.cc
* \brief Ensure that primitives only appear in the call position.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
using namespace tvm::runtime;
namespace tvm {
namespace relay {
namespace vm {
struct PrimitiveInliner : ExprMutator {
Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
explicit PrimitiveInliner(const Module& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)});
return ExprMutator::VisitExpr_(let_node);
}
Expr VisitExpr_(const CallNode* call) {
Expr op = call->op;
// For now just collapse the chain of variables to see if
// they point to a primitive function.
const VarNode* var_node;
// Collapse a chain of let bindings
//
// let x = fn (..) { .. };
// let y = x
// let w = y
// in w(...)
while ((var_node = op.as<VarNode>())) {
auto var = GetRef<Var>(var_node);
DLOG(INFO) << "Var: " << var << std::endl;
auto it = var_map.find(GetRef<Var>(var_node));
if (it != var_map.end()) {
op = it->second;
} else {
return ExprMutator::VisitExpr_(call);
}
}
if (auto func = op.as<FunctionNode>()) {
if (func->IsPrimitive()) {
return CallNode::make(GetRef<Function>(func), call->args, call->attrs, call->type_args);
}
}
if (auto global = op.as<GlobalVarNode>()) {
return CallNode::make(GetRef<GlobalVar>(global), call->args, call->attrs, call->type_args);
}
return ExprMutator::VisitExpr_(call);
}
Expr VisitExpr_(const FunctionNode* func) {
if (func->IsPrimitive()) {
return GetRef<Function>(func);
} else {
return ExprMutator::VisitExpr_(func);
}
}
Function Inline(const Function& func) {
DLOG(INFO) << "Before inlining primitives: " << std::endl
<< "func= " << AsText(func, false) << std::endl;
auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
inlined = Downcast<Function>(DeadCodeElimination(inlined));
DLOG(INFO) << "After inlining primitives" << std::endl
<< "after_func= " << AsText(inlined, false) << std::endl;
return inlined;
}
};
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
tvm::Map<GlobalVar, Function> updates;
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, inliner.Inline(func));
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
return module;
}
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/backend/vm/lambda_lift.cc
* \brief Lift all local functions into global functions.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
using namespace tvm::runtime;
namespace tvm {
namespace relay {
namespace vm {
static const char* kIsClosure = "IsClosure";
inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
}
bool IsClosure(const Function& func) {
NodeRef res = FunctionGetAttr(func, kIsClosure);
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && pval->value != 0;
}
Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
}
struct LambdaLifter : ExprMutator {
Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
// We should not transform primitive functions.
if (func->IsPrimitive()) {
return std::move(func);
}
auto free_vars = FreeVars(func);
auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
// When performing this optimization there are two
// cases.
//
// The first case in which we have no free variables
// we can just lift the function into the global
// environment without needing to allocate a closure.
//
//
// The second case requires that we generate a special
// function with makes a distinction between allocating
// a closure, and then the code for the closure.
//
// We represent a closure allocation by lifting the
// closure to a global function which takes its
// captured arguments and then directly returns
// the function representing the closure's code.
//
// When we generate code later on a call to the "outer"
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function is should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
lifted_func = MarkClosure(lifted_func);
}
CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name);
lifted_.push_back({global, lifted_func});
if (free_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure
// we pass the variables in its environment
// here.
Array<Expr> fvs;
for (auto fv : free_vars) {
fvs.push_back(fv);
}
return CallNode::make(global, fvs);
}
}
Function Lift(const Function& func) {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);
tvm::Map<GlobalVar, Function> updates;
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, lifter.Lift(func));
}
for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
return module;
}
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/pass.h>
namespace tvm {
namespace relay {
namespace vm {
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is " << module->entry_func << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke(module->entry_func->name_hint, vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
CHECK(module.defined() && type.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
// const auto* tuple_type
// const auto& data_type = obj.AsDatatype();
// tvm::Array<Value> fields;
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
// }
// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
LOG(FATAL) << "fix me";
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
auto return_type = module->Lookup(module->entry_func)->ret_type;
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, return_type, result);
});
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -154,6 +154,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs,
F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
if (inputs[0]->shape.size() == 0) {
return { topi::identity(inputs[0]) };
}
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
......@@ -251,7 +254,6 @@ bool ReduceRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
......
......@@ -124,7 +124,8 @@ class CalcDep : private ExprVisitor {
friend CalcDep;
bool HasLet(const Var& v) {
return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
}
Expr VisitExpr_(const VarNode* op) final {
......
......@@ -118,6 +118,86 @@ Instruction::Instruction(const Instruction& instr) {
}
}
template<typename T>
static inline void FreeIf(T* t) {
if (t != nullptr) {
delete t;
}
}
Instruction& Instruction::operator=(const Instruction& instr) {
this->op = instr.op;
this->dst = instr.dst;
switch (instr.op) {
case Opcode::Move:
this->from = instr.from;
return *this;
case Opcode::Select:
this->select_cond = instr.select_cond;
this->select_op1 = instr.select_op1;
this->select_op2 = instr.select_op2;
return *this;
case Opcode::Ret:
this->result = instr.result;
return *this;
case Opcode::AllocTensor:
this->shape_register = instr.shape_register;
this->dtype = instr.dtype;
return *this;
case Opcode::AllocDatatype:
this->constructor_tag = instr.constructor_tag;
this->num_fields = instr.num_fields;
FreeIf(this->datatype_fields);
this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
return *this;
case Opcode::AllocClosure:
this->clo_index = instr.clo_index;
this->num_freevar = instr.num_freevar;
FreeIf(this->free_vars);
this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
return *this;
case Opcode::InvokePacked:
this->packed_index = instr.packed_index;
this->arity = instr.arity;
this->output_size = instr.output_size;
FreeIf(this->packed_args);
this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
return *this;
case Opcode::InvokeClosure:
this->closure = instr.closure;
this->closure_args_num = instr.closure_args_num;
FreeIf(this->closure_args);
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
return *this;
case Opcode::Invoke:
this->func_index = instr.func_index;
this->num_args = instr.num_args;
FreeIf(this->invoke_args_registers);
this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
return *this;
case Opcode::If:
this->if_cond = instr.if_cond;
this->true_offset = instr.true_offset;
this->false_offset = instr.false_offset;
return *this;
case Opcode::LoadConst:
this->const_index = instr.const_index;
return *this;
case Opcode::GetField:
this->object = instr.object;
this->field_index = instr.field_index;
return *this;
case Opcode::Goto:
this->pc_offset = instr.pc_offset;
return *this;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
throw std::runtime_error(out.str());
}
}
Instruction::~Instruction() {
switch (this->op) {
case Opcode::Move:
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmarking Relay VM using models from MXNet."""
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm import relay
from tvm.relay import testing
def benchmark_execution(net,
params,
measure=False,
data_shape=(1, 3, 224, 224),
out_shape=(1, 1000),
dtype='float32'):
def get_tvm_output(net, data, params, target, ctx, dtype='float32'):
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(net, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("data", data)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
if measure:
print("Evaluate graph runtime inference time cost...")
ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
# Measure in millisecond.
prof_res = np.array(ftimer().results) * 1000
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
return out.asnumpy()
def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
result = ex.evaluate(net)(data, **params)
return result.asnumpy().astype(dtype)
# random input
data = np.random.uniform(size=data_shape).astype(dtype)
target = "llvm"
ctx = tvm.cpu(0)
tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_mlp():
image_shape = (1, 28, 28)
net, params = testing.mlp.get_workload(1)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10))
def test_vgg():
for n in [11, 16]:
net, params = testing.vgg.get_workload(1, num_layers=n)
benchmark_execution(net, params)
def test_resnet():
for n in [18, 50]:
net, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
benchmark_execution(net, params, True)
def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = testing.squeezenet.get_workload(version=version)
benchmark_execution(net, params)
def test_inception_v3():
image_shape = (3, 299, 299)
net, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape)
def test_dqn():
image_shape = (4, 84, 84)
net, params = testing.dqn.get_workload(
batch_size=1, image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18))
def test_dcgan():
image_shape = (1, 100)
net, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(net, params, data_shape=image_shape)
def test_mobilenet():
net, params = testing.mobilenet.get_workload(batch_size=1)
benchmark_execution(net, params)
def test_densenet():
net, params = testing.densenet.get_workload(batch_size=1)
benchmark_execution(net, params)
if __name__ == '__main__':
test_resnet()
test_vgg()
test_squeezenet()
test_mobilenet()
test_densenet()
# The following networks fail
# test_inception_v3()
# test_mlp()
# test_dqn()
# test_dcgan()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from nose.tools import nottest
import tvm
import numpy as np
from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
def veval(f, *args, ctx=tvm.cpu()):
if isinstance(f, relay.Expr):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
if len(args) == 0:
return ex.evaluate(f)
else:
return ex.evaluate(f)(*args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx)
if len(args) == 0:
return ex.evaluate(mod[mod.entry_func])
else:
return ex.evaluate(mod[mod.entry_func])(*args)
def test_split():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
f = relay.Function([x], z)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
def test_op():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
x_data = np.random.rand(10, 10).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
def any(x):
x = relay.op.nn.batch_flatten(x)
return relay.op.min(x, axis=[0, 1])
def test_cond():
x = relay.var('x', shape=(10, 10))
y = relay.var('x', shape=(10, 10))
# f = relay.Function([x, y], relay.op.equal(x, y))
f = relay.Function([x, y], any(relay.op.equal(x, y)))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
# same
res = veval(f, x_data, x_data)
np.testing.assert_allclose(res.asnumpy(), True)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), False)
def test_simple_if():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
f = relay.Function([x, y],
relay.If(any(relay.op.equal(x, y)), x, y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
# same
res = veval(f, x_data, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), y_data)
def test_simple_call():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
sb.ret(i)
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_count_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less])
sb.ret(relay.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_sum_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
sb.ret(accum)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, 'int32'))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
mod[sum_up] = func
loop_bound = 0
i_data = np.array(loop_bound, dtype='int32')
accum_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
aarg = relay.var('accum', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
result = veval(mod, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
def test_tuple_fst():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 0))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_tuple_second():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 1))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
@nottest
def test_list_constructor():
# TODO(wweic): implement pattern match to support this test
def to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
result = []
for f in o.fields:
result.extend(to_list(f))
return result
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
l = p.l
one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3)
f = relay.Function([], one4)
mod[mod.entry_func] = f
result = veval(mod)()
obj = to_list(result)
import pdb; pdb.set_trace()
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
def test_let_tensor():
sb = relay.ScopeBuilder()
shape = (1,)
x = relay.var('x', shape=shape, dtype='float32')
x1 = relay.var('x1', shape=shape, dtype='float32')
x1 = sb.let(x1, x)
xplusone = x1 + relay.const(42.0, 'float32')
sb.ret(xplusone)
body = sb.get()
f = relay.Function([x], body)
x_data = np.random.rand(*shape).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
def test_let_scalar():
sb = relay.ScopeBuilder()
x = relay.var('x', 'float32')
x1 = sb.let('x1', x)
xplusone = x1 + relay.const(42.0, 'float32')
sb.ret(xplusone)
body = sb.get()
f = relay.Function([x], body)
x_data = np.array(np.random.rand()).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
def test_closure():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
f = relay.Function([x], x + y)
ff = relay.Function([y], f)
clo = ff(relay.const(1.0))
main = clo(relay.const(2.0))
res = veval(main)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
if __name__ == "__main__":
test_id()
test_op()
test_cond()
test_simple_if()
test_simple_call()
test_count_loop()
test_sum_loop()
test_tuple_fst()
test_tuple_second()
test_let_scalar()
test_let_tensor()
# TODO(@jroesch): restore when match is supported
# test_list_constructor()
test_closure()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment