Unverified Commit b2521604 by Tianqi Chen Committed by GitHub

[RELAY][PASS] Bind, FoldConstant (#2100)

parent 1b863732
...@@ -182,6 +182,17 @@ class ExprMutator ...@@ -182,6 +182,17 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
}; };
/*
* \brief Bind function parameters or free variables.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
*/
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_ #endif // TVM_RELAY_EXPR_FUNCTOR_H_
...@@ -39,6 +39,16 @@ enum OpPatternKind { ...@@ -39,6 +39,16 @@ enum OpPatternKind {
using TOpPattern = int; using TOpPattern = int;
/*! /*!
* \brief Whether operator is stateful or contain internal state.
*
* All the primitive ops we registered so far are pure.
* This attribute is left for potential future compatible reasons.
* We can always work around the stateful ops by adding an additional
* handle argument and return it.
*/
using TOpIsStateful = bool;
/*!
* \brief Computation description interface. * \brief Computation description interface.
* *
* \note This function have a special convention * \note This function have a special convention
......
...@@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); ...@@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/ */
Expr DeadCodeElimination(const Expr& e); Expr DeadCodeElimination(const Expr& e);
/*!
* \brief Fold constant expressions.
* \param expr the expression to be optimized.
* \return The optimized expression.
*/
Expr FoldConstant(const Expr& expr);
/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
/*! \brief Hash a Relay type. /*! \brief Hash a Relay type.
......
...@@ -54,7 +54,7 @@ TupleGetItem = expr.TupleGetItem ...@@ -54,7 +54,7 @@ TupleGetItem = expr.TupleGetItem
# helper functions # helper functions
var = expr.var var = expr.var
const = expr.const const = expr.const
bind = expr.bind
# pylint: disable=unused-argument # pylint: disable=unused-argument
@register_func("relay.debug") @register_func("relay.debug")
......
...@@ -102,6 +102,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -102,6 +102,7 @@ class GraphRuntimeCodegen(ExprFunctor):
self.target = target self.target = target
self.nodes = [] self.nodes = []
self.var_map = {} self.var_map = {}
self.params = {}
self.compile_engine = compile_engine.get() self.compile_engine = compile_engine.get()
self.lowered_funcs = set() self.lowered_funcs = set()
self._name_map = {} self._name_map = {}
...@@ -162,8 +163,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -162,8 +163,12 @@ class GraphRuntimeCodegen(ExprFunctor):
assert isinstance(vtuple, tuple) assert isinstance(vtuple, tuple)
return vtuple[op.index] return vtuple[op.index]
def visit_constant(self, _): def visit_constant(self, op):
raise RuntimeError("constant not supported") index = len(self.params)
name = "p%d" % index
self.params[name] = op.data
node = InputNode(name, {})
return self.add_node(node, op.checked_type)
def visit_function(self, _): def visit_function(self, _):
raise RuntimeError("function not supported") raise RuntimeError("function not supported")
...@@ -312,6 +317,9 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -312,6 +317,9 @@ class GraphRuntimeCodegen(ExprFunctor):
lowered_funcs : List[tvm.LoweredFunc] lowered_funcs : List[tvm.LoweredFunc]
The lowered functions. The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
""" """
# First we convert all the parameters into input nodes. # First we convert all the parameters into input nodes.
for param in func.params: for param in func.params:
...@@ -324,7 +332,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -324,7 +332,7 @@ class GraphRuntimeCodegen(ExprFunctor):
self.heads = self.visit(func.body) self.heads = self.visit(func.body)
graph_json = self._get_json() graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs) lowered_funcs = list(self.lowered_funcs)
return graph_json, lowered_funcs return graph_json, lowered_funcs, self.params
def _get_unique_name(self, name): def _get_unique_name(self, name):
if name not in self._name_map: if name not in self._name_map:
......
...@@ -6,6 +6,7 @@ from ..build_module import build as _tvm_build_module ...@@ -6,6 +6,7 @@ from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt from ..contrib import graph_runtime as _graph_rt
from . import ir_pass from . import ir_pass
from . import expr
from .backend import interpreter as _interpreter from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen from .backend import graph_runtime_codegen as _graph_gen
...@@ -13,6 +14,7 @@ from .backend import graph_runtime_codegen as _graph_gen ...@@ -13,6 +14,7 @@ from .backend import graph_runtime_codegen as _graph_gen
OPT_PASS_LEVEL = { OPT_PASS_LEVEL = {
"SimplifyInference": 0, "SimplifyInference": 0,
"OpFusion": 1, "OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3, "FoldScaleAxis": 3,
} }
...@@ -95,7 +97,27 @@ def build_config(**kwargs): ...@@ -95,7 +97,27 @@ def build_config(**kwargs):
return BuildConfig(**kwargs) return BuildConfig(**kwargs)
def optimize(func): def _bind_params_by_name(func, params):
"""Bind parameters of function by its name."""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = expr.const(v)
return expr.bind(func, bind_dict)
def optimize(func, params=None):
"""Perform target invariant optimizations. """Perform target invariant optimizations.
Parameters Parameters
...@@ -103,6 +125,10 @@ def optimize(func): ...@@ -103,6 +125,10 @@ def optimize(func):
func : tvm.relay.Function func : tvm.relay.Function
The input to optimization. The input to optimization.
params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
during inference time. used for constant folding.
Returns Returns
------- -------
opt_func : tvm.relay.Function opt_func : tvm.relay.Function
...@@ -110,7 +136,11 @@ def optimize(func): ...@@ -110,7 +136,11 @@ def optimize(func):
""" """
cfg = BuildConfig.current cfg = BuildConfig.current
if cfg.pass_enabled("FoldScaleAxis"): # bind expressions
if params:
func = _bind_params_by_name(func, params)
if cfg.pass_enabled("SimplifyInference"):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func) func = ir_pass.simplify_inference(func)
...@@ -119,6 +149,10 @@ def optimize(func): ...@@ -119,6 +149,10 @@ def optimize(func):
func = ir_pass.backward_fold_scale_axis(func) func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func) func = ir_pass.forward_fold_scale_axis(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
return func return func
...@@ -147,8 +181,7 @@ def build(func, ...@@ -147,8 +181,7 @@ def build(func,
params : dict of str to NDArray params : dict of str to NDArray
Input parameters to the graph that do not change Input parameters to the graph that do not change
during inference time. Used for pre-compute during inference time. Used for constant folding.
folding optimization.
Returns Returns
------- -------
...@@ -176,14 +209,14 @@ def build(func, ...@@ -176,14 +209,14 @@ def build(func,
cfg = BuildConfig.current cfg = BuildConfig.current
with tophub_context: with tophub_context:
func = optimize(func) func = optimize(func, params)
# Fuse ops before running code gen # Fuse ops before running code gen
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level) func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation # Graph code generation
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func) graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host) mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params return graph_json, mod, params
...@@ -210,21 +243,22 @@ class GraphExecutor(_interpreter.Executor): ...@@ -210,21 +243,22 @@ class GraphExecutor(_interpreter.Executor):
self.target = target self.target = target
def _make_executor(self, func): def _make_executor(self, func):
def _graph_wrapper(*args):
graph_json, mod, params = build(func, target=self.target) graph_json, mod, params = build(func, target=self.target)
assert params is None
gmodule = _graph_rt.create(graph_json, mod, self.ctx) gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(*params)
def _graph_wrapper(*args):
# Create map of inputs. # Create map of inputs.
for i, arg in enumerate(args): for i, arg in enumerate(args):
gmodule.set_input(i, arg) gmodule.set_input(i, arg)
# Run the module, and fetch the output. # Run the module, and fetch the output.
gmodule.run() gmodule.run()
return gmodule.get_output(0) # make a copy so multiple invocation won't hurt perf.
return gmodule.get_output(0).copyto(_nd.cpu(0))
return _graph_wrapper return _graph_wrapper
def create_executor(kind="debug", def create_executor(kind="debug",
mod=None, mod=None,
ctx=None, ctx=None,
......
...@@ -6,6 +6,7 @@ from numbers import Number as _Number ...@@ -6,6 +6,7 @@ from numbers import Number as _Number
import numpy as _np import numpy as _np
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
from . import _expr
from . import ty as _ty from . import ty as _ty
from .._ffi import base as _base from .._ffi import base as _base
from .. import nd as _nd from .. import nd as _nd
...@@ -577,3 +578,24 @@ def const(value, dtype=None): ...@@ -577,3 +578,24 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray): if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray") raise ValueError("value has to be scalar or NDArray")
return Constant(value) return Constant(value)
def bind(expr, binds):
"""Bind an free variables in expr or function arguments.
We can bind parameters expr if it is a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
The specific bindings.
Returns
-------
result : tvm.relay.Expr
The expression or function after binding.
"""
return _expr.Bind(expr, binds)
...@@ -259,6 +259,22 @@ def structural_hash(value): ...@@ -259,6 +259,22 @@ def structural_hash(value):
raise TypeError(msg) raise TypeError(msg)
def fold_constant(expr):
"""Fold the constant expression in expr.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
The transformed expression.
"""
return _ir_pass.FoldConstant(expr)
def fuse_ops(expr, opt_level=1): def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together. """Fuse operators in expr together.
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
* ExprMutator uses memoization and self return in order to amortize * ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates. * the cost of using functional updates.
*/ */
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { ...@@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
void ExprVisitor::VisitType(const Type& t) { return; } void ExprVisitor::VisitType(const Type& t) { return; }
// Implement bind.
class ExprBinder : public ExprMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
}
Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var))
<< "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
for (Var param : op->params) {
CHECK(!args_map_.count(param))
<< "Cannnot bind an internal function parameter";
}
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;
} else {
return id;
}
}
private:
const tvm::Map<Var, Expr>& args_map_;
};
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).Mutate(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
new_params.push_back(param);
}
}
if (new_body.same_as(func->body) &&
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
} else {
return ExprBinder(args_map).Mutate(expr);
}
}
TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
if (input->derived_from<ExprNode>()) {
*ret = Bind(Downcast<Expr>(input), args[1]);
} else {
CHECK(input->derived_from<TypeNode>());
*ret = Bind(Downcast<Type>(input), args[1]);
}
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include "./../pass/type_subst.h"
namespace dmlc { namespace dmlc {
// enable registry // enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_functor.cc
* \brief Implementations of type functors.
*/
#include "type_functor.h"
namespace tvm {
namespace relay {
void TypeVisitor::VisitType_(const TypeVarNode* op) {
}
void TypeVisitor::VisitType_(const TensorTypeNode* op) {
}
void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {
}
void TypeVisitor::VisitType_(const FuncTypeNode* op) {
for (auto type_param : op->type_params) {
this->VisitType(type_param);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type);
}
this->VisitType(op->ret_type);
}
void TypeVisitor::VisitType_(const TupleTypeNode* op) {
for (const Type& t : op->fields) {
this->VisitType(t);
}
}
void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) {
this->VisitType(t);
}
}
// Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
// If no changes are made, the original array will be returned.
for (size_t i = 0; i < arr.size(); ++i) {
Type ty = arr[i];
Type new_ty = VisitType(ty);
if (!ty.same_as(new_ty)) {
arr.Set(i, new_ty);
}
}
return arr;
}
Type TypeMutator::VisitType_(const TypeVarNode* op) {
return GetRef<TypeVar>(op);
}
Type TypeMutator::VisitType_(const TensorTypeNode* op) {
// TODO(tvm-team) recursively visit to replace Var
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const IncompleteTypeNode* op) {
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const FuncTypeNode* op) {
bool changed = false;
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
changed = changed || !new_type_param.same_as(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
LOG(FATAL) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
LOG(FATAL) << new_type_cs << std::endl;
}
}
Array<Type> new_args = MutateArray(op->arg_types);
changed = changed || new_args.same_as(op->arg_types);
Type new_ret_type = VisitType(op->ret_type);
changed = changed || new_ret_type.same_as(op->ret_type);
if (!changed) return GetRef<Type>(op);
return FuncTypeNode::make(new_args,
new_ret_type,
type_params,
type_constraints);
}
Type TypeMutator::VisitType_(const TupleTypeNode* op) {
Array<Type> new_fields = MutateArray(op->fields);
if (new_fields.same_as(op->fields)) {
return GetRef<Type>(op);
} else {
return TupleTypeNode::make(new_fields);
}
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
Array<Type> new_args = MutateArray(type_rel->args);
if (new_args.same_as(type_rel->args)) {
return GetRef<Type>(type_rel);
} else {
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
}
// Implements bind.
class TypeBinder : public TypeMutator {
public:
explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map)
: args_map_(args_map) {}
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;
} else {
return id;
}
}
private:
const tvm::Map<TypeVar, Type>& args_map_;
};
Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return TypeBinder(args_map).VisitType(type);
}
} // namespace relay
} // namespace tvm
...@@ -91,113 +91,39 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -91,113 +91,39 @@ class TypeFunctor<R(const Type& n, Args...)> {
}; };
/*! /*!
* \brief A type visitor for vistiors which make use of internal * \brief A type visitor that recursively visit types.
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/ */
class TypeVisitor : class TypeVisitor : public TypeFunctor<void(const Type& n)> {
public ::tvm::relay::TypeFunctor<void(const Type& n)> {
public: public:
void VisitType_(const TypeVarNode* op) override {} void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
void VisitType_(const FuncTypeNode* op) override { void VisitType_(const TensorTypeNode* op) override;
for (auto type_param : op->type_params) { void VisitType_(const FuncTypeNode* op) override;
this->VisitType(type_param); void VisitType_(const TupleTypeNode* op) override;
} void VisitType_(const TypeRelationNode* op) override;
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type);
}
this->VisitType(op->ret_type);
}
void VisitType_(const TensorTypeNode* op) override {}
void VisitType_(const TupleTypeNode* op) override {
for (const Type& t : op->fields) {
this->VisitType(t);
}
}
void VisitType_(const TypeRelationNode* op) override {
for (const Type& t : op->args) {
this->VisitType(t);
}
}
void VisitType_(const IncompleteTypeNode* op) override {}
}; };
// A functional visitor for rebuilding an AST in place. // Mutator that transform a type to another one.
struct TypeMutator : TypeFunctor<Type(const Type& n)> { class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override { public:
// TODO(@jroesch): maybe we should recursively visit Type VisitType_(const TypeVarNode* op) override;
return TensorTypeNode::make(op->shape, op->dtype); Type VisitType_(const TensorTypeNode* op) override;
} Type VisitType_(const IncompleteTypeNode* op) override;
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TypeVarNode* op) override { Type VisitType_(const TupleTypeNode* op) override;
return GetRef<TypeVar>(op); Type VisitType_(const TypeRelationNode* type_rel) override;
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override { private:
std::vector<Type> new_fields; Array<Type> MutateArray(Array<Type> arr);
for (const Type& t : op->fields) { };
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override { /*!
std::vector<Type> new_args; * \brief Bind free type variables in the type.
for (const Type& t : type_rel->args) { * \param type The type to be updated.
new_args.push_back(this->VisitType(t)); * \param args_map The binding map.
} */
return TypeRelationNode::make(type_rel->func, Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_ #endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
...@@ -71,7 +71,8 @@ std::vector<T> AsVector(const Array<T> &array) { ...@@ -71,7 +71,8 @@ std::vector<T> AsVector(const Array<T> &array) {
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \ .add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false)
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2018 by Contributors
* \file constant_folding.cc
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
namespace tvm {
namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor)
: executor_(executor) {
}
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
}
}
}
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (arg.as<ConstantNode>() == nullptr) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
private:
// Internal interepreter.
FInterpreter executor_;
// Convert value to expression.
Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) {
return ConstantNode::make(val->data);
} else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields;
for (Value field : val->fields) {
fields.push_back(ValueToExpr(field));
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "Cannot handle " << value->type_key();
return Expr();
}
}
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
expr = InferType(expr, Module(nullptr));
expr = FuseOps(expr, 0);
expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr));
}
};
Expr FoldConstant(const Expr& expr) {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::create("llvm");
return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr);
}
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FoldConstant(args[0]);
});
} // namespace relay
} // namespace tvm
...@@ -22,6 +22,23 @@ namespace relay { ...@@ -22,6 +22,23 @@ namespace relay {
std::unordered_map<const Node*, size_t> std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body); GetExprRefCount(const Expr& body);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
* \param tvar The type variable to be substituted.
* \param subst The target of substitution.
* \return The substituted result.
*/
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute type vars in type.
* \param type The type to be substituted.
* \param subst_map The map of substitution.
* \return The substituted result.
*/
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_PASS_PASS_UTIL_H_
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "type_solver.h" #include "type_solver.h"
#include "type_subst.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -278,7 +278,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -278,7 +278,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {}, ret_type, {},
fn_ty->type_constraints); fn_ty->type_constraints);
inst_ty = TypeSubst(inst_ty, subst_map); inst_ty = Bind(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty); return Downcast<FuncType>(inst_ty);
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_subst.cc
* \brief Function for substituting a concrete type in place of a type ID
*/
#include "./type_subst.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
struct TypeSubstV : TypeMutator {
tvm::Map<TypeVar, Type> subst_map;
explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map)
: subst_map(subst_map) {}
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
if (subst_map.find(id) != subst_map.end()) {
return this->subst_map[id];
} else {
return id;
}
}
};
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) {
TypeSubstV ty_sub({ {target, subst} });
return ty_sub.VisitType(type);
}
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) {
TypeSubstV ty_sub(subst_map);
return ty_sub.VisitType(type);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/type_subst.h
* \brief Utility functions for substituting types.
*/
#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_
#define TVM_RELAY_PASS_TYPE_SUBST_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst);
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SUBST_H_
...@@ -13,7 +13,6 @@ namespace tvm { ...@@ -13,7 +13,6 @@ namespace tvm {
namespace relay { namespace relay {
// FreeTypeVar // FreeTypeVar
class FreeTypeVarTVisitor : public TypeVisitor { class FreeTypeVarTVisitor : public TypeVisitor {
public: public:
FreeTypeVarTVisitor( FreeTypeVarTVisitor(
......
import numpy as np import numpy as np
import tvm
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.ir_pass import infer_type from tvm.relay.ir_pass import infer_type
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add from tvm.relay.op import add
...@@ -27,7 +29,7 @@ def check_rts(expr, args, expected_result, mod=None): ...@@ -27,7 +29,7 @@ def check_rts(expr, args, expected_result, mod=None):
graph = relay.create_executor('graph', mod=mod) graph = relay.create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args) eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args) rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) tvm.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar(): def test_add_op_scalar():
""" """
...@@ -71,7 +73,26 @@ def test_add_op_broadcast(): ...@@ -71,7 +73,26 @@ def test_add_op_broadcast():
y_data = np.random.rand(1, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32')
check_rts(func, [x_data, y_data], x_data + y_data) check_rts(func, [x_data, y_data], x_data + y_data)
def test_with_params():
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data}
graph, lib, params = relay.build(func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input(**params)
mod.set_input(x=x_data)
mod.run()
res = mod.get_output(0).asnumpy()
ref_res = y_data + x_data
tvm.testing.assert_allclose(res, ref_res)
if __name__ == "__main__": if __name__ == "__main__":
test_with_params()
test_add_op_scalar() test_add_op_scalar()
test_add_op_tensor() test_add_op_tensor()
test_add_op_broadcast() test_add_op_broadcast()
""" test bind function."""
import tvm
from tvm import relay
def test_bind_params():
x = relay.var("x")
y = relay.var("y")
z = relay.add(x, y)
f = relay.Function([x, y], z)
fbinded = relay.bind(f, {x : relay.const(1, "float32")})
fexpected =relay.Function(
[y],
relay.add(relay.const(1, "float32"), y))
assert relay.ir_pass.alpha_equal(fbinded, fexpected)
zbinded = relay.bind(z, {y: x})
zexpected = relay.add(x, x)
assert relay.ir_pass.alpha_equal(zbinded, zexpected)
if __name__ == "__main__":
test_bind_params()
import numpy as np
from tvm import relay
def test_fold_const():
c_data = np.array([1, 2, 3]).astype("float32")
def before():
c = relay.const(c_data)
x = relay.var("x")
y = relay.add(c, c)
y = relay.multiply(y, relay.const(2, "float32"))
y = relay.add(x, y)
z = relay.add(y, c)
return relay.Function([x], z)
def expected():
x = relay.var("x")
c_folded = (c_data + c_data) * 2
y = relay.add(x, relay.const(c_folded))
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.alpha_equal(zz, zexpected)
def test_fold_let():
c_data = np.array(1).astype("float32")
def before():
sb = relay.ScopeBuilder()
x = relay.var("x")
t1 = sb.let("t1", relay.const(c_data))
t2 = sb.let("t2", relay.add(t1, t1))
t3 = sb.let("t3", relay.add(t2, x))
sb.ret(t3)
return relay.Function([x], sb.get())
def expected():
sb = relay.ScopeBuilder()
x = relay.var("x")
c_folded = (c_data + c_data)
t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
sb.ret(t3)
return relay.Function([x], sb.get())
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.graph_equal(zz, zexpected)
def test_fold_tuple():
c_data = np.array(1).astype("float32")
def before():
c = relay.const(c_data)
x = relay.var("x")
y = relay.Tuple([x, c])
z = relay.add(y[1], c)
z = relay.add(z, y[0])
return relay.Function([x], z)
def expected():
c = relay.const(c_data + c_data)
x = relay.var("x")
z = relay.add(c, x)
return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.graph_equal(zz, zexpected)
if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment