Commit 93d1c06d by Wei Chen Committed by Jared Roesch

[Relay][VM]Compiling pattern matching (#3470)

* [Relay][VM]Compiling pattern matching

* Fix lint

* Remove debug code

* Move TreeNode definition

* merge ifi and selecti, todo: remove them

* fix lint

* remove ifi and selecti

* rename GetTagi to GetTag

* fix dltype

* fix more dltype

* Generalize If and select, and rename to Ifi and Selecti

* Fix lint

* Rename Ifi to If

* Change register default to match value

* Remove bad specialization for Move

* Stop use Select

* Remove Select

* TreeNode refactor

* Change entry_func name

* Remove Cmp due to rebase issue
parent be776dc7
...@@ -61,9 +61,11 @@ enum class Opcode { ...@@ -61,9 +61,11 @@ enum class Opcode {
AllocClosure = 8U, AllocClosure = 8U,
GetField = 9U, GetField = 9U,
If = 10U, If = 10U,
Select = 11U, LoadConst = 11U,
LoadConst = 12U, Goto = 12U,
Goto = 13U GetTag = 13U,
LoadConsti = 14U,
Fatal = 15U,
}; };
/*! \brief A single virtual machine instruction. /*! \brief A single virtual machine instruction.
...@@ -123,22 +125,16 @@ struct Instruction { ...@@ -123,22 +125,16 @@ struct Instruction {
/*! \brief The arguments to pass to the packed function. */ /*! \brief The arguments to pass to the packed function. */
RegName* packed_args; RegName* packed_args;
}; };
struct /* Select Operands */ {
/*! \brief The condition of select. */
RegName select_cond;
/*! \brief The true branch. */
RegName select_op1;
/*! \brief The false branch. */
RegName select_op2;
};
struct /* If Operands */ { struct /* If Operands */ {
/*! \brief The register containing the condition value. */ /*! \brief The register containing the test value. */
RegName if_cond; RegName test;
/*! \brief The register containing the target value. */
RegName target;
/*! \brief The program counter offset for the true branch. */ /*! \brief The program counter offset for the true branch. */
Index true_offset; Index true_offset;
/*! \brief The program counter offset for the false branch. */ /*! \brief The program counter offset for the false branch. */
Index false_offset; Index false_offset;
}; } if_op;
struct /* Invoke Operands */ { struct /* Invoke Operands */ {
/*! \brief The function to call. */ /*! \brief The function to call. */
Index func_index; Index func_index;
...@@ -151,6 +147,10 @@ struct Instruction { ...@@ -151,6 +147,10 @@ struct Instruction {
/* \brief The index into the constant pool. */ /* \brief The index into the constant pool. */
Index const_index; Index const_index;
}; };
struct /* LoadConsti Operands */ {
/* \brief The immediate integer value to load. */
size_t val;
} load_consti;
struct /* Jump Operands */ { struct /* Jump Operands */ {
/*! \brief The jump offset. */ /*! \brief The jump offset. */
Index pc_offset; Index pc_offset;
...@@ -161,6 +161,10 @@ struct Instruction { ...@@ -161,6 +161,10 @@ struct Instruction {
/*! \brief The field to read out. */ /*! \brief The field to read out. */
Index field_index; Index field_index;
}; };
struct /* GetTag Operands */ {
/*! \brief The register to project from. */
RegName object;
} get_tag;
struct /* AllocDatatype Operands */ { struct /* AllocDatatype Operands */ {
/*! \brief The datatype's constructor tag. */ /*! \brief The datatype's constructor tag. */
Index constructor_tag; Index constructor_tag;
...@@ -179,19 +183,15 @@ struct Instruction { ...@@ -179,19 +183,15 @@ struct Instruction {
}; };
}; };
/*! \brief Construct a select instruction.
* \param cond The condition register.
* \param op1 The true register.
* \param op2 The false register.
* \param dst The destination register.
* \return The select instruction.
*/
static Instruction Select(RegName cond, RegName op1, RegName op2, RegName dst);
/*! \brief Construct a return instruction. /*! \brief Construct a return instruction.
* \param return_reg The register containing the return value. * \param return_reg The register containing the return value.
* \return The return instruction. * \return The return instruction.
* */ * */
static Instruction Ret(RegName return_reg); static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction.
* \return The fatal instruction.
* */
static Instruction Fatal();
/*! \brief Construct a invoke packed instruction. /*! \brief Construct a invoke packed instruction.
* \param packed_index The index of the packed function. * \param packed_index The index of the packed function.
* \param arity The arity of the function. * \param arity The arity of the function.
...@@ -240,13 +240,20 @@ struct Instruction { ...@@ -240,13 +240,20 @@ struct Instruction {
* \return The get field instruction. * \return The get field instruction.
*/ */
static Instruction GetField(RegName object_reg, Index field_index, RegName dst); static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
/*! \brief Construct a get_tag instruction.
* \param object_reg The register containing the object to project from.
* \param dst The destination register.
* \return The get_tag instruction.
*/
static Instruction GetTag(RegName object_reg, RegName dst);
/*! \brief Construct an if instruction. /*! \brief Construct an if instruction.
* \param cond_reg The register containing the condition. * \param test The register containing the test value.
* \param target The register containing the target value.
* \param true_branch The offset to the true branch. * \param true_branch The offset to the true branch.
* \param false_branch The offset to the false branch. * \param false_branch The offset to the false branch.
* \return The if instruction. * \return The if instruction.
*/ */
static Instruction If(RegName cond_reg, Index true_branch, Index false_branch); static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
/*! \brief Construct a goto instruction. /*! \brief Construct a goto instruction.
* \param pc_offset The offset from the current pc. * \param pc_offset The offset from the current pc.
* \return The goto instruction. * \return The goto instruction.
...@@ -272,6 +279,12 @@ struct Instruction { ...@@ -272,6 +279,12 @@ struct Instruction {
* \return The load constant instruction. * \return The load constant instruction.
*/ */
static Instruction LoadConst(Index const_index, RegName dst); static Instruction LoadConst(Index const_index, RegName dst);
/*! \brief Construct a load_constanti instruction.
* \param val The integer constant value.
* \param dst The destination register.
* \return The load_constanti instruction.
*/
static Instruction LoadConsti(size_t val, RegName dst);
/*! \brief Construct a move instruction. /*! \brief Construct a move instruction.
* \param src The source register. * \param src The source register.
* \param dst The destination register. * \param dst The destination register.
...@@ -398,6 +411,12 @@ struct VirtualMachine { ...@@ -398,6 +411,12 @@ struct VirtualMachine {
*/ */
inline Object ReadRegister(RegName reg) const; inline Object ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t
* \param reg The register to read from.
* \return The read scalar.
*/
int32_t LoadScalarInt(RegName reg) const;
/*! \brief Invoke a VM function. /*! \brief Invoke a VM function.
* \param func The function. * \param func The function.
* \param args The arguments to the function. * \param args The arguments to the function.
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <vector> #include <vector>
#include "../../../runtime/vm/naive_allocator.h" #include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h" #include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -122,6 +123,49 @@ std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& modul ...@@ -122,6 +123,49 @@ std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& modul
void InstructionPrint(std::ostream& os, const Instruction& instr); void InstructionPrint(std::ostream& os, const Instruction& instr);
/*!
 * \brief Base type for a runtime object that a pattern match expression inspects.
 *
 * It carries no data itself; the concrete kinds of match values derive from it
 * and are passed around through MatchValuePtr.
 */
struct MatchValue {
  virtual ~MatchValue() = default;
};

/*! \brief Shared-ownership handle to a MatchValue. */
using MatchValuePtr = std::shared_ptr<MatchValue>;
/*!
 * \brief A match value that already resides in a VM register.
 */
struct RegisterValue : MatchValue {
  /*!
   * \brief The number of the register holding the value.
   * NOTE(review): "rergister" looks like a misspelling of "register"; the name
   * is kept because other code reads this field.
   */
  RegName rergister_num;

  /*! \brief Wrap the register numbered \p reg. */
  explicit RegisterValue(RegName reg) : rergister_num(reg) {}

  ~RegisterValue() = default;
};
// The value is a field of another runtime object
struct AccessField : MatchValue {
MatchValuePtr parent;
// Field index
size_t index;
// Runtime register num after compiling the access field path
RegName reg{-1};
AccessField(MatchValuePtr parent, size_t index)
: parent(parent), index(index) {}
~AccessField() {}
};
struct VMCompiler;
/*!
* \brief Compile a pattern match expression
* It first converts the pattern match expression into a decision tree, where each condition
* is either an object comparison or a variable binding. If any condition in a clause fails,
* the decision tree moves on to check the conditions of the next clause, and so on. If no clause
* matches the value, a fatal node is inserted.
*
* After the decision tree is built, we convert it into bytecodes using If/Goto.
*/
void CompileMatch(Match match, VMCompiler* compiler);
struct VMCompiler : ExprFunctor<void(const Expr& expr)> { struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
/*! \brief Store the expression a variable points to. */ /*! \brief Store the expression a variable points to. */
std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map; std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
...@@ -159,8 +203,9 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -159,8 +203,9 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::AllocTensor: case Opcode::AllocTensor:
case Opcode::AllocTensorReg: case Opcode::AllocTensorReg:
case Opcode::GetField: case Opcode::GetField:
case Opcode::GetTag:
case Opcode::LoadConst: case Opcode::LoadConst:
case Opcode::Select: case Opcode::LoadConsti:
case Opcode::Invoke: case Opcode::Invoke:
case Opcode::AllocClosure: case Opcode::AllocClosure:
case Opcode::Move: case Opcode::Move:
...@@ -173,6 +218,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -173,6 +218,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::If: case Opcode::If:
case Opcode::Ret: case Opcode::Ret:
case Opcode::Goto: case Opcode::Goto:
case Opcode::Fatal:
break; break;
} }
instructions.push_back(instr); instructions.push_back(instr);
...@@ -211,8 +257,9 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -211,8 +257,9 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const MatchNode* match_node) { void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node); auto match = GetRef<Match>(match_node);
LOG(FATAL) << "translation of match nodes to the VM is"
<< "currently unsupported"; this->VisitExpr(match->data);
CompileMatch(match, this);
} }
void VisitExpr_(const LetNode* let_node) { void VisitExpr_(const LetNode* let_node) {
...@@ -242,15 +289,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -242,15 +289,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const IfNode* if_node) { void VisitExpr_(const IfNode* if_node) {
this->VisitExpr(if_node->cond); this->VisitExpr(if_node->cond);
size_t cond_register = last_register; size_t test_register = last_register;
this->Emit(Instruction::LoadConsti(1, NewRegister()));
auto after_cond = this->instructions.size(); auto after_cond = this->instructions.size();
auto target_register = this->last_register;
this->Emit(Instruction::If(cond_register, 0, 0)); this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch); this->VisitExpr(if_node->true_branch);
size_t true_register = last_register; size_t true_register = last_register;
Emit(Instruction::Goto(0)); Emit(Instruction::Goto(0));
// Finally store how many instructions there are in the // Finally store how many instructions there are in the
...@@ -261,6 +308,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -261,6 +308,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
size_t false_register = last_register; size_t false_register = last_register;
// In else-branch, override the then-branch register
Emit(Instruction::Move(false_register, true_register));
// Compute the total number of instructions // Compute the total number of instructions
// after generating false. // after generating false.
auto after_false = this->instructions.size(); auto after_false = this->instructions.size();
...@@ -273,13 +322,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -273,13 +322,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// we patch up the if instruction, and goto. // we patch up the if instruction, and goto.
auto true_offset = 1; auto true_offset = 1;
auto false_offset = after_true - after_cond; auto false_offset = after_true - after_cond;
this->instructions[after_cond].true_offset = true_offset; this->instructions[after_cond].if_op.true_offset = true_offset;
this->instructions[after_cond].false_offset = false_offset; this->instructions[after_cond].if_op.false_offset = false_offset;
// Patch the Goto. // Patch the Goto.
this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1; this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1;
Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister())); this->last_register = true_register;
} }
Instruction AllocTensorFromType(const TensorTypeNode* ttype) { Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
...@@ -464,6 +513,160 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -464,6 +513,160 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
} }
}; };
/*!
 * \brief Compile a match value.
 * Emits the bytecode needed to compute the value described by \p val.
 * A RegisterValue needs no code; an AccessField first compiles its parent and
 * then emits one GetField into a fresh register.
 *
 * \param val The match value to materialize.
 * \param compiler The compiler that receives the emitted instructions.
 * \return The register number assigned for the final value.
 */
RegName CompileMatchValue(MatchValuePtr val, VMCompiler* compiler) {
  if (auto in_register = std::dynamic_pointer_cast<RegisterValue>(val)) {
    return in_register->rergister_num;
  }

  // Anything that is not a RegisterValue is treated as a field access path.
  auto field = std::dynamic_pointer_cast<AccessField>(val);
  RegName parent_reg = CompileMatchValue(field->parent, compiler);
  compiler->Emit(Instruction::GetField(parent_reg, field->index, compiler->NewRegister()));
  // Remember which register now holds the projected field.
  field->reg = compiler->last_register;
  return field->reg;
}
/*!
 * \brief Base type for a condition stored in a decision tree branch.
 */
struct ConditionNode {
  virtual ~ConditionNode() = default;
};

/*! \brief Shared-ownership handle to a ConditionNode. */
using ConditionNodePtr = std::shared_ptr<ConditionNode>;
/*!
* \brief A var binding condition
*/
struct VarBinding : ConditionNode {
Var var;
MatchValuePtr val;
VarBinding(Var var, MatchValuePtr val)
: var(var), val(val) {}
~VarBinding() {}
};
/*!
* \brief Compare the tag of the object
*/
struct TagCompare : ConditionNode {
/*! \brief The object to be examined */
MatchValuePtr obj;
/*! \brief The expected tag */
int target_tag;
TagCompare(MatchValuePtr obj, size_t target)
: obj(obj), target_tag(target) {
}
~TagCompare() {}
};
// Decision tree node types used by the match compiler: the relay:: tree
// templates instantiated with ConditionNodePtr as the branch condition type.
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;
/*!
 * \brief Lower a decision tree into VM bytecode.
 *
 * - A leaf compiles its body expression.
 * - A fatal leaf emits a Fatal instruction.
 * - A branch on a VarBinding records the variable's register and continues
 *   with the then-branch only.
 * - A branch on a TagCompare emits GetTag / LoadConsti / If, compiles both
 *   sub-trees, and patches the If and Goto offsets afterwards.
 *
 * \param tree The decision tree to compile.
 * \param compiler The compiler that receives the emitted instructions.
 */
void CompileTreeNode(TreeNodePtr tree, VMCompiler* compiler) {
  if (auto leaf = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
    compiler->VisitExpr(leaf->body);
    return;
  }

  if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
    compiler->Emit(Instruction::Fatal());
    return;
  }

  auto branch = std::dynamic_pointer_cast<TreeBranchNode>(tree);
  if (!branch) {
    return;
  }

  auto tag_check = std::dynamic_pointer_cast<TagCompare>(branch->cond);
  if (!tag_check) {
    // Non-branching condition: bind the variable, then fall straight into the then-branch.
    auto binding = std::dynamic_pointer_cast<VarBinding>(branch->cond);
    compiler->var_register_map[binding->var] = CompileMatchValue(binding->val, compiler);
    CompileTreeNode(branch->then_branch, compiler);
    return;
  }

  // Tag comparison: load the object's tag and the expected tag, then branch on them.
  auto object_reg = CompileMatchValue(tag_check->obj, compiler);
  compiler->Emit(Instruction::GetTag(object_reg, compiler->NewRegister()));
  auto actual_tag_reg = compiler->last_register;
  compiler->Emit(Instruction::LoadConsti(tag_check->target_tag, compiler->NewRegister()));
  auto expected_tag_reg = compiler->last_register;

  // The false offset is a placeholder until the then-branch length is known.
  compiler->Emit(Instruction::If(actual_tag_reg, expected_tag_reg, 1, 0));
  auto if_index = compiler->instructions.size() - 1;

  CompileTreeNode(branch->then_branch, compiler);
  auto then_reg = compiler->last_register;

  // The Goto skips the else-branch; its offset is patched below.
  compiler->Emit(Instruction::Goto(1));
  auto goto_index = compiler->instructions.size() - 1;

  CompileTreeNode(branch->else_branch, compiler);
  auto else_reg = compiler->last_register;

  // Make both paths deliver their result in the then-branch register.
  compiler->Emit(Instruction::Move(else_reg, then_reg));
  compiler->last_register = then_reg;
  auto move_index = compiler->instructions.size() - 1;

  // Patch the offsets: false jumps just past the Goto, Goto jumps just past the Move.
  compiler->instructions[if_index].if_op.false_offset = goto_index - if_index + 1;
  compiler->instructions[goto_index].pc_offset = move_index - goto_index + 1;
}
/*!
 * \brief Build the decision tree that tests one pattern against a match value.
 *
 * \param data The value the pattern is matched against.
 * \param pattern The pattern to test.
 * \param then_branch The tree to continue with when the pattern matches.
 * \param else_branch The tree to continue with when the pattern does not match.
 * \return The root of the resulting decision tree.
 */
TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
                                         Pattern pattern,
                                         TreeNodePtr then_branch,
                                         TreeNodePtr else_branch) {
  // A wildcard introduces no variables and always matches.
  if (pattern.as<PatternWildcardNode>()) {
    return then_branch;
  }

  if (const auto* var_node = pattern.as<PatternVarNode>()) {
    auto var_pattern = GetRef<PatternVar>(var_node);
    auto binding = std::make_shared<VarBinding>(var_pattern->var, data);
    return TreeBranchNode::Make(binding, then_branch, else_branch);
  }

  // Every remaining pattern is handled as a constructor pattern.
  const auto* ctor_node = pattern.as<PatternConstructorNode>();
  auto ctor_pattern = GetRef<PatternConstructor>(ctor_node);
  auto tag = ctor_pattern->constructor->tag;

  // Wrap the then-branch with the conditions of each sub-pattern, field by field.
  size_t field_index = 0;
  for (const auto& sub_pattern : ctor_pattern->patterns) {
    auto field = std::make_shared<AccessField>(data, field_index);
    then_branch = BuildDecisionTreeFromPattern(field, sub_pattern, then_branch, else_branch);
    ++field_index;
  }

  // The tag check becomes the root, so it is evaluated before any field conditions.
  auto tag_check = std::make_shared<TagCompare>(data, tag);
  return TreeBranchNode::Make(tag_check, then_branch, else_branch);
}
/*!
 * \brief Build the decision tree for a single match clause.
 *
 * \param data The value being matched.
 * \param clause The clause; its lhs is the pattern and its rhs the body.
 * \param else_branch The tree to continue with when the clause does not match.
 * \return The root of the resulting decision tree.
 */
TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
                                        Clause clause,
                                        TreeNodePtr else_branch) {
  TreeNodePtr on_match = TreeLeafNode::Make(clause->rhs);
  return BuildDecisionTreeFromPattern(data, clause->lhs, on_match, else_branch);
}
/*!
 * \brief Build the decision tree for an ordered list of match clauses.
 *
 * The clauses are folded from last to first, so each clause's failure path is
 * the tree built from the clauses after it. The innermost failure path is a
 * fatal leaf.
 *
 * \param data The value being matched.
 * \param clauses The clauses, in source order.
 * \return The root of the resulting decision tree.
 */
TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // If nothing matches, the VM raises a fatal error.
  TreeNodePtr tree = TreeLeafFatalNode::Make();

  auto it = clauses.rbegin();
  while (it != clauses.rend()) {
    tree = BuildDecisionTreeFromClause(data, *it, tree);
    ++it;
  }
  return tree;
}
/*!
 * \brief Compile a match expression into bytecode.
 *
 * The matched value is taken from the compiler's last_register at the time of
 * the call. A decision tree is built from the clauses and then lowered.
 *
 * \param match The match expression.
 * \param compiler The compiler that receives the emitted instructions.
 */
void CompileMatch(Match match, VMCompiler* compiler) {
  MatchValuePtr scrutinee = std::make_shared<RegisterValue>(compiler->last_register);
  TreeNodePtr tree = BuildDecisionTreeFromClauses(scrutinee, match->clauses);
  CompileTreeNode(tree, compiler);
}
void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs, void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
std::vector<PackedFunc>* packed_funcs) { std::vector<PackedFunc>* packed_funcs) {
runtime::Module mod; runtime::Module mod;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* Copyright (c) 2018 by Contributors. * Copyright (c) 2018 by Contributors.
* *
* \file tvm/relay/pass/pass_util.h * \file tvm/relay/pass/pass_util.h
* \brief Utilities for writing * \brief Utilities for writing passes
*/ */
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_ #ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_ #define TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <memory>
#include <unordered_map> #include <unordered_map>
namespace tvm { namespace tvm {
...@@ -108,6 +109,63 @@ inline bool IsAtomic(const Expr& e) { ...@@ -108,6 +109,63 @@ inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>(); return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
} }
/*!
 * \brief Base type of a decision tree node.
 * \tparam ConditionNodePtr The type used for branch conditions.
 */
template<typename ConditionNodePtr>
struct TreeNode {
  /*! \brief Shared-ownership handle to a node of this tree. */
  using pointer = std::shared_ptr<TreeNode<ConditionNodePtr>>;

  virtual ~TreeNode() = default;
};
/*!
 * \brief A decision tree leaf that carries an expression.
 * \tparam ConditionNodePtr The type used for branch conditions.
 */
template<typename ConditionNodePtr>
struct TreeLeafNode : TreeNode<ConditionNodePtr> {
  using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

  /*! \brief The expression held by this leaf. */
  Expr body;

  explicit TreeLeafNode(Expr leaf_body) : body(leaf_body) {}

  ~TreeLeafNode() = default;

  /*! \brief Allocate a leaf holding \p body and return it as a generic node pointer. */
  static TreeNodePtr Make(Expr body) {
    TreeNodePtr node = std::make_shared<TreeLeafNode>(body);
    return node;
  }
};
/*!
 * \brief A decision tree leaf that carries no expression and marks a fatal outcome.
 * \tparam ConditionNodePtr The type used for branch conditions.
 */
template<typename ConditionNodePtr>
struct TreeLeafFatalNode : TreeNode<ConditionNodePtr> {
  using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

  TreeLeafFatalNode() = default;

  ~TreeLeafFatalNode() = default;

  /*! \brief Allocate a fatal leaf and return it as a generic node pointer. */
  static TreeNodePtr Make() {
    TreeNodePtr node = std::make_shared<TreeLeafFatalNode>();
    return node;
  }
};
/*!
 * \brief An inner decision tree node: a condition with a then- and an else-subtree.
 * \tparam ConditionNodePtr The type used for branch conditions.
 */
template<typename ConditionNodePtr>
struct TreeBranchNode : TreeNode<ConditionNodePtr> {
  using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;

  /*! \brief The condition attached to this node. */
  ConditionNodePtr cond;
  /*! \brief The subtree stored as the then-branch. */
  TreeNodePtr then_branch;
  /*! \brief The subtree stored as the else-branch. */
  TreeNodePtr else_branch;

  TreeBranchNode(ConditionNodePtr condition,
                 TreeNodePtr on_then,
                 TreeNodePtr on_else)
      : cond(condition), then_branch(on_then), else_branch(on_else) {}

  ~TreeBranchNode() = default;

  /*! \brief Allocate a branch node and return it as a generic node pointer. */
  static TreeNodePtr Make(ConditionNodePtr cond,
                          TreeNodePtr then_branch,
                          TreeNodePtr else_branch) {
    TreeNodePtr node = std::make_shared<TreeBranchNode>(cond, then_branch, else_branch);
    return node;
  }
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -58,10 +58,7 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -58,10 +58,7 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::Move: case Opcode::Move:
this->from = instr.from; this->from = instr.from;
return; return;
case Opcode::Select: case Opcode::Fatal:
this->select_cond = instr.select_cond;
this->select_op1 = instr.select_op1;
this->select_op2 = instr.select_op2;
return; return;
case Opcode::Ret: case Opcode::Ret:
this->result = instr.result; this->result = instr.result;
...@@ -103,17 +100,21 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -103,17 +100,21 @@ Instruction::Instruction(const Instruction& instr) {
this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args); this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
return; return;
case Opcode::If: case Opcode::If:
this->if_cond = instr.if_cond; this->if_op = instr.if_op;
this->true_offset = instr.true_offset;
this->false_offset = instr.false_offset;
return; return;
case Opcode::LoadConst: case Opcode::LoadConst:
this->const_index = instr.const_index; this->const_index = instr.const_index;
return; return;
case Opcode::LoadConsti:
this->load_consti = instr.load_consti;
return;
case Opcode::GetField: case Opcode::GetField:
this->object = instr.object; this->object = instr.object;
this->field_index = instr.field_index; this->field_index = instr.field_index;
return; return;
case Opcode::GetTag:
this->get_tag = instr.get_tag;
return;
case Opcode::Goto: case Opcode::Goto:
this->pc_offset = instr.pc_offset; this->pc_offset = instr.pc_offset;
return; return;
...@@ -139,10 +140,10 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -139,10 +140,10 @@ Instruction& Instruction::operator=(const Instruction& instr) {
case Opcode::Move: case Opcode::Move:
this->from = instr.from; this->from = instr.from;
return *this; return *this;
case Opcode::Select: case Opcode::Fatal:
this->select_cond = instr.select_cond; return *this;
this->select_op1 = instr.select_op1; case Opcode::LoadConsti:
this->select_op2 = instr.select_op2; this->load_consti = instr.load_consti;
return *this; return *this;
case Opcode::Ret: case Opcode::Ret:
this->result = instr.result; this->result = instr.result;
...@@ -189,9 +190,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -189,9 +190,7 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args); this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
return *this; return *this;
case Opcode::If: case Opcode::If:
this->if_cond = instr.if_cond; this->if_op = instr.if_op;
this->true_offset = instr.true_offset;
this->false_offset = instr.false_offset;
return *this; return *this;
case Opcode::LoadConst: case Opcode::LoadConst:
this->const_index = instr.const_index; this->const_index = instr.const_index;
...@@ -200,6 +199,9 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -200,6 +199,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->object = instr.object; this->object = instr.object;
this->field_index = instr.field_index; this->field_index = instr.field_index;
return *this; return *this;
case Opcode::GetTag:
this->get_tag = instr.get_tag;
return *this;
case Opcode::Goto: case Opcode::Goto:
this->pc_offset = instr.pc_offset; this->pc_offset = instr.pc_offset;
return *this; return *this;
...@@ -213,13 +215,15 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -213,13 +215,15 @@ Instruction& Instruction::operator=(const Instruction& instr) {
Instruction::~Instruction() { Instruction::~Instruction() {
switch (this->op) { switch (this->op) {
case Opcode::Move: case Opcode::Move:
case Opcode::Select:
case Opcode::Ret: case Opcode::Ret:
case Opcode::AllocTensorReg: case Opcode::AllocTensorReg:
case Opcode::If: case Opcode::If:
case Opcode::LoadConst: case Opcode::LoadConst:
case Opcode::GetField: case Opcode::GetField:
case Opcode::GetTag:
case Opcode::Goto: case Opcode::Goto:
case Opcode::LoadConsti:
case Opcode::Fatal:
return; return;
case Opcode::AllocTensor: case Opcode::AllocTensor:
delete this->alloc_tensor.shape; delete this->alloc_tensor.shape;
...@@ -252,6 +256,12 @@ Instruction Instruction::Ret(RegName result) { ...@@ -252,6 +256,12 @@ Instruction Instruction::Ret(RegName result) {
return instr; return instr;
} }
/*!
 * \brief Construct a fatal instruction.
 * Only the opcode is set; the instruction has no operands.
 * \return The fatal instruction.
 */
Instruction Instruction::Fatal() {
  Instruction fatal_instr;
  fatal_instr.op = Opcode::Fatal;
  return fatal_instr;
}
Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args) { const std::vector<RegName>& args) {
Instruction instr; Instruction instr;
...@@ -325,22 +335,21 @@ Instruction Instruction::GetField(RegName object, Index field_index, RegName dst ...@@ -325,22 +335,21 @@ Instruction Instruction::GetField(RegName object, Index field_index, RegName dst
return instr; return instr;
} }
Instruction Instruction::If(RegName cond, Index true_branch, Index false_branch) { Instruction Instruction::GetTag(RegName object, RegName dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::If; instr.op = Opcode::GetTag;
instr.if_cond = cond; instr.dst = dst;
instr.true_offset = true_branch; instr.get_tag.object = object;
instr.false_offset = false_branch;
return instr; return instr;
} }
Instruction Instruction::Select(RegName cond, RegName op1, RegName op2, RegName dst) { Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) {
Instruction instr; Instruction instr;
instr.op = Opcode::Select; instr.op = Opcode::If;
instr.dst = dst; instr.if_op.test = test;
instr.select_cond = cond; instr.if_op.target = target;
instr.select_op1 = op1; instr.if_op.true_offset = true_branch;
instr.select_op2 = op2; instr.if_op.false_offset = false_branch;
return instr; return instr;
} }
...@@ -387,6 +396,14 @@ Instruction Instruction::LoadConst(Index const_index, RegName dst) { ...@@ -387,6 +396,14 @@ Instruction Instruction::LoadConst(Index const_index, RegName dst) {
return instr; return instr;
} }
/*!
 * \brief Construct a load_consti instruction.
 * \param val The immediate integer value carried by the instruction.
 * \param dst The destination register.
 * \return The load_consti instruction.
 */
Instruction Instruction::LoadConsti(size_t val, RegName dst) {
  Instruction load_instr;
  load_instr.op = Opcode::LoadConsti;
  load_instr.load_consti.val = val;
  load_instr.dst = dst;
  return load_instr;
}
Instruction Instruction::Move(RegName src, RegName dst) { Instruction Instruction::Move(RegName src, RegName dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::Move; instr.op = Opcode::Move;
...@@ -437,6 +454,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -437,6 +454,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
os << "ret $" << instr.result; os << "ret $" << instr.result;
break; break;
} }
case Opcode::Fatal: {
os << "fatal";
break;
}
case Opcode::InvokePacked: { case Opcode::InvokePacked: {
os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $" os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $"
<< StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ",$") << StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ",$")
...@@ -471,8 +492,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -471,8 +492,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break; break;
} }
case Opcode::If: { case Opcode::If: {
os << "if " << "$" << instr.if_cond << " " << instr.true_offset << " " os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
<< instr.false_offset; << instr.if_op.true_offset << " " << instr.if_op.false_offset;
break; break;
} }
case Opcode::Invoke: { case Opcode::Invoke: {
...@@ -491,18 +512,21 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -491,18 +512,21 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"; os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
break; break;
} }
case Opcode::LoadConsti: {
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
break;
}
case Opcode::GetField: { case Opcode::GetField: {
os << "get_field $" << instr.dst << " $" << instr.object << "[" os << "get_field $" << instr.dst << " $" << instr.object << "["
<< instr.field_index << "]"; << instr.field_index << "]";
break; break;
} }
case Opcode::Goto: { case Opcode::GetTag: {
os << "goto " << instr.pc_offset; os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
break; break;
} }
case Opcode::Select: { case Opcode::Goto: {
os << "select $" << instr.dst << " $" << instr.select_cond << " $" os << "goto " << instr.pc_offset;
<< instr.select_op1 << " $" << instr.select_op2;
break; break;
} }
default: default:
...@@ -617,6 +641,21 @@ inline Object VirtualMachine::ReadRegister(Index r) const { ...@@ -617,6 +641,21 @@ inline Object VirtualMachine::ReadRegister(Index r) const {
return frames.back().register_file[r]; return frames.back().register_file[r];
} }
/*!
 * \brief Read a VM register holding a scalar integer tensor and return it as int32_t.
 *
 * The tensor is copied to the CPU and its first element is read using the
 * element width recorded in the tensor's dtype. Widths up to 8, 16 and 32 bits
 * are read as int8_t, int16_t and int32_t respectively. Anything wider is read
 * as int64_t and narrowed, so that a 64-bit scalar is not read through a
 * 32-bit pointer.
 *
 * NOTE(review): only dtype.bits is inspected, so the element is assumed to be
 * a signed integer; confirm callers never pass float or unsigned tensors.
 *
 * \param r The register to read from.
 * \return The scalar value, converted to int32_t.
 */
inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
  const auto& obj = ReadRegister(r);
  // Copy to the CPU so the data pointer can be dereferenced here.
  NDArray array = ToNDArray(obj).CopyTo({kDLCPU, 0});

  const auto bits = array->dtype.bits;
  if (bits <= 8) {
    return reinterpret_cast<int8_t*>(array->data)[0];
  }
  if (bits <= 16) {
    return reinterpret_cast<int16_t*>(array->data)[0];
  }
  if (bits <= 32) {
    return reinterpret_cast<int32_t*>(array->data)[0];
  }
  // Wider elements: read the full 64-bit value, then narrow explicitly.
  return static_cast<int32_t>(reinterpret_cast<int64_t*>(array->data)[0]);
}
void VirtualMachine::Run() { void VirtualMachine::Run() {
CHECK(this->code); CHECK(this->code);
this->pc = 0; this->pc = 0;
...@@ -632,20 +671,26 @@ void VirtualMachine::Run() { ...@@ -632,20 +671,26 @@ void VirtualMachine::Run() {
switch (instr.op) { switch (instr.op) {
case Opcode::Move: { case Opcode::Move: {
Object from_obj; Object from_obj;
if (instr.from == 0) { from_obj = ReadRegister(instr.from);
from_obj = return_register;
} else {
from_obj = ReadRegister(instr.from);
}
WriteRegister(instr.dst, from_obj); WriteRegister(instr.dst, from_obj);
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::Fatal: {
throw std::runtime_error("VM encountered fatal error");
}
case Opcode::LoadConst: { case Opcode::LoadConst: {
WriteRegister(instr.dst, this->constants[instr.const_index]); WriteRegister(instr.dst, this->constants[instr.const_index]);
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Object::Tensor(tensor));
pc++;
goto main_loop;
}
case Opcode::Invoke: { case Opcode::Invoke: {
std::vector<Object> args; std::vector<Object> args;
for (Index i = 0; i < instr.num_args; ++i) { for (Index i = 0; i < instr.num_args; ++i) {
...@@ -695,25 +740,34 @@ void VirtualMachine::Run() { ...@@ -695,25 +740,34 @@ void VirtualMachine::Run() {
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::GetTag: {
auto object = ReadRegister(instr.get_tag.object);
CHECK(object->tag == ObjectTag::kDatatype)
<< "Object is not data type object, register "
<< instr.get_tag.object << ", Object tag "
<< static_cast<int>(object->tag);
const auto& data = object.AsDatatype();
auto tag = data->tag;
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Object::Tensor(tag_tensor));
pc++;
goto main_loop;
}
case Opcode::Goto: { case Opcode::Goto: {
pc += instr.pc_offset; pc += instr.pc_offset;
goto main_loop; goto main_loop;
} }
case Opcode::If: { case Opcode::If: {
// How do we do this efficiently? int32_t test_val = LoadScalarInt(instr.if_op.test);
DLContext cpu_ctx; int32_t target_val = LoadScalarInt(instr.if_op.target);
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
const auto& cond = ReadRegister(instr.if_cond); if (test_val == target_val) {
NDArray cpu_array = ToNDArray(cond).CopyTo(cpu_ctx); CHECK_NE(instr.if_op.true_offset, 0);
// CHECK_EQ(cpu_array->dtype, Bool()); pc += instr.if_op.true_offset;
bool branch = reinterpret_cast<uint8_t*>(cpu_array->data)[0];
if (branch) {
pc += instr.true_offset;
} else { } else {
pc += instr.false_offset; CHECK_NE(instr.if_op.false_offset, 0);
pc += instr.if_op.false_offset;
} }
goto main_loop; goto main_loop;
...@@ -768,26 +822,6 @@ void VirtualMachine::Run() { ...@@ -768,26 +822,6 @@ void VirtualMachine::Run() {
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::Select: {
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto cond = ReadRegister(instr.select_cond);
NDArray cpu_array = ToNDArray(cond).CopyTo(cpu_ctx);
// CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
bool branch = reinterpret_cast<uint8_t*>(cpu_array->data)[0];
if (branch) {
auto op1 = ReadRegister(instr.select_op1);
WriteRegister(instr.dst, op1);
} else {
auto op2 = ReadRegister(instr.select_op2);
WriteRegister(instr.dst, op2);
}
pc++;
goto main_loop;
}
case Opcode::Ret: { case Opcode::Ret: {
// If we have hit the point from which we started // If we have hit the point from which we started
// running, we should return to the caller breaking // running, we should return to the caller breaking
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os import os
from nose.tools import nottest from nose.tools import nottest, raises
import tvm import tvm
import numpy as np import numpy as np
...@@ -39,6 +39,15 @@ def veval(f, *args, ctx=tvm.cpu()): ...@@ -39,6 +39,15 @@ def veval(f, *args, ctx=tvm.cpu()):
else: else:
return ex.evaluate()(*args) return ex.evaluate()(*args)
def vmobj_to_list(o):
    """Flatten an evaluation result into a flat Python list.

    Parameters
    ----------
    o : TensorValue or ConstructorValue
        A tensor value, or a constructor value whose fields are flattened
        recursively, in field order.

    Returns
    -------
    list
        One entry per tensor reached, each being that tensor's data
        converted with ``asnumpy().tolist()``.

    Raises
    ------
    TypeError
        If ``o`` (or a nested field) is neither a TensorValue nor a
        ConstructorValue.
    """
    interp = tvm.relay.backend.interpreter
    if isinstance(o, interp.TensorValue):
        return [o.data.asnumpy().tolist()]
    if isinstance(o, interp.ConstructorValue):
        result = []
        for f in o.fields:
            result.extend(vmobj_to_list(f))
        return result
    # Fail loudly instead of returning None, which would only surface later
    # as an unrelated error in the caller.
    raise TypeError("Unknown object type: %s" % type(o))
def test_split(): def test_split():
x = relay.var('x', shape=(12,)) x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple() y = relay.split(x, 3, axis=0).astuple()
...@@ -186,15 +195,6 @@ def test_tuple_second(): ...@@ -186,15 +195,6 @@ def test_tuple_second():
tvm.testing.assert_allclose(result.asnumpy(), j_data) tvm.testing.assert_allclose(result.asnumpy(), j_data)
def test_list_constructor(): def test_list_constructor():
def to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
result = []
for f in o.fields:
result.extend(to_list(f))
return result
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -202,11 +202,6 @@ def test_list_constructor(): ...@@ -202,11 +202,6 @@ def test_list_constructor():
cons = p.cons cons = p.cons
l = p.l l = p.l
# remove all functions to not have pattern match to pass vm compilation
# TODO(wweic): remove the hack and implement pattern match
for v, _ in mod.functions.items():
mod[v] = relay.const(0)
one2 = cons(relay.const(1), nil()) one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2) one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3) one4 = cons(relay.const(3), one3)
...@@ -215,7 +210,7 @@ def test_list_constructor(): ...@@ -215,7 +210,7 @@ def test_list_constructor():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)()
obj = to_list(result) obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1])) tvm.testing.assert_allclose(obj, np.array([3,2,1]))
def test_let_tensor(): def test_let_tensor():
...@@ -256,13 +251,6 @@ def test_compose(): ...@@ -256,13 +251,6 @@ def test_compose():
compose = p.compose compose = p.compose
# remove all functions to not have pattern match to pass vm compilation
# TODO(wweic): remove the hack and implement pattern match
for v, _ in mod.functions.items():
if v.name_hint == 'compose':
continue
mod[v] = relay.const(0)
# add_one = fun x -> x + 1 # add_one = fun x -> x + 1
sb = relay.ScopeBuilder() sb = relay.ScopeBuilder()
x = relay.var('x', 'float32') x = relay.var('x', 'float32')
...@@ -291,6 +279,215 @@ def test_compose(): ...@@ -291,6 +279,215 @@ def test_compose():
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
def test_list_hd():
    """`hd` applied to the Prelude list [3, 2, 1] evaluates to 3."""
    mod = relay.Module()
    p = Prelude(mod)
    nil = p.nil
    cons = p.cons
    hd = p.hd
    # Build the list back to front: 1 :: nil, then 2 :: ..., then 3 :: ...
    one2 = cons(relay.const(1), nil())
    one3 = cons(relay.const(2), one2)
    one4 = cons(relay.const(3), one3)
    three = hd(one4)
    f = relay.Function([], three)
    mod["main"] = f
    result = veval(mod)()
    tvm.testing.assert_allclose(result.asnumpy(), 3)
@raises(Exception)
def test_list_tl_empty_list():
    """`tl` applied to an empty Prelude list is expected to raise."""
    mod = relay.Module()
    p = Prelude(mod)
    nil = p.nil
    tl = p.tl
    f = relay.Function([], tl(nil()))
    mod["main"] = f
    result = veval(mod)()
    # Only reached if no exception was raised, in which case the test fails.
    print(result)
def test_list_tl():
    """`tl` applied to the Prelude list [3, 2, 1] evaluates to [2, 1]."""
    mod = relay.Module()
    p = Prelude(mod)
    nil = p.nil
    cons = p.cons
    tl = p.tl
    # Build the list back to front: 1 :: nil, then 2 :: ..., then 3 :: ...
    one2 = cons(relay.const(1), nil())
    one3 = cons(relay.const(2), one2)
    one4 = cons(relay.const(3), one3)
    f = relay.Function([], tl(one4))
    mod["main"] = f
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
def test_list_nth():
    """`nth(l, i)` returns the i-th element of the list [0..9] for every valid i."""
    expected = list(range(10))
    for i in range(len(expected)):
        mod = relay.Module()
        p = Prelude(mod)
        nil = p.nil
        cons = p.cons
        nth = p.nth
        l = nil()
        # Use a separate loop variable here: reusing `i` would overwrite the
        # index under test and leave every iteration checking position 0.
        for v in reversed(expected):
            l = cons(relay.const(v), l)
        f = relay.Function([], nth(l, relay.const(i)))
        mod["main"] = f
        result = veval(mod)()
        tvm.testing.assert_allclose(result.asnumpy(), expected[i])
def test_list_update():
    """Updating every slot of a zero-filled list yields the values 0..9."""
    expected = list(range(10))
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    update = prelude.update

    # Start from a list holding one zero per expected element.
    lst = nil()
    for _ in expected:
        lst = cons(relay.const(0), lst)

    # Overwrite position idx with its expected value.
    for idx, val in enumerate(expected):
        lst = update(lst, relay.const(idx), relay.const(val))

    mod["main"] = relay.Function([], lst)
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
def test_list_length():
    """`length` of a ten-element Prelude list evaluates to 10."""
    expected = list(range(10))
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    length = prelude.length

    # Only the number of elements matters, so fill the list with zeros.
    lst = nil()
    for _ in expected:
        lst = cons(relay.const(0), lst)

    mod["main"] = relay.Function([], length(lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(result.asnumpy(), 10)
def test_list_map():
    """Mapping `1 + x` over the Prelude list [2, 1] yields [3, 2]."""
    mod = relay.Module()
    prelude = Prelude(mod)

    arg = relay.var('x', 'int32')
    add_one_func = relay.Function([arg], relay.const(1) + arg)

    nil = prelude.nil
    cons = prelude.cons
    list_map = prelude.map

    tail = cons(relay.const(1), nil())
    lst = cons(relay.const(2), tail)

    mod["main"] = relay.Function([], list_map(add_one_func, lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
def test_list_foldl():
    """Left-folding [1, 2, 3] with a prepend-twice step yields [3, 3, 2, 2, 1, 1]."""
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    foldl = prelude.foldl

    elem = relay.var("x")
    acc = relay.var("y")
    # Step function: push the current element onto the accumulator twice.
    rev_dup_func = relay.Function([acc, elem], cons(elem, cons(elem, acc)))

    # Build [1, 2, 3] back to front.
    lst = nil()
    for value in (3, 2, 1):
        lst = cons(relay.const(value), lst)

    mod["main"] = relay.Function([], foldl(rev_dup_func, nil(), lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
def test_list_foldr():
    """Right-folding [1, 2, 3] with `cons` rebuilds the same list."""
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    foldr = prelude.foldr

    head = relay.var("x")
    rest = relay.var("y")
    identity_func = relay.Function([head, rest], cons(head, rest))

    # Build [1, 2, 3] back to front.
    lst = nil()
    for value in (3, 2, 1):
        lst = cons(relay.const(value), lst)

    mod["main"] = relay.Function([], foldr(identity_func, nil(), lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
def test_list_sum():
    """`sum` of the Prelude list [1, 2, 3] evaluates to 6."""
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    list_sum = prelude.sum

    # Build [1, 2, 3] back to front.
    lst = nil()
    for value in (3, 2, 1):
        lst = cons(relay.const(value), lst)

    mod["main"] = relay.Function([], list_sum(lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(result.asnumpy(), 6)
def test_list_filter():
    """Filtering [1, 3, 1, 5, 1] with `x > 1` keeps [3, 5]."""
    mod = relay.Module()
    prelude = Prelude(mod)
    nil = prelude.nil
    cons = prelude.cons
    list_filter = prelude.filter

    elem = relay.var("x", 'int32')
    greater_than_one = relay.Function([elem], elem > relay.const(1))

    # Build [1, 3, 1, 5, 1] back to front.
    lst = nil()
    for value in (1, 5, 1, 3, 1):
        lst = cons(relay.const(value), lst)

    mod["main"] = relay.Function([], list_filter(greater_than_one, lst))
    result = veval(mod)()
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
def test_closure(): def test_closure():
x = relay.var('x', shape=()) x = relay.var('x', shape=())
y = relay.var('y', shape=()) y = relay.var('y', shape=())
...@@ -315,6 +512,15 @@ if __name__ == "__main__": ...@@ -315,6 +512,15 @@ if __name__ == "__main__":
test_let_tensor() test_let_tensor()
test_split() test_split()
test_split_no_fuse() test_split_no_fuse()
# TODO(@jroesch): restore when match is supported test_list_constructor()
# test_list_constructor() test_list_tl_empty_list()
test_list_tl()
test_list_nth()
test_list_update()
test_list_length()
test_list_map()
test_list_foldl()
test_list_foldr()
test_list_sum()
test_list_filter()
test_closure() test_closure()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment