Commit 6616355d by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] GetItem (#1861)

parent 4e309e67
......@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
IfNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
......@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
/*! \brief Get a field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
Expr tuple;
/*! \brief which value to get */
int index;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("index", &index);
}
TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char * _type_key = "relay.GetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
......
......@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
......@@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
return vtable;
}
};
......@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);
};
......@@ -153,6 +156,7 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
......
......@@ -39,3 +39,4 @@ Function = expr.Function
Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
......@@ -125,4 +125,12 @@ class If(Expr):
self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value)
@register_relay_node
class TupleGetItem(Expr):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, tuple_, index):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_, index)
debug_print = _expr._debug_print
......@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}
Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
......@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(o->name);
}
Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
......@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
......
......@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->false_branch << ")";
});
TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
return TupleGetItem(n);
}
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
} // namespace relay
} // namespace tvm
......@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}
Type ExprMutator::VisitType(const Type& t) {
return t;
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
}
}
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
}
......@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}
void ExprVisitor::VisitType(const Type& t) { return; }
} // namespace relay
......
......@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false;
}
}
void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
......
......@@ -8,7 +8,6 @@
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>
namespace tvm {
......
......@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return TupleTypeNode::make(fields);
}
Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.)
// handle case where field type is not known
Type tuple_type = GetType(op->tuple);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
}
return tuple_ty_node->fields[op->index];
}
Type VisitExpr_(const OpNode* op) final {
return op->op_type;
}
......@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
return AttachCheckedType(op);
}
Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
......
......@@ -77,7 +77,7 @@ def test_call():
def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), "float32")
ty = relay.ty.TensorType((10, 20), 'float32')
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
......@@ -90,3 +90,8 @@ def test_if():
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
......@@ -175,6 +175,13 @@ def test_if():
str(ife)
def test_tuple_get_item():
tup = relay.Var("tuple")
get = relay.TupleGetItem(tup, 1)
assert get.tuple == tup
assert get.index == 1
str(get)
if __name__ == "__main__":
test_bad_constructor()
test_span()
......@@ -192,3 +199,4 @@ if __name__ == "__main__":
test_call()
test_let()
test_if()
test_tuple_get_item()
......@@ -3,7 +3,7 @@ from tvm import relay
from tvm.relay.ir_pass import well_formed
def test_well_formed():
x = relay.Var("x")
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
......@@ -16,3 +16,19 @@ def test_well_formed():
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))
def test_tuple():
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
assert well_formed(let)
assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let]))
def test_tuple_get_item():
t = relay.Var('t')
assert well_formed(relay.TupleGetItem(t, 2))
......@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs
def test_tuple_get_item_alpha_equal():
x = relay.Var('x')
y = relay.Var('y')
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_param_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_tuple_get_item_alpha_equal()
......@@ -4,6 +4,7 @@ from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract
class env:
def __init__(self):
self.a = relay.Var("a")
......@@ -22,20 +23,25 @@ class env:
self.two = convert(2.0)
self.three = convert(3.0)
e = env()
def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt)
assert alpha_equal(dead_code_elimination(orig), e.z)
def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), e.e)
# make sure we dont infinite loop
def test_recursion():
"""
......@@ -60,14 +66,23 @@ def test_recursion():
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
def test_if():
orig = relay.If(convert(True), e.a, e.b)
assert alpha_equal(dead_code_elimination(orig), e.a)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
if __name__ == "__main__":
test_let()
test_used_let()
......@@ -75,3 +90,4 @@ if __name__ == "__main__":
test_recursion()
test_op_let()
test_if()
test_tuple_get_item()
......@@ -15,6 +15,17 @@ def test_free_vars():
f = relay.Function([relay.Param(x, ty)], ty, x)
assert len(free_vars(f)) == 0
def test_tuple():
t = relay.Var('t')
fv = free_vars(relay.Tuple([t, t]))
assert len(fv) == 1
assert fv[0] == t
fv = free_vars(relay.TupleGetItem(t, 123))
assert len(fv) == 1
assert fv[0] == t
def test_free_type_vars():
tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
......
......@@ -9,6 +9,7 @@ from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
from tvm.relay.op import log, add, equal, subtract, concatenate
from tvm.relay.expr import Function
from tvm import relay
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
......@@ -110,6 +111,16 @@ def test_concat():
fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
assert_decl_has_type(ib.env, try_concat2, fn_ty)
def test_tuple():
ib = IRBuilder()
dup = ib.global_var('dup')
x = ib.param('x')
with ib.decl(dup, x):
ib.ret(relay.Tuple([x, x]))
# todo: why is this not generalized?
fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
assert_decl_has_type(ib.env, dup, fn_ty)
if __name__ == "__main__":
test_dual_op()
test_recursion()
......@@ -117,3 +128,4 @@ if __name__ == "__main__":
test_decl()
test_recursion()
test_concat()
test_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment