Unverified Commit 6a4d71ff by Jared Roesch Committed by GitHub

[Relay][Runtime] Add VM compiler. (#3139)

* Implement the VM compiler

* Fix issues

* Fix ASF headers

* Fix test issue

* Apply typo fixes.

* Update src/relay/backend/vm/compiler.cc

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

* Refactor compiler

* Fix

* Fix

* Fix in benchmark

* Fix

* Address comments
parent 5f5bf797
@@ -65,6 +65,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
@@ -593,6 +594,18 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
namespace vm {
/*! \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
* \return The constructed virtual machine.
*/
runtime::vm::VirtualMachine CompileModule(const Module& mod);
} // namespace vm
} // namespace relay
} // namespace tvm
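For orientation, a minimal usage sketch of the new compile path from Python (not part of this diff): it goes through the 'vm' executor exercised by the tests added below, and assumes a TVM build with the Relay VM enabled.

import numpy as np
import tvm
from tvm import relay

# Build a trivial Relay function and evaluate it on the Relay VM executor,
# mirroring the `veval` helper in the new tests.
x = relay.var('x', shape=(10, 10))
func = relay.Function([x], x + x)

ex = relay.create_executor('vm', mod=relay.Module(), ctx=tvm.cpu())
x_data = np.random.rand(10, 10).astype('float32')
res = ex.evaluate(func)(x_data)
np.testing.assert_allclose(res.asnumpy(), x_data + x_data)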
......
@@ -265,7 +265,7 @@ struct Instruction {
Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr) = delete;
Instruction& operator=(const Instruction& instr);
~Instruction();
friend std::ostream& operator<<(std::ostream& os, const Instruction&);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
// TODO(@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Module LambdaLift(const Module& module);
Module InlinePrimitives(const Module& module);
template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
using TagMap = NodeMap<tvm::relay::Constructor, Index>;
using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
using GlobalMap = NodeMap<GlobalVar, Index>;
using ConstMap = NodeMap<Constant, Index>;
using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
struct VMCompilerContext {
// The module context for the compilation
Module module;
// Error reporter
ErrorReporter err_reporter;
// Map from a unique integer to ADT constructor tag
TagNameMap tag_index_map;
// Map from ADT constructor tag to a unique integer
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
// Map from Const object to its index in const pool
ConstMap const_map;
// Map from Const tensor shape to its index in const pool
ConstTensorShapeMap const_tensor_shape_map;
// List of lowered functions
std::vector<LoweredFunc> lowered_funcs;
};
// Compute the constant pool, i.e., a mapping from Constant node to constant index.
struct ConstantPool : ExprVisitor {
std::set<GlobalVar> visited;
Module module;
ConstMap const_map;
ConstTensorShapeMap const_tensor_shape_map;
size_t index;
explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
void VisitExpr_(const GlobalVarNode* var_node) {
auto gvar = GetRef<GlobalVar>(var_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
void AddConstantTensorShape(TensorType expr, NDArray value) {
auto it = this->const_tensor_shape_map.find(expr);
if (it == this->const_tensor_shape_map.end()) {
this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)});
}
}
void VisitExpr_(const ConstantNode* const_node) {
auto konst = GetRef<Constant>(const_node);
auto it = this->const_map.find(konst);
if (it == this->const_map.end()) {
this->const_map.insert({konst, index++});
}
}
NDArray GetTensorConstant(const TensorTypeNode* ttype) {
std::vector<int64_t> shapes;
for (auto sh : ttype->shape) {
shapes.push_back(Downcast<tvm::Integer>(sh)->value);
}
int64_t s = shapes.size();
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx);
int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
for (size_t i = 0; i < shapes.size(); ++i) {
dims[i] = shapes[i];
}
return shape_tensor;
}
void VisitExpr_(const CallNode* call_node) {
for (auto arg : call_node->args) {
this->VisitExpr(arg);
}
Expr op = call_node->op;
auto func_node = op.as<FunctionNode>();
if (func_node) {
auto ret_type = call_node->checked_type();
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
auto shape = GetTensorConstant(ttype);
auto tensor_type = GetRef<TensorType>(ttype);
AddConstantTensorShape(tensor_type, shape);
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < ttype->fields.size(); ++i) {
auto f = ttype->fields[i];
auto f_type = f.as<TensorTypeNode>();
auto shape = GetTensorConstant(f_type);
auto tensor_type = GetRef<TensorType>(f_type);
AddConstantTensorShape(tensor_type, shape);
}
}
}
}
};
std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
auto cp = ConstantPool(module);
for (auto& func : module->functions) {
cp.VisitExpr(func.first);
}
return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
}
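// Note: ordinary constants and constant shape tensors share a single running
// index (ConstantPool::index above), which later becomes their slot in the VM
// constant table populated in CompileModule.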
void InstructionPrint(std::ostream& os, const Instruction& instr);
struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
/*! \brief Store the expression a variable points to. */
std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
std::vector<Instruction> instructions;
// var -> register num
std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map;
size_t last_register;
// Total number of virtual registers allocated
size_t registers_num;
CompileEngine engine;
/*! \brief The functions that have been lowered. */
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
/*! \brief Global shared meta data */
VMCompilerContext* context;
VMCompiler(VMCompilerContext* context)
: instructions(),
var_register_map(),
last_register(0),
registers_num(0),
engine(CompileEngine::Global()),
context(context)
{}
size_t NewRegister() { return registers_num++; }
inline void Emit(const Instruction& instr) {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) {
case Opcode::AllocDatatype:
case Opcode::AllocTensor:
case Opcode::GetField:
case Opcode::LoadConst:
case Opcode::Select:
case Opcode::Invoke:
case Opcode::AllocClosure:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register = instr.dst;
break;
case Opcode::InvokePacked:
last_register = instr.packed_args[instr.arity - 1];
break;
case Opcode::If:
case Opcode::Ret:
case Opcode::Goto:
break;
}
instructions.push_back(instr);
}
void VisitExpr_(const ConstantNode* const_node) {
auto rconst = GetRef<Constant>(const_node);
auto it = this->context->const_map.find(rconst);
CHECK(it != this->context->const_map.end());
Emit(Instruction::LoadConst(it->second, NewRegister()));
}
void VisitExpr_(const VarNode* var_node) {
auto var = GetRef<Var>(var_node);
auto reg_it = this->var_register_map.find(var);
CHECK(reg_it != this->var_register_map.end());
last_register = reg_it->second;
}
void VisitExpr_(const TupleNode* tuple_node) {
auto tuple = GetRef<Tuple>(tuple_node);
std::vector<Index> fields_registers;
for (auto& field : tuple->fields) {
this->VisitExpr(field);
fields_registers.push_back(last_register);
}
// TODO(@jroesch): use correct tag
Emit(Instruction::AllocDatatype(
0,
tuple->fields.size(),
fields_registers,
NewRegister()));
}
void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node);
LOG(FATAL) << "translation of match nodes to the VM is"
<< "currently unsupported" << std::endl;
}
void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << let_node->value << std::endl;
this->VisitExpr(let_node->value);
DLOG(INFO) << this->last_register << std::endl;
var_register_map.insert({let_node->var, this->last_register});
this->VisitExpr(let_node->body);
}
void VisitExpr_(const TupleGetItemNode* get_node) {
auto get = GetRef<TupleGetItem>(get_node);
this->VisitExpr(get->tuple);
auto tuple_register = last_register;
Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
}
void VisitExpr_(const GlobalVarNode* gvar) {
LOG(FATAL) << "Global variables should only appear in the call position";
}
void VisitExpr_(const IfNode* if_node) {
this->VisitExpr(if_node->cond);
size_t cond_register = last_register;
auto after_cond = this->instructions.size();
this->Emit(Instruction::If(cond_register, 0, 0));
this->VisitExpr(if_node->true_branch);
size_t true_register = last_register;
Emit(Instruction::Goto(0));
// Record the position immediately after the true
// branch (including its trailing Goto).
auto after_true = this->instructions.size();
this->VisitExpr(if_node->false_branch);
size_t false_register = last_register;
// Record the position immediately after
// the false branch.
auto after_false = this->instructions.size();
// Now we will compute the jump targets in order
// to properly patch the instructions with the
// requisite offsets.
// After emitting the true and false bodies, we
// patch up the If instruction and the Goto.
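// Concretely: the true branch begins immediately after the If, so its
// offset is always 1; the false branch begins just past the true branch's
// trailing Goto, i.e. `after_true - after_cond` instructions beyond the If;
// and the Goto must jump from index `after_true - 1` to `after_false`,
// hence an offset of `(after_false - after_true) + 1`.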
auto true_offset = 1;
auto false_offset = after_true - after_cond;
this->instructions[after_cond].true_offset = true_offset;
this->instructions[after_cond].false_offset = false_offset;
// Patch the Goto.
this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1;
Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister()));
}
Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
DataType dtype = ttype->dtype;
TVMType dltype = Type2TVMType(dtype);
auto tensor_type = GetRef<TensorType>(ttype);
auto it = this->context->const_tensor_shape_map.find(tensor_type);
if (it == this->context->const_tensor_shape_map.end()) {
DLOG(INFO) << "Can not find constant shape for " << tensor_type;
} else {
Emit(Instruction::LoadConst(it->second.first, NewRegister()));
}
return Instruction::AllocTensor(last_register, dltype, NewRegister());
}
void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
const Type& ret_type) {
std::vector<Instruction> allocs;
size_t return_num = 0;
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
// Allocate space for the return tensor.
auto alloc = AllocTensorFromType(ttype);
allocs.push_back(alloc);
return_num = 1;
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
std::vector<Index> fields_registers;
for (size_t i = 0; i < ttype->fields.size(); ++i) {
auto f = ttype->fields[i];
auto f_type = f.as<TensorTypeNode>();
allocs.push_back(AllocTensorFromType(f_type));
fields_registers.push_back(allocs.back().dst);
}
return_num = ttype->fields.size();
} else {
LOG(FATAL) << "Unsupported return value type";
}
for (auto& alloc : allocs) {
Emit(alloc);
args_registers.push_back(alloc.dst);
}
// Next generate the invoke instruction.
CHECK(func->IsPrimitive());
auto target = Target::create("llvm");
auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = seen_funcs[cfunc->funcs[0]];
}
// The number of output registers: 1 if the return type is a Tensor,
// or the number of fields if it is a Tuple.
size_t arity = func->params.size() + return_num;
Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
if (return_num > 1) {
// The return value is a tuple, so we need to construct one from the output registers.
std::vector<Index> fields_registers;
for (size_t i = func->params.size(); i < arity; ++i) {
fields_registers.push_back(args_registers[i]);
}
Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
}
}
void VisitExpr_(const CallNode* call_node) {
std::vector<Index> args_registers;
for (auto arg : call_node->args) {
CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
this->VisitExpr(arg);
args_registers.push_back(last_register);
}
Expr op = call_node->op;
if (auto func_node = op.as<FunctionNode>()) {
CHECK(func_node->IsPrimitive());
EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
} else if (auto global_node = op.as<GlobalVarNode>()) {
auto global = GetRef<GlobalVar>(global_node);
auto it = this->context->global_map.find(global);
CHECK(it != this->context->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
auto func = this->context->module->Lookup(global);
if (IsClosure(func)) {
auto arity = func->params.size();
std::vector<Index> free_var_registers;
for (size_t i = 0; i < arity; ++i) {
free_var_registers.push_back(var_register_map.at(func->params[i]));
}
Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
} else {
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
auto tag = GetConstructorTag(constructor);
Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
} else {
LOG(FATAL) << "unsupported case in vm compiler: " << op;
}
}
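// Constructor tags are assigned densely in first-use order; the mapping is
// recorded in both directions so a tag can be mapped back to its constructor.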
size_t GetConstructorTag(tvm::relay::Constructor constructor) {
auto it = this->context->tag_map.find(constructor);
if (it != this->context->tag_map.end()) {
return it->second;
} else {
auto tag = this->context->tag_map.size();
this->context->tag_map[constructor] = tag;
this->context->tag_index_map[tag] = constructor;
return tag;
}
}
void VisitExpr_(const FunctionNode* func_node) {
if (!func_node->IsPrimitive()) {
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
<< "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
<< "AST: " << GetRef<Function>(func_node);
}
}
void CompileClosure(const Function& func) {
// We first lay out the inner function's arguments.
auto inner_func = Downcast<Function>(func->body);
size_t i = 0;
for (auto param : inner_func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map.insert({param, arg_register});
i++;
}
// We then assign register numbers to the free variables (the outer function's parameters).
for (auto param : func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map.insert({param, arg_register});
i++;
}
// We will now process the body like normal.
this->VisitExpr(inner_func->body);
}
void Compile(const Function& func) {
// We need to generate code specially for lifted closures.
if (IsClosure(func)) {
CompileClosure(func);
return;
}
for (size_t i = 0; i < func->params.size(); ++i) {
auto arg_register = NewRegister();
CHECK_EQ(arg_register, i);
var_register_map.insert({func->params[i], arg_register});
}
this->VisitExpr(func->body);
}
};
void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
std::vector<PackedFunc>* packed_funcs) {
runtime::Module mod;
if (lowered_funcs.size() > 0) {
// TODO(@jroesch): we need to read target from build config
Target target = Target::create("llvm");
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
CHECK(mod.operator->());
for (auto lfunc : lowered_funcs) {
packed_funcs->push_back(mod.GetFunction(lfunc->name));
}
}
}
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
size_t params = func->params.size();
VMCompiler compiler(context);
compiler.Compile(func);
// return the last evaluated expression
compiler.instructions.push_back(Instruction::Ret(compiler.last_register));
// We would like to refactor this so we only check whether the function is a closure once.
if (IsClosure(func)) {
auto inner_params = Downcast<Function>(func->body)->params.size();
return VMFunction(var->name_hint, params + inner_params, compiler.instructions,
compiler.registers_num);
} else {
return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num);
}
}
Module OptimizeModule(const Module& mod) {
ToANormalForm(mod->entry_func, mod);
InlinePrimitives(mod);
LambdaLift(mod);
return InlinePrimitives(mod);
}
void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
global_map->insert({gvar, global_index++});
}
}
VirtualMachine CompileModule(const Module& mod_ref) {
Module mod = mod_ref;
// Run some optimizations first; this code should
// be moved to the pass manager.
mod = OptimizeModule(mod);
VirtualMachine vm;
VMCompilerContext context;
context.module = mod;
// Populate the global map.
//
// This maps global variables to a global index
// in the VMFunction table.
PopulateGlobalMap(&context.global_map, mod);
// Next we populate constant map.
auto constant_analysis_result = LayoutConstantPool(mod);
context.const_map = std::get<0>(constant_analysis_result);
context.const_tensor_shape_map = std::get<1>(constant_analysis_result);
// Next we get ready by allocating space for
// the global state.
vm.functions.resize(mod->functions.size());
vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size());
for (auto pair : context.const_map) {
vm.constants[pair.second] = Object::Tensor(pair.first->data);
}
for (auto pair : context.const_tensor_shape_map) {
vm.constants[pair.second.first] = Object::Tensor(pair.second.second);
}
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
auto vm_func = CompileFunc(&context, gvar, func);
size_t func_index = context.global_map.at(gvar);
CHECK(func_index < vm.functions.size());
vm.functions[func_index] = vm_func;
}
#ifdef USE_RELAY_DEBUG
for (auto vm_func : vm.functions) {
std::cout << "Function: " << vm_func.name << std::endl
<< vm_func << "-------------" << std::endl;
}
#endif // USE_RELAY_DEBUG
PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs);
for (auto gv : context.global_map) {
vm.global_map_.insert({gv.first->name_hint, gv.second});
}
return vm;
}
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/backend/vm/inline_primitives.cc
* \brief Ensure that primitives only appear in the call position.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
using namespace tvm::runtime;
namespace tvm {
namespace relay {
namespace vm {
struct PrimitiveInliner : ExprMutator {
Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
explicit PrimitiveInliner(const Module& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)});
return ExprMutator::VisitExpr_(let_node);
}
Expr VisitExpr_(const CallNode* call) {
Expr op = call->op;
// For now just collapse the chain of variables to see if
// they point to a primitive function.
const VarNode* var_node;
// Collapse a chain of let bindings
//
// let x = fn (..) { .. };
// let y = x
// let w = y
// in w(...)
while ((var_node = op.as<VarNode>())) {
auto var = GetRef<Var>(var_node);
DLOG(INFO) << "Var: " << var << std::endl;
auto it = var_map.find(GetRef<Var>(var_node));
if (it != var_map.end()) {
op = it->second;
} else {
return ExprMutator::VisitExpr_(call);
}
}
if (auto func = op.as<FunctionNode>()) {
if (func->IsPrimitive()) {
return CallNode::make(GetRef<Function>(func), call->args, call->attrs, call->type_args);
}
}
if (auto global = op.as<GlobalVarNode>()) {
return CallNode::make(GetRef<GlobalVar>(global), call->args, call->attrs, call->type_args);
}
return ExprMutator::VisitExpr_(call);
}
Expr VisitExpr_(const FunctionNode* func) {
if (func->IsPrimitive()) {
return GetRef<Function>(func);
} else {
return ExprMutator::VisitExpr_(func);
}
}
Function Inline(const Function& func) {
DLOG(INFO) << "Before inlining primitives: " << std::endl
<< "func= " << AsText(func, false) << std::endl;
auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
inlined = Downcast<Function>(DeadCodeElimination(inlined));
DLOG(INFO) << "After inlining primitives" << std::endl
<< "after_func= " << AsText(inlined, false) << std::endl;
return inlined;
}
};
// TODO(@jroesch): write verifier
/* This pass eliminates the bindings of primitives introduced by the ANF
* transform by inlining the primitives directly into their call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
tvm::Map<GlobalVar, Function> updates;
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, inliner.Inline(func));
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
return module;
}
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/backend/vm/lambda_lift.cc
* \brief Lift all local functions into global functions.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
using namespace tvm::runtime;
namespace tvm {
namespace relay {
namespace vm {
static const char* kIsClosure = "IsClosure";
inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
}
bool IsClosure(const Function& func) {
NodeRef res = FunctionGetAttr(func, kIsClosure);
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && pval->value != 0;
}
Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
}
struct LambdaLifter : ExprMutator {
Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
// We should not transform primitive functions.
if (func->IsPrimitive()) {
return std::move(func);
}
auto free_vars = FreeVars(func);
auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
// When performing this optimization there are two
// cases.
//
// In the first case, in which we have no free variables,
// we can simply lift the function into the global
// environment without needing to allocate a closure.
//
// The second case requires that we generate a special
// function which distinguishes between allocating
// a closure and the code of the closure itself.
//
// We represent a closure allocation by lifting the
// closure to a global function which takes its
// captured arguments and then directly returns
// the function representing the closure's code.
//
// When we later generate code, a call to the "outer"
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function should be used to generate the
// code for the closure's body.
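// For example (sketch, names illustrative):
//
// let y = 1;
// let f = fn (x) { x + y };
// f(2)
//
// is lifted to, roughly:
//
// lifted_nameN = fn (y) { fn (x) { x + y } }; // marked as a closure
// let y = 1;
// let f = lifted_nameN(y); // allocates the closure over y
// f(2)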
Function lifted_func;
if (free_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
lifted_func = MarkClosure(lifted_func);
}
CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name);
lifted_.push_back({global, lifted_func});
if (free_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure
// we pass the variables in its environment
// here.
Array<Expr> fvs;
for (auto fv : free_vars) {
fvs.push_back(fv);
}
return CallNode::make(global, fvs);
}
}
Function Lift(const Function& func) {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of free variables
* and then return the function representing the closure's code.
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);
tvm::Map<GlobalVar, Function> updates;
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, lifter.Lift(func));
}
for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
return module;
}
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/pass.h>
namespace tvm {
namespace relay {
namespace vm {
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is " << module->entry_func << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke(module->entry_func->name_hint, vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
CHECK(module.defined() && type.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
// const auto* tuple_type
// const auto& data_type = obj.AsDatatype();
// tvm::Array<Value> fields;
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
// }
// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
LOG(FATAL) << "fix me";
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
auto return_type = module->Lookup(module->entry_func)->ret_type;
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, return_type, result);
});
} // namespace vm
} // namespace relay
} // namespace tvm
@@ -154,6 +154,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs,
F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
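// A rank-0 (scalar) input has nothing to reduce over, so return it unchanged.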
if (inputs[0]->shape.size() == 0) {
return { topi::identity(inputs[0]) };
}
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
@@ -251,7 +254,6 @@ bool ReduceRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
......
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -124,7 +124,8 @@ class CalcDep : private ExprVisitor {
friend CalcDep;
bool HasLet(const Var& v) {
return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
}
Expr VisitExpr_(const VarNode* op) final {
......
@@ -118,6 +118,86 @@ Instruction::Instruction(const Instruction& instr) {
}
}
template<typename T>
static inline void FreeIf(T* t) {
if (t != nullptr) {
delete t;
}
}
Instruction& Instruction::operator=(const Instruction& instr) {
this->op = instr.op;
this->dst = instr.dst;
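// Copy the per-opcode payload. For opcodes that carry register arrays, free any
// array this instruction already owns and duplicate the source's array so both
// instructions own independent storage.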
switch (instr.op) {
case Opcode::Move:
this->from = instr.from;
return *this;
case Opcode::Select:
this->select_cond = instr.select_cond;
this->select_op1 = instr.select_op1;
this->select_op2 = instr.select_op2;
return *this;
case Opcode::Ret:
this->result = instr.result;
return *this;
case Opcode::AllocTensor:
this->shape_register = instr.shape_register;
this->dtype = instr.dtype;
return *this;
case Opcode::AllocDatatype:
this->constructor_tag = instr.constructor_tag;
this->num_fields = instr.num_fields;
FreeIf(this->datatype_fields);
this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
return *this;
case Opcode::AllocClosure:
this->clo_index = instr.clo_index;
this->num_freevar = instr.num_freevar;
FreeIf(this->free_vars);
this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
return *this;
case Opcode::InvokePacked:
this->packed_index = instr.packed_index;
this->arity = instr.arity;
this->output_size = instr.output_size;
FreeIf(this->packed_args);
this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
return *this;
case Opcode::InvokeClosure:
this->closure = instr.closure;
this->closure_args_num = instr.closure_args_num;
FreeIf(this->closure_args);
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
return *this;
case Opcode::Invoke:
this->func_index = instr.func_index;
this->num_args = instr.num_args;
FreeIf(this->invoke_args_registers);
this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
return *this;
case Opcode::If:
this->if_cond = instr.if_cond;
this->true_offset = instr.true_offset;
this->false_offset = instr.false_offset;
return *this;
case Opcode::LoadConst:
this->const_index = instr.const_index;
return *this;
case Opcode::GetField:
this->object = instr.object;
this->field_index = instr.field_index;
return *this;
case Opcode::Goto:
this->pc_offset = instr.pc_offset;
return *this;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
throw std::runtime_error(out.str());
}
}
Instruction::~Instruction() {
switch (this->op) {
case Opcode::Move:
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmarking Relay VM using models from MXNet."""
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm import relay
from tvm.relay import testing
def benchmark_execution(net,
params,
measure=False,
data_shape=(1, 3, 224, 224),
out_shape=(1, 1000),
dtype='float32'):
def get_tvm_output(net, data, params, target, ctx, dtype='float32'):
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(net, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("data", data)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
if measure:
print("Evaluate graph runtime inference time cost...")
ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
# Measure in millisecond.
prof_res = np.array(ftimer().results) * 1000
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
return out.asnumpy()
def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
result = ex.evaluate(net)(data, **params)
return result.asnumpy().astype(dtype)
# random input
data = np.random.uniform(size=data_shape).astype(dtype)
target = "llvm"
ctx = tvm.cpu(0)
tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_mlp():
image_shape = (1, 28, 28)
net, params = testing.mlp.get_workload(1)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10))
def test_vgg():
for n in [11, 16]:
net, params = testing.vgg.get_workload(1, num_layers=n)
benchmark_execution(net, params)
def test_resnet():
for n in [18, 50]:
net, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
benchmark_execution(net, params, True)
def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = testing.squeezenet.get_workload(version=version)
benchmark_execution(net, params)
def test_inception_v3():
image_shape = (3, 299, 299)
net, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape)
def test_dqn():
image_shape = (4, 84, 84)
net, params = testing.dqn.get_workload(
batch_size=1, image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18))
def test_dcgan():
image_shape = (1, 100)
net, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(net, params, data_shape=image_shape)
def test_mobilenet():
net, params = testing.mobilenet.get_workload(batch_size=1)
benchmark_execution(net, params)
def test_densenet():
net, params = testing.densenet.get_workload(batch_size=1)
benchmark_execution(net, params)
if __name__ == '__main__':
test_resnet()
test_vgg()
test_squeezenet()
test_mobilenet()
test_densenet()
# The following networks fail
# test_inception_v3()
# test_mlp()
# test_dqn()
# test_dcgan()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from nose.tools import nottest
import tvm
import numpy as np
from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
def veval(f, *args, ctx=tvm.cpu()):
if isinstance(f, relay.Expr):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
if len(args) == 0:
return ex.evaluate(f)
else:
return ex.evaluate(f)(*args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx)
if len(args) == 0:
return ex.evaluate(mod[mod.entry_func])
else:
return ex.evaluate(mod[mod.entry_func])(*args)
def test_split():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
f = relay.Function([x], z)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
def test_op():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
x_data = np.random.rand(10, 10).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
def any(x):
x = relay.op.nn.batch_flatten(x)
return relay.op.min(x, axis=[0, 1])
def test_cond():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
# f = relay.Function([x, y], relay.op.equal(x, y))
f = relay.Function([x, y], any(relay.op.equal(x, y)))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
# same
res = veval(f, x_data, x_data)
np.testing.assert_allclose(res.asnumpy(), True)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), False)
def test_simple_if():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
f = relay.Function([x, y],
relay.If(any(relay.op.equal(x, y)), x, y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
# same
res = veval(f, x_data, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), y_data)
def test_simple_call():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
sb.ret(i)
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_count_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less])
sb.ret(relay.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_sum_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
sb.ret(accum)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, 'int32'))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
mod[sum_up] = func
loop_bound = 0
i_data = np.array(loop_bound, dtype='int32')
accum_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
aarg = relay.var('accum', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
result = veval(mod, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
def test_tuple_fst():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 0))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), i_data)
def test_tuple_second():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 1))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
@nottest
def test_list_constructor():
# TODO(wweic): implement pattern match to support this test
def to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
result = []
for f in o.fields:
result.extend(to_list(f))
return result
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
l = p.l
one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3)
f = relay.Function([], one4)
mod[mod.entry_func] = f
result = veval(mod)()
obj = to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
def test_let_tensor():
sb = relay.ScopeBuilder()
shape = (1,)
x = relay.var('x', shape=shape, dtype='float32')
x1 = relay.var('x1', shape=shape, dtype='float32')
x1 = sb.let(x1, x)
xplusone = x1 + relay.const(42.0, 'float32')
sb.ret(xplusone)
body = sb.get()
f = relay.Function([x], body)
x_data = np.random.rand(*shape).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
def test_let_scalar():
sb = relay.ScopeBuilder()
x = relay.var('x', 'float32')
x1 = sb.let('x1', x)
xplusone = x1 + relay.const(42.0, 'float32')
sb.ret(xplusone)
body = sb.get()
f = relay.Function([x], body)
x_data = np.array(np.random.rand()).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
def test_closure():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
f = relay.Function([x], x + y)
ff = relay.Function([y], f)
clo = ff(relay.const(1.0))
main = clo(relay.const(2.0))
res = veval(main)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
if __name__ == "__main__":
test_id()
test_op()
test_cond()
test_simple_if()
test_simple_call()
test_count_loop()
test_sum_loop()
test_tuple_fst()
test_tuple_second()
test_let_scalar()
test_let_tensor()
# TODO(@jroesch): restore when match is supported
# test_list_constructor()
test_closure()