Commit ead3ac6c by Jared Roesch Committed by Tianqi Chen

Rename relay::Environment to relay::Module (#2054)

parent 420ec786
...@@ -165,7 +165,7 @@ class RelayNode : public Node { ...@@ -165,7 +165,7 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
}; };
struct Environment; struct Module;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#define TVM_RELAY_BUILD_MODULE_H_ #define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/relay/environment.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
...@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef); ...@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
* \note This will do a reachability analysis and lower all definitions * \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression. * reachable from the provided expression.
* *
* \param env The environment. * \param mod The module.
* \param expr The expression with operations to be lowered. * \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to. * \param target The target to lower the functions to.
* *
* \return The set of lowered operations. * \return The set of lowered operations.
*/ */
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr, Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
const std::string& target = "llvm"); const std::string& target = "llvm");
} // namespace relay } // namespace relay
......
...@@ -160,7 +160,7 @@ class VarNode : public ExprNode { ...@@ -160,7 +160,7 @@ class VarNode : public ExprNode {
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
/*! /*!
* \brief Global variable that leaves in the top-level environment. * \brief Global variable that leaves in the top-level module.
* This is used to enable recursive calls between function. * This is used to enable recursive calls between function.
* *
* \note A GlobalVar may only point to functions. * \note A GlobalVar may only point to functions.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief An interpreter for Relay. * \brief An interpreter for Relay.
* *
* This file implements a simple reference interpreter for Relay programs. * This file implements a simple reference interpreter for Relay programs.
* Given a Relay environment, and a Relay expression it produces a value. * Given a Relay module, and a Relay expression it produces a value.
* *
* The interpreter's values are a naive representation of the values that * The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's * can be produced by a Relay program and are exposed via tvm::Node's
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_ #ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_ #define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/environment.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
namespace tvm { namespace tvm {
...@@ -39,7 +39,7 @@ class Value; ...@@ -39,7 +39,7 @@ class Value;
* Our intent is that this will never be the most efficient implementation of * Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one. * Relay's semantics, but a readable and clear one.
*/ */
Value Evaluate(Environment env, Expr e); Value Evaluate(Module mod, Expr e);
/*! \brief The base container type of Relay values. */ /*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode { class ValueNode : public RelayNode {
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file tvm/relay/environment.h * \file tvm/relay/module.h
* \brief The global environment: contains information needed to * \brief The global environment: contains information needed to
* compile & optimize Relay programs. * compile & optimize Relay programs.
*/ */
#ifndef TVM_RELAY_ENVIRONMENT_H_ #ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_MODULE_H_
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
struct Environment; struct Module;
/*! \brief The global environment of Relay programs. /*! \brief The global environment of Relay programs.
* *
...@@ -28,29 +28,29 @@ struct Environment; ...@@ -28,29 +28,29 @@ struct Environment;
* options. * options.
* *
* Many operations require access to the global * Many operations require access to the global
* Environment. We pass the Environment by value * Module. We pass the Module by value
* in a functional style as an explicit argument, * in a functional style as an explicit argument,
* but we mutate the Environment while optimizing * but we mutate the Module while optimizing
* Relay programs. * Relay programs.
* *
* The functional style allows users to construct custom * The functional style allows users to construct custom
* environments easily, for example each thread can store * environments easily, for example each thread can store
* an Environment while auto-tuning. * an Module while auto-tuning.
* */ * */
class EnvironmentNode : public RelayNode { class ModuleNode : public RelayNode {
public: public:
/*! \brief A map from ids to all global functions. */ /*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions; tvm::Map<GlobalVar, Function> functions;
EnvironmentNode() {} ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions); v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_); v->Visit("global_var_map_", &global_var_map_);
} }
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs); TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
...@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode { ...@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
* functions in another environment. * functions in another environment.
* \param other The other environment. * \param other The other environment.
*/ */
void Update(const Environment& other); void Update(const Module& other);
static constexpr const char* _type_key = "relay.Environment"; static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
private: private:
/*! \brief A map from string names to global variables that /*! \brief A map from string names to global variables that
...@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode { ...@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
tvm::Map<std::string, GlobalVar> global_var_map_; tvm::Map<std::string, GlobalVar> global_var_map_;
}; };
struct Environment : public NodeRef { struct Module : public NodeRef {
Environment() {} Module() {}
explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {} explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
inline EnvironmentNode* operator->() const { inline ModuleNode* operator->() const {
return static_cast<EnvironmentNode*>(node_.get()); return static_cast<ModuleNode*>(node_.get());
} }
using ContainerType = EnvironmentNode; using ContainerType = ModuleNode;
}; };
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ENVIRONMENT_H_ #endif // TVM_RELAY_MODULE_H_
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_H_ #ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_
#include <tvm/relay/environment.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
...@@ -21,23 +21,23 @@ namespace relay { ...@@ -21,23 +21,23 @@ namespace relay {
* populated with the result type. * populated with the result type.
* *
* \param expr The expression to type check. * \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be * \param mod The module used for referencing global functions, can be
* None. * None.
* *
* \return A type checked expression with its checked_type field populated. * \return A type checked expression with its checked_type field populated.
*/ */
Expr InferType(const Expr& expr, const Environment& env); Expr InferType(const Expr& expr, const Module& mod);
/*! /*!
* \brief Infer the type of a function as if it is mapped to var in the env. * \brief Infer the type of a function as if it is mapped to var in the mod.
* *
* \param f the function. * \param f the function.
* \param env The environment used for referencing global functions. * \param mod The module used for referencing global functions.
* \param var The global variable corresponding to the function. * \param var The global variable corresponding to the function.
* *
* \return A type checked Function with its checked_type field populated. * \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe. * \note this function mutates mod and is not thread-safe.
*/ */
Function InferType(const Function& f, const Environment& env, Function InferType(const Function& f, const Module& mod,
const GlobalVar& var); const GlobalVar& var);
/*! /*!
...@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env, ...@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
* a data type such as `int`, `float`, `uint`. * a data type such as `int`, `float`, `uint`.
* *
* \param t The type to check. * \param t The type to check.
* \param env The global environment. * \param mod The global module.
* *
* \return true if the rules are satisified otherwise false * \return true if the rules are satisified otherwise false
*/ */
bool KindCheck(const Type& t, const Environment& env); bool KindCheck(const Type& t, const Module& mod);
/*! \brief Compare two expressions for structural equivalence. /*! \brief Compare two expressions for structural equivalence.
* *
......
...@@ -349,14 +349,14 @@ class TypeRelation; ...@@ -349,14 +349,14 @@ class TypeRelation;
/*! /*!
* \brief TypeRelation container. * \brief TypeRelation container.
* \note This node is not directly serializable. * \note This node is not directly serializable.
* The type function need to be lookedup in the environment. * The type function need to be lookedup in the module.
*/ */
class TypeRelationNode : public TypeConstraintNode { class TypeRelationNode : public TypeConstraintNode {
public: public:
/*! /*!
* \brief The function on input and output variables which * \brief The function on input and output variables which
* this is not directly serializable, * this is not directly serializable,
* need to be looked-up in the environment. * need to be looked-up in the module.
*/ */
TypeRelationFn func; TypeRelationFn func;
/*! \brief The type arguments to the type function. */ /*! \brief The type arguments to the type function. */
......
...@@ -5,7 +5,7 @@ from ..api import register_func ...@@ -5,7 +5,7 @@ from ..api import register_func
from . import base from . import base
from . import ty from . import ty
from . import expr from . import expr
from . import env from . import module
from . import ir_pass from . import ir_pass
from .build_module import build from .build_module import build
from .interpreter import create_executor from .interpreter import create_executor
...@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder ...@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
Span = base.Span Span = base.Span
# Env # Env
Environment = env.Environment Module = module.Module
# Type # Type
Type = ty.Type Type = ty.Type
......
from .env import Environment from .env import Module
from . import ir from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ... def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
\ No newline at end of file
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Environment exposed from C++.""" """The interface to the Module exposed from C++."""
from tvm._ffi.function import _init_api from tvm._ffi.function import _init_api
_init_api("relay._env", __name__) _init_api("relay._module", __name__)
...@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List ...@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ... class Module(NodeBase): ...
\ No newline at end of file
...@@ -5,9 +5,9 @@ from a Relay expression. ...@@ -5,9 +5,9 @@ from a Relay expression.
from ..build_module import build as tvm_build_module from ..build_module import build as tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen from . graph_runtime_codegen import GraphRuntimeCodegen
from . import ir_pass from . import ir_pass
from .env import Environment from .module import Module
def build(func, params=None, target=None, env=None): def build(func, params=None, target=None, mod=None):
""" """
Compile a single function to the components needed by the Compile a single function to the components needed by the
TVM RTS. TVM RTS.
...@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None): ...@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
if target is None: if target is None:
target = 'llvm' target = 'llvm'
if env is None: if mod is None:
env = Environment({}) mod = Module({})
comp = GraphRuntimeCodegen(env) comp = GraphRuntimeCodegen(mod)
# NB(@jroesch) This creates lowered functions, and generates names for them # NB(@jroesch) This creates lowered functions, and generates names for them
# #
# We need these names to emit the correct graph as these are names of the # We need these names to emit the correct graph as these are names of the
# functions contained in the module. # functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func) lowered_ops = ir_pass.lower_ops(mod, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target) mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after. # Therefore the call to compile must come after.
......
...@@ -172,7 +172,7 @@ class GlobalVar(Expr): ...@@ -172,7 +172,7 @@ class GlobalVar(Expr):
"""A global variable in Tvm.Relay. """A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions GlobalVar is used to refer to the global functions
stored in the environment. stored in the module.
Parameters Parameters
---------- ----------
......
...@@ -8,7 +8,7 @@ from . import build_module ...@@ -8,7 +8,7 @@ from . import build_module
from . import _make from . import _make
from . import _interpreter from . import _interpreter
from . import ir_pass from . import ir_pass
from .env import Environment from .module import Module
from .expr import Call, Constant, GlobalVar, Function, const from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types from .._ffi.base import integer_types
...@@ -90,24 +90,24 @@ def _arg_to_ast(arg): ...@@ -90,24 +90,24 @@ def _arg_to_ast(arg):
class Executor(object): class Executor(object):
"""An abstract interface for executing Relay programs.""" """An abstract interface for executing Relay programs."""
def __init__(self, env=None): def __init__(self, mod=None):
""" """
Parameters Parameters
---------- ----------
env: relay.Environment mod: relay.Module
The environment. The module.
""" """
if env is None: if mod is None:
self.env = Environment({}) self.mod = Module({})
else: else:
self.env = env self.mod = mod
def optimize(self, expr): def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager # TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, env=self.env) ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(self.env, ck_expr) fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, env=self.env) ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused return ck_fused
def _make_executor(self, _): def _make_executor(self, _):
...@@ -153,8 +153,8 @@ class Interpreter(Executor): ...@@ -153,8 +153,8 @@ class Interpreter(Executor):
""" """
A wrapper around the Relay interpreter, implements the excecutor interface. A wrapper around the Relay interpreter, implements the excecutor interface.
""" """
def __init__(self, env=None): def __init__(self, mod=None):
Executor.__init__(self, env) Executor.__init__(self, mod)
def _make_executor(self, expr): def _make_executor(self, expr):
def _interp_wrapper(*args): def _interp_wrapper(*args):
...@@ -163,28 +163,28 @@ class Interpreter(Executor): ...@@ -163,28 +163,28 @@ class Interpreter(Executor):
relay_args.append(_arg_to_ast(arg)) relay_args.append(_arg_to_ast(arg))
if isinstance(expr, GlobalVar): if isinstance(expr, GlobalVar):
func = self.env[expr] func = self.mod[expr]
func = self.optimize(func) func = self.optimize(func)
self.env._add(expr, func, True) self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args) opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.env, opt_expr) return _interpreter.evaluate(self.mod, opt_expr)
else: else:
call = Call(expr, relay_args) call = Call(expr, relay_args)
opt_expr = self.optimize(call) opt_expr = self.optimize(call)
return _interpreter.evaluate(self.env, opt_expr) return _interpreter.evaluate(self.mod, opt_expr)
return _interp_wrapper return _interp_wrapper
class GraphRuntime(Executor): class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface.""" """A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, env=None): def __init__(self, mod=None):
Executor.__init__(self, env) Executor.__init__(self, mod)
def _make_executor(self, expr): def _make_executor(self, expr):
def _graph_wrapper(*args): def _graph_wrapper(*args):
func = self.optimize(expr) func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, env=self.env) graph_json, mod, params = build_module.build(func, mod=self.mod)
assert params is None assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0)) gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs. # Create map of inputs.
...@@ -199,10 +199,10 @@ class GraphRuntime(Executor): ...@@ -199,10 +199,10 @@ class GraphRuntime(Executor):
return _graph_wrapper return _graph_wrapper
def create_executor(mode='debug', env=None): def create_executor(mode='debug', mod=None):
if mode == 'debug': if mode == 'debug':
return Interpreter(env) return Interpreter(mod)
elif mode == 'graph': elif mode == 'graph':
return GraphRuntime(env) return GraphRuntime(mod)
else: else:
raise Exception("unknown mode {0}".format(mode)) raise Exception("unknown mode {0}".format(mode))
...@@ -11,16 +11,16 @@ from .expr import Expr ...@@ -11,16 +11,16 @@ from .expr import Expr
from .ty import Type from .ty import Type
def infer_type(expr, env=None): def infer_type(expr, mod=None):
"""Infer the type of expr under the context of env. """Infer the type of expr under the context of mod.
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr: tvm.relay.Expr
The input expression. The input expression.
env: Optional[tvm.relay.Environment] mod: Optional[tvm.relay.Module]
The global environment. The global module.
Returns Returns
...@@ -28,7 +28,7 @@ def infer_type(expr, env=None): ...@@ -28,7 +28,7 @@ def infer_type(expr, env=None):
checked_expr : tvm.relay.Expr checked_expr : tvm.relay.Expr
The checked expression. The checked expression.
""" """
return _ir_pass.infer_type(expr, env) return _ir_pass.infer_type(expr, mod)
def backward_fold_scale_axis(expr): def backward_fold_scale_axis(expr):
...@@ -93,7 +93,7 @@ def well_formed(expr): ...@@ -93,7 +93,7 @@ def well_formed(expr):
return _ir_pass.well_formed(expr) return _ir_pass.well_formed(expr)
def check_kind(t, env=None): def check_kind(t, mod=None):
"""Check that the type is well kinded. """Check that the type is well kinded.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
...@@ -102,8 +102,8 @@ def check_kind(t, env=None): ...@@ -102,8 +102,8 @@ def check_kind(t, env=None):
t: tvm.relay.Type t: tvm.relay.Type
The type to check The type to check
env: tvm.relay.Environment, optional mod: tvm.relay.Module, optional
The global environment The global module
Returns Returns
------- -------
...@@ -117,8 +117,8 @@ def check_kind(t, env=None): ...@@ -117,8 +117,8 @@ def check_kind(t, env=None):
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
""" """
if env is not None: if mod is not None:
return _ir_pass.check_kind(t, env) return _ir_pass.check_kind(t, mod)
else: else:
return _ir_pass.check_kind(t) return _ir_pass.check_kind(t)
...@@ -256,8 +256,8 @@ def structural_hash(value): ...@@ -256,8 +256,8 @@ def structural_hash(value):
"relay.Expr or relay.Type").format(type(value)) "relay.Expr or relay.Type").format(type(value))
raise TypeError(msg) raise TypeError(msg)
def fuse_ops(expr, env): def fuse_ops(expr, mod):
return _ir_pass.FuseOps(env, expr) return _ir_pass.FuseOps(mod, expr)
def lower_ops(env, expr, target='llvm'): def lower_ops(mod, expr, target='llvm'):
return _ir_pass.LowerOps(env, expr, target) return _ir_pass.LowerOps(mod, expr, target)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program.""" """A global module storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, RelayNode from .base import register_relay_node, RelayNode
from .._ffi import base as _base from .._ffi import base as _base
from . import _make from . import _make
from . import _env from . import _module
from . import expr as _expr from . import expr as _expr
@register_relay_node @register_relay_node
class Environment(RelayNode): class Module(RelayNode):
"""The global Relay environment containing collection of functions. """The global Relay module containing collection of functions.
Each global function is identified by an unique tvm.relay.GlobalVar. Each global function is identified by an unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and Environment is necessary in order to enable tvm.relay.GlobalVar and Module is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x recursions in function to avoid cyclic reference in the function.x
Parameters Parameters
...@@ -32,10 +32,10 @@ class Environment(RelayNode): ...@@ -32,10 +32,10 @@ class Environment(RelayNode):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]") raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v mapped_funcs[k] = v
functions = mapped_funcs functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Environment, functions) self.__init_handle_by_constructor__(_make.Module, functions)
def __setitem__(self, var, func): def __setitem__(self, var, func):
"""Add a function to the environment. """Add a function to the module.
Parameters Parameters
--------- ---------
...@@ -50,7 +50,7 @@ class Environment(RelayNode): ...@@ -50,7 +50,7 @@ class Environment(RelayNode):
def _add(self, var, func, update=False): def _add(self, var, func, update=False):
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var) var = _expr.GlobalVar(var)
return _env.Environment_Add(self, var, func, update) return _module.Module_Add(self, var, func, update)
def __getitem__(self, var): def __getitem__(self, var):
"""Lookup a global function by name or by variable. """Lookup a global function by name or by variable.
...@@ -66,21 +66,21 @@ class Environment(RelayNode): ...@@ -66,21 +66,21 @@ class Environment(RelayNode):
The function referenced by :code:`var`. The function referenced by :code:`var`.
""" """
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
return _env.Environment_Lookup_str(self, var) return _module.Module_Lookup_str(self, var)
else: else:
return _env.Environment_Lookup(self, var) return _module.Module_Lookup(self, var)
def update(self, other): def update(self, other):
"""Insert functions in another Environment to current one. """Insert functions in another Module to current one.
Parameters Parameters
---------- ----------
other: Environment other: Module
The environment to merge into the current Environment. The module to merge into the current Module.
""" """
if isinstance(other, dict): if isinstance(other, dict):
other = Environment(other) other = Module(other)
return _env.Environment_Update(self, other) return _module.Module_Update(self, other)
def get_global_var(self, name): def get_global_var(self, name):
"""Get a global variable in the function by name. """Get a global variable in the function by name.
...@@ -99,4 +99,4 @@ class Environment(RelayNode): ...@@ -99,4 +99,4 @@ class Environment(RelayNode):
------ ------
tvm.TVMError if we cannot find corresponding global var. tvm.TVMError if we cannot find corresponding global var.
""" """
return _env.Environment_GetGlobalVar(self, name) return _module.Module_GetGlobalVar(self, name)
...@@ -183,7 +183,7 @@ struct ExprEqual { ...@@ -183,7 +183,7 @@ struct ExprEqual {
}; };
struct Interpreter : ExprFunctor<Value(const Expr& n)> { struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Environment env; Module mod;
Stack stack; Stack stack;
using JitKey = Function; using JitKey = Function;
...@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> { ...@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return f(); return f();
} }
Interpreter(Environment env) : env(env), operator_map_() {} Interpreter(Module mod) : mod(mod), operator_map_() {}
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {} Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {}
void extend(const Var& id, Value v) { void extend(const Var& id, Value v) {
this->stack.current_frame().locals.Set(id, v); this->stack.current_frame().locals.Set(id, v);
...@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> { ...@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
} }
Value VisitExpr_(const GlobalVarNode* op) override { Value VisitExpr_(const GlobalVarNode* op) override {
return Eval(this->env->Lookup(GetRef<GlobalVar>(op))); return Eval(this->mod->Lookup(GetRef<GlobalVar>(op)));
} }
Value VisitExpr_(const OpNode* id) override { Value VisitExpr_(const OpNode* id) override {
...@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> { ...@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Value VisitExpr_(const FunctionNode* func_node) override { Value VisitExpr_(const FunctionNode* func_node) override {
auto func = GetRef<Function>(func_node); auto func = GetRef<Function>(func_node);
tvm::Map<Var, Value> captured_env; tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func); Array<Var> free_vars = FreeVars(func);
for (const auto& var : free_vars) { for (const auto& var : free_vars) {
captured_env.Set(var, Eval(var)); captured_mod.Set(var, Eval(var));
} }
return ClosureNode::make(captured_env, func); return ClosureNode::make(captured_mod, func);
} }
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args, inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
...@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> { ...@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
locals.Set(func->params[i], args[i]); locals.Set(func->params[i], args[i]);
} }
// Add the var to value mappings from the Closure's environment. // Add the var to value mappings from the Closure's modironment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0); CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second); locals.Set((*it).first, (*it).second);
...@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> { ...@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
} }
}; };
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) {
Interpreter::OpMap op_map; Interpreter::OpMap op_map;
auto lowered_ops = LowerOps(env, e); auto lowered_ops = LowerOps(mod, e);
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl; RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
if (lowered_ops.size()) { if (lowered_ops.size()) {
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build"); const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
...@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { ...@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
lowered_funcs.push_back(lop->lowered_func); lowered_funcs.push_back(lop->lowered_func);
} }
Module module = fbuild(lowered_funcs); runtime::Module module = fbuild(lowered_funcs);
// Loop over the lowered operations to map them into the operator map. // Loop over the lowered operations to map them into the operator map.
for (auto lop : lowered_ops) { for (auto lop : lowered_ops) {
...@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { ...@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
return op_map; return op_map;
} }
Value Evaluate(Environment env, Expr e) { Value Evaluate(Module mod, Expr e) {
auto op_map = CompileOperators(env, e); auto op_map = CompileOperators(mod, e);
Interpreter interp(env, op_map); Interpreter interp(mod, op_map);
return interp.Eval(e); return interp.Eval(e);
} }
TVM_REGISTER_API("relay._interpreter.evaluate") TVM_REGISTER_API("relay._interpreter.evaluate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0]; Module mod = args[0];
Expr expr = args[1]; Expr expr = args[1];
*ret = Evaluate(env, expr); *ret = Evaluate(mod, expr);
}); });
} // namespace relay } // namespace relay
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file environment.cc * \file module.cc
* \brief The global environment in Relay. * \brief The global module in Relay.
*/ */
#include <tvm/relay/environment.h> #include <tvm/relay/module.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <sstream> #include <sstream>
...@@ -13,8 +13,8 @@ namespace relay { ...@@ -13,8 +13,8 @@ namespace relay {
using tvm::IRPrinter; using tvm::IRPrinter;
using namespace runtime; using namespace runtime;
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) { Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
auto n = make_node<EnvironmentNode>(); auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
...@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) { ...@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint; << "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first); n->global_var_map_.Set(kv.first->name_hint, kv.first);
} }
return Environment(n); return Module(n);
} }
GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) { GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
auto it = global_var_map_.find(name); auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end()) CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Environment"; << "Cannot find global var " << name << " in the Module";
return (*it).second; return (*it).second;
} }
void EnvironmentNode::Add(const GlobalVar& var, void ModuleNode::Add(const GlobalVar& var,
const Function& func, const Function& func,
bool update) { bool update) {
// Type check the item before we add it to the environment. // Type check the item before we add it to the modironment.
auto env = GetRef<Environment>(this); auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, env, var); Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr); CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) { if (functions.find(var) != functions.end()) {
...@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var, ...@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
<< "Already have definition for " << var->name_hint; << "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<FunctionNode>()->checked_type(); auto old_type = functions[var].as<FunctionNode>()->checked_type();
CHECK(AlphaEqual(type, old_type)) CHECK(AlphaEqual(type, old_type))
<< "Environment#update changes type, not possible in this mode."; << "Module#update changes type, not possible in this mode.";
} }
this->functions.Set(var, checked_func); this->functions.Set(var, checked_func);
...@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var, ...@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var); global_var_map_.Set(var->name_hint, var);
} }
void EnvironmentNode::Update(const GlobalVar& var, const Function& func) { void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true); this->Add(var, func, true);
} }
void EnvironmentNode::Remove(const GlobalVar& var) { void ModuleNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite(); auto functions_node = this->functions.CopyOnWrite();
functions_node->data.erase(var.node_); functions_node->data.erase(var.node_);
auto gvar_node = global_var_map_.CopyOnWrite(); auto gvar_node = global_var_map_.CopyOnWrite();
gvar_node->data.erase(var->name_hint); gvar_node->data.erase(var->name_hint);
} }
Function EnvironmentNode::Lookup(const GlobalVar& var) { Function ModuleNode::Lookup(const GlobalVar& var) {
auto it = functions.find(var); auto it = functions.find(var);
CHECK(it != functions.end()) CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint; << "There is no definition of " << var->name_hint;
return (*it).second; return (*it).second;
} }
Function EnvironmentNode::Lookup(const std::string& name) { Function ModuleNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name); GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id); return this->Lookup(id);
} }
void EnvironmentNode::Update(const Environment& env) { void ModuleNode::Update(const Module& mod) {
for (auto pair : env->functions) { for (auto pair : mod->functions) {
this->Update(pair.first, pair.second); this->Update(pair.first, pair.second);
} }
} }
TVM_REGISTER_NODE_TYPE(EnvironmentNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Environment") TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]); *ret = ModuleNode::make(args[0]);
}); });
TVM_REGISTER_API("relay._env.Environment_Add") TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Module mod = args[0];
env->Add(args[1], args[2], args[3]); mod->Add(args[1], args[2], args[3]);
}); });
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Module mod = args[0];
*ret = env->GetGlobalVar(args[1]); *ret = mod->GetGlobalVar(args[1]);
}); });
TVM_REGISTER_API("relay._env.Environment_Lookup") TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Module mod = args[0];
GlobalVar var = args[1]; GlobalVar var = args[1];
*ret = env->Lookup(var); *ret = mod->Lookup(var);
}); });
TVM_REGISTER_API("relay._env.Environment_Lookup_str") TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Module mod = args[0];
std::string var_name = args[1]; std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name); auto var = mod->GetGlobalVar(var_name);
*ret = env->Lookup(var); *ret = mod->Lookup(var);
}); });
TVM_REGISTER_API("relay._env.Environment_Update") TVM_REGISTER_API("relay._module.Module_Update")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Module mod = args[0];
env->Update(args[1]); mod->Update(args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<EnvironmentNode>( .set_dispatch<ModuleNode>(
[](const EnvironmentNode *node, tvm::IRPrinter *p) { [](const ModuleNode *node, tvm::IRPrinter *p) {
p->stream << "EnvironmentNode( " << node->functions << ")"; p->stream << "ModuleNode( " << node->functions << ")";
}); });
} // namespace relay } // namespace relay
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
* \file text_printer.cc * \file text_printer.cc
* \brief Text printer to print relay in text form. * \brief Text printer to print relay in text form.
*/ */
#include <tvm/relay/environment.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <sstream> #include <sstream>
#include "type_functor.h" #include "type_functor.h"
...@@ -133,8 +133,8 @@ class TextPrinter : ...@@ -133,8 +133,8 @@ class TextPrinter :
std::string Print(const NodeRef& node) { std::string Print(const NodeRef& node) {
if (node.as<FunctionNode>()) { if (node.as<FunctionNode>()) {
this->PrintFunc(Downcast<Function>(node)); this->PrintFunc(Downcast<Function>(node));
} else if (node.as<EnvironmentNode>()) { } else if (node.as<ModuleNode>()) {
this->PrintEnv(Downcast<Environment>(node)); this->PrintEnv(Downcast<Module>(node));
} else if (node.as_derived<TypeNode>()) { } else if (node.as_derived<TypeNode>()) {
this->PrintType(Downcast<Type>(node), stream_); this->PrintType(Downcast<Type>(node), stream_);
} else if (node.as_derived<ExprNode>()) { } else if (node.as_derived<ExprNode>()) {
...@@ -158,9 +158,9 @@ class TextPrinter : ...@@ -158,9 +158,9 @@ class TextPrinter :
stream_ << "\n"; stream_ << "\n";
} }
void PrintEnv(const Environment& env) { void PrintEnv(const Module& mod) {
int counter = 0; int counter = 0;
for (const auto& kv : env->functions) { for (const auto& kv : mod->functions) {
std::ostringstream os; std::ostringstream os;
if (counter++ != 0) { if (counter++ != 0) {
stream_ << "\n"; stream_ << "\n";
......
...@@ -20,12 +20,12 @@ namespace relay { ...@@ -20,12 +20,12 @@ namespace relay {
using namespace runtime; using namespace runtime;
struct AbstractFusableOps : ExprMutator { struct AbstractFusableOps : ExprMutator {
Environment env; Module mod;
Array<GlobalVar> fusable_funcs; Array<GlobalVar> fusable_funcs;
int counter = 0; int counter = 0;
size_t expr_hash; size_t expr_hash;
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {} AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
Expr VisitExpr_(const CallNode* call) { Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) { if (auto op_node = call->op.as<OpNode>()) {
...@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator { ...@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
func_name += "_"; func_name += "_";
func_name += std::to_string(expr_hash); func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name); auto gv = GlobalVarNode::make(func_name);
env->Add(gv, func); mod->Add(gv, func);
fusable_funcs.push_back(gv); fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs()); return CallNode::make(gv, args, Attrs());
} else { } else {
...@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator { ...@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
} }
}; };
Expr FuseOps(const Environment& env, const Expr& e) { Expr FuseOps(const Module& mod, const Expr& e) {
// First we convert all chains of fusable ops into // First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive // abstracted functions which we mark as primtive
// then we convert these primtive functions into // then we convert these primtive functions into
// new operators. // new operators.
auto abstract = AbstractFusableOps(env, StructuralHash()(e)); auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e); auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e; << "Fuse: after=" << abstracted_e;
......
...@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor { ...@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
} }
}; };
bool KindCheck(const Type& t, const Environment& env) { bool KindCheck(const Type& t, const Module& mod) {
KindChecker kc; KindChecker kc;
return kc.Check(t); return kc.Check(t);
} }
...@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) { ...@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
TVM_REGISTER_API("relay._ir_pass.check_kind") TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], EnvironmentNode::make({})); *ret = KindCheck(args[0], ModuleNode::make({}));
} else { } else {
*ret = KindCheck(args[0], args[1]); *ret = KindCheck(args[0], args[1]);
} }
......
...@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) { ...@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
} }
struct AbstractLocalFunctions : ExprMutator { struct AbstractLocalFunctions : ExprMutator {
Environment env; Module mod;
size_t expr_hash; size_t expr_hash;
int counter = 0; int counter = 0;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs; std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
explicit AbstractLocalFunctions(Environment env) explicit AbstractLocalFunctions(Module mod)
: env(env), expr_hash(0), counter(0), visited_funcs() {} : mod(mod), expr_hash(0), counter(0), visited_funcs() {}
Expr Abstract(const Expr& e) { Expr Abstract(const Expr& e) {
expr_hash = StructuralHash()(e); expr_hash = StructuralHash()(e);
...@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator { ...@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
auto gvar = GetRef<GlobalVar>(gvar_node); auto gvar = GetRef<GlobalVar>(gvar_node);
auto it = visited_funcs.find(gvar); auto it = visited_funcs.find(gvar);
if (it == visited_funcs.end()) { if (it == visited_funcs.end()) {
auto func = env->Lookup(gvar); auto func = mod->Lookup(gvar);
visited_funcs.insert(gvar); visited_funcs.insert(gvar);
auto new_func = FunctionNode::make( auto new_func = FunctionNode::make(
func->params, func->params,
...@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator { ...@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
func->ret_type, func->ret_type,
func->type_params, func->type_params,
func->attrs); func->attrs);
env->Update(gvar, new_func); mod->Update(gvar, new_func);
} }
return gvar; return gvar;
} }
...@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator { ...@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
abs_func += std::to_string(expr_hash); abs_func += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(abs_func); auto gv = GlobalVarNode::make(abs_func);
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {}); auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
env->Add(gv, lifted_func); mod->Add(gv, lifted_func);
Array<Expr> args; Array<Expr> args;
for (auto free_var : free_vars) { for (auto free_var : free_vars) {
args.push_back(free_var); args.push_back(free_var);
...@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator { ...@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
}; };
struct LiveFunctions : ExprVisitor { struct LiveFunctions : ExprVisitor {
Environment env; Module mod;
explicit LiveFunctions(Environment env) : env(env), global_funcs() {} explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs; std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs; std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
...@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor { ...@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
GlobalVar var = GetRef<GlobalVar>(var_node); GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var); auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) { if (it == visited_funcs.end()) {
auto func = env->Lookup(var); auto func = mod->Lookup(var);
visited_funcs.insert(var); visited_funcs.insert(var);
// The last pass has trasnformed functions of the form: // The last pass has trasnformed functions of the form:
// //
...@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor { ...@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call); RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) { if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node); GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = env->Lookup(gvar); Function func = mod->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive"); auto attr = FunctionGetAttr(func, "Primitive");
...@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>( ...@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>; using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
/*! \brief Return the set of operators in their TVM format. */ /*! \brief Return the set of operators in their TVM format. */
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e, Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
const std::string& target) { const std::string& target) {
RELAY_LOG(INFO) << "LowerOps: e=" << e; RELAY_LOG(INFO) << "LowerOps: e=" << e;
auto flower_ptr = Registry::Get("relay.op.compiler._lower"); auto flower_ptr = Registry::Get("relay.op.compiler._lower");
CHECK(flower_ptr); CHECK(flower_ptr);
PackedFunc flower = *flower_ptr; PackedFunc flower = *flower_ptr;
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e); auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
auto live_funcs = LiveFunctions(env); auto live_funcs = LiveFunctions(mod);
live_funcs.VisitExpr(abstracted_e); live_funcs.VisitExpr(abstracted_e);
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule"); auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
...@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e, ...@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
Array<LoweredOp> lowered_funcs; Array<LoweredOp> lowered_funcs;
for (auto func_name : live_funcs.global_funcs) { for (auto func_name : live_funcs.global_funcs) {
auto func = env->Lookup(func_name); auto func = mod->Lookup(func_name);
auto call = Downcast<Call>(func->body); auto call = Downcast<Call>(func->body);
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call"; CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
...@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e, ...@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
LoweredFunc lf = LoweredFunc lf =
flower(op->name + std::to_string(hash), schedule, inputs, outputs); flower(op->name + std::to_string(hash), schedule, inputs, outputs);
func = FunctionSetAttr(func, "LoweredFunc", lf); func = FunctionSetAttr(func, "LoweredFunc", lf);
env->Add(func_name, func, true); mod->Add(func_name, func, true);
lowered_funcs.push_back(LoweredOpNode::make(func, lf)); lowered_funcs.push_back(LoweredOpNode::make(func, lf));
} }
......
...@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// constructors // constructors
TypeInferencer() { TypeInferencer() {
} }
explicit TypeInferencer(Environment env) explicit TypeInferencer(Module mod)
: env_(env) { : mod_(mod) {
} }
// inference the type of expr. // inference the type of expr.
...@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type resolver that maps back to type // type resolver that maps back to type
class Resolver; class Resolver;
// internal environment // internal environment
Environment env_; Module mod_;
// map from expression to checked type // map from expression to checked type
// type inferencer will populate it up // type inferencer will populate it up
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_; std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
...@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const GlobalVarNode* op) final { Type VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op); GlobalVar var = GetRef<GlobalVar>(op);
CHECK(env_.defined()) CHECK(mod_.defined())
<< "Cannot do type inference without a global variable"; << "Cannot do type inference without a global variable";
Expr e = env_->Lookup(var); Expr e = mod_->Lookup(var);
return e->checked_type(); return e->checked_type();
} }
...@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) { ...@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
} }
Expr InferType(const Expr& expr, const Environment& env) { Expr InferType(const Expr& expr, const Module& mod) {
auto e = TypeInferencer(env).Infer(expr); auto e = TypeInferencer(mod).Infer(expr);
CHECK(WellFormed(e)); CHECK(WellFormed(e));
return e; return e;
} }
Function InferType(const Function& func, Function InferType(const Function& func,
const Environment& env, const Module& mod,
const GlobalVar& var) { const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation(); func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy); mod->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy); Expr func_ret = TypeInferencer(mod).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite(); auto map_node = mod->functions.CopyOnWrite();
map_node->data.erase(var.node_); map_node->data.erase(var.node_);
CHECK(WellFormed(func_ret)); CHECK(WellFormed(func_ret));
return Downcast<Function>(func_ret); return Downcast<Function>(func_ret);
......
...@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) { ...@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
auto x = relay::VarNode::make("x", type_a); auto x = relay::VarNode::make("x", type_a);
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{}); auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x }); auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{})); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a); CHECK_EQ(type_fx->checked_type(), type_a);
} }
......
...@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type ...@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import Interpreter from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add from tvm.relay.op import add
from tvm.relay.env import Environment from tvm.relay.module import Module
# @tq, @jr should we put this in testing ns? # @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, env=None): def check_rts(expr, args, expected_result, mod=None):
""" """
Check that evaluating `expr` applied to the arguments produces Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime. `result` on both the evaluator and TVM runtime.
...@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None): ...@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
expected_result: expected_result:
The expected result of running the expression. The expected result of running the expression.
""" """
intrp = create_executor('graph', env=env) intrp = create_executor('graph', mod=mod)
graph = create_executor('graph', env=env) graph = create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args) eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args) rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
......
...@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder ...@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, env=None, rtol=1e-07): def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
intrp = create_executor(env=env) intrp = create_executor(mod=mod)
result = intrp.evaluate(expr)(*args) result = intrp.evaluate(expr)(*args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
...@@ -87,7 +87,7 @@ def test_subtract(): ...@@ -87,7 +87,7 @@ def test_subtract():
check_eval(func, [i_data], 0) check_eval(func, [i_data], 0)
def test_simple_loop(): def test_simple_loop():
env = relay.env.Environment({}) mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up') sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder() sb = ScopeBuilder()
...@@ -98,12 +98,12 @@ def test_simple_loop(): ...@@ -98,12 +98,12 @@ def test_simple_loop():
rec_call = relay.Call(sum_up, [one_less]) rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i)) sb.ret(op.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
env[sum_up] = func mod[sum_up] = func
i_data = np.array(10, dtype='int32') i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env) check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
def test_loop(): def test_loop():
env = relay.env.Environment({}) mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up') sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32')
...@@ -115,10 +115,10 @@ def test_loop(): ...@@ -115,10 +115,10 @@ def test_loop():
new_accum = op.add(accum, i) new_accum = op.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum])) sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get()) func = relay.Function([i, accum], sb.get())
env[sum_up] = func mod[sum_up] = func
i_data = np.array(10, dtype='int32') i_data = np.array(10, dtype='int32')
accum_data = np.array(0, dtype='int32') accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env) check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_mlp(): def test_mlp():
pass pass
......
...@@ -28,7 +28,7 @@ def test_env(): ...@@ -28,7 +28,7 @@ def test_env():
z = relay.add(x, y) z = relay.add(x, y)
z = relay.add(z, z) z = relay.add(z, z)
f = relay.Function([x, y], z) f = relay.Function([x, y], z)
env = relay.Environment() env = relay.Module()
env["myf"] = f env["myf"] = f
text = env.astext() text = env.astext()
assert "def @myf" in text assert "def @myf" in text
......
...@@ -9,8 +9,8 @@ from tvm.relay import op ...@@ -9,8 +9,8 @@ from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
def assert_has_type(expr, typ, env=relay.env.Environment({})): def assert_has_type(expr, typ, mod=relay.module.Module({})):
checked_expr = infer_type(expr, env) checked_expr = infer_type(expr, mod)
checked_type = checked_expr.checked_type checked_type = checked_expr.checked_type
if checked_type != typ: if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % ( raise RuntimeError("Type mismatch %s vs %s" % (
...@@ -105,10 +105,10 @@ def test_recursion(): ...@@ -105,10 +105,10 @@ def test_recursion():
sb.ret(data) sb.ret(data)
with sb.else_scope(): with sb.else_scope():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
env = relay.Environment() mod = relay.Module()
env[f] = relay.Function([n, data], sb.get()) mod[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in env.astext() assert "%3 = @f(%1, %2)" in mod.astext()
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32) assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system. # This currently fails and should pass under the type system.
# #
...@@ -179,12 +179,12 @@ def test_self_reference(): ...@@ -179,12 +179,12 @@ def test_self_reference():
def test_global_var_cow_issue(): def test_global_var_cow_issue():
env = relay.env.Environment({}) mod = relay.Module({})
gv = relay.GlobalVar("foo") gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[]) x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), func = relay.Function([x], relay.Call(gv, [x]),
relay.TensorType([], 'float32')) relay.TensorType([], 'float32'))
env[gv] = func mod[gv] = func
def test_equal(): def test_equal():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment