Commit ead3ac6c by Jared Roesch Committed by Tianqi Chen

Rename relay::Environment to relay::Module (#2054)

parent 420ec786
......@@ -165,7 +165,7 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
};
struct Environment;
struct Module;
} // namespace relay
} // namespace tvm
......
......@@ -8,7 +8,7 @@
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <string>
......@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param env The environment.
* \param mod The module.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
......
......@@ -160,7 +160,7 @@ class VarNode : public ExprNode {
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
/*!
* \brief Global variable that leaves in the top-level environment.
* \brief Global variable that leaves in the top-level module.
* This is used to enable recursive calls between function.
*
* \note A GlobalVar may only point to functions.
......
......@@ -4,7 +4,7 @@
* \brief An interpreter for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay environment, and a Relay expression it produces a value.
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
......@@ -16,7 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
namespace tvm {
......@@ -39,7 +39,7 @@ class Value;
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*/
Value Evaluate(Environment env, Expr e);
Value Evaluate(Module mod, Expr e);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/environment.h
* \file tvm/relay/module.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
*/
#ifndef TVM_RELAY_ENVIRONMENT_H_
#define TVM_RELAY_ENVIRONMENT_H_
#ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_MODULE_H_
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
......@@ -17,7 +17,7 @@
namespace tvm {
namespace relay {
struct Environment;
struct Module;
/*! \brief The global environment of Relay programs.
*
......@@ -28,29 +28,29 @@ struct Environment;
* options.
*
* Many operations require access to the global
* Environment. We pass the Environment by value
* Module. We pass the Module by value
* in a functional style as an explicit argument,
* but we mutate the Environment while optimizing
* but we mutate the Module while optimizing
* Relay programs.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* an Environment while auto-tuning.
* an Module while auto-tuning.
* */
class EnvironmentNode : public RelayNode {
class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
EnvironmentNode() {}
ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
}
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
/*!
* \brief Add a function to the global environment.
......@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
* functions in another environment.
* \param other The other environment.
*/
void Update(const Environment& other);
void Update(const Module& other);
static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
private:
/*! \brief A map from string names to global variables that
......@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
tvm::Map<std::string, GlobalVar> global_var_map_;
};
struct Environment : public NodeRef {
Environment() {}
explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {}
struct Module : public NodeRef {
Module() {}
explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
inline EnvironmentNode* operator->() const {
return static_cast<EnvironmentNode*>(node_.get());
inline ModuleNode* operator->() const {
return static_cast<ModuleNode*>(node_.get());
}
using ContainerType = EnvironmentNode;
using ContainerType = ModuleNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ENVIRONMENT_H_
#endif // TVM_RELAY_MODULE_H_
......@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <string>
......@@ -21,23 +21,23 @@ namespace relay {
* populated with the result type.
*
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be
* \param mod The module used for referencing global functions, can be
* None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Expr& expr, const Environment& env);
Expr InferType(const Expr& expr, const Module& mod);
/*!
* \brief Infer the type of a function as if it is mapped to var in the env.
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
* \param f the function.
* \param env The environment used for referencing global functions.
* \param mod The module used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
* \note this function mutates mod and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env,
Function InferType(const Function& f, const Module& mod,
const GlobalVar& var);
/*!
......@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
* a data type such as `int`, `float`, `uint`.
*
* \param t The type to check.
* \param env The global environment.
* \param mod The global module.
*
* \return true if the rules are satisified otherwise false
*/
bool KindCheck(const Type& t, const Environment& env);
bool KindCheck(const Type& t, const Module& mod);
/*! \brief Compare two expressions for structural equivalence.
*
......
......@@ -349,14 +349,14 @@ class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the environment.
* The type function need to be lookedup in the module.
*/
class TypeRelationNode : public TypeConstraintNode {
public:
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the environment.
* need to be looked-up in the module.
*/
TypeRelationFn func;
/*! \brief The type arguments to the type function. */
......
......@@ -5,7 +5,7 @@ from ..api import register_func
from . import base
from . import ty
from . import expr
from . import env
from . import module
from . import ir_pass
from .build_module import build
from .interpreter import create_executor
......@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
Span = base.Span
# Env
Environment = env.Environment
Module = module.Module
# Type
Type = ty.Type
......
from .env import Environment
from .env import Module
from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Environment exposed from C++."""
"""The interface to the Module exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._env", __name__)
_init_api("relay._module", __name__)
......@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ...
\ No newline at end of file
class Module(NodeBase): ...
......@@ -5,9 +5,9 @@ from a Relay expression.
from ..build_module import build as tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen
from . import ir_pass
from .env import Environment
from .module import Module
def build(func, params=None, target=None, env=None):
def build(func, params=None, target=None, mod=None):
"""
Compile a single function to the components needed by the
TVM RTS.
......@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
if target is None:
target = 'llvm'
if env is None:
env = Environment({})
if mod is None:
mod = Module({})
comp = GraphRuntimeCodegen(env)
comp = GraphRuntimeCodegen(mod)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
lowered_ops = ir_pass.lower_ops(mod, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
......
......@@ -172,7 +172,7 @@ class GlobalVar(Expr):
"""A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the environment.
stored in the module.
Parameters
----------
......
......@@ -8,7 +8,7 @@ from . import build_module
from . import _make
from . import _interpreter
from . import ir_pass
from .env import Environment
from .module import Module
from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
......@@ -90,24 +90,24 @@ def _arg_to_ast(arg):
class Executor(object):
"""An abstract interface for executing Relay programs."""
def __init__(self, env=None):
def __init__(self, mod=None):
"""
Parameters
----------
env: relay.Environment
The environment.
mod: relay.Module
The module.
"""
if env is None:
self.env = Environment({})
if mod is None:
self.mod = Module({})
else:
self.env = env
self.mod = mod
def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, env=self.env)
fused_expr = ir_pass.fuse_ops(self.env, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, env=self.env)
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
def _make_executor(self, _):
......@@ -153,8 +153,8 @@ class Interpreter(Executor):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
"""
def __init__(self, env=None):
Executor.__init__(self, env)
def __init__(self, mod=None):
Executor.__init__(self, mod)
def _make_executor(self, expr):
def _interp_wrapper(*args):
......@@ -163,28 +163,28 @@ class Interpreter(Executor):
relay_args.append(_arg_to_ast(arg))
if isinstance(expr, GlobalVar):
func = self.env[expr]
func = self.mod[expr]
func = self.optimize(func)
self.env._add(expr, func, True)
self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args)
return _interpreter.evaluate(self.env, opt_expr)
return _interpreter.evaluate(self.mod, opt_expr)
else:
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return _interpreter.evaluate(self.env, opt_expr)
return _interpreter.evaluate(self.mod, opt_expr)
return _interp_wrapper
class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, env=None):
Executor.__init__(self, env)
def __init__(self, mod=None):
Executor.__init__(self, mod)
def _make_executor(self, expr):
def _graph_wrapper(*args):
func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, env=self.env)
graph_json, mod, params = build_module.build(func, mod=self.mod)
assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
......@@ -199,10 +199,10 @@ class GraphRuntime(Executor):
return _graph_wrapper
def create_executor(mode='debug', env=None):
def create_executor(mode='debug', mod=None):
if mode == 'debug':
return Interpreter(env)
return Interpreter(mod)
elif mode == 'graph':
return GraphRuntime(env)
return GraphRuntime(mod)
else:
raise Exception("unknown mode {0}".format(mode))
......@@ -11,16 +11,16 @@ from .expr import Expr
from .ty import Type
def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
def infer_type(expr, mod=None):
"""Infer the type of expr under the context of mod.
Parameters
----------
expr: tvm.relay.Expr
The input expression.
env: Optional[tvm.relay.Environment]
The global environment.
mod: Optional[tvm.relay.Module]
The global module.
Returns
......@@ -28,7 +28,7 @@ def infer_type(expr, env=None):
checked_expr : tvm.relay.Expr
The checked expression.
"""
return _ir_pass.infer_type(expr, env)
return _ir_pass.infer_type(expr, mod)
def backward_fold_scale_axis(expr):
......@@ -93,7 +93,7 @@ def well_formed(expr):
return _ir_pass.well_formed(expr)
def check_kind(t, env=None):
def check_kind(t, mod=None):
"""Check that the type is well kinded.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
......@@ -102,8 +102,8 @@ def check_kind(t, env=None):
t: tvm.relay.Type
The type to check
env: tvm.relay.Environment, optional
The global environment
mod: tvm.relay.Module, optional
The global module
Returns
-------
......@@ -117,8 +117,8 @@ def check_kind(t, env=None):
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
"""
if env is not None:
return _ir_pass.check_kind(t, env)
if mod is not None:
return _ir_pass.check_kind(t, mod)
else:
return _ir_pass.check_kind(t)
......@@ -256,8 +256,8 @@ def structural_hash(value):
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def fuse_ops(expr, env):
return _ir_pass.FuseOps(env, expr)
def fuse_ops(expr, mod):
return _ir_pass.FuseOps(mod, expr)
def lower_ops(env, expr, target='llvm'):
return _ir_pass.LowerOps(env, expr, target)
def lower_ops(mod, expr, target='llvm'):
return _ir_pass.LowerOps(mod, expr, target)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
"""A global module storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, RelayNode
from .._ffi import base as _base
from . import _make
from . import _env
from . import _module
from . import expr as _expr
@register_relay_node
class Environment(RelayNode):
"""The global Relay environment containing collection of functions.
class Module(RelayNode):
"""The global Relay module containing collection of functions.
Each global function is identified by an unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and Environment is necessary in order to enable
tvm.relay.GlobalVar and Module is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x
Parameters
......@@ -32,10 +32,10 @@ class Environment(RelayNode):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v
functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Environment, functions)
self.__init_handle_by_constructor__(_make.Module, functions)
def __setitem__(self, var, func):
"""Add a function to the environment.
"""Add a function to the module.
Parameters
---------
......@@ -50,7 +50,7 @@ class Environment(RelayNode):
def _add(self, var, func, update=False):
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
return _env.Environment_Add(self, var, func, update)
return _module.Module_Add(self, var, func, update)
def __getitem__(self, var):
"""Lookup a global function by name or by variable.
......@@ -66,21 +66,21 @@ class Environment(RelayNode):
The function referenced by :code:`var`.
"""
if isinstance(var, _base.string_types):
return _env.Environment_Lookup_str(self, var)
return _module.Module_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)
return _module.Module_Lookup(self, var)
def update(self, other):
"""Insert functions in another Environment to current one.
"""Insert functions in another Module to current one.
Parameters
----------
other: Environment
The environment to merge into the current Environment.
other: Module
The module to merge into the current Module.
"""
if isinstance(other, dict):
other = Environment(other)
return _env.Environment_Update(self, other)
other = Module(other)
return _module.Module_Update(self, other)
def get_global_var(self, name):
"""Get a global variable in the function by name.
......@@ -99,4 +99,4 @@ class Environment(RelayNode):
------
tvm.TVMError if we cannot find corresponding global var.
"""
return _env.Environment_GetGlobalVar(self, name)
return _module.Module_GetGlobalVar(self, name)
......@@ -183,7 +183,7 @@ struct ExprEqual {
};
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Environment env;
Module mod;
Stack stack;
using JitKey = Function;
......@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return f();
}
Interpreter(Environment env) : env(env), operator_map_() {}
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {}
Interpreter(Module mod) : mod(mod), operator_map_() {}
Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {}
void extend(const Var& id, Value v) {
this->stack.current_frame().locals.Set(id, v);
......@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
Value VisitExpr_(const GlobalVarNode* op) override {
return Eval(this->env->Lookup(GetRef<GlobalVar>(op)));
return Eval(this->mod->Lookup(GetRef<GlobalVar>(op)));
}
Value VisitExpr_(const OpNode* id) override {
......@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Value VisitExpr_(const FunctionNode* func_node) override {
auto func = GetRef<Function>(func_node);
tvm::Map<Var, Value> captured_env;
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
for (const auto& var : free_vars) {
captured_env.Set(var, Eval(var));
captured_mod.Set(var, Eval(var));
}
return ClosureNode::make(captured_env, func);
return ClosureNode::make(captured_mod, func);
}
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
......@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
locals.Set(func->params[i], args[i]);
}
// Add the var to value mappings from the Closure's environment.
// Add the var to value mappings from the Closure's modironment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
......@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
};
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) {
Interpreter::OpMap op_map;
auto lowered_ops = LowerOps(env, e);
auto lowered_ops = LowerOps(mod, e);
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
if (lowered_ops.size()) {
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
......@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
lowered_funcs.push_back(lop->lowered_func);
}
Module module = fbuild(lowered_funcs);
runtime::Module module = fbuild(lowered_funcs);
// Loop over the lowered operations to map them into the operator map.
for (auto lop : lowered_ops) {
......@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
return op_map;
}
Value Evaluate(Environment env, Expr e) {
auto op_map = CompileOperators(env, e);
Interpreter interp(env, op_map);
Value Evaluate(Module mod, Expr e) {
auto op_map = CompileOperators(mod, e);
Interpreter interp(mod, op_map);
return interp.Eval(e);
}
TVM_REGISTER_API("relay._interpreter.evaluate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0];
Module mod = args[0];
Expr expr = args[1];
*ret = Evaluate(env, expr);
*ret = Evaluate(mod, expr);
});
} // namespace relay
......
/*!
* Copyright (c) 2018 by Contributors
* \file environment.cc
* \brief The global environment in Relay.
* \file module.cc
* \brief The global module in Relay.
*/
#include <tvm/relay/environment.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pass.h>
#include <sstream>
......@@ -13,8 +13,8 @@ namespace relay {
using tvm::IRPrinter;
using namespace runtime;
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
auto n = make_node<EnvironmentNode>();
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs);
for (const auto& kv : n->functions) {
......@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}
return Environment(n);
return Module(n);
}
GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) {
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Environment";
<< "Cannot find global var " << name << " in the Module";
return (*it).second;
}
void EnvironmentNode::Add(const GlobalVar& var,
void ModuleNode::Add(const GlobalVar& var,
const Function& func,
bool update) {
// Type check the item before we add it to the environment.
auto env = GetRef<Environment>(this);
Function checked_func = InferType(func, env, var);
// Type check the item before we add it to the modironment.
auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) {
......@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<FunctionNode>()->checked_type();
CHECK(AlphaEqual(type, old_type))
<< "Environment#update changes type, not possible in this mode.";
<< "Module#update changes type, not possible in this mode.";
}
this->functions.Set(var, checked_func);
......@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var);
}
void EnvironmentNode::Update(const GlobalVar& var, const Function& func) {
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true);
}
void EnvironmentNode::Remove(const GlobalVar& var) {
void ModuleNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite();
functions_node->data.erase(var.node_);
auto gvar_node = global_var_map_.CopyOnWrite();
gvar_node->data.erase(var->name_hint);
}
Function EnvironmentNode::Lookup(const GlobalVar& var) {
Function ModuleNode::Lookup(const GlobalVar& var) {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
Function EnvironmentNode::Lookup(const std::string& name) {
Function ModuleNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
void EnvironmentNode::Update(const Environment& env) {
for (auto pair : env->functions) {
void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
}
}
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Environment")
TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
*ret = ModuleNode::make(args[0]);
});
TVM_REGISTER_API("relay._env.Environment_Add")
TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Add(args[1], args[2], args[3]);
Module mod = args[0];
mod->Add(args[1], args[2], args[3]);
});
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
*ret = env->GetGlobalVar(args[1]);
Module mod = args[0];
*ret = mod->GetGlobalVar(args[1]);
});
TVM_REGISTER_API("relay._env.Environment_Lookup")
TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
Module mod = args[0];
GlobalVar var = args[1];
*ret = env->Lookup(var);
*ret = mod->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Lookup_str")
TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
Module mod = args[0];
std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name);
*ret = env->Lookup(var);
auto var = mod->GetGlobalVar(var_name);
*ret = mod->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Update")
TVM_REGISTER_API("relay._module.Module_Update")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Update(args[1]);
Module mod = args[0];
mod->Update(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<EnvironmentNode>(
[](const EnvironmentNode *node, tvm::IRPrinter *p) {
p->stream << "EnvironmentNode( " << node->functions << ")";
.set_dispatch<ModuleNode>(
[](const ModuleNode *node, tvm::IRPrinter *p) {
p->stream << "ModuleNode( " << node->functions << ")";
});
} // namespace relay
......
......@@ -3,7 +3,7 @@
* \file text_printer.cc
* \brief Text printer to print relay in text form.
*/
#include <tvm/relay/environment.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include "type_functor.h"
......@@ -133,8 +133,8 @@ class TextPrinter :
std::string Print(const NodeRef& node) {
if (node.as<FunctionNode>()) {
this->PrintFunc(Downcast<Function>(node));
} else if (node.as<EnvironmentNode>()) {
this->PrintEnv(Downcast<Environment>(node));
} else if (node.as<ModuleNode>()) {
this->PrintEnv(Downcast<Module>(node));
} else if (node.as_derived<TypeNode>()) {
this->PrintType(Downcast<Type>(node), stream_);
} else if (node.as_derived<ExprNode>()) {
......@@ -158,9 +158,9 @@ class TextPrinter :
stream_ << "\n";
}
void PrintEnv(const Environment& env) {
void PrintEnv(const Module& mod) {
int counter = 0;
for (const auto& kv : env->functions) {
for (const auto& kv : mod->functions) {
std::ostringstream os;
if (counter++ != 0) {
stream_ << "\n";
......
......@@ -20,12 +20,12 @@ namespace relay {
using namespace runtime;
struct AbstractFusableOps : ExprMutator {
Environment env;
Module mod;
Array<GlobalVar> fusable_funcs;
int counter = 0;
size_t expr_hash;
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) {
......@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
env->Add(gv, func);
mod->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
} else {
......@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
}
};
Expr FuseOps(const Environment& env, const Expr& e) {
Expr FuseOps(const Module& mod, const Expr& e) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
......
......@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
}
};
bool KindCheck(const Type& t, const Environment& env) {
bool KindCheck(const Type& t, const Module& mod) {
KindChecker kc;
return kc.Check(t);
}
......@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], EnvironmentNode::make({}));
*ret = KindCheck(args[0], ModuleNode::make({}));
} else {
*ret = KindCheck(args[0], args[1]);
}
......
......@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
}
struct AbstractLocalFunctions : ExprMutator {
Environment env;
Module mod;
size_t expr_hash;
int counter = 0;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
explicit AbstractLocalFunctions(Environment env)
: env(env), expr_hash(0), counter(0), visited_funcs() {}
explicit AbstractLocalFunctions(Module mod)
: mod(mod), expr_hash(0), counter(0), visited_funcs() {}
Expr Abstract(const Expr& e) {
expr_hash = StructuralHash()(e);
......@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
auto gvar = GetRef<GlobalVar>(gvar_node);
auto it = visited_funcs.find(gvar);
if (it == visited_funcs.end()) {
auto func = env->Lookup(gvar);
auto func = mod->Lookup(gvar);
visited_funcs.insert(gvar);
auto new_func = FunctionNode::make(
func->params,
......@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
func->ret_type,
func->type_params,
func->attrs);
env->Update(gvar, new_func);
mod->Update(gvar, new_func);
}
return gvar;
}
......@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
abs_func += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(abs_func);
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
env->Add(gv, lifted_func);
mod->Add(gv, lifted_func);
Array<Expr> args;
for (auto free_var : free_vars) {
args.push_back(free_var);
......@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
};
struct LiveFunctions : ExprVisitor {
Environment env;
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
Module mod;
explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
......@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) {
auto func = env->Lookup(var);
auto func = mod->Lookup(var);
visited_funcs.insert(var);
// The last pass has trasnformed functions of the form:
//
......@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = env->Lookup(gvar);
Function func = mod->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive");
......@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
/*! \brief Return the set of operators in their TVM format. */
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
const std::string& target) {
RELAY_LOG(INFO) << "LowerOps: e=" << e;
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
CHECK(flower_ptr);
PackedFunc flower = *flower_ptr;
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
auto live_funcs = LiveFunctions(env);
auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
auto live_funcs = LiveFunctions(mod);
live_funcs.VisitExpr(abstracted_e);
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
......@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
Array<LoweredOp> lowered_funcs;
for (auto func_name : live_funcs.global_funcs) {
auto func = env->Lookup(func_name);
auto func = mod->Lookup(func_name);
auto call = Downcast<Call>(func->body);
auto op_node = call->op.as<OpNode>();
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
......@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
LoweredFunc lf =
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
func = FunctionSetAttr(func, "LoweredFunc", lf);
env->Add(func_name, func, true);
mod->Add(func_name, func, true);
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
}
......
......@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// constructors
TypeInferencer() {
}
explicit TypeInferencer(Environment env)
: env_(env) {
explicit TypeInferencer(Module mod)
: mod_(mod) {
}
// inference the type of expr.
......@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type resolver that maps back to type
class Resolver;
// internal environment
Environment env_;
Module mod_;
// map from expression to checked type
// type inferencer will populate it up
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
......@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op);
CHECK(env_.defined())
CHECK(mod_.defined())
<< "Cannot do type inference without a global variable";
Expr e = env_->Lookup(var);
Expr e = mod_->Lookup(var);
return e->checked_type();
}
......@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
}
Expr InferType(const Expr& expr, const Environment& env) {
auto e = TypeInferencer(env).Infer(expr);
Expr InferType(const Expr& expr, const Module& mod) {
auto e = TypeInferencer(mod).Infer(expr);
CHECK(WellFormed(e));
return e;
}
Function InferType(const Function& func,
const Environment& env,
const Module& mod,
const GlobalVar& var) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite();
mod->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(mod).Infer(func_copy);
auto map_node = mod->functions.CopyOnWrite();
map_node->data.erase(var.node_);
CHECK(WellFormed(func_ret));
return Downcast<Function>(func_ret);
......
......@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
auto x = relay::VarNode::make("x", type_a);
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{}));
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a);
}
......
......@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.env import Environment
from tvm.relay.module import Module
# @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, env=None):
def check_rts(expr, args, expected_result, mod=None):
"""
Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime.
......@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
expected_result:
The expected result of running the expression.
"""
intrp = create_executor('graph', env=env)
graph = create_executor('graph', env=env)
intrp = create_executor('graph', mod=mod)
graph = create_executor('graph', mod=mod)
eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
......
......@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
intrp = create_executor(env=env)
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
intrp = create_executor(mod=mod)
result = intrp.evaluate(expr)(*args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
......@@ -87,7 +87,7 @@ def test_subtract():
check_eval(func, [i_data], 0)
def test_simple_loop():
env = relay.env.Environment({})
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
......@@ -98,12 +98,12 @@ def test_simple_loop():
rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
env[sum_up] = func
mod[sum_up] = func
i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
def test_loop():
env = relay.env.Environment({})
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
......@@ -115,10 +115,10 @@ def test_loop():
new_accum = op.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
env[sum_up] = func
mod[sum_up] = func
i_data = np.array(10, dtype='int32')
accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_mlp():
pass
......
......@@ -28,7 +28,7 @@ def test_env():
z = relay.add(x, y)
z = relay.add(z, z)
f = relay.Function([x, y], z)
env = relay.Environment()
env = relay.Module()
env["myf"] = f
text = env.astext()
assert "def @myf" in text
......
......@@ -9,8 +9,8 @@ from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
def assert_has_type(expr, typ, env=relay.env.Environment({})):
checked_expr = infer_type(expr, env)
def assert_has_type(expr, typ, mod=relay.module.Module({})):
checked_expr = infer_type(expr, mod)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
......@@ -105,10 +105,10 @@ def test_recursion():
sb.ret(data)
with sb.else_scope():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
env = relay.Environment()
env[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in env.astext()
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
mod = relay.Module()
mod[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in mod.astext()
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system.
#
......@@ -179,12 +179,12 @@ def test_self_reference():
def test_global_var_cow_issue():
env = relay.env.Environment({})
mod = relay.Module({})
gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]),
relay.TensorType([], 'float32'))
env[gv] = func
mod[gv] = func
def test_equal():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment