Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ead3ac6c
Commit
ead3ac6c
authored
Nov 02, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Nov 02, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rename relay::Environment to relay::Module (#2054)
parent
420ec786
Hide whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
224 additions
and
226 deletions
+224
-226
include/tvm/relay/base.h
+1
-1
include/tvm/relay/build_module.h
+3
-3
include/tvm/relay/expr.h
+1
-1
include/tvm/relay/interpreter.h
+3
-3
include/tvm/relay/module.h
+20
-20
include/tvm/relay/pass.h
+9
-9
include/tvm/relay/type.h
+2
-2
python/tvm/relay/__init__.py
+2
-2
python/tvm/relay/_ir_pass.pyi
+4
-5
python/tvm/relay/_module.py
+2
-2
python/tvm/relay/_module.pyi
+1
-2
python/tvm/relay/build_module.py
+6
-6
python/tvm/relay/expr.py
+1
-1
python/tvm/relay/interpreter.py
+22
-22
python/tvm/relay/ir_pass.py
+14
-14
python/tvm/relay/module.py
+16
-16
src/relay/interpreter.cc
+16
-16
src/relay/ir/module.cc
+41
-41
src/relay/ir/text_printer.cc
+5
-5
src/relay/pass/fuse_ops.cc
+5
-5
src/relay/pass/kind_check.cc
+2
-2
src/relay/pass/lower_ops.cc
+15
-15
src/relay/pass/type_infer.cc
+11
-11
tests/cpp/relay_pass_type_infer_test.cc
+1
-1
tests/python/relay/test_graph_runtime.py
+4
-4
tests/python/relay/test_interpreter.py
+8
-8
tests/python/relay/test_ir_text_printer.py
+1
-1
tests/python/relay/test_type_infer.py
+8
-8
No files found.
include/tvm/relay/base.h
View file @
ead3ac6c
...
...
@@ -165,7 +165,7 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO
(
RelayNode
,
Node
);
};
struct
Environment
;
struct
Module
;
}
// namespace relay
}
// namespace tvm
...
...
include/tvm/relay/build_module.h
View file @
ead3ac6c
...
...
@@ -8,7 +8,7 @@
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
#include <string>
...
...
@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param
env The environment
.
* \param
mod The module
.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array
<
LoweredOp
>
LowerOps
(
const
Environment
&
env
,
const
Expr
&
expr
,
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
expr
,
const
std
::
string
&
target
=
"llvm"
);
}
// namespace relay
...
...
include/tvm/relay/expr.h
View file @
ead3ac6c
...
...
@@ -160,7 +160,7 @@ class VarNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
Var
,
VarNode
,
Expr
);
/*!
* \brief Global variable that leaves in the top-level
environment
.
* \brief Global variable that leaves in the top-level
module
.
* This is used to enable recursive calls between function.
*
* \note A GlobalVar may only point to functions.
...
...
include/tvm/relay/interpreter.h
View file @
ead3ac6c
...
...
@@ -4,7 +4,7 @@
* \brief An interpreter for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay
environment
, and a Relay expression it produces a value.
* Given a Relay
module
, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
...
...
@@ -16,7 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
namespace
tvm
{
...
...
@@ -39,7 +39,7 @@ class Value;
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*/
Value
Evaluate
(
Environment
env
,
Expr
e
);
Value
Evaluate
(
Module
mod
,
Expr
e
);
/*! \brief The base container type of Relay values. */
class
ValueNode
:
public
RelayNode
{
...
...
include/tvm/relay/
environment
.h
→
include/tvm/relay/
module
.h
View file @
ead3ac6c
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/
environment
.h
* \file tvm/relay/
module
.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
*/
#ifndef TVM_RELAY_
ENVIRONMENT
_H_
#define TVM_RELAY_
ENVIRONMENT
_H_
#ifndef TVM_RELAY_
MODULE
_H_
#define TVM_RELAY_
MODULE
_H_
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
...
...
@@ -17,7 +17,7 @@
namespace
tvm
{
namespace
relay
{
struct
Environment
;
struct
Module
;
/*! \brief The global environment of Relay programs.
*
...
...
@@ -28,29 +28,29 @@ struct Environment;
* options.
*
* Many operations require access to the global
*
Environment. We pass the Environment
by value
*
Module. We pass the Module
by value
* in a functional style as an explicit argument,
* but we mutate the
Environment
while optimizing
* but we mutate the
Module
while optimizing
* Relay programs.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* an
Environment
while auto-tuning.
* an
Module
while auto-tuning.
* */
class
Environment
Node
:
public
RelayNode
{
class
Module
Node
:
public
RelayNode
{
public
:
/*! \brief A map from ids to all global functions. */
tvm
::
Map
<
GlobalVar
,
Function
>
functions
;
Environment
Node
()
{}
Module
Node
()
{}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"functions"
,
&
functions
);
v
->
Visit
(
"global_var_map_"
,
&
global_var_map_
);
}
TVM_DLL
static
Environment
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
);
TVM_DLL
static
Module
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
);
/*!
* \brief Add a function to the global environment.
...
...
@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
* functions in another environment.
* \param other The other environment.
*/
void
Update
(
const
Environment
&
other
);
void
Update
(
const
Module
&
other
);
static
constexpr
const
char
*
_type_key
=
"relay.
Environment
"
;
TVM_DECLARE_NODE_TYPE_INFO
(
Environment
Node
,
Node
);
static
constexpr
const
char
*
_type_key
=
"relay.
Module
"
;
TVM_DECLARE_NODE_TYPE_INFO
(
Module
Node
,
Node
);
private
:
/*! \brief A map from string names to global variables that
...
...
@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
tvm
::
Map
<
std
::
string
,
GlobalVar
>
global_var_map_
;
};
struct
Environment
:
public
NodeRef
{
Environment
()
{}
explicit
Environment
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
struct
Module
:
public
NodeRef
{
Module
()
{}
explicit
Module
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
inline
Environment
Node
*
operator
->
()
const
{
return
static_cast
<
Environment
Node
*>
(
node_
.
get
());
inline
Module
Node
*
operator
->
()
const
{
return
static_cast
<
Module
Node
*>
(
node_
.
get
());
}
using
ContainerType
=
Environment
Node
;
using
ContainerType
=
Module
Node
;
};
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_
ENVIRONMENT
_H_
#endif // TVM_RELAY_
MODULE
_H_
include/tvm/relay/pass.h
View file @
ead3ac6c
...
...
@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
#include <string>
...
...
@@ -21,23 +21,23 @@ namespace relay {
* populated with the result type.
*
* \param expr The expression to type check.
* \param
env The environment
used for referencing global functions, can be
* \param
mod The module
used for referencing global functions, can be
* None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr
InferType
(
const
Expr
&
expr
,
const
Environment
&
env
);
Expr
InferType
(
const
Expr
&
expr
,
const
Module
&
mod
);
/*!
* \brief Infer the type of a function as if it is mapped to var in the
env
.
* \brief Infer the type of a function as if it is mapped to var in the
mod
.
*
* \param f the function.
* \param
env The environment
used for referencing global functions.
* \param
mod The module
used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates
env
and is not thread-safe.
* \note this function mutates
mod
and is not thread-safe.
*/
Function
InferType
(
const
Function
&
f
,
const
Environment
&
env
,
Function
InferType
(
const
Function
&
f
,
const
Module
&
mod
,
const
GlobalVar
&
var
);
/*!
...
...
@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
* a data type such as `int`, `float`, `uint`.
*
* \param t The type to check.
* \param
env The global environment
.
* \param
mod The global module
.
*
* \return true if the rules are satisified otherwise false
*/
bool
KindCheck
(
const
Type
&
t
,
const
Environment
&
env
);
bool
KindCheck
(
const
Type
&
t
,
const
Module
&
mod
);
/*! \brief Compare two expressions for structural equivalence.
*
...
...
include/tvm/relay/type.h
View file @
ead3ac6c
...
...
@@ -349,14 +349,14 @@ class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the
environment
.
* The type function need to be lookedup in the
module
.
*/
class
TypeRelationNode
:
public
TypeConstraintNode
{
public
:
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the
environment
.
* need to be looked-up in the
module
.
*/
TypeRelationFn
func
;
/*! \brief The type arguments to the type function. */
...
...
python/tvm/relay/__init__.py
View file @
ead3ac6c
...
...
@@ -5,7 +5,7 @@ from ..api import register_func
from
.
import
base
from
.
import
ty
from
.
import
expr
from
.
import
env
from
.
import
module
from
.
import
ir_pass
from
.build_module
import
build
from
.interpreter
import
create_executor
...
...
@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
Span
=
base
.
Span
# Env
Environment
=
env
.
Environment
Module
=
module
.
Module
# Type
Type
=
ty
.
Type
...
...
python/tvm/relay/_ir_pass.pyi
View file @
ead3ac6c
from .env import
Environment
from .env import
Module
from . import ir
def check_expr(env:
Environment
, expr: ir.Expr) -> ir.Type: ...
def generalize(env:
Environment
, expr: ir.Expr) -> ir.Expr: ...
def check_expr(env:
Module
, expr: ir.Expr) -> ir.Type: ...
def generalize(env:
Module
, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
\ No newline at end of file
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
python/tvm/relay/_
env
.py
→
python/tvm/relay/_
module
.py
View file @
ead3ac6c
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the
Environment
exposed from C++."""
"""The interface to the
Module
exposed from C++."""
from
tvm._ffi.function
import
_init_api
_init_api
(
"relay._
env
"
,
__name__
)
_init_api
(
"relay._
module
"
,
__name__
)
python/tvm/relay/_
env
.pyi
→
python/tvm/relay/_
module
.pyi
View file @
ead3ac6c
...
...
@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ...
\ No newline at end of file
class Module(NodeBase): ...
python/tvm/relay/build_module.py
View file @
ead3ac6c
...
...
@@ -5,9 +5,9 @@ from a Relay expression.
from
..build_module
import
build
as
tvm_build_module
from
.
graph_runtime_codegen
import
GraphRuntimeCodegen
from
.
import
ir_pass
from
.
env
import
Environment
from
.
module
import
Module
def
build
(
func
,
params
=
None
,
target
=
None
,
env
=
None
):
def
build
(
func
,
params
=
None
,
target
=
None
,
mod
=
None
):
"""
Compile a single function to the components needed by the
TVM RTS.
...
...
@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
if
target
is
None
:
target
=
'llvm'
if
env
is
None
:
env
=
Environment
({})
if
mod
is
None
:
mod
=
Module
({})
comp
=
GraphRuntimeCodegen
(
env
)
comp
=
GraphRuntimeCodegen
(
mod
)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops
=
ir_pass
.
lower_ops
(
env
,
func
)
lowered_ops
=
ir_pass
.
lower_ops
(
mod
,
func
)
mod
=
tvm_build_module
([
lf
.
lowered_func
for
lf
in
lowered_ops
],
target
)
# Therefore the call to compile must come after.
...
...
python/tvm/relay/expr.py
View file @
ead3ac6c
...
...
@@ -172,7 +172,7 @@ class GlobalVar(Expr):
"""A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the
environment
.
stored in the
module
.
Parameters
----------
...
...
python/tvm/relay/interpreter.py
View file @
ead3ac6c
...
...
@@ -8,7 +8,7 @@ from . import build_module
from
.
import
_make
from
.
import
_interpreter
from
.
import
ir_pass
from
.
env
import
Environment
from
.
module
import
Module
from
.expr
import
Call
,
Constant
,
GlobalVar
,
Function
,
const
from
.scope_builder
import
ScopeBuilder
from
.._ffi.base
import
integer_types
...
...
@@ -90,24 +90,24 @@ def _arg_to_ast(arg):
class
Executor
(
object
):
"""An abstract interface for executing Relay programs."""
def
__init__
(
self
,
env
=
None
):
def
__init__
(
self
,
mod
=
None
):
"""
Parameters
----------
env: relay.Environment
The
environment
.
mod: relay.Module
The
module
.
"""
if
env
is
None
:
self
.
env
=
Environment
({})
if
mod
is
None
:
self
.
mod
=
Module
({})
else
:
self
.
env
=
env
self
.
mod
=
mod
def
optimize
(
self
,
expr
):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr
=
ir_pass
.
infer_type
(
expr
,
env
=
self
.
env
)
fused_expr
=
ir_pass
.
fuse_ops
(
self
.
env
,
ck_expr
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
env
=
self
.
env
)
ck_expr
=
ir_pass
.
infer_type
(
expr
,
mod
=
self
.
mod
)
fused_expr
=
ir_pass
.
fuse_ops
(
self
.
mod
,
ck_expr
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
self
.
mod
)
return
ck_fused
def
_make_executor
(
self
,
_
):
...
...
@@ -153,8 +153,8 @@ class Interpreter(Executor):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
"""
def
__init__
(
self
,
env
=
None
):
Executor
.
__init__
(
self
,
env
)
def
__init__
(
self
,
mod
=
None
):
Executor
.
__init__
(
self
,
mod
)
def
_make_executor
(
self
,
expr
):
def
_interp_wrapper
(
*
args
):
...
...
@@ -163,28 +163,28 @@ class Interpreter(Executor):
relay_args
.
append
(
_arg_to_ast
(
arg
))
if
isinstance
(
expr
,
GlobalVar
):
func
=
self
.
env
[
expr
]
func
=
self
.
mod
[
expr
]
func
=
self
.
optimize
(
func
)
self
.
env
.
_add
(
expr
,
func
,
True
)
self
.
mod
.
_add
(
expr
,
func
,
True
)
opt_expr
=
Call
(
expr
,
relay_args
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
else
:
call
=
Call
(
expr
,
relay_args
)
opt_expr
=
self
.
optimize
(
call
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
return
_interp_wrapper
class
GraphRuntime
(
Executor
):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def
__init__
(
self
,
env
=
None
):
Executor
.
__init__
(
self
,
env
)
def
__init__
(
self
,
mod
=
None
):
Executor
.
__init__
(
self
,
mod
)
def
_make_executor
(
self
,
expr
):
def
_graph_wrapper
(
*
args
):
func
=
self
.
optimize
(
expr
)
graph_json
,
mod
,
params
=
build_module
.
build
(
func
,
env
=
self
.
env
)
graph_json
,
mod
,
params
=
build_module
.
build
(
func
,
mod
=
self
.
mod
)
assert
params
is
None
gmodule
=
tvm_runtime
.
create
(
graph_json
,
mod
,
cpu
(
0
))
# Create map of inputs.
...
...
@@ -199,10 +199,10 @@ class GraphRuntime(Executor):
return
_graph_wrapper
def
create_executor
(
mode
=
'debug'
,
env
=
None
):
def
create_executor
(
mode
=
'debug'
,
mod
=
None
):
if
mode
==
'debug'
:
return
Interpreter
(
env
)
return
Interpreter
(
mod
)
elif
mode
==
'graph'
:
return
GraphRuntime
(
env
)
return
GraphRuntime
(
mod
)
else
:
raise
Exception
(
"unknown mode {0}"
.
format
(
mode
))
python/tvm/relay/ir_pass.py
View file @
ead3ac6c
...
...
@@ -11,16 +11,16 @@ from .expr import Expr
from
.ty
import
Type
def
infer_type
(
expr
,
env
=
None
):
"""Infer the type of expr under the context of
env
.
def
infer_type
(
expr
,
mod
=
None
):
"""Infer the type of expr under the context of
mod
.
Parameters
----------
expr: tvm.relay.Expr
The input expression.
env: Optional[tvm.relay.Environment
]
The global
environment
.
mod: Optional[tvm.relay.Module
]
The global
module
.
Returns
...
...
@@ -28,7 +28,7 @@ def infer_type(expr, env=None):
checked_expr : tvm.relay.Expr
The checked expression.
"""
return
_ir_pass
.
infer_type
(
expr
,
env
)
return
_ir_pass
.
infer_type
(
expr
,
mod
)
def
backward_fold_scale_axis
(
expr
):
...
...
@@ -93,7 +93,7 @@ def well_formed(expr):
return
_ir_pass
.
well_formed
(
expr
)
def
check_kind
(
t
,
env
=
None
):
def
check_kind
(
t
,
mod
=
None
):
"""Check that the type is well kinded.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
...
...
@@ -102,8 +102,8 @@ def check_kind(t, env=None):
t: tvm.relay.Type
The type to check
env: tvm.relay.Environment
, optional
The global
environment
mod: tvm.relay.Module
, optional
The global
module
Returns
-------
...
...
@@ -117,8 +117,8 @@ def check_kind(t, env=None):
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
"""
if
env
is
not
None
:
return
_ir_pass
.
check_kind
(
t
,
env
)
if
mod
is
not
None
:
return
_ir_pass
.
check_kind
(
t
,
mod
)
else
:
return
_ir_pass
.
check_kind
(
t
)
...
...
@@ -256,8 +256,8 @@ def structural_hash(value):
"relay.Expr or relay.Type"
)
.
format
(
type
(
value
))
raise
TypeError
(
msg
)
def
fuse_ops
(
expr
,
env
):
return
_ir_pass
.
FuseOps
(
env
,
expr
)
def
fuse_ops
(
expr
,
mod
):
return
_ir_pass
.
FuseOps
(
mod
,
expr
)
def
lower_ops
(
env
,
expr
,
target
=
'llvm'
):
return
_ir_pass
.
LowerOps
(
env
,
expr
,
target
)
def
lower_ops
(
mod
,
expr
,
target
=
'llvm'
):
return
_ir_pass
.
LowerOps
(
mod
,
expr
,
target
)
python/tvm/relay/
env
.py
→
python/tvm/relay/
module
.py
View file @
ead3ac6c
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global
environment
storing everything needed to interpret or compile a Relay program."""
"""A global
module
storing everything needed to interpret or compile a Relay program."""
from
.base
import
register_relay_node
,
RelayNode
from
.._ffi
import
base
as
_base
from
.
import
_make
from
.
import
_
env
from
.
import
_
module
from
.
import
expr
as
_expr
@register_relay_node
class
Environment
(
RelayNode
):
"""The global Relay
environment
containing collection of functions.
class
Module
(
RelayNode
):
"""The global Relay
module
containing collection of functions.
Each global function is identified by an unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and
Environment
is necessary in order to enable
tvm.relay.GlobalVar and
Module
is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x
Parameters
...
...
@@ -32,10 +32,10 @@ class Environment(RelayNode):
raise
TypeError
(
"Expect functions to be Dict[GlobalVar, Function]"
)
mapped_funcs
[
k
]
=
v
functions
=
mapped_funcs
self
.
__init_handle_by_constructor__
(
_make
.
Environment
,
functions
)
self
.
__init_handle_by_constructor__
(
_make
.
Module
,
functions
)
def
__setitem__
(
self
,
var
,
func
):
"""Add a function to the
environment
.
"""Add a function to the
module
.
Parameters
---------
...
...
@@ -50,7 +50,7 @@ class Environment(RelayNode):
def
_add
(
self
,
var
,
func
,
update
=
False
):
if
isinstance
(
var
,
_base
.
string_types
):
var
=
_expr
.
GlobalVar
(
var
)
return
_
env
.
Environment
_Add
(
self
,
var
,
func
,
update
)
return
_
module
.
Module
_Add
(
self
,
var
,
func
,
update
)
def
__getitem__
(
self
,
var
):
"""Lookup a global function by name or by variable.
...
...
@@ -66,21 +66,21 @@ class Environment(RelayNode):
The function referenced by :code:`var`.
"""
if
isinstance
(
var
,
_base
.
string_types
):
return
_
env
.
Environment
_Lookup_str
(
self
,
var
)
return
_
module
.
Module
_Lookup_str
(
self
,
var
)
else
:
return
_
env
.
Environment
_Lookup
(
self
,
var
)
return
_
module
.
Module
_Lookup
(
self
,
var
)
def
update
(
self
,
other
):
"""Insert functions in another
Environment
to current one.
"""Insert functions in another
Module
to current one.
Parameters
----------
other:
Environment
The
environment to merge into the current Environment
.
other:
Module
The
module to merge into the current Module
.
"""
if
isinstance
(
other
,
dict
):
other
=
Environment
(
other
)
return
_
env
.
Environment
_Update
(
self
,
other
)
other
=
Module
(
other
)
return
_
module
.
Module
_Update
(
self
,
other
)
def
get_global_var
(
self
,
name
):
"""Get a global variable in the function by name.
...
...
@@ -99,4 +99,4 @@ class Environment(RelayNode):
------
tvm.TVMError if we cannot find corresponding global var.
"""
return
_
env
.
Environment
_GetGlobalVar
(
self
,
name
)
return
_
module
.
Module
_GetGlobalVar
(
self
,
name
)
src/relay/interpreter.cc
View file @
ead3ac6c
...
...
@@ -183,7 +183,7 @@ struct ExprEqual {
};
struct
Interpreter
:
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
{
Environment
env
;
Module
mod
;
Stack
stack
;
using
JitKey
=
Function
;
...
...
@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return
f
();
}
Interpreter
(
Environment
env
)
:
env
(
env
),
operator_map_
()
{}
Interpreter
(
Environment
env
,
OpMap
operator_map
)
:
env
(
env
),
operator_map_
(
operator_map
)
{}
Interpreter
(
Module
mod
)
:
mod
(
mod
),
operator_map_
()
{}
Interpreter
(
Module
mod
,
OpMap
operator_map
)
:
mod
(
mod
),
operator_map_
(
operator_map
)
{}
void
extend
(
const
Var
&
id
,
Value
v
)
{
this
->
stack
.
current_frame
().
locals
.
Set
(
id
,
v
);
...
...
@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
Value
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
{
return
Eval
(
this
->
env
->
Lookup
(
GetRef
<
GlobalVar
>
(
op
)));
return
Eval
(
this
->
mod
->
Lookup
(
GetRef
<
GlobalVar
>
(
op
)));
}
Value
VisitExpr_
(
const
OpNode
*
id
)
override
{
...
...
@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Value
VisitExpr_
(
const
FunctionNode
*
func_node
)
override
{
auto
func
=
GetRef
<
Function
>
(
func_node
);
tvm
::
Map
<
Var
,
Value
>
captured_
env
;
tvm
::
Map
<
Var
,
Value
>
captured_
mod
;
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
for
(
const
auto
&
var
:
free_vars
)
{
captured_
env
.
Set
(
var
,
Eval
(
var
));
captured_
mod
.
Set
(
var
,
Eval
(
var
));
}
return
ClosureNode
::
make
(
captured_
env
,
func
);
return
ClosureNode
::
make
(
captured_
mod
,
func
);
}
inline
Value
InvokeCompiledOp
(
PackedFunc
func
,
const
Array
<
Value
>&
args
,
...
...
@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
locals
.
Set
(
func
->
params
[
i
],
args
[
i
]);
}
// Add the var to value mappings from the Closure's
env
ironment.
// Add the var to value mappings from the Closure's
mod
ironment.
for
(
auto
it
=
closure
->
env
.
begin
();
it
!=
closure
->
env
.
end
();
++
it
)
{
CHECK_EQ
(
locals
.
count
((
*
it
).
first
),
0
);
locals
.
Set
((
*
it
).
first
,
(
*
it
).
second
);
...
...
@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
};
Interpreter
::
OpMap
CompileOperators
(
const
Environment
&
env
,
const
Expr
&
e
)
{
Interpreter
::
OpMap
CompileOperators
(
const
Module
&
mod
,
const
Expr
&
e
)
{
Interpreter
::
OpMap
op_map
;
auto
lowered_ops
=
LowerOps
(
env
,
e
);
auto
lowered_ops
=
LowerOps
(
mod
,
e
);
RELAY_LOG
(
INFO
)
<<
"LoweredFuncs: "
<<
lowered_ops
<<
std
::
endl
;
if
(
lowered_ops
.
size
())
{
const
PackedFunc
*
fbuild_ptr
=
Registry
::
Get
(
"relay.op.compiler._build"
);
...
...
@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
lowered_funcs
.
push_back
(
lop
->
lowered_func
);
}
Module
module
=
fbuild
(
lowered_funcs
);
runtime
::
Module
module
=
fbuild
(
lowered_funcs
);
// Loop over the lowered operations to map them into the operator map.
for
(
auto
lop
:
lowered_ops
)
{
...
...
@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
return
op_map
;
}
Value
Evaluate
(
Environment
env
,
Expr
e
)
{
auto
op_map
=
CompileOperators
(
env
,
e
);
Interpreter
interp
(
env
,
op_map
);
Value
Evaluate
(
Module
mod
,
Expr
e
)
{
auto
op_map
=
CompileOperators
(
mod
,
e
);
Interpreter
interp
(
mod
,
op_map
);
return
interp
.
Eval
(
e
);
}
TVM_REGISTER_API
(
"relay._interpreter.evaluate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
Expr
expr
=
args
[
1
];
*
ret
=
Evaluate
(
env
,
expr
);
*
ret
=
Evaluate
(
mod
,
expr
);
});
}
// namespace relay
...
...
src/relay/ir/
environment
.cc
→
src/relay/ir/
module
.cc
View file @
ead3ac6c
/*!
* Copyright (c) 2018 by Contributors
* \file
environment
.cc
* \brief The global
environment
in Relay.
* \file
module
.cc
* \brief The global
module
in Relay.
*/
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/pass.h>
#include <sstream>
...
...
@@ -13,8 +13,8 @@ namespace relay {
using
tvm
::
IRPrinter
;
using
namespace
runtime
;
Environment
Environment
Node
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
auto
n
=
make_node
<
Environment
Node
>
();
Module
Module
Node
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
auto
n
=
make_node
<
Module
Node
>
();
n
->
functions
=
std
::
move
(
global_funcs
);
for
(
const
auto
&
kv
:
n
->
functions
)
{
...
...
@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<<
"Duplicate global function name "
<<
kv
.
first
->
name_hint
;
n
->
global_var_map_
.
Set
(
kv
.
first
->
name_hint
,
kv
.
first
);
}
return
Environment
(
n
);
return
Module
(
n
);
}
GlobalVar
Environment
Node
::
GetGlobalVar
(
const
std
::
string
&
name
)
{
GlobalVar
Module
Node
::
GetGlobalVar
(
const
std
::
string
&
name
)
{
auto
it
=
global_var_map_
.
find
(
name
);
CHECK
(
it
!=
global_var_map_
.
end
())
<<
"Cannot find global var "
<<
name
<<
" in the
Environment
"
;
<<
"Cannot find global var "
<<
name
<<
" in the
Module
"
;
return
(
*
it
).
second
;
}
void
Environment
Node
::
Add
(
const
GlobalVar
&
var
,
void
Module
Node
::
Add
(
const
GlobalVar
&
var
,
const
Function
&
func
,
bool
update
)
{
// Type check the item before we add it to the
env
ironment.
auto
env
=
GetRef
<
Environment
>
(
this
);
Function
checked_func
=
InferType
(
func
,
env
,
var
);
// Type check the item before we add it to the
mod
ironment.
auto
mod
=
GetRef
<
Module
>
(
this
);
Function
checked_func
=
InferType
(
func
,
mod
,
var
);
auto
type
=
checked_func
->
checked_type
();
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
if
(
functions
.
find
(
var
)
!=
functions
.
end
())
{
...
...
@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
<<
"Already have definition for "
<<
var
->
name_hint
;
auto
old_type
=
functions
[
var
].
as
<
FunctionNode
>
()
->
checked_type
();
CHECK
(
AlphaEqual
(
type
,
old_type
))
<<
"
Environment
#update changes type, not possible in this mode."
;
<<
"
Module
#update changes type, not possible in this mode."
;
}
this
->
functions
.
Set
(
var
,
checked_func
);
...
...
@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
global_var_map_
.
Set
(
var
->
name_hint
,
var
);
}
void
Environment
Node
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
void
Module
Node
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
this
->
Add
(
var
,
func
,
true
);
}
void
Environment
Node
::
Remove
(
const
GlobalVar
&
var
)
{
void
Module
Node
::
Remove
(
const
GlobalVar
&
var
)
{
auto
functions_node
=
this
->
functions
.
CopyOnWrite
();
functions_node
->
data
.
erase
(
var
.
node_
);
auto
gvar_node
=
global_var_map_
.
CopyOnWrite
();
gvar_node
->
data
.
erase
(
var
->
name_hint
);
}
Function
Environment
Node
::
Lookup
(
const
GlobalVar
&
var
)
{
Function
Module
Node
::
Lookup
(
const
GlobalVar
&
var
)
{
auto
it
=
functions
.
find
(
var
);
CHECK
(
it
!=
functions
.
end
())
<<
"There is no definition of "
<<
var
->
name_hint
;
return
(
*
it
).
second
;
}
Function
Environment
Node
::
Lookup
(
const
std
::
string
&
name
)
{
Function
Module
Node
::
Lookup
(
const
std
::
string
&
name
)
{
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
return
this
->
Lookup
(
id
);
}
void
EnvironmentNode
::
Update
(
const
Environment
&
env
)
{
for
(
auto
pair
:
env
->
functions
)
{
void
ModuleNode
::
Update
(
const
Module
&
mod
)
{
for
(
auto
pair
:
mod
->
functions
)
{
this
->
Update
(
pair
.
first
,
pair
.
second
);
}
}
TVM_REGISTER_NODE_TYPE
(
Environment
Node
);
TVM_REGISTER_NODE_TYPE
(
Module
Node
);
TVM_REGISTER_API
(
"relay._make.
Environment
"
)
TVM_REGISTER_API
(
"relay._make.
Module
"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Environment
Node
::
make
(
args
[
0
]);
*
ret
=
Module
Node
::
make
(
args
[
0
]);
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Add"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Add"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
env
->
Add
(
args
[
1
],
args
[
2
],
args
[
3
]);
Module
mod
=
args
[
0
];
mod
->
Add
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_API
(
"relay._
env.Environment
_GetGlobalVar"
)
TVM_REGISTER_API
(
"relay._
module.Module
_GetGlobalVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
*
ret
=
env
->
GetGlobalVar
(
args
[
1
]);
Module
mod
=
args
[
0
];
*
ret
=
mod
->
GetGlobalVar
(
args
[
1
]);
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Lookup"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Lookup"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
GlobalVar
var
=
args
[
1
];
*
ret
=
env
->
Lookup
(
var
);
*
ret
=
mod
->
Lookup
(
var
);
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Lookup_str"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Lookup_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
std
::
string
var_name
=
args
[
1
];
auto
var
=
env
->
GetGlobalVar
(
var_name
);
*
ret
=
env
->
Lookup
(
var
);
auto
var
=
mod
->
GetGlobalVar
(
var_name
);
*
ret
=
mod
->
Lookup
(
var
);
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Update"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Update"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
env
->
Update
(
args
[
1
]);
Module
mod
=
args
[
0
];
mod
->
Update
(
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
Environment
Node
>
(
[](
const
Environment
Node
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"
Environment
Node( "
<<
node
->
functions
<<
")"
;
.
set_dispatch
<
Module
Node
>
(
[](
const
Module
Node
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"
Module
Node( "
<<
node
->
functions
<<
")"
;
});
}
// namespace relay
...
...
src/relay/ir/text_printer.cc
View file @
ead3ac6c
...
...
@@ -3,7 +3,7 @@
* \file text_printer.cc
* \brief Text printer to print relay in text form.
*/
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include "type_functor.h"
...
...
@@ -133,8 +133,8 @@ class TextPrinter :
std
::
string
Print
(
const
NodeRef
&
node
)
{
if
(
node
.
as
<
FunctionNode
>
())
{
this
->
PrintFunc
(
Downcast
<
Function
>
(
node
));
}
else
if
(
node
.
as
<
Environment
Node
>
())
{
this
->
PrintEnv
(
Downcast
<
Environment
>
(
node
));
}
else
if
(
node
.
as
<
Module
Node
>
())
{
this
->
PrintEnv
(
Downcast
<
Module
>
(
node
));
}
else
if
(
node
.
as_derived
<
TypeNode
>
())
{
this
->
PrintType
(
Downcast
<
Type
>
(
node
),
stream_
);
}
else
if
(
node
.
as_derived
<
ExprNode
>
())
{
...
...
@@ -158,9 +158,9 @@ class TextPrinter :
stream_
<<
"
\n
"
;
}
void
PrintEnv
(
const
Environment
&
env
)
{
void
PrintEnv
(
const
Module
&
mod
)
{
int
counter
=
0
;
for
(
const
auto
&
kv
:
env
->
functions
)
{
for
(
const
auto
&
kv
:
mod
->
functions
)
{
std
::
ostringstream
os
;
if
(
counter
++
!=
0
)
{
stream_
<<
"
\n
"
;
...
...
src/relay/pass/fuse_ops.cc
View file @
ead3ac6c
...
...
@@ -20,12 +20,12 @@ namespace relay {
using
namespace
runtime
;
struct
AbstractFusableOps
:
ExprMutator
{
Environment
env
;
Module
mod
;
Array
<
GlobalVar
>
fusable_funcs
;
int
counter
=
0
;
size_t
expr_hash
;
AbstractFusableOps
(
Environment
env
,
size_t
expr_hash
)
:
env
(
env
),
expr_hash
(
expr_hash
)
{}
AbstractFusableOps
(
Module
mod
,
size_t
expr_hash
)
:
mod
(
mod
),
expr_hash
(
expr_hash
)
{}
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
if
(
auto
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
...
...
@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
func_name
+=
"_"
;
func_name
+=
std
::
to_string
(
expr_hash
);
auto
gv
=
GlobalVarNode
::
make
(
func_name
);
env
->
Add
(
gv
,
func
);
mod
->
Add
(
gv
,
func
);
fusable_funcs
.
push_back
(
gv
);
return
CallNode
::
make
(
gv
,
args
,
Attrs
());
}
else
{
...
...
@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
}
};
Expr
FuseOps
(
const
Environment
&
env
,
const
Expr
&
e
)
{
Expr
FuseOps
(
const
Module
&
mod
,
const
Expr
&
e
)
{
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
auto
abstract
=
AbstractFusableOps
(
env
,
StructuralHash
()(
e
));
auto
abstract
=
AbstractFusableOps
(
mod
,
StructuralHash
()(
e
));
auto
abstracted_e
=
abstract
.
VisitExpr
(
e
);
RELAY_LOG
(
INFO
)
<<
"FuseOps: before="
<<
e
<<
"Fuse: after="
<<
abstracted_e
;
...
...
src/relay/pass/kind_check.cc
View file @
ead3ac6c
...
...
@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
}
};
bool
KindCheck
(
const
Type
&
t
,
const
Environment
&
env
)
{
bool
KindCheck
(
const
Type
&
t
,
const
Module
&
mod
)
{
KindChecker
kc
;
return
kc
.
Check
(
t
);
}
...
...
@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
TVM_REGISTER_API
(
"relay._ir_pass.check_kind"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
1
)
{
*
ret
=
KindCheck
(
args
[
0
],
Environment
Node
::
make
({}));
*
ret
=
KindCheck
(
args
[
0
],
Module
Node
::
make
({}));
}
else
{
*
ret
=
KindCheck
(
args
[
0
],
args
[
1
]);
}
...
...
src/relay/pass/lower_ops.cc
View file @
ead3ac6c
...
...
@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
}
struct
AbstractLocalFunctions
:
ExprMutator
{
Environment
env
;
Module
mod
;
size_t
expr_hash
;
int
counter
=
0
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
explicit
AbstractLocalFunctions
(
Environment
env
)
:
env
(
env
),
expr_hash
(
0
),
counter
(
0
),
visited_funcs
()
{}
explicit
AbstractLocalFunctions
(
Module
mod
)
:
mod
(
mod
),
expr_hash
(
0
),
counter
(
0
),
visited_funcs
()
{}
Expr
Abstract
(
const
Expr
&
e
)
{
expr_hash
=
StructuralHash
()(
e
);
...
...
@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
auto
gvar
=
GetRef
<
GlobalVar
>
(
gvar_node
);
auto
it
=
visited_funcs
.
find
(
gvar
);
if
(
it
==
visited_funcs
.
end
())
{
auto
func
=
env
->
Lookup
(
gvar
);
auto
func
=
mod
->
Lookup
(
gvar
);
visited_funcs
.
insert
(
gvar
);
auto
new_func
=
FunctionNode
::
make
(
func
->
params
,
...
...
@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
env
->
Update
(
gvar
,
new_func
);
mod
->
Update
(
gvar
,
new_func
);
}
return
gvar
;
}
...
...
@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
abs_func
+=
std
::
to_string
(
expr_hash
);
auto
gv
=
GlobalVarNode
::
make
(
abs_func
);
auto
lifted_func
=
FunctionNode
::
make
(
params
,
func
,
Type
(),
{},
{});
env
->
Add
(
gv
,
lifted_func
);
mod
->
Add
(
gv
,
lifted_func
);
Array
<
Expr
>
args
;
for
(
auto
free_var
:
free_vars
)
{
args
.
push_back
(
free_var
);
...
...
@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
};
struct
LiveFunctions
:
ExprVisitor
{
Environment
env
;
explicit
LiveFunctions
(
Environment
env
)
:
env
(
env
),
global_funcs
()
{}
Module
mod
;
explicit
LiveFunctions
(
Module
mod
)
:
mod
(
mod
),
global_funcs
()
{}
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
global_funcs
;
...
...
@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
var_node
);
auto
it
=
visited_funcs
.
find
(
var
);
if
(
it
==
visited_funcs
.
end
())
{
auto
func
=
env
->
Lookup
(
var
);
auto
func
=
mod
->
Lookup
(
var
);
visited_funcs
.
insert
(
var
);
// The last pass has trasnformed functions of the form:
//
...
...
@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
RELAY_LOG
(
INFO
)
<<
"LiveOps: CallNode="
<<
GetRef
<
Call
>
(
call
);
if
(
auto
gv_node
=
call
->
op
.
as
<
GlobalVarNode
>
())
{
GlobalVar
gvar
=
GetRef
<
GlobalVar
>
(
gv_node
);
Function
func
=
env
->
Lookup
(
gvar
);
Function
func
=
mod
->
Lookup
(
gvar
);
auto
attr
=
FunctionGetAttr
(
func
,
"Primitive"
);
...
...
@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
using
FSchedule
=
TypedPackedFunc
<
Schedule
(
const
Array
<
Tensor
>&
,
std
::
string
)
>
;
/*! \brief Return the set of operators in their TVM format. */
Array
<
LoweredOp
>
LowerOps
(
const
Environment
&
env
,
const
Expr
&
e
,
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
e
,
const
std
::
string
&
target
)
{
RELAY_LOG
(
INFO
)
<<
"LowerOps: e="
<<
e
;
auto
flower_ptr
=
Registry
::
Get
(
"relay.op.compiler._lower"
);
CHECK
(
flower_ptr
);
PackedFunc
flower
=
*
flower_ptr
;
auto
abstracted_e
=
AbstractLocalFunctions
(
env
).
Abstract
(
e
);
auto
live_funcs
=
LiveFunctions
(
env
);
auto
abstracted_e
=
AbstractLocalFunctions
(
mod
).
Abstract
(
e
);
auto
live_funcs
=
LiveFunctions
(
mod
);
live_funcs
.
VisitExpr
(
abstracted_e
);
auto
schedule_reg
=
Op
::
GetAttr
<
FSchedule
>
(
"FTVMSchedule"
);
...
...
@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
Array
<
LoweredOp
>
lowered_funcs
;
for
(
auto
func_name
:
live_funcs
.
global_funcs
)
{
auto
func
=
env
->
Lookup
(
func_name
);
auto
func
=
mod
->
Lookup
(
func_name
);
auto
call
=
Downcast
<
Call
>
(
func
->
body
);
auto
op_node
=
call
->
op
.
as
<
OpNode
>
();
CHECK
(
op_node
)
<<
"violated invariant that primtiive calls contain a single op call"
;
...
...
@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
LoweredFunc
lf
=
flower
(
op
->
name
+
std
::
to_string
(
hash
),
schedule
,
inputs
,
outputs
);
func
=
FunctionSetAttr
(
func
,
"LoweredFunc"
,
lf
);
env
->
Add
(
func_name
,
func
,
true
);
mod
->
Add
(
func_name
,
func
,
true
);
lowered_funcs
.
push_back
(
LoweredOpNode
::
make
(
func
,
lf
));
}
...
...
src/relay/pass/type_infer.cc
View file @
ead3ac6c
...
...
@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// constructors
TypeInferencer
()
{
}
explicit
TypeInferencer
(
Environment
env
)
:
env_
(
env
)
{
explicit
TypeInferencer
(
Module
mod
)
:
mod_
(
mod
)
{
}
// inference the type of expr.
...
...
@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type resolver that maps back to type
class
Resolver
;
// internal environment
Environment
env
_
;
Module
mod
_
;
// map from expression to checked type
// type inferencer will populate it up
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
...
...
@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
op
);
CHECK
(
env
_
.
defined
())
CHECK
(
mod
_
.
defined
())
<<
"Cannot do type inference without a global variable"
;
Expr
e
=
env
_
->
Lookup
(
var
);
Expr
e
=
mod
_
->
Lookup
(
var
);
return
e
->
checked_type
();
}
...
...
@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
}
Expr
InferType
(
const
Expr
&
expr
,
const
Environment
&
env
)
{
auto
e
=
TypeInferencer
(
env
).
Infer
(
expr
);
Expr
InferType
(
const
Expr
&
expr
,
const
Module
&
mod
)
{
auto
e
=
TypeInferencer
(
mod
).
Infer
(
expr
);
CHECK
(
WellFormed
(
e
));
return
e
;
}
Function
InferType
(
const
Function
&
func
,
const
Environment
&
env
,
const
Module
&
mod
,
const
GlobalVar
&
var
)
{
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
func_copy
->
checked_type_
=
func_copy
->
func_type_annotation
();
env
->
functions
.
Set
(
var
,
func_copy
);
Expr
func_ret
=
TypeInferencer
(
env
).
Infer
(
func_copy
);
auto
map_node
=
env
->
functions
.
CopyOnWrite
();
mod
->
functions
.
Set
(
var
,
func_copy
);
Expr
func_ret
=
TypeInferencer
(
mod
).
Infer
(
func_copy
);
auto
map_node
=
mod
->
functions
.
CopyOnWrite
();
map_node
->
data
.
erase
(
var
.
node_
);
CHECK
(
WellFormed
(
func_ret
));
return
Downcast
<
Function
>
(
func_ret
);
...
...
tests/cpp/relay_pass_type_infer_test.cc
View file @
ead3ac6c
...
...
@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
auto
x
=
relay
::
VarNode
::
make
(
"x"
,
type_a
);
auto
f
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
x
},
x
,
type_b
,
Array
<
relay
::
TypeVar
>
{});
auto
fx
=
relay
::
CallNode
::
make
(
f
,
Array
<
relay
::
Expr
>
{
x
});
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
Environment
Node
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
Module
Node
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
CHECK_EQ
(
type_fx
->
checked_type
(),
type_a
);
}
...
...
tests/python/relay/test_graph_runtime.py
View file @
ead3ac6c
...
...
@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
from
tvm.relay.interpreter
import
Interpreter
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay.op
import
add
from
tvm.relay.
env
import
Environment
from
tvm.relay.
module
import
Module
# @tq, @jr should we put this in testing ns?
def
check_rts
(
expr
,
args
,
expected_result
,
env
=
None
):
def
check_rts
(
expr
,
args
,
expected_result
,
mod
=
None
):
"""
Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime.
...
...
@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
expected_result:
The expected result of running the expression.
"""
intrp
=
create_executor
(
'graph'
,
env
=
env
)
graph
=
create_executor
(
'graph'
,
env
=
env
)
intrp
=
create_executor
(
'graph'
,
mod
=
mod
)
graph
=
create_executor
(
'graph'
,
mod
=
mod
)
eval_result
=
intrp
.
evaluate
(
expr
)(
*
args
)
rts_result
=
graph
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
eval_result
.
asnumpy
(),
rts_result
.
asnumpy
())
...
...
tests/python/relay/test_interpreter.py
View file @
ead3ac6c
...
...
@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
from
tvm.relay
import
testing
,
create_executor
def
check_eval
(
expr
,
args
,
expected_result
,
env
=
None
,
rtol
=
1e-07
):
intrp
=
create_executor
(
env
=
env
)
def
check_eval
(
expr
,
args
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
intrp
=
create_executor
(
mod
=
mod
)
result
=
intrp
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
result
.
asnumpy
(),
expected_result
,
rtol
=
rtol
)
...
...
@@ -87,7 +87,7 @@ def test_subtract():
check_eval
(
func
,
[
i_data
],
0
)
def
test_simple_loop
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
module
.
Module
({})
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
sb
=
ScopeBuilder
()
...
...
@@ -98,12 +98,12 @@ def test_simple_loop():
rec_call
=
relay
.
Call
(
sum_up
,
[
one_less
])
sb
.
ret
(
op
.
add
(
rec_call
,
i
))
func
=
relay
.
Function
([
i
],
sb
.
get
(),
ret_type
=
relay
.
TensorType
([],
'int32'
))
env
[
sum_up
]
=
func
mod
[
sum_up
]
=
func
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
check_eval
(
sum_up
,
[
i_data
],
sum
(
range
(
1
,
11
)),
env
=
env
)
check_eval
(
sum_up
,
[
i_data
],
sum
(
range
(
1
,
11
)),
mod
=
mod
)
def
test_loop
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
module
.
Module
({})
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
accum
=
relay
.
var
(
'accum'
,
shape
=
[],
dtype
=
'int32'
)
...
...
@@ -115,10 +115,10 @@ def test_loop():
new_accum
=
op
.
add
(
accum
,
i
)
sb
.
ret
(
relay
.
Call
(
sum_up
,
[
one_less
,
new_accum
]))
func
=
relay
.
Function
([
i
,
accum
],
sb
.
get
())
env
[
sum_up
]
=
func
mod
[
sum_up
]
=
func
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
accum_data
=
np
.
array
(
0
,
dtype
=
'int32'
)
check_eval
(
sum_up
,
[
i_data
,
accum_data
],
sum
(
range
(
1
,
11
)),
env
=
env
)
check_eval
(
sum_up
,
[
i_data
,
accum_data
],
sum
(
range
(
1
,
11
)),
mod
=
mod
)
def
test_mlp
():
pass
...
...
tests/python/relay/test_ir_text_printer.py
View file @
ead3ac6c
...
...
@@ -28,7 +28,7 @@ def test_env():
z
=
relay
.
add
(
x
,
y
)
z
=
relay
.
add
(
z
,
z
)
f
=
relay
.
Function
([
x
,
y
],
z
)
env
=
relay
.
Environment
()
env
=
relay
.
Module
()
env
[
"myf"
]
=
f
text
=
env
.
astext
()
assert
"def @myf"
in
text
...
...
tests/python/relay/test_type_infer.py
View file @
ead3ac6c
...
...
@@ -9,8 +9,8 @@ from tvm.relay import op
from
tvm.relay.scope_builder
import
ScopeBuilder
def
assert_has_type
(
expr
,
typ
,
env
=
relay
.
env
.
Environment
({})):
checked_expr
=
infer_type
(
expr
,
env
)
def
assert_has_type
(
expr
,
typ
,
mod
=
relay
.
module
.
Module
({})):
checked_expr
=
infer_type
(
expr
,
mod
)
checked_type
=
checked_expr
.
checked_type
if
checked_type
!=
typ
:
raise
RuntimeError
(
"Type mismatch
%
s vs
%
s"
%
(
...
...
@@ -105,10 +105,10 @@ def test_recursion():
sb
.
ret
(
data
)
with
sb
.
else_scope
():
sb
.
ret
(
f
(
relay
.
subtract
(
n
,
relay
.
const
(
1
,
ti32
)),
relay
.
log
(
data
)))
env
=
relay
.
Environment
()
env
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
assert
"
%3
= @f(
%1
,
%2
)"
in
env
.
astext
()
assert
env
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
mod
=
relay
.
Module
()
mod
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
assert
"
%3
= @f(
%1
,
%2
)"
in
mod
.
astext
()
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
# This currently fails and should pass under the type system.
#
...
...
@@ -179,12 +179,12 @@ def test_self_reference():
def
test_global_var_cow_issue
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
Module
({})
gv
=
relay
.
GlobalVar
(
"foo"
)
x
=
relay
.
var
(
'x'
,
shape
=
[])
func
=
relay
.
Function
([
x
],
relay
.
Call
(
gv
,
[
x
]),
relay
.
TensorType
([],
'float32'
))
env
[
gv
]
=
func
mod
[
gv
]
=
func
def
test_equal
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment