Commit 395804e5 by Jared Roesch Committed by Tianqi Chen

Small refactors and bug fixes. (#2281)

parent 5cb729ec
...@@ -248,6 +248,13 @@ class FunctionNode : public ExprNode { ...@@ -248,6 +248,13 @@ class FunctionNode : public ExprNode {
*/ */
TVM_DLL FuncType func_type_annotation() const; TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;
TVM_DLL static Function make(tvm::Array<Var> params, TVM_DLL static Function make(tvm::Array<Var> params,
Expr body, Expr body,
Type ret_type, Type ret_type,
......
...@@ -5,6 +5,7 @@ from ..api import register_func ...@@ -5,6 +5,7 @@ from ..api import register_func
from . import base from . import base
from . import ty from . import ty
from . import expr from . import expr
from . import expr_functor
from . import module from . import module
from . import ir_pass from . import ir_pass
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor
...@@ -53,6 +54,10 @@ Let = expr.Let ...@@ -53,6 +54,10 @@ Let = expr.Let
If = expr.If If = expr.If
TupleGetItem = expr.TupleGetItem TupleGetItem = expr.TupleGetItem
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
# helper functions # helper functions
var = expr.var var = expr.var
const = expr.const const = expr.const
......
...@@ -24,7 +24,8 @@ import attr ...@@ -24,7 +24,8 @@ import attr
from . import _backend from . import _backend
from . import compile_engine from . import compile_engine
from ..op import Op from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType from ..ty import TupleType, TensorType
...@@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name, inputs, {}) op_name, inputs, {})
return self.add_node(op_node, call) return self.add_node(op_node, call)
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")
def _get_json(self): def _get_json(self):
""" """
Convert the sequence of nodes stored by the compiler into the Convert the sequence of nodes stored by the compiler into the
......
...@@ -222,12 +222,13 @@ class Function(Expr): ...@@ -222,12 +222,13 @@ class Function(Expr):
params, params,
body, body,
ret_type=None, ret_type=None,
type_params=None): type_params=None,
attrs=None):
if type_params is None: if type_params is None:
type_params = convert([]) type_params = convert([])
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Function, params, body, ret_type, type_params) _make.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args): def __call__(self, *args):
"""Invoke the gobal function. """Invoke the gobal function.
...@@ -343,131 +344,6 @@ class TempExpr(Expr): ...@@ -343,131 +344,6 @@ class TempExpr(Expr):
return _expr.TempExprRealize(self) return _expr.TempExprRealize(self)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise NotImplementedError()
def visit_let(self, _):
raise NotImplementedError()
def visit_call(self, _):
raise NotImplementedError()
def visit_var(self, _):
raise NotImplementedError()
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise NotImplementedError()
def visit_tuple(self, _):
raise NotImplementedError()
def visit_tuple_getitem(self, _):
raise NotImplementedError()
def visit_constant(self, _):
raise NotImplementedError()
def visit_global_var(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, rvar):
return rvar
def visit_global_id(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op
def visit_global_var(self, gvar):
return gvar
def visit_constant(self, rconst):
return rconst
class TupleWrapper(object): class TupleWrapper(object):
"""TupleWrapper. """TupleWrapper.
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant
from .op import Op
class ExprFunctor:
"""
An abstract visitor defined over Expr.
Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise NotImplementedError()
def visit_let(self, _):
raise NotImplementedError()
def visit_call(self, _):
raise NotImplementedError()
def visit_var(self, _):
raise NotImplementedError()
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise NotImplementedError()
def visit_tuple(self, _):
raise NotImplementedError()
def visit_tuple_getitem(self, _):
raise NotImplementedError()
def visit_global_var(self, _):
raise NotImplementedError()
def visit_op(self, _):
raise NotImplementedError()
def visit_constant(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
new_body,
fn.ret_type,
fn.type_params,
fn.attrs)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, rvar):
return rvar
def visit_global_id(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op
def visit_global_var(self, gvar):
return gvar
def visit_op(self, op):
return op
def visit_constant(self, const):
return const
def visit_constructor(self, con):
return con
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
def visit_ref_new(self, r):
return RefNew(self.visit(r.value))
def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))
...@@ -157,14 +157,14 @@ class ScheduleGetter : ...@@ -157,14 +157,14 @@ class ScheduleGetter :
int op_pattern = fpattern[op]; int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) { if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce) CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
<< "Two complicated op in a primitive function " << "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op; << " master=" << master_op_ << " current=" << op;
} }
if (op_pattern >= master_op_patetrn_) { if (op_pattern >= master_op_pattern_) {
master_op_ = op; master_op_ = op;
master_attrs_ = call_node->attrs; master_attrs_ = call_node->attrs;
master_op_patetrn_ = op_pattern; master_op_pattern_ = op_pattern;
} }
if (outputs.size() != 1) { if (outputs.size() != 1) {
const auto* tuple_type = const auto* tuple_type =
...@@ -213,7 +213,7 @@ class ScheduleGetter : ...@@ -213,7 +213,7 @@ class ScheduleGetter :
tvm::Target target_; tvm::Target target_;
Op master_op_; Op master_op_;
Attrs master_attrs_; Attrs master_attrs_;
int master_op_patetrn_{0}; int master_op_pattern_{0};
std::ostringstream readable_name_stream_; std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
}; };
......
...@@ -292,17 +292,10 @@ class Interpreter : ...@@ -292,17 +292,10 @@ class Interpreter :
} }
} }
// Check if function is a primitive function.
bool IsPrimitive(const Function& func) const {
NodeRef res = FunctionGetAttr(func, "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && pval->value != 0;
}
// Invoke the closure // Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) { Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
// Get a reference to the function inside the closure. // Get a reference to the function inside the closure.
if (IsPrimitive(closure->func)) { if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args); return InvokePrimitiveOp(closure->func, args);
} }
auto func = closure->func; auto func = closure->func;
......
...@@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const { ...@@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
} }
bool FunctionNode::IsPrimitive() const {
NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && pval->value != 0;
}
NodeRef FunctionGetAttr(const Function& func, const std::string& key) { NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); } if (!func->attrs.defined()) { return NodeRef(); }
...@@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); ...@@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function") TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]); *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
...@@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator { ...@@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator {
std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_; std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
// Skip primitive function. // Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) { Expr VisitExpr_(const FunctionNode* fn_node) {
NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive"); if (fn_node->IsPrimitive()) {
const ir::IntImm* pval = res.as<ir::IntImm>();
if (pval && pval->value != 0) {
return GetRef<Expr>(fn_node); return GetRef<Expr>(fn_node);
} else { } else {
return ExprMutator::VisitExpr_(fn_node); return ExprMutator::VisitExpr_(fn_node);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment