Unverified Commit b2521604 by Tianqi Chen Committed by GitHub

[RELAY][PASS] Bind, FoldConstant (#2100)

parent 1b863732
......@@ -182,6 +182,17 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
/*
* \brief Bind function parameters or free variables.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
*/
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
......@@ -39,6 +39,16 @@ enum OpPatternKind {
using TOpPattern = int;
/*!
* \brief Whether operator is stateful or contain internal state.
*
* All the primitive ops we registered so far are pure.
* This attribute is left for potential future compatible reasons.
* We can always work around the stateful ops by adding an additional
* handle argument and return it.
*/
using TOpIsStateful = bool;
/*!
* \brief Computation description interface.
*
* \note This function have a special convention
......
......@@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
Expr DeadCodeElimination(const Expr& e);
/*!
* \brief Fold constant expressions.
* \param expr the expression to be optimized.
* \return The optimized expression.
*/
Expr FoldConstant(const Expr& expr);
/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
......
......@@ -54,7 +54,7 @@ TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
bind = expr.bind
# pylint: disable=unused-argument
@register_func("relay.debug")
......
......@@ -102,6 +102,7 @@ class GraphRuntimeCodegen(ExprFunctor):
self.target = target
self.nodes = []
self.var_map = {}
self.params = {}
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self._name_map = {}
......@@ -162,8 +163,12 @@ class GraphRuntimeCodegen(ExprFunctor):
assert isinstance(vtuple, tuple)
return vtuple[op.index]
def visit_constant(self, _):
raise RuntimeError("constant not supported")
def visit_constant(self, op):
index = len(self.params)
name = "p%d" % index
self.params[name] = op.data
node = InputNode(name, {})
return self.add_node(node, op.checked_type)
def visit_function(self, _):
raise RuntimeError("function not supported")
......@@ -312,6 +317,9 @@ class GraphRuntimeCodegen(ExprFunctor):
lowered_funcs : List[tvm.LoweredFunc]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
# First we convert all the parameters into input nodes.
for param in func.params:
......@@ -324,7 +332,7 @@ class GraphRuntimeCodegen(ExprFunctor):
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
return graph_json, lowered_funcs
return graph_json, lowered_funcs, self.params
def _get_unique_name(self, name):
if name not in self._name_map:
......
......@@ -6,6 +6,7 @@ from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
......@@ -13,6 +14,7 @@ from .backend import graph_runtime_codegen as _graph_gen
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
}
......@@ -95,7 +97,27 @@ def build_config(**kwargs):
return BuildConfig(**kwargs)
def optimize(func):
def _bind_params_by_name(func, params):
"""Bind parameters of function by its name."""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = expr.const(v)
return expr.bind(func, bind_dict)
def optimize(func, params=None):
"""Perform target invariant optimizations.
Parameters
......@@ -103,6 +125,10 @@ def optimize(func):
func : tvm.relay.Function
The input to optimization.
params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
during inference time. used for constant folding.
Returns
-------
opt_func : tvm.relay.Function
......@@ -110,7 +136,11 @@ def optimize(func):
"""
cfg = BuildConfig.current
if cfg.pass_enabled("FoldScaleAxis"):
# bind expressions
if params:
func = _bind_params_by_name(func, params)
if cfg.pass_enabled("SimplifyInference"):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)
......@@ -119,6 +149,10 @@ def optimize(func):
func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
return func
......@@ -147,8 +181,7 @@ def build(func,
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
during inference time. Used for constant folding.
Returns
-------
......@@ -176,14 +209,14 @@ def build(func,
cfg = BuildConfig.current
with tophub_context:
func = optimize(func)
func = optimize(func, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func)
graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params
......@@ -210,21 +243,22 @@ class GraphExecutor(_interpreter.Executor):
self.target = target
def _make_executor(self, func):
def _graph_wrapper(*args):
graph_json, mod, params = build(func, target=self.target)
assert params is None
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(*params)
def _graph_wrapper(*args):
# Create map of inputs.
for i, arg in enumerate(args):
gmodule.set_input(i, arg)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
# make a copy so multiple invocation won't hurt perf.
return gmodule.get_output(0).copyto(_nd.cpu(0))
return _graph_wrapper
def create_executor(kind="debug",
mod=None,
ctx=None,
......
......@@ -6,6 +6,7 @@ from numbers import Number as _Number
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
......@@ -577,3 +578,24 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
return Constant(value)
def bind(expr, binds):
"""Bind an free variables in expr or function arguments.
We can bind parameters expr if it is a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
The specific bindings.
Returns
-------
result : tvm.relay.Expr
The expression or function after binding.
"""
return _expr.Bind(expr, binds)
......@@ -259,6 +259,22 @@ def structural_hash(value):
raise TypeError(msg)
def fold_constant(expr):
"""Fold the constant expression in expr.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
The transformed expression.
"""
return _ir_pass.FoldConstant(expr)
def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together.
......
......@@ -6,8 +6,8 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/expr_functor.h>
#include "type_functor.h"
namespace tvm {
namespace relay {
......@@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
void ExprVisitor::VisitType(const Type& t) { return; }
// Implement bind.
class ExprBinder : public ExprMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
}
Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var))
<< "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
for (Var param : op->params) {
CHECK(!args_map_.count(param))
<< "Cannnot bind an internal function parameter";
}
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;
} else {
return id;
}
}
private:
const tvm::Map<Var, Expr>& args_map_;
};
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).Mutate(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
new_params.push_back(param);
}
}
if (new_body.same_as(func->body) &&
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
} else {
return ExprBinder(args_map).Mutate(expr);
}
}
TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
if (input->derived_from<ExprNode>()) {
*ret = Bind(Downcast<Expr>(input), args[1]);
} else {
CHECK(input->derived_from<TypeNode>());
*ret = Bind(Downcast<Type>(input), args[1]);
}
});
} // namespace relay
} // namespace tvm
......@@ -11,8 +11,6 @@
#include <memory>
#include <mutex>
#include "./../pass/type_subst.h"
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_functor.cc
* \brief Implementations of type functors.
*/
#include "type_functor.h"
namespace tvm {
namespace relay {
void TypeVisitor::VisitType_(const TypeVarNode* op) {
}
void TypeVisitor::VisitType_(const TensorTypeNode* op) {
}
void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {
}
void TypeVisitor::VisitType_(const FuncTypeNode* op) {
for (auto type_param : op->type_params) {
this->VisitType(type_param);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type);
}
this->VisitType(op->ret_type);
}
void TypeVisitor::VisitType_(const TupleTypeNode* op) {
for (const Type& t : op->fields) {
this->VisitType(t);
}
}
void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) {
this->VisitType(t);
}
}
// Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
// If no changes are made, the original array will be returned.
for (size_t i = 0; i < arr.size(); ++i) {
Type ty = arr[i];
Type new_ty = VisitType(ty);
if (!ty.same_as(new_ty)) {
arr.Set(i, new_ty);
}
}
return arr;
}
Type TypeMutator::VisitType_(const TypeVarNode* op) {
return GetRef<TypeVar>(op);
}
Type TypeMutator::VisitType_(const TensorTypeNode* op) {
// TODO(tvm-team) recursively visit to replace Var
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const IncompleteTypeNode* op) {
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const FuncTypeNode* op) {
bool changed = false;
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
changed = changed || !new_type_param.same_as(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
LOG(FATAL) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
LOG(FATAL) << new_type_cs << std::endl;
}
}
Array<Type> new_args = MutateArray(op->arg_types);
changed = changed || new_args.same_as(op->arg_types);
Type new_ret_type = VisitType(op->ret_type);
changed = changed || new_ret_type.same_as(op->ret_type);
if (!changed) return GetRef<Type>(op);
return FuncTypeNode::make(new_args,
new_ret_type,
type_params,
type_constraints);
}
Type TypeMutator::VisitType_(const TupleTypeNode* op) {
Array<Type> new_fields = MutateArray(op->fields);
if (new_fields.same_as(op->fields)) {
return GetRef<Type>(op);
} else {
return TupleTypeNode::make(new_fields);
}
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
Array<Type> new_args = MutateArray(type_rel->args);
if (new_args.same_as(type_rel->args)) {
return GetRef<Type>(type_rel);
} else {
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
}
// Implements bind.
class TypeBinder : public TypeMutator {
public:
explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map)
: args_map_(args_map) {}
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;
} else {
return id;
}
}
private:
const tvm::Map<TypeVar, Type>& args_map_;
};
Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return TypeBinder(args_map).VisitType(type);
}
} // namespace relay
} // namespace tvm
......@@ -91,113 +91,39 @@ class TypeFunctor<R(const Type& n, Args...)> {
};
/*!
* \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
* \brief A type visitor that recursively visit types.
*/
class TypeVisitor :
public ::tvm::relay::TypeFunctor<void(const Type& n)> {
class TypeVisitor : public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override {}
void VisitType_(const FuncTypeNode* op) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type);
}
this->VisitType(op->ret_type);
}
void VisitType_(const TensorTypeNode* op) override {}
void VisitType_(const TupleTypeNode* op) override {
for (const Type& t : op->fields) {
this->VisitType(t);
}
}
void VisitType_(const TypeRelationNode* op) override {
for (const Type& t : op->args) {
this->VisitType(t);
}
}
void VisitType_(const IncompleteTypeNode* op) override {}
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
void VisitType_(const TensorTypeNode* op) override;
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
// Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType_(const TypeVarNode* op) override;
Type VisitType_(const TensorTypeNode* op) override;
Type VisitType_(const IncompleteTypeNode* op) override;
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
private:
Array<Type> MutateArray(Array<Type> arr);
};
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
/*!
* \brief Bind free type variables in the type.
* \param type The type to be updated.
* \param args_map The binding map.
*/
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
......@@ -71,7 +71,8 @@ std::vector<T> AsVector(const Array<T> &array) {
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false)
} // namespace relay
} // namespace tvm
......
/*!
* Copyright (c) 2018 by Contributors
* \file constant_folding.cc
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
namespace tvm {
namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor)
: executor_(executor) {
}
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
}
}
}
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (arg.as<ConstantNode>() == nullptr) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
private:
// Internal interepreter.
FInterpreter executor_;
// Convert value to expression.
Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) {
return ConstantNode::make(val->data);
} else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields;
for (Value field : val->fields) {
fields.push_back(ValueToExpr(field));
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "Cannot handle " << value->type_key();
return Expr();
}
}
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
expr = InferType(expr, Module(nullptr));
expr = FuseOps(expr, 0);
expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr));
}
};
Expr FoldConstant(const Expr& expr) {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::create("llvm");
return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr);
}
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FoldConstant(args[0]);
});
} // namespace relay
} // namespace tvm
......@@ -22,6 +22,23 @@ namespace relay {
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
* \param tvar The type variable to be substituted.
* \param subst The target of substitution.
* \return The substituted result.
*/
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);
/*!
* \brief Substitute type vars in type.
* \param type The type to be substituted.
* \param subst_map The map of substitution.
* \return The substituted result.
*/
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
......@@ -24,7 +24,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include "type_solver.h"
#include "type_subst.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -278,7 +278,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {},
fn_ty->type_constraints);
inst_ty = TypeSubst(inst_ty, subst_map);
inst_ty = Bind(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty);
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_subst.cc
* \brief Function for substituting a concrete type in place of a type ID
*/
#include "./type_subst.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
struct TypeSubstV : TypeMutator {
tvm::Map<TypeVar, Type> subst_map;
explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map)
: subst_map(subst_map) {}
Type VisitType_(const TypeVarNode* op) override {
auto id = GetRef<TypeVar>(op);
if (subst_map.find(id) != subst_map.end()) {
return this->subst_map[id];
} else {
return id;
}
}
};
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) {
TypeSubstV ty_sub({ {target, subst} });
return ty_sub.VisitType(type);
}
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) {
TypeSubstV ty_sub(subst_map);
return ty_sub.VisitType(type);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/type_subst.h
* \brief Utility functions for substituting types.
*/
#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_
#define TVM_RELAY_PASS_TYPE_SUBST_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst);
Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SUBST_H_
......@@ -13,7 +13,6 @@ namespace tvm {
namespace relay {
// FreeTypeVar
class FreeTypeVarTVisitor : public TypeVisitor {
public:
FreeTypeVarTVisitor(
......
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.ir_pass import infer_type
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
......@@ -27,7 +29,7 @@ def check_rts(expr, args, expected_result, mod=None):
graph = relay.create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
tvm.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar():
"""
......@@ -71,7 +73,26 @@ def test_add_op_broadcast():
y_data = np.random.rand(1, 5).astype('float32')
check_rts(func, [x_data, y_data], x_data + y_data)
def test_with_params():
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data}
graph, lib, params = relay.build(func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input(**params)
mod.set_input(x=x_data)
mod.run()
res = mod.get_output(0).asnumpy()
ref_res = y_data + x_data
tvm.testing.assert_allclose(res, ref_res)
if __name__ == "__main__":
test_with_params()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
""" test bind function."""
import tvm
from tvm import relay
def test_bind_params():
x = relay.var("x")
y = relay.var("y")
z = relay.add(x, y)
f = relay.Function([x, y], z)
fbinded = relay.bind(f, {x : relay.const(1, "float32")})
fexpected =relay.Function(
[y],
relay.add(relay.const(1, "float32"), y))
assert relay.ir_pass.alpha_equal(fbinded, fexpected)
zbinded = relay.bind(z, {y: x})
zexpected = relay.add(x, x)
assert relay.ir_pass.alpha_equal(zbinded, zexpected)
if __name__ == "__main__":
test_bind_params()
import numpy as np
from tvm import relay
def test_fold_const():
c_data = np.array([1, 2, 3]).astype("float32")
def before():
c = relay.const(c_data)
x = relay.var("x")
y = relay.add(c, c)
y = relay.multiply(y, relay.const(2, "float32"))
y = relay.add(x, y)
z = relay.add(y, c)
return relay.Function([x], z)
def expected():
x = relay.var("x")
c_folded = (c_data + c_data) * 2
y = relay.add(x, relay.const(c_folded))
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.alpha_equal(zz, zexpected)
def test_fold_let():
c_data = np.array(1).astype("float32")
def before():
sb = relay.ScopeBuilder()
x = relay.var("x")
t1 = sb.let("t1", relay.const(c_data))
t2 = sb.let("t2", relay.add(t1, t1))
t3 = sb.let("t3", relay.add(t2, x))
sb.ret(t3)
return relay.Function([x], sb.get())
def expected():
sb = relay.ScopeBuilder()
x = relay.var("x")
c_folded = (c_data + c_data)
t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
sb.ret(t3)
return relay.Function([x], sb.get())
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.graph_equal(zz, zexpected)
def test_fold_tuple():
c_data = np.array(1).astype("float32")
def before():
c = relay.const(c_data)
x = relay.var("x")
y = relay.Tuple([x, c])
z = relay.add(y[1], c)
z = relay.add(z, y[0])
return relay.Function([x], z)
def expected():
c = relay.const(c_data + c_data)
x = relay.var("x")
z = relay.add(c, x)
return relay.Function([x], z)
zz = relay.ir_pass.fold_constant(before())
zexpected = expected()
assert relay.ir_pass.graph_equal(zz, zexpected)
if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment