Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ead3ac6c
Commit
ead3ac6c
authored
Nov 02, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Nov 02, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rename relay::Environment to relay::Module (#2054)
parent
420ec786
Show whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
223 additions
and
224 deletions
+223
-224
include/tvm/relay/base.h
+1
-1
include/tvm/relay/build_module.h
+3
-3
include/tvm/relay/expr.h
+1
-1
include/tvm/relay/interpreter.h
+3
-3
include/tvm/relay/module.h
+20
-20
include/tvm/relay/pass.h
+9
-9
include/tvm/relay/type.h
+2
-2
python/tvm/relay/__init__.py
+2
-2
python/tvm/relay/_ir_pass.pyi
+3
-3
python/tvm/relay/_module.py
+2
-2
python/tvm/relay/_module.pyi
+1
-2
python/tvm/relay/build_module.py
+6
-6
python/tvm/relay/expr.py
+1
-1
python/tvm/relay/interpreter.py
+22
-22
python/tvm/relay/ir_pass.py
+14
-14
python/tvm/relay/module.py
+16
-16
src/relay/interpreter.cc
+16
-16
src/relay/ir/module.cc
+41
-41
src/relay/ir/text_printer.cc
+5
-5
src/relay/pass/fuse_ops.cc
+5
-5
src/relay/pass/kind_check.cc
+2
-2
src/relay/pass/lower_ops.cc
+15
-15
src/relay/pass/type_infer.cc
+11
-11
tests/cpp/relay_pass_type_infer_test.cc
+1
-1
tests/python/relay/test_graph_runtime.py
+4
-4
tests/python/relay/test_interpreter.py
+8
-8
tests/python/relay/test_ir_text_printer.py
+1
-1
tests/python/relay/test_type_infer.py
+8
-8
No files found.
include/tvm/relay/base.h
View file @
ead3ac6c
...
@@ -165,7 +165,7 @@ class RelayNode : public Node {
...
@@ -165,7 +165,7 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO
(
RelayNode
,
Node
);
TVM_DECLARE_BASE_NODE_INFO
(
RelayNode
,
Node
);
};
};
struct
Environment
;
struct
Module
;
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
...
...
include/tvm/relay/build_module.h
View file @
ead3ac6c
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#define TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/lowered_func.h>
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <string>
#include <string>
...
@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
...
@@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
* \note This will do a reachability analysis and lower all definitions
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
* reachable from the provided expression.
*
*
* \param
env The environment
.
* \param
mod The module
.
* \param expr The expression with operations to be lowered.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
* \param target The target to lower the functions to.
*
*
* \return The set of lowered operations.
* \return The set of lowered operations.
*/
*/
Array
<
LoweredOp
>
LowerOps
(
const
Environment
&
env
,
const
Expr
&
expr
,
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
expr
,
const
std
::
string
&
target
=
"llvm"
);
const
std
::
string
&
target
=
"llvm"
);
}
// namespace relay
}
// namespace relay
...
...
include/tvm/relay/expr.h
View file @
ead3ac6c
...
@@ -160,7 +160,7 @@ class VarNode : public ExprNode {
...
@@ -160,7 +160,7 @@ class VarNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
Var
,
VarNode
,
Expr
);
RELAY_DEFINE_NODE_REF
(
Var
,
VarNode
,
Expr
);
/*!
/*!
* \brief Global variable that leaves in the top-level
environment
.
* \brief Global variable that leaves in the top-level
module
.
* This is used to enable recursive calls between function.
* This is used to enable recursive calls between function.
*
*
* \note A GlobalVar may only point to functions.
* \note A GlobalVar may only point to functions.
...
...
include/tvm/relay/interpreter.h
View file @
ead3ac6c
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
* \brief An interpreter for Relay.
* \brief An interpreter for Relay.
*
*
* This file implements a simple reference interpreter for Relay programs.
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay
environment
, and a Relay expression it produces a value.
* Given a Relay
module
, and a Relay expression it produces a value.
*
*
* The interpreter's values are a naive representation of the values that
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* can be produced by a Relay program and are exposed via tvm::Node's
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
namespace
tvm
{
namespace
tvm
{
...
@@ -39,7 +39,7 @@ class Value;
...
@@ -39,7 +39,7 @@ class Value;
* Our intent is that this will never be the most efficient implementation of
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
* Relay's semantics, but a readable and clear one.
*/
*/
Value
Evaluate
(
Environment
env
,
Expr
e
);
Value
Evaluate
(
Module
mod
,
Expr
e
);
/*! \brief The base container type of Relay values. */
/*! \brief The base container type of Relay values. */
class
ValueNode
:
public
RelayNode
{
class
ValueNode
:
public
RelayNode
{
...
...
include/tvm/relay/
environment
.h
→
include/tvm/relay/
module
.h
View file @
ead3ac6c
/*!
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2018 by Contributors
* \file tvm/relay/
environment
.h
* \file tvm/relay/
module
.h
* \brief The global environment: contains information needed to
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
* compile & optimize Relay programs.
*/
*/
#ifndef TVM_RELAY_
ENVIRONMENT
_H_
#ifndef TVM_RELAY_
MODULE
_H_
#define TVM_RELAY_
ENVIRONMENT
_H_
#define TVM_RELAY_
MODULE
_H_
#include <tvm/relay/error.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
struct
Environment
;
struct
Module
;
/*! \brief The global environment of Relay programs.
/*! \brief The global environment of Relay programs.
*
*
...
@@ -28,29 +28,29 @@ struct Environment;
...
@@ -28,29 +28,29 @@ struct Environment;
* options.
* options.
*
*
* Many operations require access to the global
* Many operations require access to the global
*
Environment. We pass the Environment
by value
*
Module. We pass the Module
by value
* in a functional style as an explicit argument,
* in a functional style as an explicit argument,
* but we mutate the
Environment
while optimizing
* but we mutate the
Module
while optimizing
* Relay programs.
* Relay programs.
*
*
* The functional style allows users to construct custom
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* environments easily, for example each thread can store
* an
Environment
while auto-tuning.
* an
Module
while auto-tuning.
* */
* */
class
Environment
Node
:
public
RelayNode
{
class
Module
Node
:
public
RelayNode
{
public
:
public
:
/*! \brief A map from ids to all global functions. */
/*! \brief A map from ids to all global functions. */
tvm
::
Map
<
GlobalVar
,
Function
>
functions
;
tvm
::
Map
<
GlobalVar
,
Function
>
functions
;
Environment
Node
()
{}
Module
Node
()
{}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"functions"
,
&
functions
);
v
->
Visit
(
"functions"
,
&
functions
);
v
->
Visit
(
"global_var_map_"
,
&
global_var_map_
);
v
->
Visit
(
"global_var_map_"
,
&
global_var_map_
);
}
}
TVM_DLL
static
Environment
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
);
TVM_DLL
static
Module
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
);
/*!
/*!
* \brief Add a function to the global environment.
* \brief Add a function to the global environment.
...
@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
...
@@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
* functions in another environment.
* functions in another environment.
* \param other The other environment.
* \param other The other environment.
*/
*/
void
Update
(
const
Environment
&
other
);
void
Update
(
const
Module
&
other
);
static
constexpr
const
char
*
_type_key
=
"relay.
Environment
"
;
static
constexpr
const
char
*
_type_key
=
"relay.
Module
"
;
TVM_DECLARE_NODE_TYPE_INFO
(
Environment
Node
,
Node
);
TVM_DECLARE_NODE_TYPE_INFO
(
Module
Node
,
Node
);
private
:
private
:
/*! \brief A map from string names to global variables that
/*! \brief A map from string names to global variables that
...
@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
...
@@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
tvm
::
Map
<
std
::
string
,
GlobalVar
>
global_var_map_
;
tvm
::
Map
<
std
::
string
,
GlobalVar
>
global_var_map_
;
};
};
struct
Environment
:
public
NodeRef
{
struct
Module
:
public
NodeRef
{
Environment
()
{}
Module
()
{}
explicit
Environment
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
explicit
Module
(
NodePtr
<
tvm
::
Node
>
p
)
:
NodeRef
(
p
)
{}
inline
Environment
Node
*
operator
->
()
const
{
inline
Module
Node
*
operator
->
()
const
{
return
static_cast
<
Environment
Node
*>
(
node_
.
get
());
return
static_cast
<
Module
Node
*>
(
node_
.
get
());
}
}
using
ContainerType
=
Environment
Node
;
using
ContainerType
=
Module
Node
;
};
};
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_
ENVIRONMENT
_H_
#endif // TVM_RELAY_
MODULE
_H_
include/tvm/relay/pass.h
View file @
ead3ac6c
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_H_
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <string>
#include <string>
...
@@ -21,23 +21,23 @@ namespace relay {
...
@@ -21,23 +21,23 @@ namespace relay {
* populated with the result type.
* populated with the result type.
*
*
* \param expr The expression to type check.
* \param expr The expression to type check.
* \param
env The environment
used for referencing global functions, can be
* \param
mod The module
used for referencing global functions, can be
* None.
* None.
*
*
* \return A type checked expression with its checked_type field populated.
* \return A type checked expression with its checked_type field populated.
*/
*/
Expr
InferType
(
const
Expr
&
expr
,
const
Environment
&
env
);
Expr
InferType
(
const
Expr
&
expr
,
const
Module
&
mod
);
/*!
/*!
* \brief Infer the type of a function as if it is mapped to var in the
env
.
* \brief Infer the type of a function as if it is mapped to var in the
mod
.
*
*
* \param f the function.
* \param f the function.
* \param
env The environment
used for referencing global functions.
* \param
mod The module
used for referencing global functions.
* \param var The global variable corresponding to the function.
* \param var The global variable corresponding to the function.
*
*
* \return A type checked Function with its checked_type field populated.
* \return A type checked Function with its checked_type field populated.
* \note this function mutates
env
and is not thread-safe.
* \note this function mutates
mod
and is not thread-safe.
*/
*/
Function
InferType
(
const
Function
&
f
,
const
Environment
&
env
,
Function
InferType
(
const
Function
&
f
,
const
Module
&
mod
,
const
GlobalVar
&
var
);
const
GlobalVar
&
var
);
/*!
/*!
...
@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
...
@@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
* a data type such as `int`, `float`, `uint`.
* a data type such as `int`, `float`, `uint`.
*
*
* \param t The type to check.
* \param t The type to check.
* \param
env The global environment
.
* \param
mod The global module
.
*
*
* \return true if the rules are satisified otherwise false
* \return true if the rules are satisified otherwise false
*/
*/
bool
KindCheck
(
const
Type
&
t
,
const
Environment
&
env
);
bool
KindCheck
(
const
Type
&
t
,
const
Module
&
mod
);
/*! \brief Compare two expressions for structural equivalence.
/*! \brief Compare two expressions for structural equivalence.
*
*
...
...
include/tvm/relay/type.h
View file @
ead3ac6c
...
@@ -349,14 +349,14 @@ class TypeRelation;
...
@@ -349,14 +349,14 @@ class TypeRelation;
/*!
/*!
* \brief TypeRelation container.
* \brief TypeRelation container.
* \note This node is not directly serializable.
* \note This node is not directly serializable.
* The type function need to be lookedup in the
environment
.
* The type function need to be lookedup in the
module
.
*/
*/
class
TypeRelationNode
:
public
TypeConstraintNode
{
class
TypeRelationNode
:
public
TypeConstraintNode
{
public
:
public
:
/*!
/*!
* \brief The function on input and output variables which
* \brief The function on input and output variables which
* this is not directly serializable,
* this is not directly serializable,
* need to be looked-up in the
environment
.
* need to be looked-up in the
module
.
*/
*/
TypeRelationFn
func
;
TypeRelationFn
func
;
/*! \brief The type arguments to the type function. */
/*! \brief The type arguments to the type function. */
...
...
python/tvm/relay/__init__.py
View file @
ead3ac6c
...
@@ -5,7 +5,7 @@ from ..api import register_func
...
@@ -5,7 +5,7 @@ from ..api import register_func
from
.
import
base
from
.
import
base
from
.
import
ty
from
.
import
ty
from
.
import
expr
from
.
import
expr
from
.
import
env
from
.
import
module
from
.
import
ir_pass
from
.
import
ir_pass
from
.build_module
import
build
from
.build_module
import
build
from
.interpreter
import
create_executor
from
.interpreter
import
create_executor
...
@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
...
@@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
Span
=
base
.
Span
Span
=
base
.
Span
# Env
# Env
Environment
=
env
.
Environment
Module
=
module
.
Module
# Type
# Type
Type
=
ty
.
Type
Type
=
ty
.
Type
...
...
python/tvm/relay/_ir_pass.pyi
View file @
ead3ac6c
from .env import
Environment
from .env import
Module
from . import ir
from . import ir
def check_expr(env:
Environment
, expr: ir.Expr) -> ir.Type: ...
def check_expr(env:
Module
, expr: ir.Expr) -> ir.Type: ...
def generalize(env:
Environment
, expr: ir.Expr) -> ir.Expr: ...
def generalize(env:
Module
, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
python/tvm/relay/_
env
.py
→
python/tvm/relay/_
module
.py
View file @
ead3ac6c
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the
Environment
exposed from C++."""
"""The interface to the
Module
exposed from C++."""
from
tvm._ffi.function
import
_init_api
from
tvm._ffi.function
import
_init_api
_init_api
(
"relay._
env
"
,
__name__
)
_init_api
(
"relay._
module
"
,
__name__
)
python/tvm/relay/_
env
.pyi
→
python/tvm/relay/_
module
.pyi
View file @
ead3ac6c
...
@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
...
@@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ...
class Module(NodeBase): ...
\ No newline at end of file
python/tvm/relay/build_module.py
View file @
ead3ac6c
...
@@ -5,9 +5,9 @@ from a Relay expression.
...
@@ -5,9 +5,9 @@ from a Relay expression.
from
..build_module
import
build
as
tvm_build_module
from
..build_module
import
build
as
tvm_build_module
from
.
graph_runtime_codegen
import
GraphRuntimeCodegen
from
.
graph_runtime_codegen
import
GraphRuntimeCodegen
from
.
import
ir_pass
from
.
import
ir_pass
from
.
env
import
Environment
from
.
module
import
Module
def
build
(
func
,
params
=
None
,
target
=
None
,
env
=
None
):
def
build
(
func
,
params
=
None
,
target
=
None
,
mod
=
None
):
"""
"""
Compile a single function to the components needed by the
Compile a single function to the components needed by the
TVM RTS.
TVM RTS.
...
@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
...
@@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
if
target
is
None
:
if
target
is
None
:
target
=
'llvm'
target
=
'llvm'
if
env
is
None
:
if
mod
is
None
:
env
=
Environment
({})
mod
=
Module
({})
comp
=
GraphRuntimeCodegen
(
env
)
comp
=
GraphRuntimeCodegen
(
mod
)
# NB(@jroesch) This creates lowered functions, and generates names for them
# NB(@jroesch) This creates lowered functions, and generates names for them
#
#
# We need these names to emit the correct graph as these are names of the
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
# functions contained in the module.
lowered_ops
=
ir_pass
.
lower_ops
(
env
,
func
)
lowered_ops
=
ir_pass
.
lower_ops
(
mod
,
func
)
mod
=
tvm_build_module
([
lf
.
lowered_func
for
lf
in
lowered_ops
],
target
)
mod
=
tvm_build_module
([
lf
.
lowered_func
for
lf
in
lowered_ops
],
target
)
# Therefore the call to compile must come after.
# Therefore the call to compile must come after.
...
...
python/tvm/relay/expr.py
View file @
ead3ac6c
...
@@ -172,7 +172,7 @@ class GlobalVar(Expr):
...
@@ -172,7 +172,7 @@ class GlobalVar(Expr):
"""A global variable in Tvm.Relay.
"""A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
GlobalVar is used to refer to the global functions
stored in the
environment
.
stored in the
module
.
Parameters
Parameters
----------
----------
...
...
python/tvm/relay/interpreter.py
View file @
ead3ac6c
...
@@ -8,7 +8,7 @@ from . import build_module
...
@@ -8,7 +8,7 @@ from . import build_module
from
.
import
_make
from
.
import
_make
from
.
import
_interpreter
from
.
import
_interpreter
from
.
import
ir_pass
from
.
import
ir_pass
from
.
env
import
Environment
from
.
module
import
Module
from
.expr
import
Call
,
Constant
,
GlobalVar
,
Function
,
const
from
.expr
import
Call
,
Constant
,
GlobalVar
,
Function
,
const
from
.scope_builder
import
ScopeBuilder
from
.scope_builder
import
ScopeBuilder
from
.._ffi.base
import
integer_types
from
.._ffi.base
import
integer_types
...
@@ -90,24 +90,24 @@ def _arg_to_ast(arg):
...
@@ -90,24 +90,24 @@ def _arg_to_ast(arg):
class
Executor
(
object
):
class
Executor
(
object
):
"""An abstract interface for executing Relay programs."""
"""An abstract interface for executing Relay programs."""
def
__init__
(
self
,
env
=
None
):
def
__init__
(
self
,
mod
=
None
):
"""
"""
Parameters
Parameters
----------
----------
env: relay.Environment
mod: relay.Module
The
environment
.
The
module
.
"""
"""
if
env
is
None
:
if
mod
is
None
:
self
.
env
=
Environment
({})
self
.
mod
=
Module
({})
else
:
else
:
self
.
env
=
env
self
.
mod
=
mod
def
optimize
(
self
,
expr
):
def
optimize
(
self
,
expr
):
# TODO: We need to move this optimization code into the optimizer/pass manager
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr
=
ir_pass
.
infer_type
(
expr
,
env
=
self
.
env
)
ck_expr
=
ir_pass
.
infer_type
(
expr
,
mod
=
self
.
mod
)
fused_expr
=
ir_pass
.
fuse_ops
(
self
.
env
,
ck_expr
)
fused_expr
=
ir_pass
.
fuse_ops
(
self
.
mod
,
ck_expr
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
env
=
self
.
env
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
self
.
mod
)
return
ck_fused
return
ck_fused
def
_make_executor
(
self
,
_
):
def
_make_executor
(
self
,
_
):
...
@@ -153,8 +153,8 @@ class Interpreter(Executor):
...
@@ -153,8 +153,8 @@ class Interpreter(Executor):
"""
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
A wrapper around the Relay interpreter, implements the excecutor interface.
"""
"""
def
__init__
(
self
,
env
=
None
):
def
__init__
(
self
,
mod
=
None
):
Executor
.
__init__
(
self
,
env
)
Executor
.
__init__
(
self
,
mod
)
def
_make_executor
(
self
,
expr
):
def
_make_executor
(
self
,
expr
):
def
_interp_wrapper
(
*
args
):
def
_interp_wrapper
(
*
args
):
...
@@ -163,28 +163,28 @@ class Interpreter(Executor):
...
@@ -163,28 +163,28 @@ class Interpreter(Executor):
relay_args
.
append
(
_arg_to_ast
(
arg
))
relay_args
.
append
(
_arg_to_ast
(
arg
))
if
isinstance
(
expr
,
GlobalVar
):
if
isinstance
(
expr
,
GlobalVar
):
func
=
self
.
env
[
expr
]
func
=
self
.
mod
[
expr
]
func
=
self
.
optimize
(
func
)
func
=
self
.
optimize
(
func
)
self
.
env
.
_add
(
expr
,
func
,
True
)
self
.
mod
.
_add
(
expr
,
func
,
True
)
opt_expr
=
Call
(
expr
,
relay_args
)
opt_expr
=
Call
(
expr
,
relay_args
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
else
:
else
:
call
=
Call
(
expr
,
relay_args
)
call
=
Call
(
expr
,
relay_args
)
opt_expr
=
self
.
optimize
(
call
)
opt_expr
=
self
.
optimize
(
call
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
return
_interp_wrapper
return
_interp_wrapper
class
GraphRuntime
(
Executor
):
class
GraphRuntime
(
Executor
):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def
__init__
(
self
,
env
=
None
):
def
__init__
(
self
,
mod
=
None
):
Executor
.
__init__
(
self
,
env
)
Executor
.
__init__
(
self
,
mod
)
def
_make_executor
(
self
,
expr
):
def
_make_executor
(
self
,
expr
):
def
_graph_wrapper
(
*
args
):
def
_graph_wrapper
(
*
args
):
func
=
self
.
optimize
(
expr
)
func
=
self
.
optimize
(
expr
)
graph_json
,
mod
,
params
=
build_module
.
build
(
func
,
env
=
self
.
env
)
graph_json
,
mod
,
params
=
build_module
.
build
(
func
,
mod
=
self
.
mod
)
assert
params
is
None
assert
params
is
None
gmodule
=
tvm_runtime
.
create
(
graph_json
,
mod
,
cpu
(
0
))
gmodule
=
tvm_runtime
.
create
(
graph_json
,
mod
,
cpu
(
0
))
# Create map of inputs.
# Create map of inputs.
...
@@ -199,10 +199,10 @@ class GraphRuntime(Executor):
...
@@ -199,10 +199,10 @@ class GraphRuntime(Executor):
return
_graph_wrapper
return
_graph_wrapper
def
create_executor
(
mode
=
'debug'
,
env
=
None
):
def
create_executor
(
mode
=
'debug'
,
mod
=
None
):
if
mode
==
'debug'
:
if
mode
==
'debug'
:
return
Interpreter
(
env
)
return
Interpreter
(
mod
)
elif
mode
==
'graph'
:
elif
mode
==
'graph'
:
return
GraphRuntime
(
env
)
return
GraphRuntime
(
mod
)
else
:
else
:
raise
Exception
(
"unknown mode {0}"
.
format
(
mode
))
raise
Exception
(
"unknown mode {0}"
.
format
(
mode
))
python/tvm/relay/ir_pass.py
View file @
ead3ac6c
...
@@ -11,16 +11,16 @@ from .expr import Expr
...
@@ -11,16 +11,16 @@ from .expr import Expr
from
.ty
import
Type
from
.ty
import
Type
def
infer_type
(
expr
,
env
=
None
):
def
infer_type
(
expr
,
mod
=
None
):
"""Infer the type of expr under the context of
env
.
"""Infer the type of expr under the context of
mod
.
Parameters
Parameters
----------
----------
expr: tvm.relay.Expr
expr: tvm.relay.Expr
The input expression.
The input expression.
env: Optional[tvm.relay.Environment
]
mod: Optional[tvm.relay.Module
]
The global
environment
.
The global
module
.
Returns
Returns
...
@@ -28,7 +28,7 @@ def infer_type(expr, env=None):
...
@@ -28,7 +28,7 @@ def infer_type(expr, env=None):
checked_expr : tvm.relay.Expr
checked_expr : tvm.relay.Expr
The checked expression.
The checked expression.
"""
"""
return
_ir_pass
.
infer_type
(
expr
,
env
)
return
_ir_pass
.
infer_type
(
expr
,
mod
)
def
backward_fold_scale_axis
(
expr
):
def
backward_fold_scale_axis
(
expr
):
...
@@ -93,7 +93,7 @@ def well_formed(expr):
...
@@ -93,7 +93,7 @@ def well_formed(expr):
return
_ir_pass
.
well_formed
(
expr
)
return
_ir_pass
.
well_formed
(
expr
)
def
check_kind
(
t
,
env
=
None
):
def
check_kind
(
t
,
mod
=
None
):
"""Check that the type is well kinded.
"""Check that the type is well kinded.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
...
@@ -102,8 +102,8 @@ def check_kind(t, env=None):
...
@@ -102,8 +102,8 @@ def check_kind(t, env=None):
t: tvm.relay.Type
t: tvm.relay.Type
The type to check
The type to check
env: tvm.relay.Environment
, optional
mod: tvm.relay.Module
, optional
The global
environment
The global
module
Returns
Returns
-------
-------
...
@@ -117,8 +117,8 @@ def check_kind(t, env=None):
...
@@ -117,8 +117,8 @@ def check_kind(t, env=None):
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
"""
"""
if
env
is
not
None
:
if
mod
is
not
None
:
return
_ir_pass
.
check_kind
(
t
,
env
)
return
_ir_pass
.
check_kind
(
t
,
mod
)
else
:
else
:
return
_ir_pass
.
check_kind
(
t
)
return
_ir_pass
.
check_kind
(
t
)
...
@@ -256,8 +256,8 @@ def structural_hash(value):
...
@@ -256,8 +256,8 @@ def structural_hash(value):
"relay.Expr or relay.Type"
)
.
format
(
type
(
value
))
"relay.Expr or relay.Type"
)
.
format
(
type
(
value
))
raise
TypeError
(
msg
)
raise
TypeError
(
msg
)
def
fuse_ops
(
expr
,
env
):
def
fuse_ops
(
expr
,
mod
):
return
_ir_pass
.
FuseOps
(
env
,
expr
)
return
_ir_pass
.
FuseOps
(
mod
,
expr
)
def
lower_ops
(
env
,
expr
,
target
=
'llvm'
):
def
lower_ops
(
mod
,
expr
,
target
=
'llvm'
):
return
_ir_pass
.
LowerOps
(
env
,
expr
,
target
)
return
_ir_pass
.
LowerOps
(
mod
,
expr
,
target
)
python/tvm/relay/
env
.py
→
python/tvm/relay/
module
.py
View file @
ead3ac6c
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global
environment
storing everything needed to interpret or compile a Relay program."""
"""A global
module
storing everything needed to interpret or compile a Relay program."""
from
.base
import
register_relay_node
,
RelayNode
from
.base
import
register_relay_node
,
RelayNode
from
.._ffi
import
base
as
_base
from
.._ffi
import
base
as
_base
from
.
import
_make
from
.
import
_make
from
.
import
_
env
from
.
import
_
module
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
@register_relay_node
@register_relay_node
class
Environment
(
RelayNode
):
class
Module
(
RelayNode
):
"""The global Relay
environment
containing collection of functions.
"""The global Relay
module
containing collection of functions.
Each global function is identified by an unique tvm.relay.GlobalVar.
Each global function is identified by an unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and
Environment
is necessary in order to enable
tvm.relay.GlobalVar and
Module
is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x
recursions in function to avoid cyclic reference in the function.x
Parameters
Parameters
...
@@ -32,10 +32,10 @@ class Environment(RelayNode):
...
@@ -32,10 +32,10 @@ class Environment(RelayNode):
raise
TypeError
(
"Expect functions to be Dict[GlobalVar, Function]"
)
raise
TypeError
(
"Expect functions to be Dict[GlobalVar, Function]"
)
mapped_funcs
[
k
]
=
v
mapped_funcs
[
k
]
=
v
functions
=
mapped_funcs
functions
=
mapped_funcs
self
.
__init_handle_by_constructor__
(
_make
.
Environment
,
functions
)
self
.
__init_handle_by_constructor__
(
_make
.
Module
,
functions
)
def
__setitem__
(
self
,
var
,
func
):
def
__setitem__
(
self
,
var
,
func
):
"""Add a function to the
environment
.
"""Add a function to the
module
.
Parameters
Parameters
---------
---------
...
@@ -50,7 +50,7 @@ class Environment(RelayNode):
...
@@ -50,7 +50,7 @@ class Environment(RelayNode):
def
_add
(
self
,
var
,
func
,
update
=
False
):
def
_add
(
self
,
var
,
func
,
update
=
False
):
if
isinstance
(
var
,
_base
.
string_types
):
if
isinstance
(
var
,
_base
.
string_types
):
var
=
_expr
.
GlobalVar
(
var
)
var
=
_expr
.
GlobalVar
(
var
)
return
_
env
.
Environment
_Add
(
self
,
var
,
func
,
update
)
return
_
module
.
Module
_Add
(
self
,
var
,
func
,
update
)
def
__getitem__
(
self
,
var
):
def
__getitem__
(
self
,
var
):
"""Lookup a global function by name or by variable.
"""Lookup a global function by name or by variable.
...
@@ -66,21 +66,21 @@ class Environment(RelayNode):
...
@@ -66,21 +66,21 @@ class Environment(RelayNode):
The function referenced by :code:`var`.
The function referenced by :code:`var`.
"""
"""
if
isinstance
(
var
,
_base
.
string_types
):
if
isinstance
(
var
,
_base
.
string_types
):
return
_
env
.
Environment
_Lookup_str
(
self
,
var
)
return
_
module
.
Module
_Lookup_str
(
self
,
var
)
else
:
else
:
return
_
env
.
Environment
_Lookup
(
self
,
var
)
return
_
module
.
Module
_Lookup
(
self
,
var
)
def
update
(
self
,
other
):
def
update
(
self
,
other
):
"""Insert functions in another
Environment
to current one.
"""Insert functions in another
Module
to current one.
Parameters
Parameters
----------
----------
other:
Environment
other:
Module
The
environment to merge into the current Environment
.
The
module to merge into the current Module
.
"""
"""
if
isinstance
(
other
,
dict
):
if
isinstance
(
other
,
dict
):
other
=
Environment
(
other
)
other
=
Module
(
other
)
return
_
env
.
Environment
_Update
(
self
,
other
)
return
_
module
.
Module
_Update
(
self
,
other
)
def
get_global_var
(
self
,
name
):
def
get_global_var
(
self
,
name
):
"""Get a global variable in the function by name.
"""Get a global variable in the function by name.
...
@@ -99,4 +99,4 @@ class Environment(RelayNode):
...
@@ -99,4 +99,4 @@ class Environment(RelayNode):
------
------
tvm.TVMError if we cannot find corresponding global var.
tvm.TVMError if we cannot find corresponding global var.
"""
"""
return
_
env
.
Environment
_GetGlobalVar
(
self
,
name
)
return
_
module
.
Module
_GetGlobalVar
(
self
,
name
)
src/relay/interpreter.cc
View file @
ead3ac6c
...
@@ -183,7 +183,7 @@ struct ExprEqual {
...
@@ -183,7 +183,7 @@ struct ExprEqual {
};
};
struct
Interpreter
:
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
{
struct
Interpreter
:
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
{
Environment
env
;
Module
mod
;
Stack
stack
;
Stack
stack
;
using
JitKey
=
Function
;
using
JitKey
=
Function
;
...
@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
...
@@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
return
f
();
return
f
();
}
}
Interpreter
(
Environment
env
)
:
env
(
env
),
operator_map_
()
{}
Interpreter
(
Module
mod
)
:
mod
(
mod
),
operator_map_
()
{}
Interpreter
(
Environment
env
,
OpMap
operator_map
)
:
env
(
env
),
operator_map_
(
operator_map
)
{}
Interpreter
(
Module
mod
,
OpMap
operator_map
)
:
mod
(
mod
),
operator_map_
(
operator_map
)
{}
void
extend
(
const
Var
&
id
,
Value
v
)
{
void
extend
(
const
Var
&
id
,
Value
v
)
{
this
->
stack
.
current_frame
().
locals
.
Set
(
id
,
v
);
this
->
stack
.
current_frame
().
locals
.
Set
(
id
,
v
);
...
@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
...
@@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
}
Value
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
{
Value
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
{
return
Eval
(
this
->
env
->
Lookup
(
GetRef
<
GlobalVar
>
(
op
)));
return
Eval
(
this
->
mod
->
Lookup
(
GetRef
<
GlobalVar
>
(
op
)));
}
}
Value
VisitExpr_
(
const
OpNode
*
id
)
override
{
Value
VisitExpr_
(
const
OpNode
*
id
)
override
{
...
@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
...
@@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Value
VisitExpr_
(
const
FunctionNode
*
func_node
)
override
{
Value
VisitExpr_
(
const
FunctionNode
*
func_node
)
override
{
auto
func
=
GetRef
<
Function
>
(
func_node
);
auto
func
=
GetRef
<
Function
>
(
func_node
);
tvm
::
Map
<
Var
,
Value
>
captured_
env
;
tvm
::
Map
<
Var
,
Value
>
captured_
mod
;
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
for
(
const
auto
&
var
:
free_vars
)
{
for
(
const
auto
&
var
:
free_vars
)
{
captured_
env
.
Set
(
var
,
Eval
(
var
));
captured_
mod
.
Set
(
var
,
Eval
(
var
));
}
}
return
ClosureNode
::
make
(
captured_
env
,
func
);
return
ClosureNode
::
make
(
captured_
mod
,
func
);
}
}
inline
Value
InvokeCompiledOp
(
PackedFunc
func
,
const
Array
<
Value
>&
args
,
inline
Value
InvokeCompiledOp
(
PackedFunc
func
,
const
Array
<
Value
>&
args
,
...
@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
...
@@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
locals
.
Set
(
func
->
params
[
i
],
args
[
i
]);
locals
.
Set
(
func
->
params
[
i
],
args
[
i
]);
}
}
// Add the var to value mappings from the Closure's
env
ironment.
// Add the var to value mappings from the Closure's
mod
ironment.
for
(
auto
it
=
closure
->
env
.
begin
();
it
!=
closure
->
env
.
end
();
++
it
)
{
for
(
auto
it
=
closure
->
env
.
begin
();
it
!=
closure
->
env
.
end
();
++
it
)
{
CHECK_EQ
(
locals
.
count
((
*
it
).
first
),
0
);
CHECK_EQ
(
locals
.
count
((
*
it
).
first
),
0
);
locals
.
Set
((
*
it
).
first
,
(
*
it
).
second
);
locals
.
Set
((
*
it
).
first
,
(
*
it
).
second
);
...
@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
...
@@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
}
}
};
};
Interpreter
::
OpMap
CompileOperators
(
const
Environment
&
env
,
const
Expr
&
e
)
{
Interpreter
::
OpMap
CompileOperators
(
const
Module
&
mod
,
const
Expr
&
e
)
{
Interpreter
::
OpMap
op_map
;
Interpreter
::
OpMap
op_map
;
auto
lowered_ops
=
LowerOps
(
env
,
e
);
auto
lowered_ops
=
LowerOps
(
mod
,
e
);
RELAY_LOG
(
INFO
)
<<
"LoweredFuncs: "
<<
lowered_ops
<<
std
::
endl
;
RELAY_LOG
(
INFO
)
<<
"LoweredFuncs: "
<<
lowered_ops
<<
std
::
endl
;
if
(
lowered_ops
.
size
())
{
if
(
lowered_ops
.
size
())
{
const
PackedFunc
*
fbuild_ptr
=
Registry
::
Get
(
"relay.op.compiler._build"
);
const
PackedFunc
*
fbuild_ptr
=
Registry
::
Get
(
"relay.op.compiler._build"
);
...
@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
...
@@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
lowered_funcs
.
push_back
(
lop
->
lowered_func
);
lowered_funcs
.
push_back
(
lop
->
lowered_func
);
}
}
Module
module
=
fbuild
(
lowered_funcs
);
runtime
::
Module
module
=
fbuild
(
lowered_funcs
);
// Loop over the lowered operations to map them into the operator map.
// Loop over the lowered operations to map them into the operator map.
for
(
auto
lop
:
lowered_ops
)
{
for
(
auto
lop
:
lowered_ops
)
{
...
@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
...
@@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
return
op_map
;
return
op_map
;
}
}
Value
Evaluate
(
Environment
env
,
Expr
e
)
{
Value
Evaluate
(
Module
mod
,
Expr
e
)
{
auto
op_map
=
CompileOperators
(
env
,
e
);
auto
op_map
=
CompileOperators
(
mod
,
e
);
Interpreter
interp
(
env
,
op_map
);
Interpreter
interp
(
mod
,
op_map
);
return
interp
.
Eval
(
e
);
return
interp
.
Eval
(
e
);
}
}
TVM_REGISTER_API
(
"relay._interpreter.evaluate"
)
TVM_REGISTER_API
(
"relay._interpreter.evaluate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
Expr
expr
=
args
[
1
];
Expr
expr
=
args
[
1
];
*
ret
=
Evaluate
(
env
,
expr
);
*
ret
=
Evaluate
(
mod
,
expr
);
});
});
}
// namespace relay
}
// namespace relay
...
...
src/relay/ir/
environment
.cc
→
src/relay/ir/
module
.cc
View file @
ead3ac6c
/*!
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2018 by Contributors
* \file
environment
.cc
* \file
module
.cc
* \brief The global
environment
in Relay.
* \brief The global
module
in Relay.
*/
*/
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/pass.h>
#include <sstream>
#include <sstream>
...
@@ -13,8 +13,8 @@ namespace relay {
...
@@ -13,8 +13,8 @@ namespace relay {
using
tvm
::
IRPrinter
;
using
tvm
::
IRPrinter
;
using
namespace
runtime
;
using
namespace
runtime
;
Environment
Environment
Node
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
Module
Module
Node
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
auto
n
=
make_node
<
Environment
Node
>
();
auto
n
=
make_node
<
Module
Node
>
();
n
->
functions
=
std
::
move
(
global_funcs
);
n
->
functions
=
std
::
move
(
global_funcs
);
for
(
const
auto
&
kv
:
n
->
functions
)
{
for
(
const
auto
&
kv
:
n
->
functions
)
{
...
@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
...
@@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<<
"Duplicate global function name "
<<
kv
.
first
->
name_hint
;
<<
"Duplicate global function name "
<<
kv
.
first
->
name_hint
;
n
->
global_var_map_
.
Set
(
kv
.
first
->
name_hint
,
kv
.
first
);
n
->
global_var_map_
.
Set
(
kv
.
first
->
name_hint
,
kv
.
first
);
}
}
return
Environment
(
n
);
return
Module
(
n
);
}
}
GlobalVar
Environment
Node
::
GetGlobalVar
(
const
std
::
string
&
name
)
{
GlobalVar
Module
Node
::
GetGlobalVar
(
const
std
::
string
&
name
)
{
auto
it
=
global_var_map_
.
find
(
name
);
auto
it
=
global_var_map_
.
find
(
name
);
CHECK
(
it
!=
global_var_map_
.
end
())
CHECK
(
it
!=
global_var_map_
.
end
())
<<
"Cannot find global var "
<<
name
<<
" in the
Environment
"
;
<<
"Cannot find global var "
<<
name
<<
" in the
Module
"
;
return
(
*
it
).
second
;
return
(
*
it
).
second
;
}
}
void
Environment
Node
::
Add
(
const
GlobalVar
&
var
,
void
Module
Node
::
Add
(
const
GlobalVar
&
var
,
const
Function
&
func
,
const
Function
&
func
,
bool
update
)
{
bool
update
)
{
// Type check the item before we add it to the
env
ironment.
// Type check the item before we add it to the
mod
ironment.
auto
env
=
GetRef
<
Environment
>
(
this
);
auto
mod
=
GetRef
<
Module
>
(
this
);
Function
checked_func
=
InferType
(
func
,
env
,
var
);
Function
checked_func
=
InferType
(
func
,
mod
,
var
);
auto
type
=
checked_func
->
checked_type
();
auto
type
=
checked_func
->
checked_type
();
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
if
(
functions
.
find
(
var
)
!=
functions
.
end
())
{
if
(
functions
.
find
(
var
)
!=
functions
.
end
())
{
...
@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
...
@@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
<<
"Already have definition for "
<<
var
->
name_hint
;
<<
"Already have definition for "
<<
var
->
name_hint
;
auto
old_type
=
functions
[
var
].
as
<
FunctionNode
>
()
->
checked_type
();
auto
old_type
=
functions
[
var
].
as
<
FunctionNode
>
()
->
checked_type
();
CHECK
(
AlphaEqual
(
type
,
old_type
))
CHECK
(
AlphaEqual
(
type
,
old_type
))
<<
"
Environment
#update changes type, not possible in this mode."
;
<<
"
Module
#update changes type, not possible in this mode."
;
}
}
this
->
functions
.
Set
(
var
,
checked_func
);
this
->
functions
.
Set
(
var
,
checked_func
);
...
@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
...
@@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
global_var_map_
.
Set
(
var
->
name_hint
,
var
);
global_var_map_
.
Set
(
var
->
name_hint
,
var
);
}
}
void
Environment
Node
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
void
Module
Node
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
this
->
Add
(
var
,
func
,
true
);
this
->
Add
(
var
,
func
,
true
);
}
}
void
Environment
Node
::
Remove
(
const
GlobalVar
&
var
)
{
void
Module
Node
::
Remove
(
const
GlobalVar
&
var
)
{
auto
functions_node
=
this
->
functions
.
CopyOnWrite
();
auto
functions_node
=
this
->
functions
.
CopyOnWrite
();
functions_node
->
data
.
erase
(
var
.
node_
);
functions_node
->
data
.
erase
(
var
.
node_
);
auto
gvar_node
=
global_var_map_
.
CopyOnWrite
();
auto
gvar_node
=
global_var_map_
.
CopyOnWrite
();
gvar_node
->
data
.
erase
(
var
->
name_hint
);
gvar_node
->
data
.
erase
(
var
->
name_hint
);
}
}
Function
Environment
Node
::
Lookup
(
const
GlobalVar
&
var
)
{
Function
Module
Node
::
Lookup
(
const
GlobalVar
&
var
)
{
auto
it
=
functions
.
find
(
var
);
auto
it
=
functions
.
find
(
var
);
CHECK
(
it
!=
functions
.
end
())
CHECK
(
it
!=
functions
.
end
())
<<
"There is no definition of "
<<
var
->
name_hint
;
<<
"There is no definition of "
<<
var
->
name_hint
;
return
(
*
it
).
second
;
return
(
*
it
).
second
;
}
}
Function
Environment
Node
::
Lookup
(
const
std
::
string
&
name
)
{
Function
Module
Node
::
Lookup
(
const
std
::
string
&
name
)
{
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
return
this
->
Lookup
(
id
);
return
this
->
Lookup
(
id
);
}
}
void
EnvironmentNode
::
Update
(
const
Environment
&
env
)
{
void
ModuleNode
::
Update
(
const
Module
&
mod
)
{
for
(
auto
pair
:
env
->
functions
)
{
for
(
auto
pair
:
mod
->
functions
)
{
this
->
Update
(
pair
.
first
,
pair
.
second
);
this
->
Update
(
pair
.
first
,
pair
.
second
);
}
}
}
}
TVM_REGISTER_NODE_TYPE
(
Environment
Node
);
TVM_REGISTER_NODE_TYPE
(
Module
Node
);
TVM_REGISTER_API
(
"relay._make.
Environment
"
)
TVM_REGISTER_API
(
"relay._make.
Module
"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Environment
Node
::
make
(
args
[
0
]);
*
ret
=
Module
Node
::
make
(
args
[
0
]);
});
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Add"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Add"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
env
->
Add
(
args
[
1
],
args
[
2
],
args
[
3
]);
mod
->
Add
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_REGISTER_API
(
"relay._
env.Environment
_GetGlobalVar"
)
TVM_REGISTER_API
(
"relay._
module.Module
_GetGlobalVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
*
ret
=
env
->
GetGlobalVar
(
args
[
1
]);
*
ret
=
mod
->
GetGlobalVar
(
args
[
1
]);
});
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Lookup"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Lookup"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
GlobalVar
var
=
args
[
1
];
GlobalVar
var
=
args
[
1
];
*
ret
=
env
->
Lookup
(
var
);
*
ret
=
mod
->
Lookup
(
var
);
});
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Lookup_str"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Lookup_str"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
std
::
string
var_name
=
args
[
1
];
std
::
string
var_name
=
args
[
1
];
auto
var
=
env
->
GetGlobalVar
(
var_name
);
auto
var
=
mod
->
GetGlobalVar
(
var_name
);
*
ret
=
env
->
Lookup
(
var
);
*
ret
=
mod
->
Lookup
(
var
);
});
});
TVM_REGISTER_API
(
"relay._
env.Environment
_Update"
)
TVM_REGISTER_API
(
"relay._
module.Module
_Update"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Environment
env
=
args
[
0
];
Module
mod
=
args
[
0
];
env
->
Update
(
args
[
1
]);
mod
->
Update
(
args
[
1
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
Environment
Node
>
(
.
set_dispatch
<
Module
Node
>
(
[](
const
Environment
Node
*
node
,
tvm
::
IRPrinter
*
p
)
{
[](
const
Module
Node
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"
Environment
Node( "
<<
node
->
functions
<<
")"
;
p
->
stream
<<
"
Module
Node( "
<<
node
->
functions
<<
")"
;
});
});
}
// namespace relay
}
// namespace relay
...
...
src/relay/ir/text_printer.cc
View file @
ead3ac6c
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
* \file text_printer.cc
* \file text_printer.cc
* \brief Text printer to print relay in text form.
* \brief Text printer to print relay in text form.
*/
*/
#include <tvm/relay/
environment
.h>
#include <tvm/relay/
module
.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include <sstream>
#include "type_functor.h"
#include "type_functor.h"
...
@@ -133,8 +133,8 @@ class TextPrinter :
...
@@ -133,8 +133,8 @@ class TextPrinter :
std
::
string
Print
(
const
NodeRef
&
node
)
{
std
::
string
Print
(
const
NodeRef
&
node
)
{
if
(
node
.
as
<
FunctionNode
>
())
{
if
(
node
.
as
<
FunctionNode
>
())
{
this
->
PrintFunc
(
Downcast
<
Function
>
(
node
));
this
->
PrintFunc
(
Downcast
<
Function
>
(
node
));
}
else
if
(
node
.
as
<
Environment
Node
>
())
{
}
else
if
(
node
.
as
<
Module
Node
>
())
{
this
->
PrintEnv
(
Downcast
<
Environment
>
(
node
));
this
->
PrintEnv
(
Downcast
<
Module
>
(
node
));
}
else
if
(
node
.
as_derived
<
TypeNode
>
())
{
}
else
if
(
node
.
as_derived
<
TypeNode
>
())
{
this
->
PrintType
(
Downcast
<
Type
>
(
node
),
stream_
);
this
->
PrintType
(
Downcast
<
Type
>
(
node
),
stream_
);
}
else
if
(
node
.
as_derived
<
ExprNode
>
())
{
}
else
if
(
node
.
as_derived
<
ExprNode
>
())
{
...
@@ -158,9 +158,9 @@ class TextPrinter :
...
@@ -158,9 +158,9 @@ class TextPrinter :
stream_
<<
"
\n
"
;
stream_
<<
"
\n
"
;
}
}
void
PrintEnv
(
const
Environment
&
env
)
{
void
PrintEnv
(
const
Module
&
mod
)
{
int
counter
=
0
;
int
counter
=
0
;
for
(
const
auto
&
kv
:
env
->
functions
)
{
for
(
const
auto
&
kv
:
mod
->
functions
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
if
(
counter
++
!=
0
)
{
if
(
counter
++
!=
0
)
{
stream_
<<
"
\n
"
;
stream_
<<
"
\n
"
;
...
...
src/relay/pass/fuse_ops.cc
View file @
ead3ac6c
...
@@ -20,12 +20,12 @@ namespace relay {
...
@@ -20,12 +20,12 @@ namespace relay {
using
namespace
runtime
;
using
namespace
runtime
;
struct
AbstractFusableOps
:
ExprMutator
{
struct
AbstractFusableOps
:
ExprMutator
{
Environment
env
;
Module
mod
;
Array
<
GlobalVar
>
fusable_funcs
;
Array
<
GlobalVar
>
fusable_funcs
;
int
counter
=
0
;
int
counter
=
0
;
size_t
expr_hash
;
size_t
expr_hash
;
AbstractFusableOps
(
Environment
env
,
size_t
expr_hash
)
:
env
(
env
),
expr_hash
(
expr_hash
)
{}
AbstractFusableOps
(
Module
mod
,
size_t
expr_hash
)
:
mod
(
mod
),
expr_hash
(
expr_hash
)
{}
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
if
(
auto
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
if
(
auto
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
...
@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
...
@@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
func_name
+=
"_"
;
func_name
+=
"_"
;
func_name
+=
std
::
to_string
(
expr_hash
);
func_name
+=
std
::
to_string
(
expr_hash
);
auto
gv
=
GlobalVarNode
::
make
(
func_name
);
auto
gv
=
GlobalVarNode
::
make
(
func_name
);
env
->
Add
(
gv
,
func
);
mod
->
Add
(
gv
,
func
);
fusable_funcs
.
push_back
(
gv
);
fusable_funcs
.
push_back
(
gv
);
return
CallNode
::
make
(
gv
,
args
,
Attrs
());
return
CallNode
::
make
(
gv
,
args
,
Attrs
());
}
else
{
}
else
{
...
@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
...
@@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
}
}
};
};
Expr
FuseOps
(
const
Environment
&
env
,
const
Expr
&
e
)
{
Expr
FuseOps
(
const
Module
&
mod
,
const
Expr
&
e
)
{
// First we convert all chains of fusable ops into
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// then we convert these primtive functions into
// new operators.
// new operators.
auto
abstract
=
AbstractFusableOps
(
env
,
StructuralHash
()(
e
));
auto
abstract
=
AbstractFusableOps
(
mod
,
StructuralHash
()(
e
));
auto
abstracted_e
=
abstract
.
VisitExpr
(
e
);
auto
abstracted_e
=
abstract
.
VisitExpr
(
e
);
RELAY_LOG
(
INFO
)
<<
"FuseOps: before="
<<
e
RELAY_LOG
(
INFO
)
<<
"FuseOps: before="
<<
e
<<
"Fuse: after="
<<
abstracted_e
;
<<
"Fuse: after="
<<
abstracted_e
;
...
...
src/relay/pass/kind_check.cc
View file @
ead3ac6c
...
@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
...
@@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
}
}
};
};
bool
KindCheck
(
const
Type
&
t
,
const
Environment
&
env
)
{
bool
KindCheck
(
const
Type
&
t
,
const
Module
&
mod
)
{
KindChecker
kc
;
KindChecker
kc
;
return
kc
.
Check
(
t
);
return
kc
.
Check
(
t
);
}
}
...
@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
...
@@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
TVM_REGISTER_API
(
"relay._ir_pass.check_kind"
)
TVM_REGISTER_API
(
"relay._ir_pass.check_kind"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
1
)
{
if
(
args
.
size
()
==
1
)
{
*
ret
=
KindCheck
(
args
[
0
],
Environment
Node
::
make
({}));
*
ret
=
KindCheck
(
args
[
0
],
Module
Node
::
make
({}));
}
else
{
}
else
{
*
ret
=
KindCheck
(
args
[
0
],
args
[
1
]);
*
ret
=
KindCheck
(
args
[
0
],
args
[
1
]);
}
}
...
...
src/relay/pass/lower_ops.cc
View file @
ead3ac6c
...
@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
...
@@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
}
}
struct
AbstractLocalFunctions
:
ExprMutator
{
struct
AbstractLocalFunctions
:
ExprMutator
{
Environment
env
;
Module
mod
;
size_t
expr_hash
;
size_t
expr_hash
;
int
counter
=
0
;
int
counter
=
0
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
explicit
AbstractLocalFunctions
(
Environment
env
)
explicit
AbstractLocalFunctions
(
Module
mod
)
:
env
(
env
),
expr_hash
(
0
),
counter
(
0
),
visited_funcs
()
{}
:
mod
(
mod
),
expr_hash
(
0
),
counter
(
0
),
visited_funcs
()
{}
Expr
Abstract
(
const
Expr
&
e
)
{
Expr
Abstract
(
const
Expr
&
e
)
{
expr_hash
=
StructuralHash
()(
e
);
expr_hash
=
StructuralHash
()(
e
);
...
@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
...
@@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
auto
gvar
=
GetRef
<
GlobalVar
>
(
gvar_node
);
auto
gvar
=
GetRef
<
GlobalVar
>
(
gvar_node
);
auto
it
=
visited_funcs
.
find
(
gvar
);
auto
it
=
visited_funcs
.
find
(
gvar
);
if
(
it
==
visited_funcs
.
end
())
{
if
(
it
==
visited_funcs
.
end
())
{
auto
func
=
env
->
Lookup
(
gvar
);
auto
func
=
mod
->
Lookup
(
gvar
);
visited_funcs
.
insert
(
gvar
);
visited_funcs
.
insert
(
gvar
);
auto
new_func
=
FunctionNode
::
make
(
auto
new_func
=
FunctionNode
::
make
(
func
->
params
,
func
->
params
,
...
@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
...
@@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
func
->
ret_type
,
func
->
ret_type
,
func
->
type_params
,
func
->
type_params
,
func
->
attrs
);
func
->
attrs
);
env
->
Update
(
gvar
,
new_func
);
mod
->
Update
(
gvar
,
new_func
);
}
}
return
gvar
;
return
gvar
;
}
}
...
@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
...
@@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
abs_func
+=
std
::
to_string
(
expr_hash
);
abs_func
+=
std
::
to_string
(
expr_hash
);
auto
gv
=
GlobalVarNode
::
make
(
abs_func
);
auto
gv
=
GlobalVarNode
::
make
(
abs_func
);
auto
lifted_func
=
FunctionNode
::
make
(
params
,
func
,
Type
(),
{},
{});
auto
lifted_func
=
FunctionNode
::
make
(
params
,
func
,
Type
(),
{},
{});
env
->
Add
(
gv
,
lifted_func
);
mod
->
Add
(
gv
,
lifted_func
);
Array
<
Expr
>
args
;
Array
<
Expr
>
args
;
for
(
auto
free_var
:
free_vars
)
{
for
(
auto
free_var
:
free_vars
)
{
args
.
push_back
(
free_var
);
args
.
push_back
(
free_var
);
...
@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
...
@@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
};
};
struct
LiveFunctions
:
ExprVisitor
{
struct
LiveFunctions
:
ExprVisitor
{
Environment
env
;
Module
mod
;
explicit
LiveFunctions
(
Environment
env
)
:
env
(
env
),
global_funcs
()
{}
explicit
LiveFunctions
(
Module
mod
)
:
mod
(
mod
),
global_funcs
()
{}
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
visited_funcs
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
global_funcs
;
std
::
unordered_set
<
GlobalVar
,
NodeHash
,
NodeEqual
>
global_funcs
;
...
@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
...
@@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
var_node
);
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
var_node
);
auto
it
=
visited_funcs
.
find
(
var
);
auto
it
=
visited_funcs
.
find
(
var
);
if
(
it
==
visited_funcs
.
end
())
{
if
(
it
==
visited_funcs
.
end
())
{
auto
func
=
env
->
Lookup
(
var
);
auto
func
=
mod
->
Lookup
(
var
);
visited_funcs
.
insert
(
var
);
visited_funcs
.
insert
(
var
);
// The last pass has trasnformed functions of the form:
// The last pass has trasnformed functions of the form:
//
//
...
@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
...
@@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
RELAY_LOG
(
INFO
)
<<
"LiveOps: CallNode="
<<
GetRef
<
Call
>
(
call
);
RELAY_LOG
(
INFO
)
<<
"LiveOps: CallNode="
<<
GetRef
<
Call
>
(
call
);
if
(
auto
gv_node
=
call
->
op
.
as
<
GlobalVarNode
>
())
{
if
(
auto
gv_node
=
call
->
op
.
as
<
GlobalVarNode
>
())
{
GlobalVar
gvar
=
GetRef
<
GlobalVar
>
(
gv_node
);
GlobalVar
gvar
=
GetRef
<
GlobalVar
>
(
gv_node
);
Function
func
=
env
->
Lookup
(
gvar
);
Function
func
=
mod
->
Lookup
(
gvar
);
auto
attr
=
FunctionGetAttr
(
func
,
"Primitive"
);
auto
attr
=
FunctionGetAttr
(
func
,
"Primitive"
);
...
@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
...
@@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
using
FSchedule
=
TypedPackedFunc
<
Schedule
(
const
Array
<
Tensor
>&
,
std
::
string
)
>
;
using
FSchedule
=
TypedPackedFunc
<
Schedule
(
const
Array
<
Tensor
>&
,
std
::
string
)
>
;
/*! \brief Return the set of operators in their TVM format. */
/*! \brief Return the set of operators in their TVM format. */
Array
<
LoweredOp
>
LowerOps
(
const
Environment
&
env
,
const
Expr
&
e
,
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
e
,
const
std
::
string
&
target
)
{
const
std
::
string
&
target
)
{
RELAY_LOG
(
INFO
)
<<
"LowerOps: e="
<<
e
;
RELAY_LOG
(
INFO
)
<<
"LowerOps: e="
<<
e
;
auto
flower_ptr
=
Registry
::
Get
(
"relay.op.compiler._lower"
);
auto
flower_ptr
=
Registry
::
Get
(
"relay.op.compiler._lower"
);
CHECK
(
flower_ptr
);
CHECK
(
flower_ptr
);
PackedFunc
flower
=
*
flower_ptr
;
PackedFunc
flower
=
*
flower_ptr
;
auto
abstracted_e
=
AbstractLocalFunctions
(
env
).
Abstract
(
e
);
auto
abstracted_e
=
AbstractLocalFunctions
(
mod
).
Abstract
(
e
);
auto
live_funcs
=
LiveFunctions
(
env
);
auto
live_funcs
=
LiveFunctions
(
mod
);
live_funcs
.
VisitExpr
(
abstracted_e
);
live_funcs
.
VisitExpr
(
abstracted_e
);
auto
schedule_reg
=
Op
::
GetAttr
<
FSchedule
>
(
"FTVMSchedule"
);
auto
schedule_reg
=
Op
::
GetAttr
<
FSchedule
>
(
"FTVMSchedule"
);
...
@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
...
@@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
Array
<
LoweredOp
>
lowered_funcs
;
Array
<
LoweredOp
>
lowered_funcs
;
for
(
auto
func_name
:
live_funcs
.
global_funcs
)
{
for
(
auto
func_name
:
live_funcs
.
global_funcs
)
{
auto
func
=
env
->
Lookup
(
func_name
);
auto
func
=
mod
->
Lookup
(
func_name
);
auto
call
=
Downcast
<
Call
>
(
func
->
body
);
auto
call
=
Downcast
<
Call
>
(
func
->
body
);
auto
op_node
=
call
->
op
.
as
<
OpNode
>
();
auto
op_node
=
call
->
op
.
as
<
OpNode
>
();
CHECK
(
op_node
)
<<
"violated invariant that primtiive calls contain a single op call"
;
CHECK
(
op_node
)
<<
"violated invariant that primtiive calls contain a single op call"
;
...
@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
...
@@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
LoweredFunc
lf
=
LoweredFunc
lf
=
flower
(
op
->
name
+
std
::
to_string
(
hash
),
schedule
,
inputs
,
outputs
);
flower
(
op
->
name
+
std
::
to_string
(
hash
),
schedule
,
inputs
,
outputs
);
func
=
FunctionSetAttr
(
func
,
"LoweredFunc"
,
lf
);
func
=
FunctionSetAttr
(
func
,
"LoweredFunc"
,
lf
);
env
->
Add
(
func_name
,
func
,
true
);
mod
->
Add
(
func_name
,
func
,
true
);
lowered_funcs
.
push_back
(
LoweredOpNode
::
make
(
func
,
lf
));
lowered_funcs
.
push_back
(
LoweredOpNode
::
make
(
func
,
lf
));
}
}
...
...
src/relay/pass/type_infer.cc
View file @
ead3ac6c
...
@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// constructors
// constructors
TypeInferencer
()
{
TypeInferencer
()
{
}
}
explicit
TypeInferencer
(
Environment
env
)
explicit
TypeInferencer
(
Module
mod
)
:
env_
(
env
)
{
:
mod_
(
mod
)
{
}
}
// inference the type of expr.
// inference the type of expr.
...
@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type resolver that maps back to type
// type resolver that maps back to type
class
Resolver
;
class
Resolver
;
// internal environment
// internal environment
Environment
env
_
;
Module
mod
_
;
// map from expression to checked type
// map from expression to checked type
// type inferencer will populate it up
// type inferencer will populate it up
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
std
::
unordered_map
<
Expr
,
ResolvedTypeInfo
,
NodeHash
,
NodeEqual
>
type_map_
;
...
@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
Type
VisitExpr_
(
const
GlobalVarNode
*
op
)
final
{
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
op
);
GlobalVar
var
=
GetRef
<
GlobalVar
>
(
op
);
CHECK
(
env
_
.
defined
())
CHECK
(
mod
_
.
defined
())
<<
"Cannot do type inference without a global variable"
;
<<
"Cannot do type inference without a global variable"
;
Expr
e
=
env
_
->
Lookup
(
var
);
Expr
e
=
mod
_
->
Lookup
(
var
);
return
e
->
checked_type
();
return
e
->
checked_type
();
}
}
...
@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
...
@@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
}
}
Expr
InferType
(
const
Expr
&
expr
,
const
Environment
&
env
)
{
Expr
InferType
(
const
Expr
&
expr
,
const
Module
&
mod
)
{
auto
e
=
TypeInferencer
(
env
).
Infer
(
expr
);
auto
e
=
TypeInferencer
(
mod
).
Infer
(
expr
);
CHECK
(
WellFormed
(
e
));
CHECK
(
WellFormed
(
e
));
return
e
;
return
e
;
}
}
Function
InferType
(
const
Function
&
func
,
Function
InferType
(
const
Function
&
func
,
const
Environment
&
env
,
const
Module
&
mod
,
const
GlobalVar
&
var
)
{
const
GlobalVar
&
var
)
{
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
Function
func_copy
=
Function
(
make_node
<
FunctionNode
>
(
*
func
.
operator
->
()));
func_copy
->
checked_type_
=
func_copy
->
func_type_annotation
();
func_copy
->
checked_type_
=
func_copy
->
func_type_annotation
();
env
->
functions
.
Set
(
var
,
func_copy
);
mod
->
functions
.
Set
(
var
,
func_copy
);
Expr
func_ret
=
TypeInferencer
(
env
).
Infer
(
func_copy
);
Expr
func_ret
=
TypeInferencer
(
mod
).
Infer
(
func_copy
);
auto
map_node
=
env
->
functions
.
CopyOnWrite
();
auto
map_node
=
mod
->
functions
.
CopyOnWrite
();
map_node
->
data
.
erase
(
var
.
node_
);
map_node
->
data
.
erase
(
var
.
node_
);
CHECK
(
WellFormed
(
func_ret
));
CHECK
(
WellFormed
(
func_ret
));
return
Downcast
<
Function
>
(
func_ret
);
return
Downcast
<
Function
>
(
func_ret
);
...
...
tests/cpp/relay_pass_type_infer_test.cc
View file @
ead3ac6c
...
@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
...
@@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
auto
x
=
relay
::
VarNode
::
make
(
"x"
,
type_a
);
auto
x
=
relay
::
VarNode
::
make
(
"x"
,
type_a
);
auto
f
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
x
},
x
,
type_b
,
Array
<
relay
::
TypeVar
>
{});
auto
f
=
relay
::
FunctionNode
::
make
(
tvm
::
Array
<
relay
::
Var
>
{
x
},
x
,
type_b
,
Array
<
relay
::
TypeVar
>
{});
auto
fx
=
relay
::
CallNode
::
make
(
f
,
Array
<
relay
::
Expr
>
{
x
});
auto
fx
=
relay
::
CallNode
::
make
(
f
,
Array
<
relay
::
Expr
>
{
x
});
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
Environment
Node
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
auto
type_fx
=
relay
::
InferType
(
fx
,
relay
::
Module
Node
::
make
(
Map
<
relay
::
GlobalVar
,
relay
::
Function
>
{}));
CHECK_EQ
(
type_fx
->
checked_type
(),
type_a
);
CHECK_EQ
(
type_fx
->
checked_type
(),
type_a
);
}
}
...
...
tests/python/relay/test_graph_runtime.py
View file @
ead3ac6c
...
@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
...
@@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
from
tvm.relay.interpreter
import
Interpreter
from
tvm.relay.interpreter
import
Interpreter
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay.op
import
add
from
tvm.relay.op
import
add
from
tvm.relay.
env
import
Environment
from
tvm.relay.
module
import
Module
# @tq, @jr should we put this in testing ns?
# @tq, @jr should we put this in testing ns?
def
check_rts
(
expr
,
args
,
expected_result
,
env
=
None
):
def
check_rts
(
expr
,
args
,
expected_result
,
mod
=
None
):
"""
"""
Check that evaluating `expr` applied to the arguments produces
Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime.
`result` on both the evaluator and TVM runtime.
...
@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
...
@@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
expected_result:
expected_result:
The expected result of running the expression.
The expected result of running the expression.
"""
"""
intrp
=
create_executor
(
'graph'
,
env
=
env
)
intrp
=
create_executor
(
'graph'
,
mod
=
mod
)
graph
=
create_executor
(
'graph'
,
env
=
env
)
graph
=
create_executor
(
'graph'
,
mod
=
mod
)
eval_result
=
intrp
.
evaluate
(
expr
)(
*
args
)
eval_result
=
intrp
.
evaluate
(
expr
)(
*
args
)
rts_result
=
graph
.
evaluate
(
expr
)(
*
args
)
rts_result
=
graph
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
eval_result
.
asnumpy
(),
rts_result
.
asnumpy
())
np
.
testing
.
assert_allclose
(
eval_result
.
asnumpy
(),
rts_result
.
asnumpy
())
...
...
tests/python/relay/test_interpreter.py
View file @
ead3ac6c
...
@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
...
@@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
from
tvm.relay
import
testing
,
create_executor
from
tvm.relay
import
testing
,
create_executor
def
check_eval
(
expr
,
args
,
expected_result
,
env
=
None
,
rtol
=
1e-07
):
def
check_eval
(
expr
,
args
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
intrp
=
create_executor
(
env
=
env
)
intrp
=
create_executor
(
mod
=
mod
)
result
=
intrp
.
evaluate
(
expr
)(
*
args
)
result
=
intrp
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
result
.
asnumpy
(),
expected_result
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
result
.
asnumpy
(),
expected_result
,
rtol
=
rtol
)
...
@@ -87,7 +87,7 @@ def test_subtract():
...
@@ -87,7 +87,7 @@ def test_subtract():
check_eval
(
func
,
[
i_data
],
0
)
check_eval
(
func
,
[
i_data
],
0
)
def
test_simple_loop
():
def
test_simple_loop
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
module
.
Module
({})
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
sb
=
ScopeBuilder
()
sb
=
ScopeBuilder
()
...
@@ -98,12 +98,12 @@ def test_simple_loop():
...
@@ -98,12 +98,12 @@ def test_simple_loop():
rec_call
=
relay
.
Call
(
sum_up
,
[
one_less
])
rec_call
=
relay
.
Call
(
sum_up
,
[
one_less
])
sb
.
ret
(
op
.
add
(
rec_call
,
i
))
sb
.
ret
(
op
.
add
(
rec_call
,
i
))
func
=
relay
.
Function
([
i
],
sb
.
get
(),
ret_type
=
relay
.
TensorType
([],
'int32'
))
func
=
relay
.
Function
([
i
],
sb
.
get
(),
ret_type
=
relay
.
TensorType
([],
'int32'
))
env
[
sum_up
]
=
func
mod
[
sum_up
]
=
func
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
check_eval
(
sum_up
,
[
i_data
],
sum
(
range
(
1
,
11
)),
env
=
env
)
check_eval
(
sum_up
,
[
i_data
],
sum
(
range
(
1
,
11
)),
mod
=
mod
)
def
test_loop
():
def
test_loop
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
module
.
Module
({})
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
sum_up
=
relay
.
GlobalVar
(
'sum_up'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
i
=
relay
.
var
(
'i'
,
shape
=
[],
dtype
=
'int32'
)
accum
=
relay
.
var
(
'accum'
,
shape
=
[],
dtype
=
'int32'
)
accum
=
relay
.
var
(
'accum'
,
shape
=
[],
dtype
=
'int32'
)
...
@@ -115,10 +115,10 @@ def test_loop():
...
@@ -115,10 +115,10 @@ def test_loop():
new_accum
=
op
.
add
(
accum
,
i
)
new_accum
=
op
.
add
(
accum
,
i
)
sb
.
ret
(
relay
.
Call
(
sum_up
,
[
one_less
,
new_accum
]))
sb
.
ret
(
relay
.
Call
(
sum_up
,
[
one_less
,
new_accum
]))
func
=
relay
.
Function
([
i
,
accum
],
sb
.
get
())
func
=
relay
.
Function
([
i
,
accum
],
sb
.
get
())
env
[
sum_up
]
=
func
mod
[
sum_up
]
=
func
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
i_data
=
np
.
array
(
10
,
dtype
=
'int32'
)
accum_data
=
np
.
array
(
0
,
dtype
=
'int32'
)
accum_data
=
np
.
array
(
0
,
dtype
=
'int32'
)
check_eval
(
sum_up
,
[
i_data
,
accum_data
],
sum
(
range
(
1
,
11
)),
env
=
env
)
check_eval
(
sum_up
,
[
i_data
,
accum_data
],
sum
(
range
(
1
,
11
)),
mod
=
mod
)
def
test_mlp
():
def
test_mlp
():
pass
pass
...
...
tests/python/relay/test_ir_text_printer.py
View file @
ead3ac6c
...
@@ -28,7 +28,7 @@ def test_env():
...
@@ -28,7 +28,7 @@ def test_env():
z
=
relay
.
add
(
x
,
y
)
z
=
relay
.
add
(
x
,
y
)
z
=
relay
.
add
(
z
,
z
)
z
=
relay
.
add
(
z
,
z
)
f
=
relay
.
Function
([
x
,
y
],
z
)
f
=
relay
.
Function
([
x
,
y
],
z
)
env
=
relay
.
Environment
()
env
=
relay
.
Module
()
env
[
"myf"
]
=
f
env
[
"myf"
]
=
f
text
=
env
.
astext
()
text
=
env
.
astext
()
assert
"def @myf"
in
text
assert
"def @myf"
in
text
...
...
tests/python/relay/test_type_infer.py
View file @
ead3ac6c
...
@@ -9,8 +9,8 @@ from tvm.relay import op
...
@@ -9,8 +9,8 @@ from tvm.relay import op
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay.scope_builder
import
ScopeBuilder
def
assert_has_type
(
expr
,
typ
,
env
=
relay
.
env
.
Environment
({})):
def
assert_has_type
(
expr
,
typ
,
mod
=
relay
.
module
.
Module
({})):
checked_expr
=
infer_type
(
expr
,
env
)
checked_expr
=
infer_type
(
expr
,
mod
)
checked_type
=
checked_expr
.
checked_type
checked_type
=
checked_expr
.
checked_type
if
checked_type
!=
typ
:
if
checked_type
!=
typ
:
raise
RuntimeError
(
"Type mismatch
%
s vs
%
s"
%
(
raise
RuntimeError
(
"Type mismatch
%
s vs
%
s"
%
(
...
@@ -105,10 +105,10 @@ def test_recursion():
...
@@ -105,10 +105,10 @@ def test_recursion():
sb
.
ret
(
data
)
sb
.
ret
(
data
)
with
sb
.
else_scope
():
with
sb
.
else_scope
():
sb
.
ret
(
f
(
relay
.
subtract
(
n
,
relay
.
const
(
1
,
ti32
)),
relay
.
log
(
data
)))
sb
.
ret
(
f
(
relay
.
subtract
(
n
,
relay
.
const
(
1
,
ti32
)),
relay
.
log
(
data
)))
env
=
relay
.
Environment
()
mod
=
relay
.
Module
()
env
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
mod
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
assert
"
%3
= @f(
%1
,
%2
)"
in
env
.
astext
()
assert
"
%3
= @f(
%1
,
%2
)"
in
mod
.
astext
()
assert
env
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
# This currently fails and should pass under the type system.
# This currently fails and should pass under the type system.
#
#
...
@@ -179,12 +179,12 @@ def test_self_reference():
...
@@ -179,12 +179,12 @@ def test_self_reference():
def
test_global_var_cow_issue
():
def
test_global_var_cow_issue
():
env
=
relay
.
env
.
Environment
({})
mod
=
relay
.
Module
({})
gv
=
relay
.
GlobalVar
(
"foo"
)
gv
=
relay
.
GlobalVar
(
"foo"
)
x
=
relay
.
var
(
'x'
,
shape
=
[])
x
=
relay
.
var
(
'x'
,
shape
=
[])
func
=
relay
.
Function
([
x
],
relay
.
Call
(
gv
,
[
x
]),
func
=
relay
.
Function
([
x
],
relay
.
Call
(
gv
,
[
x
]),
relay
.
TensorType
([],
'float32'
))
relay
.
TensorType
([],
'float32'
))
env
[
gv
]
=
func
mod
[
gv
]
=
func
def
test_equal
():
def
test_equal
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment