Commit 93d1c06d by Wei Chen Committed by Jared Roesch

[Relay][VM]Compiling pattern matching (#3470)

* [Relay][VM]Compiling pattern matching

* Fix lint

* Remove debug code

* Move TreeNode definition

* merge ifi and selecti, todo: remove them

* fix lint

* remove ifi and selecti

* rename GetTagi to GetTag

* fix dltype

* fix more dltype

* Generalize If and select, and rename to Ifi and Selecti

* Fix lint

* Rename Ifi to If

* Change register default to match value

* Remove bad specialization for Move

* Stop use Select

* Remove Select

* TreeNode refactor

* Change entry_func name

* Remove Cmp due to rebase issue
parent be776dc7
...@@ -61,9 +61,11 @@ enum class Opcode { ...@@ -61,9 +61,11 @@ enum class Opcode {
AllocClosure = 8U, AllocClosure = 8U,
GetField = 9U, GetField = 9U,
If = 10U, If = 10U,
Select = 11U, LoadConst = 11U,
LoadConst = 12U, Goto = 12U,
Goto = 13U GetTag = 13U,
LoadConsti = 14U,
Fatal = 15U,
}; };
/*! \brief A single virtual machine instruction. /*! \brief A single virtual machine instruction.
...@@ -123,22 +125,16 @@ struct Instruction { ...@@ -123,22 +125,16 @@ struct Instruction {
/*! \brief The arguments to pass to the packed function. */ /*! \brief The arguments to pass to the packed function. */
RegName* packed_args; RegName* packed_args;
}; };
struct /* Select Operands */ {
/*! \brief The condition of select. */
RegName select_cond;
/*! \brief The true branch. */
RegName select_op1;
/*! \brief The false branch. */
RegName select_op2;
};
struct /* If Operands */ { struct /* If Operands */ {
/*! \brief The register containing the condition value. */ /*! \brief The register containing the test value. */
RegName if_cond; RegName test;
/*! \brief The register containing the target value. */
RegName target;
/*! \brief The program counter offset for the true branch. */ /*! \brief The program counter offset for the true branch. */
Index true_offset; Index true_offset;
/*! \brief The program counter offset for the false branch. */ /*! \brief The program counter offset for the false branch. */
Index false_offset; Index false_offset;
}; } if_op;
struct /* Invoke Operands */ { struct /* Invoke Operands */ {
/*! \brief The function to call. */ /*! \brief The function to call. */
Index func_index; Index func_index;
...@@ -151,6 +147,10 @@ struct Instruction { ...@@ -151,6 +147,10 @@ struct Instruction {
/* \brief The index into the constant pool. */ /* \brief The index into the constant pool. */
Index const_index; Index const_index;
}; };
struct /* LoadConsti Operands */ {
/* \brief The index into the constant pool. */
size_t val;
} load_consti;
struct /* Jump Operands */ { struct /* Jump Operands */ {
/*! \brief The jump offset. */ /*! \brief The jump offset. */
Index pc_offset; Index pc_offset;
...@@ -161,6 +161,10 @@ struct Instruction { ...@@ -161,6 +161,10 @@ struct Instruction {
/*! \brief The field to read out. */ /*! \brief The field to read out. */
Index field_index; Index field_index;
}; };
struct /* GetTag Operands */ {
/*! \brief The register to project from. */
RegName object;
} get_tag;
struct /* AllocDatatype Operands */ { struct /* AllocDatatype Operands */ {
/*! \brief The datatype's constructor tag. */ /*! \brief The datatype's constructor tag. */
Index constructor_tag; Index constructor_tag;
...@@ -179,19 +183,15 @@ struct Instruction { ...@@ -179,19 +183,15 @@ struct Instruction {
}; };
}; };
/*! \brief Construct a select instruction.
* \param cond The condition register.
* \param op1 The true register.
* \param op2 The false register.
* \param dst The destination register.
* \return The select instruction.
*/
static Instruction Select(RegName cond, RegName op1, RegName op2, RegName dst);
/*! \brief Construct a return instruction. /*! \brief Construct a return instruction.
* \param return_reg The register containing the return value. * \param return_reg The register containing the return value.
* \return The return instruction. * \return The return instruction.
* */ * */
static Instruction Ret(RegName return_reg); static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction.
* \return The fatal instruction.
* */
static Instruction Fatal();
/*! \brief Construct a invoke packed instruction. /*! \brief Construct a invoke packed instruction.
* \param packed_index The index of the packed function. * \param packed_index The index of the packed function.
* \param arity The arity of the function. * \param arity The arity of the function.
...@@ -240,13 +240,20 @@ struct Instruction { ...@@ -240,13 +240,20 @@ struct Instruction {
* \return The get field instruction. * \return The get field instruction.
*/ */
static Instruction GetField(RegName object_reg, Index field_index, RegName dst); static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
/*! \brief Construct a get_tag instruction.
* \param object_reg The register containing the object to project from.
* \param dst The destination register.
* \return The get_tag instruction.
*/
static Instruction GetTag(RegName object_reg, RegName dst);
/*! \brief Construct an if instruction. /*! \brief Construct an if instruction.
* \param cond_reg The register containing the condition. * \param test The register containing the test value.
* \param target The register containing the target value.
* \param true_branch The offset to the true branch. * \param true_branch The offset to the true branch.
* \param false_branch The offset to the false branch. * \param false_branch The offset to the false branch.
* \return The if instruction. * \return The if instruction.
*/ */
static Instruction If(RegName cond_reg, Index true_branch, Index false_branch); static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
/*! \brief Construct a goto instruction. /*! \brief Construct a goto instruction.
* \param pc_offset The offset from the current pc. * \param pc_offset The offset from the current pc.
* \return The goto instruction. * \return The goto instruction.
...@@ -272,6 +279,12 @@ struct Instruction { ...@@ -272,6 +279,12 @@ struct Instruction {
* \return The load constant instruction. * \return The load constant instruction.
*/ */
static Instruction LoadConst(Index const_index, RegName dst); static Instruction LoadConst(Index const_index, RegName dst);
/*! \brief Construct a load_constanti instruction.
* \param val The interger constant value.
* \param dst The destination register.
* \return The load_constanti instruction.
*/
static Instruction LoadConsti(size_t val, RegName dst);
/*! \brief Construct a move instruction. /*! \brief Construct a move instruction.
* \param src The source register. * \param src The source register.
* \param dst The destination register. * \param dst The destination register.
...@@ -398,6 +411,12 @@ struct VirtualMachine { ...@@ -398,6 +411,12 @@ struct VirtualMachine {
*/ */
inline Object ReadRegister(RegName reg) const; inline Object ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t
* \param reg The register to read from.
* \return The read scalar.
*/
int32_t LoadScalarInt(RegName reg) const;
/*! \brief Invoke a VM function. /*! \brief Invoke a VM function.
* \param func The function. * \param func The function.
* \param args The arguments to the function. * \param args The arguments to the function.
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* Copyright (c) 2018 by Contributors. * Copyright (c) 2018 by Contributors.
* *
* \file tvm/relay/pass/pass_util.h * \file tvm/relay/pass/pass_util.h
* \brief Utilities for writing * \brief Utilities for writing passes
*/ */
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_ #ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_ #define TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <memory>
#include <unordered_map> #include <unordered_map>
namespace tvm { namespace tvm {
...@@ -108,6 +109,63 @@ inline bool IsAtomic(const Expr& e) { ...@@ -108,6 +109,63 @@ inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>(); return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
} }
template<typename ConditionNodePtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionNodePtr>> pointer;
virtual ~TreeNode() {}
};
template<typename ConditionNodePtr>
struct TreeLeafNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;
Expr body;
explicit TreeLeafNode(Expr body): body(body) {}
static TreeNodePtr Make(Expr body) {
return std::make_shared<TreeLeafNode>(body);
}
~TreeLeafNode() {}
};
template<typename ConditionNodePtr>
struct TreeLeafFatalNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;
TreeLeafFatalNode() = default;
static TreeNodePtr Make() {
return std::make_shared<TreeLeafFatalNode>();
}
~TreeLeafFatalNode() {}
};
template<typename ConditionNodePtr>
struct TreeBranchNode : TreeNode<ConditionNodePtr> {
using TreeNodePtr = typename TreeNode<ConditionNodePtr>::pointer;
ConditionNodePtr cond;
TreeNodePtr then_branch;
TreeNodePtr else_branch;
TreeBranchNode(ConditionNodePtr cond,
TreeNodePtr then_branch,
TreeNodePtr else_branch)
: cond(cond), then_branch(then_branch), else_branch(else_branch) {}
static TreeNodePtr Make(ConditionNodePtr cond,
TreeNodePtr then_branch,
TreeNodePtr else_branch) {
return std::make_shared<TreeBranchNode>(cond, then_branch, else_branch);
}
~TreeBranchNode() {}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os import os
from nose.tools import nottest from nose.tools import nottest, raises
import tvm import tvm
import numpy as np import numpy as np
...@@ -39,6 +39,15 @@ def veval(f, *args, ctx=tvm.cpu()): ...@@ -39,6 +39,15 @@ def veval(f, *args, ctx=tvm.cpu()):
else: else:
return ex.evaluate()(*args) return ex.evaluate()(*args)
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
result = []
for f in o.fields:
result.extend(vmobj_to_list(f))
return result
def test_split(): def test_split():
x = relay.var('x', shape=(12,)) x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple() y = relay.split(x, 3, axis=0).astuple()
...@@ -186,15 +195,6 @@ def test_tuple_second(): ...@@ -186,15 +195,6 @@ def test_tuple_second():
tvm.testing.assert_allclose(result.asnumpy(), j_data) tvm.testing.assert_allclose(result.asnumpy(), j_data)
def test_list_constructor(): def test_list_constructor():
def to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
result = []
for f in o.fields:
result.extend(to_list(f))
return result
mod = relay.Module() mod = relay.Module()
p = Prelude(mod) p = Prelude(mod)
...@@ -202,11 +202,6 @@ def test_list_constructor(): ...@@ -202,11 +202,6 @@ def test_list_constructor():
cons = p.cons cons = p.cons
l = p.l l = p.l
# remove all functions to not have pattern match to pass vm compilation
# TODO(wweic): remove the hack and implement pattern match
for v, _ in mod.functions.items():
mod[v] = relay.const(0)
one2 = cons(relay.const(1), nil()) one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2) one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3) one4 = cons(relay.const(3), one3)
...@@ -215,7 +210,7 @@ def test_list_constructor(): ...@@ -215,7 +210,7 @@ def test_list_constructor():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)()
obj = to_list(result) obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1])) tvm.testing.assert_allclose(obj, np.array([3,2,1]))
def test_let_tensor(): def test_let_tensor():
...@@ -256,13 +251,6 @@ def test_compose(): ...@@ -256,13 +251,6 @@ def test_compose():
compose = p.compose compose = p.compose
# remove all functions to not have pattern match to pass vm compilation
# TODO(wweic): remove the hack and implement pattern match
for v, _ in mod.functions.items():
if v.name_hint == 'compose':
continue
mod[v] = relay.const(0)
# add_one = fun x -> x + 1 # add_one = fun x -> x + 1
sb = relay.ScopeBuilder() sb = relay.ScopeBuilder()
x = relay.var('x', 'float32') x = relay.var('x', 'float32')
...@@ -291,6 +279,215 @@ def test_compose(): ...@@ -291,6 +279,215 @@ def test_compose():
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
def test_list_hd():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
l = p.l
hd = p.hd
one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3)
three = hd(one4)
f = relay.Function([], three)
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(result.asnumpy(), 3)
@raises(Exception)
def test_list_tl_empty_list():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
l = p.l
tl = p.tl
f = relay.Function([], tl(nil()))
mod["main"] = f
result = veval(mod)()
print(result)
def test_list_tl():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
l = p.l
tl = p.tl
one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3)
f = relay.Function([], tl(one4))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
def test_list_nth():
expected = list(range(10))
for i in range(len(expected)):
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
nth = p.nth
l = nil()
for i in reversed(expected):
l = cons(relay.const(i), l)
f = relay.Function([], nth(l, relay.const(i)))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(result.asnumpy(), expected[i])
def test_list_update():
expected = list(range(10))
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
update = p.update
l = nil()
# create zero initialized list
for i in range(len(expected)):
l = cons(relay.const(0), l)
# set value
for i, v in enumerate(expected):
l = update(l, relay.const(i), relay.const(v))
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
def test_list_length():
expected = list(range(10))
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
length = p.length
l = nil()
# create zero initialized list
for i in range(len(expected)):
l = cons(relay.const(0), l)
l = length(l)
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(result.asnumpy(), 10)
def test_list_map():
mod = relay.Module()
p = Prelude(mod)
x = relay.var('x', 'int32')
add_one_func = relay.Function([x], relay.const(1) + x)
nil = p.nil
cons = p.cons
map = p.map
l = cons(relay.const(2), cons(relay.const(1), nil()))
f = relay.Function([], map(add_one_func, l))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
def test_list_foldl():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
foldl = p.foldl
x = relay.var("x")
y = relay.var("y")
rev_dup_func = relay.Function([y, x], cons(x, cons(x, y)))
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldl(rev_dup_func, nil(), l))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
def test_list_foldr():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
foldr = p.foldr
x = relay.var("x")
y = relay.var("y")
identity_func = relay.Function([x, y], cons(x, y))
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldr(identity_func, nil(), l))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
def test_list_sum():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
sum = p.sum
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], sum(l))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(result.asnumpy(), 6)
def test_list_filter():
mod = relay.Module()
p = Prelude(mod)
nil = p.nil
cons = p.cons
filter = p.filter
x = relay.var("x", 'int32')
greater_than_one = relay.Function([x], x > relay.const(1))
l = cons(relay.const(1),
cons(relay.const(3),
cons(relay.const(1),
cons(relay.const(5),
cons(relay.const(1), nil())))))
f = relay.Function([], filter(greater_than_one, l))
mod["main"] = f
result = veval(mod)()
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
def test_closure(): def test_closure():
x = relay.var('x', shape=()) x = relay.var('x', shape=())
y = relay.var('y', shape=()) y = relay.var('y', shape=())
...@@ -315,6 +512,15 @@ if __name__ == "__main__": ...@@ -315,6 +512,15 @@ if __name__ == "__main__":
test_let_tensor() test_let_tensor()
test_split() test_split()
test_split_no_fuse() test_split_no_fuse()
# TODO(@jroesch): restore when match is supported test_list_constructor()
# test_list_constructor() test_list_tl_empty_list()
test_list_tl()
test_list_nth()
test_list_update()
test_list_length()
test_list_map()
test_list_foldl()
test_list_foldr()
test_list_sum()
test_list_filter()
test_closure() test_closure()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment