Commit 8876eac8 in wenyuanbo/tic (unverified)
Authored Oct 19, 2018 by Tianqi Chen; committed by GitHub on Oct 19, 2018

[RELAY] IR builder stablize refactor, clean pass (#1934)

Parent: 4300bbc2

Showing 52 changed files with 1212 additions and 1836 deletions.
include/tvm/relay/attrs/nn.h (+1, -1)
include/tvm/relay/environment.h (+18, -10)
include/tvm/relay/expr.h (+5, -4)
include/tvm/relay/op.h (+3, -3)
include/tvm/relay/pass.h (+18, -9)
include/tvm/relay/type.h (+16, -16)
python/tvm/relay/__init__.py (+4, -2)
python/tvm/relay/env.py (+60, -46)
python/tvm/relay/expr.py (+54, -9)
python/tvm/relay/ir_builder.py (+0, -387)
python/tvm/relay/ir_pass.py (+38, -31)
python/tvm/relay/scope_builder.py (+185, -0)
python/tvm/relay/ty.py (+42, -20)
src/relay/ir/environment.cc (+43, -59)
src/relay/ir/expr.cc (+1, -1)
src/relay/ir/expr_functor.cc (+2, -2)
src/relay/ir/text_printer.cc (+10, -2)
src/relay/ir/type.cc (+11, -11)
src/relay/op/image/resize.cc (+1, -0)
src/relay/op/nn/nn.cc (+5, -0)
src/relay/op/nn/pad.cc (+1, -0)
src/relay/op/nn/pooling.cc (+5, -0)
src/relay/op/nn/upsampling.cc (+1, -0)
src/relay/op/tensor/reduce.cc (+2, -2)
src/relay/op/tensor/transform.cc (+13, -0)
src/relay/op/type_relations.cc (+43, -134)
src/relay/op/type_relations.h (+0, -27)
src/relay/op/vision/multibox_op.cc (+1, -0)
src/relay/pass/alpha_eq.cc (+5, -5)
src/relay/pass/dead_code.cc (+3, -1)
src/relay/pass/kind_check.cc (+2, -2)
src/relay/pass/let_list.h (+1, -1)
src/relay/pass/type_functor.h (+2, -2)
src/relay/pass/type_infer.cc (+58, -22)
src/relay/pass/type_subst.cc (+6, -6)
src/relay/pass/type_subst.h (+2, -2)
src/relay/pass/type_visitor.h (+6, -6)
src/relay/pass/util.cc (+12, -12)
tests/python/relay/test_ir_builder.py (+0, -19)
tests/python/relay/test_ir_nodes.py (+4, -4)
tests/python/relay/test_ir_text_printer.py (+21, -11)
tests/python/relay/test_op_level1.py (+98, -241)
tests/python/relay/test_op_level2.py (+123, -192)
tests/python/relay/test_op_level3.py (+78, -161)
tests/python/relay/test_op_level4.py (+51, -172)
tests/python/relay/test_op_level5.py (+19, -35)
tests/python/relay/test_pass_alpha_equal.py (+51, -52)
tests/python/relay/test_pass_check_kind.py (+22, -22)
tests/python/relay/test_pass_dead_code_elimination.py (+13, -9)
tests/python/relay/test_pass_free_vars.py (+1, -1)
tests/python/relay/test_type_infer.py (+51, -80)
tests/python/relay/test_type_solver.py (+0, -2)
include/tvm/relay/attrs/nn.h

@@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
 struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
   double alpha;
 
-  TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
+  TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
     TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
         .describe("Slope coefficient for the negative half axis.");
   }
include/tvm/relay/environment.h

@@ -47,12 +47,13 @@ class EnvironmentNode : public RelayNode {
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("functions", &functions);
-    v->Visit("global_map_", &global_map_);
+    v->Visit("global_var_map_", &global_var_map_);
   }
 
   TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
 
-  /*! \brief Add a function to the global environment.
+  /*!
+   * \brief Add a function to the global environment.
    * \param var The name of the global function.
    * \param func The function.
    * \param update Controls whether you can replace a definition in the
@@ -60,39 +61,46 @@ class EnvironmentNode : public RelayNode {
    */
   void Add(const GlobalVar& var, const Function& func, bool update = false);
 
-  /*! \brief Update a function in the global environment.
+  /*!
+   * \brief Update a function in the global environment.
    * \param var The name of the global function to update.
    * \param func The new function.
    */
   void Update(const GlobalVar& var, const Function& func);
 
-  /*! \brief Remove a function from the global environment.
+  /*!
+   * \brief Remove a function from the global environment.
    * \param var The name of the global function to update.
    */
   void Remove(const GlobalVar& var);
 
-  /*! \brief Lookup a global function by its variable.
+  /*!
+   * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
   GlobalVar GetGlobalVar(const std::string& str);
 
-  /*! \brief Lookup a global function by its variable.
+  /*!
+   * \brief Lookup a global function by its variable.
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
    */
   Function Lookup(const GlobalVar& var);
 
-  /*! \brief Lookup a global function by its string name
+  /*!
+   * \brief Lookup a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
   Function Lookup(const std::string& name);
 
-  /*! \brief Combine with another Environment.
+  /*!
+   * \brief Update the functions inside this environment by
+   * functions in another environment.
    * \param other The other environment.
    */
-  void Merge(const Environment& other);
+  void Update(const Environment& other);
 
   static constexpr const char* _type_key = "relay.Environment";
   TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
@@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
   /*! \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
-  tvm::Map<std::string, GlobalVar> global_map_;
+  tvm::Map<std::string, GlobalVar> global_var_map_;
 };
 
 struct Environment : public NodeRef {
include/tvm/relay/expr.h

@@ -197,7 +197,7 @@ class FunctionNode : public ExprNode {
    *
    * \note This can be usually empty for non-polymorphic functions.
    */
-  tvm::Array<TypeParam> type_params;
+  tvm::Array<TypeVar> type_params;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("params", &params);
@@ -219,7 +219,7 @@ class FunctionNode : public ExprNode {
   TVM_DLL static Function make(tvm::Array<Var> params,
                                Expr body,
                                Type ret_type,
-                               tvm::Array<TypeParam> ty_params);
+                               tvm::Array<TypeVar> ty_params);
 
   static constexpr const char* _type_key = "relay.Function";
   TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
@@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
   int index;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("tuple", &tuple);
+    v->Visit("tuple_value", &tuple);
     v->Visit("index", &index);
+    v->Visit("_checked_type_", &checked_type_);
   }
 
   TVM_DLL static TupleGetItem make(Expr tuple, int index);
 
-  static constexpr const char* _type_key = "relay.GetItem";
+  static constexpr const char* _type_key = "relay.TupleGetItem";
   TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
 };
include/tvm/relay/op.h

@@ -371,14 +371,14 @@ inline OpRegistry& OpRegistry::add_type_rel(
     env_type_rel_func = env_func;
   }
 
-  Array<TypeParam> type_params;
+  Array<TypeVar> type_params;
   Array<Type> arg_types;
 
   // Add inputs.
   std::string input_name_prefix = "in";
   for (int i = 0; i < get()->num_inputs; i++) {
     auto name = input_name_prefix + std::to_string(i);
-    auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
+    auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
     type_params.push_back(param);
     arg_types.push_back(param);
   }
@@ -386,7 +386,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
   Array<Type> ty_call_args = arg_types;
 
   // Add output type.
-  auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
+  auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
   type_params.push_back(out_param);
   // this will trigger copy on write.
   ty_call_args.push_back(out_param);
include/tvm/relay/pass.h

@@ -12,21 +12,30 @@
 namespace tvm {
 namespace relay {
 
-/*! \brief Infer the type of an expression with the provided environment.
+/*!
+ * \brief Infer the type of an expression.
  *
  * The result of type checking is a new expression with unambigous
  * type information filled in, as well as it's checked type field
  * populated with the result type.
  *
- * \param env The environment used for global settings and referencing
- * global functions.
- *
- * \param e The expression to type check.
+ * \param expr The expression to type check.
+ * \param env The environment used for referencing global functions, can be None.
  *
  * \return A type checked expression with its checked_type field populated.
  */
-Expr InferType(const Environment& env, const Expr& e);
-Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
+Expr InferType(const Expr& expr, const Environment& env);
+
+/*!
+ * \brief Infer the type of a function as if it is mapped to var in the env.
+ *
+ * \param f the function.
+ * \param env The environment used for referencing global functions.
+ * \param var The global variable corresponding to the function.
+ *
+ * \return A type checked Function with its checked_type field populated.
+ * \note this function mutates env and is not thread-safe.
+ */
+Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
 
 /*!
  * \brief Check that types are well kinded by applying "kinding rules".
@@ -111,7 +120,7 @@ tvm::Array<Var> FreeVariables(const Expr& e);
  *
  * \return the set of free type variables.
  */
-tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
+tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);
 
 /*! \brief Get free type parameters from type t.
  *
@@ -121,7 +130,7 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
  *
  * \return the set of free type variables.
  */
-tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
+tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
 
 /*! \brief Remove expressions which does not effect the program result.
  *
include/tvm/relay/type.h

@@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
  * This can be viewed as template parameter in c++ template function.
  *
  * For example, in the following pesudo code,
- * the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
+ * the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
  * This function can take in a Tensor with shape=(3, 3) and
  * returns a Tensor with shape=(9,)
  *
@@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
  *  f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
  *
  * \endcode
- * \sa TypeParamNode The actual container class of TypeParam
+ * \sa TypeVarNode The actual container class of TypeVar
  */
-class TypeParam;
-/*! \brief TypeParam container node */
-class TypeParamNode : public TypeNode {
+class TypeVar;
+/*! \brief TypeVar container node */
+class TypeVarNode : public TypeNode {
  public:
-  /*! \brief possible kinds of TypeParam */
+  /*! \brief possible kinds of TypeVar */
   enum Kind : int {
     /*! \brief template variable in shape expression */
     kType = 0,
@@ -136,13 +136,13 @@ class TypeParamNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  TVM_DLL static TypeParam make(std::string name, Kind kind);
+  TVM_DLL static TypeVar make(std::string name, Kind kind);
 
-  static constexpr const char* _type_key = "relay.TypeParam";
-  TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
+  static constexpr const char* _type_key = "relay.TypeVar";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
 };
 
-RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
+RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
 
 /*!
  * \brief IncompleteType.
@@ -150,20 +150,20 @@ RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
  *
  * If we view the type relations as "computational graph of types",
 * then IncompleteType represents intermediate values of the graph,
- * TypeParam represents the input to the graph.
+ * TypeVar represents the input to the graph.
  */
 class IncompleteType;
 
 /*! \brief IncompleteType container node */
 class IncompleteTypeNode : public TypeNode {
  public:
-  TypeParamNode::Kind kind;
+  TypeVarNode::Kind kind;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("kind", &kind);
   }
 
-  TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
+  TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);
 
   static constexpr const char* _type_key = "relay.IncompleteType";
   TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
@@ -192,7 +192,7 @@ class FuncType;
  * Relay support polymorphic function type.
 * This can be roughly viewed as template function in C++.
  *
- * \sa TypeParam, TypeConstraint
+ * \sa TypeVar, TypeConstraint
  */
 class FuncTypeNode : public TypeNode {
  public:
@@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode {
   // The following fields are used in polymorphic(template) functions
   // For normal functions, the following two fields will be empty.
   /*! \brief The type parameters of the function */
-  tvm::Array<TypeParam> type_params;
+  tvm::Array<TypeVar> type_params;
   /*!
    * \brief potential constraint the type need to obey
    * \note this field is reserved for futher purposes.
@@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode {
   TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
                                Type ret_type,
-                               tvm::Array<TypeParam> type_params,
+                               tvm::Array<TypeVar> type_params,
                                tvm::Array<TypeConstraint> type_constraints);
 
   static constexpr const char* _type_key = "relay.FuncType";
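The rename is visible from the Python frontend as well. A minimal sketch, assuming a TVM build that includes this commit (the name "n" is illustrative):

    from tvm import relay

    # TypeParam is renamed to TypeVar everywhere; the Kind enum is unchanged.
    tv = relay.TypeVar("n", relay.Kind.Type)
    assert isinstance(tv, relay.Type)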
python/tvm/relay/__init__.py

@@ -5,7 +5,6 @@ from . import ty
 from . import expr
 from . import env
 from . import ir_pass
-from . import ir_builder
 
 # Root operators
 from .op import Op
@@ -16,6 +15,8 @@ from . import nn
 from . import vision
 from . import image
 
+from .scope_builder import ScopeBuilder
+
 # Span
 Span = base.Span
@@ -27,11 +28,12 @@ Type = ty.Type
 TupleType = ty.TupleType
 TensorType = ty.TensorType
 Kind = ty.Kind
-TypeParam = ty.TypeParam
+TypeVar = ty.TypeVar
 TypeConstraint = ty.TypeConstraint
 FuncType = ty.FuncType
 TypeRelation = ty.TypeRelation
 IncompleteType = ty.IncompleteType
+scalar_type = ty.scalar_type
 
 # Expr
 Constant = expr.Constant
python/tvm/relay/env.py

 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
 """A global environment storing everything needed to interpret or compile a Relay program."""
 from .base import register_relay_node, RelayNode
+from .._ffi import base as _base
 from . import _make
 from . import _env
+from . import expr as _expr
 
 @register_relay_node
 class Environment(RelayNode):
-    """The global Relay environment containing functions,
-    options and more.
-    """
-    def __init__(self, funcs=None):
-        """Construct an environment.
-
-        Parameters
-        ------
-        funcs : optional, dict
-            Map of global var to Function
-
-        Returns
-        ------
-        env: A new environment containing :py:class:`~relay.env.Environment`.
-        """
-        funcs = funcs if funcs else {}
-        self.__init_handle_by_constructor__(_make.Environment, funcs)
-
-    def add(self, var, func):
+    """The global Relay environment containing collection of functions.
+
+    Each global function is identified by an unique tvm.relay.GlobalVar.
+    tvm.relay.GlobalVar and Environment is necessary in order to enable
+    recursions in function to avoid cyclic reference in the function.x
+
+    Parameters
+    ----------
+    functions : dict, optional.
+        Map of global var to Function
+    """
+    def __init__(self, functions=None):
+        if functions is None:
+            functions = {}
+        elif isinstance(functions, dict):
+            mapped_funcs = {}
+            for k, v in functions.items():
+                if isinstance(k, _base.string_types):
+                    k = _expr.GlobalVar(k)
+                if not isinstance(k, _expr.GlobalVar):
+                    raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
+                mapped_funcs[k] = v
+            functions = mapped_funcs
+        self.__init_handle_by_constructor__(_make.Environment, functions)
+
+    def __setitem__(self, var, func):
         """Add a function to the environment.
 
         Parameters
@@ -36,50 +45,55 @@ class Environment(RelayNode):
         func: Function
             The function.
         """
-        if isinstance(var, str):
-            var = _env.Environment_GetGlobalVar(self, var)
+        if isinstance(var, _base.string_types):
+            var = _expr.GlobalVar(var)
         _env.Environment_Add(self, var, func)
 
-    def merge(self, other):
-        """Merge two environments.
+    def __getitem__(self, var):
+        """Lookup a global function by name or by variable.
 
         Parameters
         ----------
-        other: Environment
-            The environment to merge into the current Environment.
+        var: str or GlobalVar
+            The name or global variable.
+
+        Returns
+        -------
+        func: Function
+            The function referenced by :code:`var`.
         """
-        return _env.Environment_Merge(self, other)
+        if isinstance(var, _base.string_types):
+            return _env.Environment_Lookup_str(self, var)
+        else:
+            return _env.Environment_Lookup(self, var)
 
-    def global_var(self, name):
-        """Get a global variable by name.
+    def update(self, other):
+        """Insert functions in another Environment to current one.
 
         Parameters
         ----------
-        name: str
-            The name of the global variable.
-
-        Returns
-        -------
-        global_var: GlobalVar
-            The global variable mapped to :code:`name`.
+        other: Environment
+            The environment to merge into the current Environment.
         """
-        return _env.Environment_GetGlobalVar(self, name)
+        if isinstance(other, dict):
+            other = Environment(other)
+        return _env.Environment_Update(self, other)
 
-    def __getitem__(self, var):
-        """Lookup a global function by name or by variable.
+    def get_global_var(self, name):
+        """Get a global variable in the function by name.
 
         Parameters
         ----------
-        var: str or GlobalVar
-            The name or global variable.
+        name: str
+            The name of the global variable.
 
         Returns
         -------
-        func: Function
-            The function referenced by :code:`var`.
+        global_var: GlobalVar
+            The global variable mapped to :code:`name`.
+
+        Raises
+        ------
+        tvm.TVMError if we cannot find corresponding global var.
         """
-        if isinstance(var, str):
-            return _env.Environment_Lookup_str(self, var)
-        else:
-            return _env.Environment_Lookup(self, var)
+        return _env.Environment_GetGlobalVar(self, name)
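Put together, the environment now behaves like a mapping from global names to functions. A minimal usage sketch, assuming a build with this commit and that relay.Environment is re-exported as elsewhere in the package; the function body and the names "double"/"copy" are illustrative:

    from tvm import relay

    x = relay.var("x", shape=(10,))
    func = relay.Function([x], relay.add(x, x))

    env = relay.Environment()
    env["double"] = func                # replaces env.add("double", func)
    gv = env.get_global_var("double")   # replaces env.global_var("double")
    looked_up = env[gv]                 # __getitem__ accepts a GlobalVar or a str
    env.update({"copy": func})          # replaces env.merge(other_env)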
python/tvm/relay/expr.py

@@ -28,9 +28,6 @@ class Expr(RelayNode):
                            " the checked_type for this node")
         return ret
 
-    def __call__(self, *args):
-        return Call(self, args, None, None)
-
 
 @register_relay_node
 class Constant(Expr):
@@ -57,6 +54,14 @@ class Tuple(Expr):
     def __init__(self, fields):
         self.__init_handle_by_constructor__(_make.Tuple, fields)
 
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Tuple index out of range")
+        return self.fields[index]
+
+    def __len__(self):
+        return len(self.fields)
+
 
 @register_relay_node
 class Var(Expr):
@@ -95,6 +100,16 @@ class GlobalVar(Expr):
     def __init__(self, name_hint):
         self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
 
+    def __call__(self, *args):
+        """Invoke the gobal function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
 
 @register_relay_node
 class Function(Expr):
@@ -126,6 +141,16 @@ class Function(Expr):
         self.__init_handle_by_constructor__(
             _make.Function, params, body, ret_type, type_params)
 
+    def __call__(self, *args):
+        """Invoke the gobal function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
 
 @register_relay_node
 class Call(Expr):
@@ -238,11 +263,17 @@ class TupleWrapper(_node.NodeGeneric):
         return self.tuple_value
 
-    def __getitem__(self, key):
-        return self.tuple_value.fields[key]
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Tuple index out of range")
+        return TupleGetItem(self.tuple_value, index)
 
     def __len__(self):
-        return len(self.tuple_value.fields)
+        return self.size
+
+    def __repr__(self):
+        return ("TupleWrapper(" + self.tuple_value.__repr__() +
+                ", " + self.size + ")")
 
 
 def var(name_hint,
@@ -304,13 +335,27 @@ def const(value, dtype=None):
     dtype: str, optional
         The data type of the value.
+
+    Note
+    ----
+    When dtype is None, we use the following rule:
+
+    - int maps to "int32"
+    - float maps to "float32"
+    - bool maps to "bool"
+    - other using the same default rule as numpy.
     """
-    if isinstance(value, _base.numeric_types):
-        value = _np.array(value, dtype=dtype)
-    elif isinstance(value, (bool, list)):
+    if isinstance(value, (_base.numeric_types, (bool, list))):
         value = _np.array(value, dtype=dtype)
+        # convert default to int32 and float32
+        if dtype is None:
+            if value.dtype == "float64":
+                value = value.astype("float32")
+            elif value.dtype == "int64":
+                value = value.astype("int32")
+
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
+
     if not isinstance(value, _nd.NDArray):
         raise ValueError("value has to be scalar or NDArray")
+
     return Constant(value)
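Two of the behavioral changes above can be seen directly. A small sketch, assuming a build with this commit (dtype defaults follow the new Note in const):

    from tvm import relay

    # Python float constants now default to float32, ints to int32.
    c = relay.const(3.0)

    # Tuple expressions support len() and indexing on their fields.
    t = relay.Tuple([relay.const(1), relay.const(2)])
    assert len(t) == 2
    first_field = t[0]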
python/tvm/relay/ir_builder.py (deleted, 100644 → 0)

# pylint: disable=no-else-return
"""IR builder for the Relay IR.

Enables users to construct Relay programs with a Python API.
"""
from collections import OrderedDict
import numpy as np
import tvm
from .ty import Type, FuncType, TensorType
from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment


def _convert_to_value(arg, ctxt=tvm.cpu(0)):
    # type: (Any, tvm.Context) -> tvm.nd.NDArray
    """Convert Python values into the appropriate types
       for the Relay evaluator.
    """
    if isinstance(arg, bool):
        # bool is subclass of int
        return tvm.nd.array(np.array(arg, dtype='uint8'), ctxt)
    elif isinstance(arg, int):
        return tvm.nd.array(np.array(arg, dtype='int32'), ctxt)
    elif isinstance(arg, float):
        return tvm.nd.array(arg, ctxt)
    elif isinstance(arg, np.ndarray):
        return tvm.nd.array(arg, ctxt)
    elif isinstance(arg, tvm.ndarray.NDArray):
        return arg
    else:
        # raise Exception(f"can't convert {type(arg)} to a Relay AST")
        raise Exception("unsupported argument type {0}".format(type(arg)))


def _convert_type(rtype):
    if isinstance(rtype, str):
        return scalar_type(rtype)
    elif isinstance(rtype, Type):
        return rtype
    else:
        raise Exception("unsupported conversion to Relay type {0}".format(type(rtype)))


def convert(arg):
    # type: (Any) -> Expr
    """Convert some Python objects into a Relay AST fragment.

    Parameters
    ----------
    arg: Any
        The Python object

    Returns
    -------
    expr: relay.Expr
        The converted expression.
    """
    if isinstance(arg, Expr):
        return arg
    elif isinstance(arg, tuple):
        return relay.Tuple([convert(el) for el in arg])
    elif isinstance(arg, PartialFunc):
        return arg.to_func()
    elif isinstance(arg, tvm._ffi.node.NodeGeneric):
        return arg.asnode()
    else:
        value = _convert_to_value(arg)
        return Constant(value)


class WithScope(object):
    """A wrapper for builder methods which introduce scoping."""

    def __init__(self, enter_value, exit_cb):
        self._enter_value = enter_value
        self._exit_cb = exit_cb

    def __enter__(self):
        return self._enter_value

    def __exit__(self, ptype, value, trace):
        if value:
            raise value
        else:
            self._exit_cb()


class PartialFunc(object):
    """A wrapper around functions while they are being built.

    Used by the builder as a user is building up a function,
    allows Function nodes which contain partially initialized
    state.
    """
    def __init__(self, params, ret_type, body, type_params):
        self.params = params
        self.ret_type = ret_type
        self.body = body
        self.type_params = type_params

    def param_ids(self):
        return [p for p in self.params]

    def to_func(self):
        """Converts a PartialFunc into a :py:class:`~relay.Function`."""
        return Function(
            self.params,
            self.body,
            self.ret_type,
            self.type_params)

#pylint: disable=invalid-name
def _mk_let(bindings, ret_value):
    let_expr = ret_value
    for var, value in reversed(list(bindings.items())):
        let_expr = Let(var, value, let_expr)

    return let_expr


class IRBuilder(object):
    """The IRBuilder class.

    Enables users to build up a Relay environment and program.

    Examples
    --------
    Program:
       fn (x : Tensor[f32, (10, 10)]) {
         let t1 = log(x);
         let t2 = add(t1, x);
         return t1;
       }

    ..code-block: python
        b = IRBuilder()
        with b.function(('x', tensor_type(10, 10))) as func:
            x, = func.param_ids()
            t1 = b.let('t1', log(x))
            t2 = b.let('t2', add(t1, x))
            b.ret(t2)
    """
    def __init__(self):
        self.bindings = [OrderedDict({})]
        self.scopes = [OrderedDict({})]
        self.params = []
        self.ret_values = [None]
        self.env = Environment({})

    def enter_scope(self, params=None):
        if not params:
            params = []

        self.bindings.append(OrderedDict({}))
        self.scopes.append(OrderedDict({}))
        self.params.append(params)
        self.ret_values.append(None)

    def exit_scope(self):
        bindings = self.bindings.pop()
        scopes = self.scopes.pop()
        params = self.params.pop()
        ret_value = self.ret_values.pop()
        return bindings, scopes, params, ret_value

    #pylint: disable=invalid-name
    def bind(self, name, value, ty):
        lv = Var(name, ty)
        self.scopes[-1][name] = lv
        self.bindings[-1][lv] = value
        return lv

    def let(self, name, value, value_type=None):
        if not isinstance(value, Expr):
            value = convert(value)

        return self.bind(name, value, value_type)

    def _convert_params(self, raw_params):
        relay_params = []
        for raw_param in raw_params:
            if isinstance(raw_param, Var):
                param = raw_param
            elif isinstance(raw_param, tuple):
                var, ty = raw_param
                ty = _convert_type(ty)
                param = Var(var, ty)
            elif isinstance(raw_param, str):
                param = Var(raw_param, None)
            else:
                raise Exception("unknown parameter type")

            self.scopes[-1][param.name_hint] = param
            relay_params.append(param)

        return relay_params

    def function(self, *params):
        """Construct a Relay function."""
        relay_params = self._convert_params(params)

        self.enter_scope()

        pfunc = PartialFunc(relay_params, None, None, [])

        def _on_exit():
            bindings, _, _, ret_value = self.exit_scope()
            body = _mk_let(bindings, ret_value)
            pfunc.body = body

        return WithScope(pfunc, _on_exit)

    def ret(self, x):
        """Set `x` to be the return value of the current function."""
        if not self.ret_values[-1]:
            self.ret_values[-1] = convert(x)
        else:
            raise Exception(
                "return value already set, a function can only have one return value")

    def if_scope(self, cond):
        """Construct the if branch an if expression with scoping."""
        self.enter_scope()

        def _on_exit():
            bindings, _, _, ret_value = self.exit_scope()
            assert self.ret_values[-1] is None
            true_branch = _mk_let(bindings, ret_value)
            self.ret_values[-1] = If(cond, true_branch, None)

        return WithScope(10, _on_exit)

    def else_scope(self):
        """Construct the else branch of an if expression with scoping."""
        self.enter_scope()

        def _on_exit():
            bindings, _, _, ret_value = self.exit_scope()
            partial_if = self.ret_values[-1]
            assert isinstance(partial_if, If) and partial_if.false_branch is None
            false_branch = _mk_let(bindings, ret_value)
            self.ret_values[-1] = If(
                partial_if.cond,
                partial_if.true_branch,
                false_branch)

        return WithScope(10, _on_exit)

    def param(self, name, ty=None):
        if not ty:
            ty = scalar_type('float32')
        else:
            ty = _convert_type(ty)

        return Var(name, ty)

    def global_var(self, name):
        # type: (str) -> GlobalVar
        """Construct a global var with `name` as its name hint.

        Parameters
        ----------
        name: str
            The name of the global variable.

        Returns
        -------
        global_var: relay.GlobalVar
            The global variable with `name`.
        """
        return self.env.global_var(name)

    def decl(self, name, *params, **kwargs):
        """Create a global function.

        Parameters
        ----------
        name: str or GlobalVar
            The name of the function.

        params: params
            The parameters of the function.

        Returns
        -------
        with_scope: Scope for the function.
        """
        ret_type = kwargs.get('ret_type', None)

        self.enter_scope()

        def _on_exit():
            bindings, _, _, ret_value = self.exit_scope()
            exp = _mk_let(bindings, ret_value)
            self.env.add(name, Function(params, exp, ret_type))

        return WithScope(10, _on_exit)

    def get(self):
        """Get the full program.

        Returns
        ----------
        (prog, env) : (relay.Expr, relay.Environment)
            A pair of the partial program, and the modified environment.
        """
        bindings = self.bindings.pop()
        scope = self.scopes.pop()

        if self.bindings:
            raise Exception("IRBuilder: binding error")

        if self.scopes:
            raise Exception("IRBuilder: scoping error")

        if bindings and scope and not self.ret_values:
            raise Exception("IRBuilder: no return value set")

        return _mk_let(bindings, self.ret_values[-1]), self.env


def scalar_type(dtype):
    """Construct a Relay scalar type.

    Parameters
    ----------
    dtype: dtype
        The dtype of the scalar type.

    Returns:
        scalar_type: relay.Type
            The scalar type.
    """
    return TensorType(tvm.convert([]), dtype)


def tensor_type(*shape, **kwargs):
    """Construct a Relay Tensor type.

    Parameters
    ----------
    shape: list of tvm.Expr
        The shape of the Tensor type.
    dtype: dtype
        The dtype of the Tensor type.

    Returns
    -------
    tensor_type: relay.Type
        The resulting tensor types.
    """
    dtype = kwargs.get('dtype', 'float32')
    return TensorType(tvm.convert(shape), dtype)


def func_type(args, ret_type, type_params=None):
    """Construct a Relay function type.

    Parameters
    ----------
    args: list of relay.Type
        The argument types.

    ret_type: relay.Type
        The return type.

    type_params: list of relay.TypeParam
        The type parameters.

    Returns
    -------
    func_type: The function type.
    """
    if not type_params:
        type_params = []
    args = [_convert_type(arg) for arg in args]
    ret_type = _convert_type(ret_type)
    return FuncType(args, ret_type, type_params, [])
python/tvm/relay/ir_pass.py

@@ -2,37 +2,39 @@
 # pylint: disable=unidiomatic-typecheck
 """The set of passes for Relay.
 
-Exposes an interface for configuring the passes and scripting
-them in Python.
+Exposes an interface for configuring the passes and
+scripting them in Python.
 """
 from . import _ir_pass
 from . import _make
 # pylint: disable=invalid-name
 
 
-def infer_type(env, expr):
+def infer_type(expr, env=None):
     """Infer the type of expr under the context of env.
 
     Parameters
     ----------
-    env : relay.Environment
-        The global environment.
+    expr: tvm.relay.Expr
+        The input expression.
 
-    expr : relay.Expr
-        The input expression.
+    env: Optional[tvm.relay.Environment]
+        The global environment.
 
     Returns
     -------
-    checked_expr : relay.Expr
+    checked_expr : tvm.relay.Expr
         The checked expression.
     """
-    return _ir_pass.infer_type(env, expr)
+    return _ir_pass.infer_type(expr, env)
 
 
-def well_formed(e):
+def well_formed(expr):
     """Check that each Var is only bound once (well formed).
 
     Parameters
     ----------
-    e: relay.Expr
+    expr: tvm.relay.Expr
         The input expression
 
     Returns
@@ -40,7 +42,8 @@ def well_formed(e):
     well_form : bool
         whether the input expression is well formed
     """
-    return _ir_pass.well_formed(e)
+    return _ir_pass.well_formed(expr)
+
 
 def check_kind(t, env=None):
     """Check that the type is well kinded.
@@ -48,10 +51,10 @@ def check_kind(t, env=None):
     Parameters
     ----------
-    t: relay.Type
+    t: tvm.relay.Type
         The type to check
 
-    env: relay.Environment, optional
+    env: tvm.relay.Environment, optional
         The global environment
 
     Returns
@@ -71,61 +74,65 @@ def check_kind(t, env=None):
     else:
         return _ir_pass.check_kind(t)
 
 
 def free_vars(e):
     """Get free variables from expression e.
 
     Parameters
     ----------
-    e: relay.Expr
+    e: tvm.relay.Expr
         The input expression
 
     Returns
     -------
-    free : List[relay.Var]
-        the list of free variables
+    free : List[tvm.relay.Var]
+        The list of free variables
     """
     return _ir_pass.free_vars(e)
 
 
-def free_type_vars(e):
+def free_type_vars(expr):
     """Get free type variables from expression/type e
 
     Parameters
     ----------
-    e: relay.Expr/relay.Type
+    expr: Union[tvm.relay.Expr,tvm.relay.Type]
         The input expression/type
 
     Returns
     -------
-    free : List[relay.TypeParam]
-        the list of free type variables
+    free : List[tvm.relay.TypeParam]
+        The list of free type variables
     """
-    return _ir_pass.free_type_vars(e)
+    return _ir_pass.free_type_vars(expr)
 
 
-def dead_code_elimination(e):
+def dead_code_elimination(expr):
     """ Remove expressions which does not effect the program result (dead code).
 
     Parameters
     ----------
-    e: relay.Expr
+    e: tvm.relay.Expr
         The input Expression
 
     Returns
     -------
-    result: relay.Expr
+    result: tvm.relay.Expr
         An expression which is semantically equal to the input expression,
         but with dead code removed.
     """
-    return _ir_pass.dead_code_elimination(e)
+    return _ir_pass.dead_code_elimination(expr)
 
 
 def alpha_equal(lhs, rhs):
     """Compare two Relay expr for structural equivalence (alpha equivalence).
 
     Parameters
     ----------
-    lhs: relay.Expr
+    lhs: tvm.relay.Expr
         One of the input Expression.
 
-    rhs: relay.Expr
+    rhs: tvm.relay.Expr
         One of the input Expression.
 
     Returns
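The argument order flip in infer_type is the main API break here. A minimal sketch of the new calling convention, assuming a build with this commit:

    from tvm import relay
    from tvm.relay import ir_pass

    x = relay.var("x", shape=(10,))
    func = relay.Function([x], relay.add(x, x))

    # Expression first, environment now optional (was infer_type(env, expr)).
    checked = ir_pass.infer_type(func)
    print(checked.checked_type)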
python/tvm/relay/scope_builder.py (new file, 0 → 100644)

"""The scope builder interface """
from __future__ import absolute_import

from . import expr as _expr
from .._ffi import base as _base


class WithScope(object):
    """A wrapper for builder methods which introduce scoping.

    Parameters
    ----------
    enter_value: object
        The value returned by enter.
    """

    def __init__(self, enter_value, exit_cb):
        self._enter_value = enter_value
        self._exit_cb = exit_cb

    def __enter__(self):
        return self._enter_value

    def __exit__(self, ptype, value, trace):
        if value:
            raise value
        else:
            self._exit_cb()


def _make_lets(bindings, ret_value):
    """Make a nested let expressions.

    Parameters
    ----------
    bindings: List[Tuple[tvm.relay.Var,tvm.relay.Expr]]
        The sequence of let bindings

    ret_value: tvm.relay.Expr
        The final value of the expression.

    Returns
    -------
    lets: tvm.relay.Expr
        A nested let expression.
    """
    if ret_value is None:
        raise RuntimeError("ret is not called in this scope")
    if isinstance(ret_value, _expr.If) and ret_value.false_branch is None:
        raise RuntimeError("Creating an If expression without else.")
    let_expr = ret_value
    for var, value in reversed(bindings):
        let_expr = _expr.Let(var, value, let_expr)
    return let_expr


class ScopeBuilder(object):
    """Scope builder class.

    Enables users to build up a nested
    scope(let, if) expression easily.

    Examples
    --------
    ..code-block: python

        sb = relay.ScopeBuilder()
        cond = relay.var("cond", 'bool')
        x = relay.var("x")
        y = relay.var("y")
        with sb.if_scope(cond):
            one = relay.const(1, "float32")
            t1 = sb.let(t1, relay.add(x, one))
            sb.ret(t1)
        with sb.else_scope():
            sb.ret(y)
        print(sb.get().astext())
    """
    def __init__(self):
        self._bindings = [[]]
        self._ret_values = [None]

    def _enter_scope(self):
        self._bindings.append([])
        self._ret_values.append(None)

    def _exit_scope(self):
        bindings = self._bindings.pop()
        ret_value = self._ret_values.pop()
        return bindings, ret_value

    def let(self, var, value):
        """Create a new let binding.

        Parameters
        ----------
        var: Union[Tuple[str, relay.Type], tvm.relay.Var]
            The variable or name of variable.

        value: tvm.relay.Expr
            The value to be binded
        """
        if isinstance(var, (tuple, list)):
            if len(var) > 2:
                raise ValueError("Expect var to be Tuple[str, relay.Type]")
            var = _expr.var(*var)
        elif isinstance(var, _base.string_types):
            var = _expr.var(var)
        self._bindings[-1].append((var, value))
        return var

    def if_scope(self, cond):
        """Create a new if scope.

        Parameters
        ----------
        cond: tvm.relay.Expr
            The condition

        Returns
        -------
        scope: WithScope
            The if scope.

        Note
        ----
        The user must follows with an else scope.
        """
        self._enter_scope()
        def _on_exit():
            bindings, ret_value = self._exit_scope()
            if self._ret_values[-1] is not None:
                raise RuntimeError("result already returned before if scope")
            true_branch = _make_lets(bindings, ret_value)
            self._ret_values[-1] = _expr.If(cond, true_branch, None)
        return WithScope(None, _on_exit)

    def else_scope(self):
        """Create a new else scope.

        Returns
        -------
        scope: WithScope
            The if scope.
        """
        self._enter_scope()
        def _on_exit():
            bindings, ret_value = self._exit_scope()
            partial_if = self._ret_values[-1]
            no_else = (not isinstance(partial_if, _expr.If) or
                       partial_if.false_branch is not None)
            if no_else:
                raise RuntimeError("else scope must follows")
            false_branch = _make_lets(bindings, ret_value)
            self._ret_values[-1] = _expr.If(
                partial_if.cond,
                partial_if.true_branch,
                false_branch)
        return WithScope(None, _on_exit)

    def ret(self, value):
        """Set the return value of this scope.

        Parameters
        ----------
        value: tvm.relay.Expr
            The return value.
        """
        if self._ret_values[-1] is not None:
            raise RuntimeError("ret value is already set in this scope.")
        self._ret_values[-1] = value

    def get(self):
        """Get the generated result.

        Returns
        -------
        value: tvm.relay.Expr
            The final result of the expression.
        """
        if len(self._bindings) != 1:
            raise RuntimeError("can only call get at the outmost scope")
        return _make_lets(self._bindings[-1], self._ret_values[-1])
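Note that the docstring example in the new file passes an undefined t1 into sb.let; let binds by variable, by (name, type) tuple, or by name, so a working version of the same program reads as below (a sketch, assuming a build with this commit):

    from tvm import relay

    sb = relay.ScopeBuilder()
    cond = relay.var("cond", "bool")
    x = relay.var("x", shape=(1,))
    y = relay.var("y", shape=(1,))
    with sb.if_scope(cond):
        one = relay.const(1, "float32")
        t1 = sb.let("t1", relay.add(x, one))  # bind by name, not by variable
        sb.ret(t1)
    with sb.else_scope():
        sb.ret(y)
    print(sb.get())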
python/tvm/relay/ty.py

@@ -56,7 +56,7 @@ class Kind(IntEnum):
     Shape = 3
 
 @register_relay_node
-class TypeParam(Type):
+class TypeVar(Type):
     """A type parameter used for generic types in Relay,
     see tvm/relay/type.h for more details.
@@ -66,7 +66,7 @@ class TypeParam(Type):
     """
     def __init__(self, var, kind=Kind.Type):
-        """Construct a TypeParam.
+        """Construct a TypeVar.
 
         Parameters
         ----------
@@ -78,10 +78,10 @@ class TypeParam(Type):
         Returns
         -------
-        type_param: TypeParam
+        type_param: TypeVar
             The type parameter.
         """
-        self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
+        self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
 
 
 @register_relay_node
@@ -122,26 +122,30 @@ class FuncType(Type):
     We informally write them as:
     `forall (type_params), (arg_types) -> ret_type where type_constraints`
+
+    Parameters
+    ----------
+    arg_types: List[tvm.relay.Type]
+        The argument types
+
+    ret_type: tvm.relay.Type
+        The return type.
+
+    type_params: List[tvm.relay.TypeVar]
+        The type parameters
+
+    type_constraints: List[tvm.relay.TypeConstraint]
+        The type constraints.
     """
     def __init__(self,
                  arg_types,
                  ret_type,
-                 type_params,
-                 type_constraints):
-        """Construct a function type.
-
-        Parameters
-        ----------
-        arg_types: list of Type
-        ret_type: Type
-        type_params: list of TypeParam
-        type_constraints: list of TypeConstraint
-
-        Returns
-        -------
-        func_type: FuncType
-            The function type.
-        """
+                 type_params=None,
+                 type_constraints=None):
+        if type_params is None:
+            type_params = []
+        if type_constraints is None:
+            type_constraints = []
         self.__init_handle_by_constructor__(
             _make.FuncType, arg_types, ret_type, type_params, type_constraints)
 
@@ -175,3 +179,21 @@ class TypeRelation(TypeConstraint):
     def __init__(self, func, args, num_inputs, attrs):
         self.__init_handle_by_constructor__(_make.TypeRelation,
                                             func, args, num_inputs, attrs)
+
+
+def scalar_type(dtype):
+    """Creates a scalar type.
+
+    This function returns TensorType((), dtype)
+
+    Parameters
+    ----------
+    dtype : str
+        The content data type.
+
+    Returns
+    -------
+    s_type: tvm.relay.TensorType
+        The result type.
+    """
+    return TensorType((), dtype)
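With the new defaults, monomorphic function types no longer need explicit empty lists, and scalar_type gives a shorthand for zero-rank tensor types. A minimal sketch, assuming a build with this commit:

    from tvm import relay

    f32 = relay.scalar_type("float32")   # TensorType((), "float32")
    # type_params and type_constraints now default to [].
    ft = relay.FuncType([f32, f32], f32)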
src/relay/ir/environment.cc
View file @
8876eac8
...
@@ -16,87 +16,71 @@ using namespace runtime;
...
@@ -16,87 +16,71 @@ using namespace runtime;
Environment
EnvironmentNode
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
Environment
EnvironmentNode
::
make
(
tvm
::
Map
<
GlobalVar
,
Function
>
global_funcs
)
{
auto
n
=
make_node
<
EnvironmentNode
>
();
auto
n
=
make_node
<
EnvironmentNode
>
();
n
->
functions
=
std
::
move
(
global_funcs
);
n
->
functions
=
std
::
move
(
global_funcs
);
for
(
const
auto
&
kv
:
n
->
functions
)
{
// set gloval var map
CHECK
(
!
n
->
global_var_map_
.
count
(
kv
.
first
->
name_hint
))
<<
"Duplicate global function name "
<<
kv
.
first
->
name_hint
;
n
->
global_var_map_
.
Set
(
kv
.
first
->
name_hint
,
kv
.
first
);
}
return
Environment
(
n
);
return
Environment
(
n
);
}
}
GlobalVar
EnvironmentNode
::
GetGlobalVar
(
const
std
::
string
&
str
)
{
GlobalVar
EnvironmentNode
::
GetGlobalVar
(
const
std
::
string
&
name
)
{
auto
global_id
=
global_map_
.
find
(
str
);
auto
it
=
global_var_map_
.
find
(
name
);
if
(
global_id
!=
global_map_
.
end
())
{
CHECK
(
it
!=
global_var_map_
.
end
())
return
(
*
global_id
).
second
;
<<
"Cannot find global var "
<<
name
<<
" in the Environment"
;
}
else
{
return
(
*
it
).
second
;
auto
id
=
GlobalVarNode
::
make
(
str
);
this
->
global_map_
.
Set
(
str
,
id
);
return
id
;
}
}
}
/*!
void
EnvironmentNode
::
Add
(
const
GlobalVar
&
var
,
* \brief Add a new item to the global environment
const
Function
&
func
,
* \note if the update flag is not set adding a duplicate
* definition will trigger an exception, otherwise we will
* update the definition if and only if it is type compatible.
*/
void
EnvironmentNode
::
Add
(
const
GlobalVar
&
var
,
const
Function
&
func
,
bool
update
)
{
bool
update
)
{
// Type check the item before we add it to the environment.
// Type check the item before we add it to the environment.
auto
env
=
GetRef
<
Environment
>
(
this
);
auto
env
=
GetRef
<
Environment
>
(
this
);
Function
checked_func
=
InferType
(
func
,
env
,
var
);
Expr
checked_expr
=
InferType
(
env
,
var
,
func
);
auto
type
=
checked_func
->
checked_type
();
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
if
(
const
FunctionNode
*
func_node
=
checked_expr
.
as
<
FunctionNode
>
())
{
if
(
functions
.
find
(
var
)
!=
functions
.
end
())
{
auto
checked_func
=
GetRef
<
Function
>
(
func_node
);
CHECK
(
update
)
auto
type
=
checked_func
->
checked_type
();
<<
"Already have definition for "
<<
var
->
name_hint
;
auto
old_type
=
functions
[
var
].
as
<
FunctionNode
>
()
->
checked_type
();
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
CHECK
(
AlphaEqual
(
type
,
old_type
))
<<
"Environment#update changes type, not possible in this mode."
;
if
(
functions
.
find
(
var
)
!=
functions
.
end
())
{
if
(
!
update
)
{
throw
dmlc
::
Error
(
"already have definition for XXXX."
);
}
auto
old_type
=
functions
[
var
].
as
<
FunctionNode
>
()
->
checked_type
();
if
(
!
AlphaEqual
(
type
,
old_type
))
{
throw
dmlc
::
Error
(
"Environment#update changes type, not possible in this mode."
);
}
this
->
functions
.
Set
(
var
,
checked_func
);
}
else
{
this
->
functions
.
Set
(
var
,
checked_func
);
}
}
else
{
LOG
(
FATAL
)
<<
"internal error: unknown item type, unreachable code"
;
}
}
this
->
functions
.
Set
(
var
,
checked_func
);
// set gloval var map
CHECK
(
!
global_var_map_
.
count
(
var
->
name_hint
))
<<
"Duplicate global function name "
<<
var
->
name_hint
;
global_var_map_
.
Set
(
var
->
name_hint
,
var
);
}
}
void
EnvironmentNode
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
void
EnvironmentNode
::
Update
(
const
GlobalVar
&
var
,
const
Function
&
func
)
{
this
->
Add
(
var
,
func
,
true
);
this
->
Add
(
var
,
func
,
true
);
}
}
void
EnvironmentNode
::
Remove
(
const
GlobalVar
&
var
)
{
void
EnvironmentNode
::
Remove
(
const
GlobalVar
&
var
)
{
auto
functions_node
=
this
->
functions
.
CopyOnWrite
();
auto
functions_node
=
this
->
functions
.
CopyOnWrite
();
functions_node
->
data
.
erase
(
var
.
node_
);
functions_node
->
data
.
erase
(
var
.
node_
);
auto
gvar_node
=
global_var_map_
.
CopyOnWrite
();
gvar_node
->
data
.
erase
(
var
->
name_hint
);
}
}
Function
EnvironmentNode
::
Lookup
(
const
GlobalVar
&
var
)
{
Function
EnvironmentNode
::
Lookup
(
const
GlobalVar
&
var
)
{
auto
func
=
functions
.
find
(
var
);
auto
it
=
functions
.
find
(
var
);
if
(
func
!=
functions
.
end
())
{
CHECK
(
it
!=
functions
.
end
())
return
(
*
func
).
second
;
<<
"There is no definition of "
<<
var
->
name_hint
;
}
else
{
return
(
*
it
).
second
;
throw
Error
(
std
::
string
(
"there is no definition of "
)
+
var
->
name_hint
);
}
}
}
Function
EnvironmentNode
::
Lookup
(
const
std
::
string
&
str
)
{
Function
EnvironmentNode
::
Lookup
(
const
std
::
string
&
name
)
{
GlobalVar
id
=
this
->
GetGlobalVar
(
str
);
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
return
this
->
Lookup
(
id
);
return
this
->
Lookup
(
id
);
}
}
void
EnvironmentNode
::
Merg
e
(
const
Environment
&
env
)
{
void
EnvironmentNode
::
Updat
e
(
const
Environment
&
env
)
{
for
(
auto
pair
:
env
->
functions
)
{
for
(
auto
pair
:
env
->
functions
)
{
this
->
functions
.
Set
(
pair
.
first
,
pair
.
second
);
this
->
Update
(
pair
.
first
,
pair
.
second
);
}
}
}
}
...
@@ -134,10 +118,10 @@ TVM_REGISTER_API("relay._env.Environment_Lookup_str")
  *ret = env->Lookup(var);
});

-TVM_REGISTER_API("relay._env.Environment_Merge")
+TVM_REGISTER_API("relay._env.Environment_Update")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Environment env = args[0];
- env->Merge(args[1]);
+ env->Update(args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
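With Merge renamed to Update, the Python-side Environment keeps its dict-like surface for registering global functions. A minimal sketch of the new-style usage, mirroring the updated printer test further down (nothing here goes beyond what those tests exercise):

    from tvm import relay

    x = relay.var("x", "float32")
    f = relay.Function([x], x)

    env = relay.Environment()
    env["myf"] = f        # adds and type-checks the global function
    text = env.astext()   # the text printer now shows "def @myf"
    assert "def @myf" in text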
...
src/relay/ir/expr.cc  View file @ 8876eac8
...
@@ -104,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
                            Expr body,
                            Type ret_type,
-                           tvm::Array<TypeParam> type_params) {
+                           tvm::Array<TypeVar> type_params) {
  NodePtr<FunctionNode> n = make_node<FunctionNode>();
  n->params = std::move(params);
  n->body = std::move(body);
...
src/relay/ir/expr_functor.cc  View file @ 8876eac8
...
@@ -66,11 +66,11 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}

Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
- tvm::Array<TypeParam> ty_params;
+ tvm::Array<TypeVar> ty_params;
  bool all_ty_params_changed = true;

  for (auto ty_param : op->type_params) {
-   TypeParam new_ty_param = Downcast<TypeParam>(VisitType(ty_param));
+   TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
    ty_params.push_back(new_ty_param);
    all_ty_params_changed &= new_ty_param.same_as(ty_param);
  }
...
src/relay/ir/text_printer.cc  View file @ 8876eac8
...
@@ -217,6 +217,8 @@ class TextPrinter :
      return ConstScalar(dtype, static_cast<const float*>(op->data->data));
    } else if (dtype == Float(64)) {
      return ConstScalar(dtype, static_cast<const double*>(op->data->data));
+   } else if (dtype == Bool()) {
+     return ConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
    }
  }
  // default fall-back, record it as meta node.
...
@@ -638,8 +640,14 @@ class TextPrinter :
   * \return The corresponding name.
   */
  TextValue AllocVarName(const Var& var) {
-   std::string name = GetUniqueName('%' + var->name_hint);
-   TextValue val(name);
+   std::string name = var->name_hint;
+   // always make sure first name is alpha
+   if (name.length() != 0 && !std::isalpha(name[0])) {
+     name = "%v" + name;
+   } else {
+     name = "%" + name;
+   }
+   TextValue val(GetUniqueName(name));
    CHECK(!memo_.count(var));
    memo_[var] = val;
    return val;
...
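AllocVarName now prefixes a name hint that does not start with a letter with "%v" instead of just "%", so purely numeric hints still print as well-formed identifiers. A small illustration, matching the test_variable_name test added below:

    from tvm import relay

    v1 = relay.var("1")           # name hint is a pure number
    assert "%v1" in v1.astext()   # printed as %v1 rather than %1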
src/relay/ir/type.cc  View file @ 8876eac8
...
@@ -36,30 +36,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
  p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});

-TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
-  NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
+TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
+  NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
   n->var = tvm::Var(name);
   n->kind = std::move(kind);
-  return TypeParam(n);
+  return TypeVar(n);
}

-TVM_REGISTER_NODE_TYPE(TypeParamNode);
+TVM_REGISTER_NODE_TYPE(TypeVarNode);

-TVM_REGISTER_API("relay._make.TypeParam")
+TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  int kind = args[1];
  *ret =
-   TypeParamNode::make(args[0], static_cast<TypeParamNode::Kind>(kind));
+   TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeParamNode>([](const TypeParamNode* node,
+.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
                                tvm::IRPrinter* p) {
- p->stream << "TypeParamNode(" << node->var->name_hint << ", "
+ p->stream << "TypeVarNode(" << node->var->name_hint << ", "
            << node->kind << ")";
});

-IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
+IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) {
  auto n = make_node<IncompleteTypeNode>();
  n->kind = std::move(kind);
  return IncompleteType(n);
...
@@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  int kind = args[0];
- *ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
+ *ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind));
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                            Type ret_type,
-                           tvm::Array<TypeParam> type_params,
+                           tvm::Array<TypeVar> type_params,
                            tvm::Array<TypeConstraint> type_constraints) {
  NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
  n->arg_types = std::move(arg_types);
...
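On the Python side the renamed node is constructed as relay.TypeVar; a short sketch consistent with the updated node tests later in this commit:

    import tvm
    from tvm import relay

    tp = relay.TypeVar('tp', relay.Kind.Type)
    assert tp.kind == relay.Kind.Type
    tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')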
src/relay/op/image/resize.cc  View file @ 8876eac8
...
@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("image.resize")
for layout NHWC
(batch_size, size[0], size[1], channels)
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ResizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
...
src/relay/op/nn/nn.cc  View file @ 8876eac8
...
@@ -247,6 +247,8 @@ RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu")
// Positional relay function to create LRN operator used by frontend FFI.
+TVM_REGISTER_NODE_TYPE(LRNAttrs);
+
Expr MakeLRN(Expr data,
             IndexExpr size,
             IndexExpr axis,
...
@@ -290,6 +292,8 @@ centered at that value (zero padding is added where necessary).
// Positional relay function to create L2Normalize operator used by frontend FFI.
+TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
+
Expr MakeL2Normalize(Expr data,
                     double eps,
                     Array<IndexExpr> axis) {
...
@@ -315,6 +319,7 @@ Normalizes along dimension axis using an L2 norm
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.L2NormalizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
src/relay/op/nn/pad.cc  View file @ 8876eac8
...
@@ -77,6 +77,7 @@ RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.PadAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
src/relay/op/nn/pooling.cc  View file @ 8876eac8
...
@@ -12,6 +12,7 @@ namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);

template <typename AttrTtype>
bool Pool2DRel(const Array<Type>& types,
...
@@ -115,6 +116,7 @@ RELAY_REGISTER_OP("nn.max_pool2d")
equation.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
@@ -169,6 +171,7 @@ Average pooling operation for one dimensional data.
equation.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.AvgPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
@@ -232,6 +235,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
@@ -261,6 +265,7 @@ RELAY_REGISTER_OP("nn.global_max_pool2d")
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
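With the attrs type keys registered, pooling attributes can round-trip through the FFI and the text printer. A hypothetical usage sketch; the keyword names mirror the MaxPool2DAttrs fields, and the exact Python-side defaults are an assumption here:

    from tvm import relay

    x = relay.var("x", shape=(1, 16, 32, 32))
    y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
    yy = relay.ir_pass.infer_type(y)   # (1, 16, 16, 16) for the default NCHW layout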
src/relay/op/nn/upsampling.cc  View file @ 8876eac8
...
@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.upsampling")
(batch_size, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.UpSamplingAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
...
src/relay/op/tensor/reduce.cc  View file @ 8876eac8
...
@@ -199,7 +199,7 @@ RELAY_REGISTER_REDUCE_OP("argmax")
values over a given axis.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
...
@@ -209,7 +209,7 @@ RELAY_REGISTER_REDUCE_OP("argmin")
values over a given axis.
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
...
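The arg-reduce registrations likewise expose ReduceAttrs. A hypothetical sketch, assuming the Python wrappers take axis/keepdims/exclude keywords matching the ReduceAttrs fields:

    from tvm import relay

    x = relay.var("x", shape=(3, 4, 5))
    y = relay.argmax(x, axis=(1,))
    yy = relay.ir_pass.infer_type(y)   # the reduced axis is dropped: (3, 5)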
src/relay/op/tensor/transform.cc  View file @ 8876eac8
...
@@ -144,12 +144,14 @@ RELAY_REGISTER_OP("concatenate")
- **axis** : The axis along which the tensors are concatenated.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ConcatenateAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel);

+/* relay.transpose */
+TVM_REGISTER_NODE_TYPE(TransposeAttrs);
+
bool TransposeRel(const Array<Type>& types,
                  int num_inputs,
...
@@ -224,12 +226,15 @@ RELAY_REGISTER_OP("transpose")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.TransposeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Transpose", TransposeRel);

+/* relay.reshape */
+TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
+
bool ReshapeRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
...
@@ -310,6 +315,7 @@ Example::
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);
...
@@ -397,12 +403,14 @@ Examples::
 [ 4., 3.]]
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.TakeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2)
.add_type_rel("Take", TakeRel);

+// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs);

bool FullRel(const Array<Type>& types,
...
@@ -448,6 +456,7 @@ RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3)
...
@@ -634,6 +643,10 @@ Examples::
.set_support_level(4)
.add_type_rel("Where", WhereRel);

+// Squeeze
+TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
+
Expr MakeSqueeze(Expr data,
                 Array<IndexExpr> axes) {
  auto attrs = make_node<SqueezeAttrs>();
...
src/relay/op/type_relations.cc  View file @ 8876eac8
...
@@ -7,6 +7,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/op.h>
+#include <tvm/ir_pass.h>
#include <numeric>
#include "./type_relations.h"
...
@@ -21,14 +22,6 @@ TensorType ToTensorType(const Type& t) {
  }
}

-// TODO(@jroesch) what size value do we extract, 64bit or 32bit?
-int ToInt(const tvm::Expr& e) {
-  CHECK(e.defined());
-  auto imm = e.as<tvm::ir::IntImm>();
-  CHECK(imm) << "TYPE: " << imm << imm->type << std::endl;
-  return imm->value;
-}
-
bool IdentityRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
...
@@ -39,72 +32,54 @@ bool IdentityRel(const Array<Type>& types,
  return true;
}

+bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) {
+  IndexExpr diff = lhs - rhs;
+  if (const int64_t* pdiff = as_const_int(diff)) {
+    return pdiff[0] == 0;
+  }
+  // symbolic
+  diff = tvm::ir::CanonicalSimplify(diff);
+  if (const int64_t* pdiff = as_const_int(diff)) {
+    return pdiff[0] == 0;
+  }
+  return false;
+}
+
+bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
+  if (const int64_t* pvalue = as_const_int(lhs)) {
+    return pvalue[0] == value;
+  }
+  return false;
+}
+
Type ConcreteBroadcast(const TensorType& t1,
                       const TensorType& t2,
                       DataType output_dtype) {
- RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
-                 << std::endl;
- auto sh1 = t1->shape;
- auto sh2 = t2->shape;
- RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
-                 << std::endl;
- if (sh1.size() == 0 && sh2.size() == 0) {
-   return TensorTypeNode::make({}, output_dtype);
-   // We have non-zero shapes so broadcast rules apply.
- } else {
-   auto suffix_len = static_cast<int>(std::min(sh1.size(), sh2.size()));
-   auto full_len = static_cast<int>(std::max(sh1.size(), sh2.size()));
-   auto rev_sh1 = sh1.rbegin();
-   auto rev_sh2 = sh2.rbegin();
-   while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
-     auto dim1 = ToInt(*rev_sh1);
-     auto dim2 = ToInt(*rev_sh2);
-     if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
-       CHECK(false) << "Dimension mistmatch "
-                    << "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
-     }
-     rev_sh1++;
-     rev_sh2++;
-   }
-   Array<IndexExpr> larger;
-   Array<IndexExpr> smaller;
-   for (int i = 0; i < (full_len - suffix_len); i++) {
-     smaller.push_back(make_const(tvm::Int(64), 1));
-   }
-   if (sh1.size() < sh2.size()) {
-     for (auto sh : sh1) {
-       smaller.push_back(sh);
-     }
-     larger = sh2;
-   } else if (sh1.size() > sh2.size()) {
-     for (auto sh : sh1) {
-       larger.push_back(sh);
-     }
-     smaller = sh2;
-   } else {
-     larger = sh1;
-     smaller = sh2;
-   }
-   CHECK_EQ(larger.size(), smaller.size());
-   Array<IndexExpr> out_shape;
-   for (size_t i = 0; i < smaller.size(); i++) {
-     auto left = smaller[i].as<tvm::ir::IntImm>();
-     auto right = larger[i].as<tvm::ir::IntImm>();
-     CHECK(left);
-     CHECK(right);
-     int64_t dim = std::max(left->value, right->value);
-     out_shape.push_back(make_const(tvm::Int(64), dim));
-   }
-   return TensorTypeNode::make(out_shape, output_dtype);
- }
+ std::vector<IndexExpr> oshape;
+ size_t ndim1 = t1->shape.size();
+ size_t ndim2 = t2->shape.size();
+ size_t i = 1;
+ for (; i <= std::min(ndim1, ndim2); ++i) {
+   IndexExpr s1 = t1->shape[ndim1 - i];
+   IndexExpr s2 = t2->shape[ndim2 - i];
+   if (EqualCheck(s1, s2)) {
+     oshape.push_back(s1);
+   } else if (EqualConstInt(s1, 1)) {
+     oshape.push_back(s2);
+   } else if (EqualConstInt(s2, 1)) {
+     oshape.push_back(s1);
+   } else {
+     LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
+   }
+ }
+ size_t max_ndim = std::max(ndim1, ndim2);
+ auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
+ for (; i <= max_ndim; ++i) {
+   oshape.push_back(rshape[max_ndim - i]);
+ }
+ return TensorTypeNode::make(Array<IndexExpr>(oshape.rbegin(), oshape.rend()),
+                             output_dtype);
}

bool BroadcastRel(const Array<Type>& types,
...
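The rewritten ConcreteBroadcast compares shapes from the trailing dimension backwards: equal dimensions pass through, a constant 1 broadcasts against the other side, and the untouched prefix of the longer shape is copied. A NumPy-style sketch of the same rule (illustrative only; the function name is ours):

    def broadcast_shape(sh1, sh2):
        out = []
        # walk trailing dimensions first, as the C++ loop does
        for i in range(1, min(len(sh1), len(sh2)) + 1):
            s1, s2 = sh1[-i], sh2[-i]
            if s1 == s2:
                out.append(s1)
            elif s1 == 1:
                out.append(s2)
            elif s2 == 1:
                out.append(s1)
            else:
                raise ValueError("incompatible broadcast: %s vs %s" % (sh1, sh2))
        # copy the remaining prefix of the longer shape
        longer = sh1 if len(sh1) > len(sh2) else sh2
        out.extend(reversed(longer[:len(longer) - len(out)]))
        return tuple(reversed(out))

    assert broadcast_shape((5, 10, 1), (10, 4)) == (5, 10, 4)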
@@ -141,71 +116,5 @@ bool BroadcastCompRel(const Array<Type>& types,
  return false;
}

-/*! \brief Handle concrete concat case from known input to output. */
-inline Type ConcreteConcatRel(const Type& input_type) {
-  if (auto tuple_node = input_type.as<TupleTypeNode>()) {
-    // NB: For now the axis argument is hardwired to be 0.
-    std::vector<int> dims;
-    DataType dtype;
-    CHECK_LT(1, tuple_node->fields.size());
-    bool skip_first = true;
-    // Collect the suffix dimensions since axis is zero.
-    // TODO(@jroesch): This is a demonstration of how
-    // to do varargs. It requires a little more work to
-    // fully type the behavior of concat.
-    auto first = Downcast<TensorType>(tuple_node->fields[0]);
-    dtype = first->dtype;
-    for (auto dim_expr : first->shape) {
-      if (!skip_first) {
-        dims.push_back(ToInt(dim_expr));
-      } else {
-        skip_first = false;
-      }
-    }
-    std::vector<int> axis_dims;
-    for (auto field_ty : tuple_node->fields) {
-      auto ttype = Downcast<TensorType>(field_ty);
-      for (size_t i = 0; i < ttype->shape.size(); i++) {
-        if (i != 0) {
-          CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i]));
-        } else {
-          axis_dims.push_back(ToInt(ttype->shape[i]));
-        }
-      }
-    }
-    auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
-    Array<tvm::Expr> out_shape = { make_const(Int(64), out_axis_dim) };
-    for (auto dim : dims) {
-      out_shape.push_back(make_const(Int(64), dim));
-    }
-    return TensorTypeNode::make(out_shape, dtype);
-  } else {
-    throw TypeRelationError("concat can only be used with a tuple as its argument");
-  }
-}
-
-bool ConcatRel(const Array<Type>& types,
-               int num_inputs,
-               const Attrs& attrs,
-               const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  if (types[0].as<TupleTypeNode>()) {
-    reporter->Assign(types[1], ConcreteConcatRel(types[0]));
-    return true;
-  }
-  return false;
-}
-
}  // namespace relay
}  // namespace tvm
src/relay/op/type_relations.h  View file @ 8876eac8
...
@@ -13,17 +13,6 @@
namespace tvm {
namespace relay {

-/*! \brief The error raised by a type relation.
- *
- * This error is how a type relation signals that it has failed.
- *
- */
-struct TypeRelationError : Error {
-  explicit TypeRelationError(const std::string& msg)
-    : Error(msg) {}
-};
-
/*!
 * \brief The identity type relation, all the types are equal.
 *
...
@@ -72,22 +61,6 @@ bool BroadcastCompRel(const Array<Type>& types,
                      const Attrs& attrs,
                      const TypeReporter& reporter);

-/*!
- * \brief The concat type relation, implements the concatenating
- * rule over the list of input types producing one concatenated
- * type.
- *
- * \param types The input and output types to the relation.
- * \param num_inputs The number of input arguments.
- * \param attrs The attributes
- * \param reporter The reporter.
- * \return true whether relation has been resolved.
- */
-bool ConcatRel(const Array<Type>& types,
-               int num_inputs,
-               const Attrs& attrs,
-               const TypeReporter& reporter);
-
}  // namespace relay
}  // namespace tvm
...
src/relay/op/vision/multibox_op.cc  View file @ 8876eac8
...
@@ -63,6 +63,7 @@ TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
RELAY_REGISTER_OP("vision.multibox_prior")
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
...
src/relay/pass/alpha_eq.cc  View file @ 8876eac8
...
@@ -34,7 +34,7 @@ bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
}

struct TypeAlphaEq : TypeVisitor<const Type&> {
- tvm::Map<TypeParam, TypeParam> eq_map;
+ tvm::Map<TypeVar, TypeVar> eq_map;
  bool equal;

  TypeAlphaEq() : eq_map(), equal(true) {}
...
@@ -76,10 +76,10 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
    }
  }

- void VisitType_(const TypeParamNode* ti1, const Type& t2) final {
-   if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) {
-     auto tid1 = GetRef<TypeParam>(ti1);
-     auto tid2 = GetRef<TypeParam>(ti2);
+ void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
+   if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
+     auto tid1 = GetRef<TypeVar>(ti1);
+     auto tid2 = GetRef<TypeVar>(ti2);
      // We handle open terms with this rule assuming variables are identical.
      //
...
src/relay/pass/dead_code.cc  View file @ 8876eac8
...
@@ -20,7 +20,9 @@ bool IsBoolLit(const Expr& e, bool b) {
  if (const ConstantNode* c = e.as<ConstantNode>()) {
    if (c->is_scalar()) {
      auto dt = c->tensor_type()->dtype;
-     if (dt == UInt(8)) {
+     if (dt == Bool()) {
+       return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
+     } else if (dt == UInt(8)) {
        return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
      } else if (dt == UInt(16)) {
        return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
...
src/relay/pass/kind_check.cc  View file @ 8876eac8
...
@@ -20,7 +20,7 @@ namespace tvm {
namespace relay {

using namespace tvm::runtime;
-using Kind = TypeParamNode::Kind;
+using Kind = TypeVarNode::Kind;

struct KindChecker : TypeVisitor<> {
  bool valid;
...
@@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> {
    return tv->kind == k;
  }

- if (const TypeParamNode* tp = t.as<TypeParamNode>()) {
+ if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
    return tp->kind == k;
  }
...
src/relay/pass/let_list.h  View file @ 8876eac8
...
@@ -61,7 +61,7 @@ class LetList {
   * \return a Var that hold the inserted expr.
   */
  Var Push(Expr expr) {
-   return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
+   return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr);
  }

  /*!
...
src/relay/pass/type_functor.h  View file @ 8876eac8
...
@@ -61,7 +61,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
  // Functions that can be overriden by subclass
  virtual R VisitType_(const TensorTypeNode* op,
                       Args... args) TYPE_FUNCTOR_DEFAULT;
- virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+ virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
...
@@ -79,7 +79,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
  FType vtable;
  // Set dispatch
  RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode);
+ RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
  RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
  RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
  RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
...
src/relay/pass/type_infer.cc  View file @ 8876eac8
...
@@ -28,6 +28,39 @@
namespace tvm {
namespace relay {

+// Necessary deferred relation for TupleGetItem
+struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
+  int index;
+
+  TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") {
+    TVM_ATTR_FIELD(index);
+  }
+};
+
+bool TupleGetItemRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  if (types[0].as<IncompleteTypeNode>()) return false;
+  const auto* data = types[0].as<TupleTypeNode>();
+  CHECK(data != nullptr)
+      << "TupleGetItem expect input type to be TupleType "
+      << " get " << types[0] << " instead";
+  const auto* param = attrs.as<TupleGetItemAttrs>();
+  CHECK(param != nullptr);
+  CHECK_GE(param->index, 0);
+  CHECK_LT(param->index, data->fields.size());
+  reporter->Assign(types[1], data->fields[param->index]);
+  return true;
+}
+
+TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
+TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
+.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
+    TupleGetItemRel);
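Deferring TupleGetItem to a type relation means the tuple's type no longer has to be known at construction time; the solver fills it in. A minimal sketch in the style of the rewritten tests:

    from tvm import relay

    x = relay.var("x", shape=(3, 4))
    tup = relay.Tuple([x, x])
    get = relay.TupleGetItem(tup, 1)
    yy = relay.ir_pass.infer_type(get)
    assert yy.checked_type == relay.TensorType((3, 4))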
//
// The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
...
@@ -38,8 +71,7 @@ namespace relay {
class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
 public:
  // constructors
- TypeInferencer() {
+ TypeInferencer()
+     : env_(EnvironmentNode::make({})) {
  }
  explicit TypeInferencer(Environment env)
      : env_(env) {
...
@@ -58,6 +90,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
  std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
  // The solver used by the inferencer.
  TypeSolver solver_;
+ // relation function
+ TypeRelationFn tuple_getitem_rel_;
  // Unify two types
  Type Unify(const Type& t1, const Type& t2, const Span& span) {
    // TODO(tqchen, jroesch): propagate span to solver
...
@@ -90,12 +124,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
    if (op->type_annotation.defined()) {
      return op->type_annotation;
    } else {
-     return IncompleteTypeNode::make(TypeParamNode::kType);
+     return IncompleteTypeNode::make(TypeVarNode::kType);
    }
  }

  Type VisitExpr_(const GlobalVarNode* op) final {
    GlobalVar var = GetRef<GlobalVar>(op);
+   CHECK(env_.defined()) << "Cannot do type inference without a global variable";
    Expr e = env_->Lookup(var);
    return e->checked_type();
  }
...
@@ -116,17 +152,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
  }

  Type VisitExpr_(const TupleGetItemNode* op) final {
-   // TODO(M.K.)
-   // handle case where field type is not known
-   Type tuple_type = GetType(op->tuple);
-   auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
-   if (!tuple_ty_node) {
-     LOG(FATAL) << "only expressions with tuple types is accepted"
-                << GetRef<TupleGetItem>(op);
-   }
-   if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
-     LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
-   }
-   return tuple_ty_node->fields[op->index];
+   if (!tuple_getitem_rel_.defined()) {
+     tuple_getitem_rel_ = TypeRelationFn(
+         EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_);
+   }
+   Type tuple_type = GetType(op->tuple);
+   Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+   auto attrs = make_node<TupleGetItemAttrs>();
+   attrs->index = op->index;
+   solver_.AddConstraint(TypeRelationNode::make(
+       tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)));
+   return rtype;
  }

  Type VisitExpr_(const OpNode* op) final {
...
@@ -169,7 +205,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
    for (size_t i = 0; i < op->type_params.size(); ++i) {
      if (!op->type_params[i].same_as(rel->args[i])) return Type();
    }
-   Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
+   Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
    arg_types.push_back(rtype);
    // we can do simple replacement here
    solver_.AddConstraint(TypeRelationNode::make(
...
@@ -179,7 +215,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
  // instantiate the function type with fresh
  FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) {
-   tvm::Map<TypeParam, Type> subst_map;
+   tvm::Map<TypeVar, Type> subst_map;
    // Build a subsitituion map up from the function type and type arguments.
    // Eventually allow the type vars to be passed in.
...
@@ -196,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
    // This is a temporary work around to check recursive functions whose
    // return type is not yet known.
    if (!ret_type.defined()) {
-     ret_type = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
+     ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
    }
    Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
                                      ret_type, {},
...
@@ -305,7 +341,6 @@ class TypeInferencer::Resolver : public ExprMutator {
    return AttachCheckedType(op);
  }

  Expr VisitExpr_(const FunctionNode* op) final {
    return AttachCheckedType(op);
  }
...
@@ -363,20 +398,21 @@ Expr TypeInferencer::Infer(Expr expr) {
  return Resolver(type_map_, &solver_).VisitExpr(expr);
}

-Expr InferType(const Environment& env, const Expr& expr) {
+Expr InferType(const Expr& expr, const Environment& env) {
  return TypeInferencer(env).Infer(expr);
}

-Expr InferType(const Environment& env,
-               const GlobalVar& var,
-               const Function& func) {
+Function InferType(const Function& func,
+                   const Environment& env,
+                   const GlobalVar& var) {
  Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
  func_copy->checked_type_ = func_copy->func_type_annotation();
  env->functions.Set(var, func_copy);
  Expr func_ret = TypeInferencer(env).Infer(func_copy);
  auto map_node = env->functions.CopyOnWrite();
  map_node->data.erase(var.node_);
- return func_ret;
+ return Downcast<Function>(func_ret);
}

TVM_REGISTER_API("relay._ir_pass.infer_type")
...
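After the signature change the expression comes first and the environment second, and the rewritten tests call the Python wrapper with just the expression. A minimal sketch of that usage:

    from tvm import relay

    x = relay.var("x", shape=(10, 4))
    y = relay.log(x)
    yy = relay.ir_pass.infer_type(y)   # env argument omitted, as in the new tests
    assert yy.checked_type == relay.TensorType((10, 4))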
src/relay/pass/type_subst.cc  View file @ 8876eac8
...
@@ -10,13 +10,13 @@ namespace tvm {
namespace relay {

struct TypeSubstV : TypeMutator {
- tvm::Map<TypeParam, Type> subst_map;
+ tvm::Map<TypeVar, Type> subst_map;

- explicit TypeSubstV(tvm::Map<TypeParam, Type> subst_map)
+ explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map)
      : subst_map(subst_map) {}

- Type VisitType_(const TypeParamNode* op) override {
-   auto id = GetRef<TypeParam>(op);
+ Type VisitType_(const TypeVarNode* op) override {
+   auto id = GetRef<TypeVar>(op);
    if (subst_map.find(id) != subst_map.end()) {
      return this->subst_map[id];
    } else {
...
@@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator {
  }
};

-Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) {
+Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) {
  TypeSubstV ty_sub({ {target, subst} });
  return ty_sub.VisitType(type);
}

-Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map) {
+Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) {
  TypeSubstV ty_sub(subst_map);
  return ty_sub.VisitType(type);
}
...
src/relay/pass/type_subst.h  View file @ 8876eac8
...
@@ -11,8 +11,8 @@
namespace tvm {
namespace relay {

-Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst);
-Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map);
+Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst);
+Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map);

}  // namespace relay
}  // namespace tvm
...
src/relay/pass/type_visitor.h  View file @ 8876eac8
...
@@ -19,7 +19,7 @@ namespace relay {
 */
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
- void VisitType_(const TypeParamNode* op, Args... args) override {}
+ void VisitType_(const TypeVarNode* op, Args... args) override {}

  void VisitType_(const FuncTypeNode* op, Args... args) override {
    for (auto type_param : op->type_params) {
...
@@ -60,16 +60,16 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
    return TensorTypeNode::make(op->shape, op->dtype);
  }

- Type VisitType_(const TypeParamNode* op) override {
-   return GetRef<TypeParam>(op);
+ Type VisitType_(const TypeVarNode* op) override {
+   return GetRef<TypeVar>(op);
  }

  Type VisitType_(const FuncTypeNode* op) override {
-   Array<TypeParam> type_params;
+   Array<TypeVar> type_params;
    for (auto type_param : op->type_params) {
      auto new_type_param = VisitType(type_param);
-     if (const TypeParamNode* tin = new_type_param.as<TypeParamNode>()) {
-       type_params.push_back(GetRef<TypeParam>(tin));
+     if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
+       type_params.push_back(GetRef<TypeVar>(tin));
      } else {
        CHECK(false) << new_type_param << std::endl;
      }
...
src/relay/pass/util.cc  View file @ 8876eac8
...
@@ -14,14 +14,14 @@ namespace relay {
class FreeVar;

class FreeTypeVar : private TypeVisitor<> {
- std::unordered_set<TypeParam, NodeHash, NodeEqual>* free_vars;
- std::unordered_set<TypeParam, NodeHash, NodeEqual>* bound_vars;
- FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual>* free_vars,
-             std::unordered_set<TypeParam, NodeHash, NodeEqual>* bound_vars) :
+ std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
+ std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
+ FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
+             std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
    free_vars(free_vars), bound_vars(bound_vars) { }

- void VisitType_(const TypeParamNode* tp) final {
-   auto var = GetRef<TypeParam>(tp);
+ void VisitType_(const TypeVarNode* tp) final {
+   auto var = GetRef<TypeVar>(tp);
    if (bound_vars->count(var) == 0) {
      free_vars->insert(var);
    }
...
@@ -75,8 +75,8 @@ class FreeVar : public ExprVisitor {
 public:
  std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
  std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
- std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types;
- std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types;
+ std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
+ std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;

  void VisitType(const Type& t) final {
    FreeTypeVar(&free_types, &bound_types)(t);
...
@@ -89,16 +89,16 @@ tvm::Array<Var> FreeVariables(const Expr& e) {
  return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
}

-tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) {
+tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
  FreeVar fv;
  fv.VisitExpr(e);
- return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
+ return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}

-tvm::Array<TypeParam> FreeTypeVariables(const Type& t) {
+tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
  FreeVar fv;
  fv.VisitType(t);
- return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
+ return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}

TVM_REGISTER_API("relay._ir_pass.free_vars")
...
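The free-variable helpers now hand back TypeVar arrays, and they are reachable from Python through relay.ir_pass. A minimal sketch (test_pass_free_vars, touched by this commit, exercises the same entry point):

    from tvm import relay
    from tvm.relay.ir_pass import free_vars

    x = relay.var("x")
    y = relay.var("y")
    body = relay.add(x, y)
    assert len(free_vars(body)) == 2   # both x and y occur free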
tests/python/relay/test_ir_builder.py  deleted 100644 → 0  View file @ 4300bbc2
-import numpy as np
-from tvm.relay.expr import Let, Constant
-from tvm.relay.ir_builder import IRBuilder
-
-def test_let():
-    b = IRBuilder()
-    x = b.let('x', 1)
-    b.ret(x)
-    prog, _ = b.get()
-    assert isinstance(prog, Let)
-    var = prog.var
-    value = prog.value
-    assert var.name_hint == 'x'
-    assert var == prog.body
-    assert isinstance(value, Constant)
-    assert value.data.asnumpy() == np.array(1)
-
-if __name__ == "__main__":
-    test_let()
tests/python/relay/test_ir_nodes.py  View file @ 8876eac8
...
@@ -34,7 +34,7 @@ def test_tensor_type():

def test_type_param():
-   tp = relay.TypeParam('name', relay.Kind.Type)
+   tp = relay.TypeVar('name', relay.Kind.Type)
    assert tp.kind == relay.Kind.Type
    # assert tp.span  # TODO allow us to set span
    str(tp)
...
@@ -56,7 +56,7 @@ def test_func_type():

def test_tuple_type():
-   tp = relay.TypeParam('tp', relay.Kind.Type)
+   tp = relay.TypeVar('tp', relay.Kind.Type)
    tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    fields = tvm.convert([tp, tf, tt])
...
@@ -66,7 +66,7 @@ def test_tuple_type():

def test_type_relation():
-   tp = relay.TypeParam('tp', relay.Kind.Type)
+   tp = relay.TypeVar('tp', relay.Kind.Type)
    tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    args = tvm.convert([tf, tt, tp])
...
@@ -173,7 +173,7 @@ def test_if():

def test_tuple_get_item():
    tup = relay.Var("tuple")
    get = relay.TupleGetItem(tup, 1)
-   assert get.tuple == tup
+   assert get.tuple_value == tup
    assert get.index == 1
    str(get)
...
tests/python/relay/test_ir_text_printer.py  View file @ 8876eac8
...
@@ -27,7 +27,7 @@ def test_env():
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Environment()
-   env.add("myf", f)
+   env["myf"] = f
    text = env.astext()
    assert "def @myf" in text
    assert "%1 = add(%0, %0) # ty=float32" in text
...
@@ -70,15 +70,18 @@ def test_let_if_scope():
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")
-   v1 = relay.var("v")
-   v2 = relay.var("v", "float32")
-   then_branch = relay.Let(
-       v1, relay.const(1, "float32"),
-       relay.Let(v2, x, relay.subtract(v1, v2)))
-   v3 = relay.var("v")
-   let2 = relay.Let(v3, y, v3)
-   else_branch = relay.add(let2, let2)
-   result = relay.If(cond, then_branch, else_branch)
+   sb = relay.ScopeBuilder()
+   with sb.if_scope(cond):
+       v1 = sb.let("v", relay.const(1, "float32"))
+       v2 = sb.let("v", x)
+       sb.ret(relay.subtract(v1, v2))
+   with sb.else_scope():
+       v3 = relay.var("v")
+       let2 = relay.Let(v3, y, v3)
+       sb.ret(relay.add(let2, let2))
+   result = sb.get()
    f = relay.Function([x, y, cond], result)
    text = f.astext()
    assert text.count("{") == 4
...
@@ -86,10 +89,17 @@ def test_let_if_scope():
    show(f.astext())

+def test_variable_name():
+    # avoid pure number even if the namehint is pure number
+    v1 = relay.var("1")
+    assert "%v1" in v1.astext()
+
if __name__ == "__main__":
    do_print[0] = True
+   test_let_if_scope()
    test_func()
    test_env()
    test_meta_data()
    test_call_attrs()
-   test_let_if_scope()
+   test_variable_name()
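The rewritten test above doubles as the clearest picture of the new relay.ScopeBuilder; a condensed standalone sketch of the same pattern:

    from tvm import relay

    x = relay.var("x", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        v = sb.let("v", relay.const(1, "float32"))   # let-binding scoped to the branch
        sb.ret(relay.subtract(x, v))
    with sb.else_scope():
        sb.ret(x)
    body = sb.get()                                  # the assembled If expression
    f = relay.Function([x, cond], body)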
tests/python/relay/test_op_level1.py    View file @ 8876eac8

 import tvm
 import numpy as np
 from tvm import relay
-from tvm.relay.ir_pass import infer_type
-from tvm.relay.ir_builder import IRBuilder, func_type
-from tvm.relay.ir_builder import scalar_type, convert, tensor_type
-from tvm.relay.env import Environment
-
-def assert_has_type(expr, typ, env=Environment({})):
-    checked_expr = infer_type(env, expr)
-    checked_type = checked_expr.checked_type
-    if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
 
-def test_single_op():
+def test_unary_op():
     def check_single_op(opfunc):
-        "Program: fn (x : float32) { let t1 = f(x); t1 }"
-        b = IRBuilder()
-        with b.function(('x', 'float32')) as func:
-            x, = func.param_ids()
-            t1 = b.let('t1', opfunc(x))
-            b.ret(t1)
-        assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
+        tp = relay.TensorType((10, 4), "float32")
+        x = relay.var("x", tp)
+        y = opfunc(x)
+        # test printer
+        assert ("%0 = {}(%x)".format(y.op.name)) in y.astext()
+        # test type inference
+        assert relay.ir_pass.infer_type(y).checked_type == tp
 
-    for opfunc in [tvm.relay.log, tvm.relay.exp, tvm.relay.sqrt,
-                   tvm.relay.sigmoid, tvm.relay.tanh]:
+    for opfunc in [tvm.relay.log,
+                   tvm.relay.exp, tvm.relay.sqrt, tvm.relay.sigmoid, tvm.relay.tanh, relay.nn.relu]:
         check_single_op(opfunc)
 
+def test_binary_op():
+    def check_binary_op(opfunc):
+        n = tvm.var("n")
+        t1 = relay.TensorType((5, n, 5))
+        t2 = relay.TensorType((n, 1))
+        x = relay.var("x", t1)
+        y = relay.var("y", t2)
+        z = opfunc(x, y)
+        # test printer
+        assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
+        assert relay.ir_pass.infer_type(z).checked_type == t1
+
+    for opfunc in [relay.add, relay.subtract, relay.mod, relay.multiply, relay.divide]:
+        check_binary_op(opfunc)
 
 def test_expand_dims_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.expand_dims(x, axis=2))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t, 1, 100), "float32")
+    # let's mimic a batch of sequences
+    x = relay.var("x", shape=(n, t, d))
+    y = relay.expand_dims(x, axis=2)
+    assert "axis=2" in y.astext()
+    checked = relay.ir_pass.infer_type(y)
+    assert checked.checked_type == relay.TensorType((n, t, 1, 100))
 
 def test_softmax():
-    ib = relay.ir_builder.IRBuilder()
     n, d = tvm.var("n"), tvm.var("d")
-    x = ib.param("x", relay.ty.TensorType((n, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.softmax(x, axis=1))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, d), "float32")
+    x = relay.var("x", shape=(n, d))
+    y = relay.nn.softmax(x, axis=1)
+    assert "nn.softmax" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, d))
 
 def test_log_softmax():
-    ib = relay.ir_builder.IRBuilder()
     n, d = tvm.var("n"), tvm.var("d")
-    x = ib.param("x", relay.ty.TensorType((n, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.log_softmax(x, axis=1))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, d), "float32")
-
-def test_unary_op():
-    for op in [relay.exp, relay.log, relay.sqrt, relay.sigmoid, relay.nn.relu]:
-        ib = relay.ir_builder.IRBuilder()
-        x = ib.param("x", relay.TensorType((10, 4), "int32"))
-        with ib.function(x) as func:
-            ib.ret(op(x))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.TensorType((10, 4), "int32")
-
-def test_binary_op():
-    def check_binary_op(opfunc):
-        """
-        Program:
-           fn (x, y) {
-               return x <op> y;
-           }
-        """
-        b = IRBuilder()
-        x = b.param('x', tensor_type(5, 5, 5))
-        y = b.param('y', tensor_type(5, 5, 5))
-        with b.function(x, y) as func:
-            b.ret(opfunc(x, y))
-        b.ret(func)
-        prog, env = b.get()
-        ttype = tensor_type(5, 5, 5)
-        expected_ty = func_type([ttype, ttype], ttype)
-        assert_has_type(func.to_func(), expected_ty)
-
-    for opfunc in [relay.add, relay.subtract, relay.mod, relay.multiply, relay.divide]:
-        check_binary_op(opfunc)
-
-def test_binary_broadcast_op():
-    def check_binary_broadcast_op(opfunc):
-        """
-        Program:
-            fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
-                return x <op> y;
-            }
-        """
-        b = IRBuilder()
-        x = b.param('x', tensor_type(10, 4))
-        y = b.param('y', tensor_type(5, 10, 1))
-        with b.function(x, y) as func:
-            b.ret(opfunc(x, y))
-        b.ret(func)
-        prog, env = b.get()
-        expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
-                                tensor_type(5, 10, 4))
-        assert_has_type(func.to_func(), expected_ty)
-
-    for opfunc in [relay.add, relay.subtract, relay.mod, relay.multiply, relay.divide]:
-        check_binary_broadcast_op(opfunc)
+    x = relay.var("x", shape=(n, d))
+    y = relay.nn.log_softmax(x, axis=0)
+    assert "nn.log_softmax" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, d))
 
 def test_concatenate_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x, y) as func:
-        ib.ret(relay.concatenate((x, y), axis=-1))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t, 200), "float32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x, y) as func:
-        ib.ret(relay.concatenate((x, y), axis=2))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t, 200), "float32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    y = ib.param("y", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x, y) as func:
-        ib.ret(relay.concatenate((x, y), axis=1))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t + t, 100), "float32")
-
-def test_lrn():
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
-
-def test_l2_normalize():
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.l2_normalize(x, eps=0.001, axis=[1]))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
+    x = relay.var("x", shape=(n, t, d))
+    y = relay.var("y", shape=(n, t, d))
+    z = relay.concatenate((x, y), axis=-1)
+    assert "axis=" in z.astext()
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, t, 200))
+
+    z = relay.concatenate((x, y), axis=2)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, t, 200))
+
+    z = relay.concatenate((x, y), axis=1)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, t + t, 100))
 
 def test_dropout():
-    ib = relay.ir_builder.IRBuilder()
-    input_ty = relay.ty.TensorType((3, 4, 5), "int8")
-    x = ib.param("x", input_ty)
-    with ib.function(x) as func:
-        ib.ret(relay.nn.dropout(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty])
-
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
-    input_ty = relay.ty.TensorType((n, t, d), "float32")
-    x = ib.param("x", input_ty)
-    with ib.function(x) as func:
-        ib.ret(relay.nn.dropout(x, rate=0.75))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty])
+    input_ty = relay.TensorType((n, t, d), "float32")
+    x = relay.var("x", input_ty)
+    y, _ = relay.nn.dropout(x, rate=0.75)
+    assert "rate=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == input_ty
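(Aside on the dropout hunk: the new-style test unpacks two values. A sketch of the behavior under test, assuming a TVM build at this commit; calling the second field `mask` is an assumption here — the old assertion only pins both tuple fields to the input type.)

    import tvm
    from tvm import relay

    n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
    # nn.dropout yields a two-field tuple; unpacking gives (result, mask)
    y, mask = relay.nn.dropout(x, rate=0.75)
    assert relay.ir_pass.infer_type(y).checked_type == \
        relay.TensorType((n, t, d), "float32")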
 def test_batch_norm():
     # beta and gamma ignored
-    ib = relay.ir_builder.IRBuilder()
-    data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32"))
-    gamma = ib.param("gamma", relay.ty.TensorType((5,), "int8"))
-    beta = ib.param("beta", relay.ty.TensorType((12, 16), "int64"))
-    moving_mean = ib.param("moving_mean", relay.ty.TensorType((2,), "float32"))
-    moving_var = ib.param("moving_var", relay.ty.TensorType((2,), "float32"))
-    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
-        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                   center=False, scale=False))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
-        relay.ty.TensorType((3, 2, 1), "float32"),
-        relay.ty.TensorType((2,), "float32"),
-        relay.ty.TensorType((2,), "float32")
-    ]))
+    data = relay.var("data", relay.TensorType((3, 2, 1)))
+    beta = relay.var("beta", relay.TensorType((2,)))
+    gamma = relay.var("gamma", relay.TensorType((2,)))
+    moving_mean = relay.var("moving_mean", relay.TensorType((2,)))
+    moving_var = relay.var("moving_var", relay.TensorType((2,)))
+    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                            center=False, scale=False)
+    yy = relay.ir_pass.infer_type(y)
+    assert "center=" in yy.astext()
+    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
+        relay.TensorType((3, 2, 1), "float32"),
+        relay.TensorType((2,), "float32"),
+        relay.TensorType((2,), "float32")
+    ]))
 
+    # with beta and gamma, different axis
-    ib = relay.ir_builder.IRBuilder()
-    data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32"))
-    gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32"))
-    beta = ib.param("beta", relay.ty.TensorType((3,), "float32"))
-    moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32"))
-    moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32"))
-    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
-        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                   axis=0, center=False, scale=False))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
+    beta = relay.var("beta", relay.TensorType((3,)))
+    gamma = relay.var("gamma", relay.TensorType((3,)))
+    moving_mean = relay.var("moving_mean", relay.TensorType((3,)))
+    moving_var = relay.var("moving_var", relay.TensorType((3,)))
+    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                            axis=0, center=False, scale=False)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
         relay.ty.TensorType((3, 2, 1), "float32"),
         relay.ty.TensorType((3,), "float32"),
         relay.ty.TensorType((3,), "float32")
     ]))
 
     # axis=-1
-    ib = relay.ir_builder.IRBuilder()
-    data = ib.param("data", relay.ty.TensorType((1, 2, 3), "float32"))
-    gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32"))
-    beta = ib.param("beta", relay.ty.TensorType((3,), "float32"))
-    moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32"))
-    moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32"))
-    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
-        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
-                                   axis=-1, center=False, scale=False))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
+    data = relay.var("data", relay.TensorType((1, 2, 3)))
+    beta = relay.var("beta", relay.TensorType((3,)))
+    gamma = relay.var("gamma", relay.TensorType((3,)))
+    moving_mean = relay.var("moving_mean", relay.TensorType((3,)))
+    moving_var = relay.var("moving_var", relay.TensorType((3,)))
+    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                            axis=-1, center=False, scale=False)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
         relay.ty.TensorType((1, 2, 3), "float32"),
         relay.ty.TensorType((3,), "float32"),
         relay.ty.TensorType((3,), "float32")
...
@@ -285,14 +146,10 @@ def test_batch_norm():
 if __name__ == "__main__":
     test_unary_op()
-    test_single_op()
+    test_binary_op()
     test_expand_dims_infer_type()
     test_concatenate_infer_type()
     test_softmax()
     test_log_softmax()
-    test_binary_op()
-    test_binary_broadcast_op()
-    test_lrn()
-    test_l2_normalize()
     test_dropout()
     test_batch_norm()
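(The pattern repeated across this file is the core of the refactor: the IRBuilder scaffolding — ib.param / ib.function / ib.ret / to_func — is replaced by free-standing expressions checked with relay.ir_pass.infer_type. A minimal sketch of the new style, assuming a TVM build at this commit.)

    import tvm
    from tvm import relay

    n = tvm.var("n")                      # symbolic batch dimension
    x = relay.var("x", shape=(n, 4))      # dtype defaults to float32
    y = relay.exp(x)
    yy = relay.ir_pass.infer_type(y)      # returns the expr annotated with types
    assert yy.checked_type == relay.TensorType((n, 4))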
tests/python/relay/test_op_level2.py    View file @ 8876eac8
...
@@ -3,162 +3,111 @@
 import tvm
 from tvm import relay
 
 def test_conv2d_infer_type():
     # symbolic in batch dimension
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 10, 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    w = ib.param("w", relay.ty.IncompleteType())
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.conv2d(x, w,
-                               kernel_size=(3, 3),
-                               padding=(1, 1),
-                               channels=2))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, 2, 224, 224), "float32")
-    assert ftype.arg_types[1] == relay.ty.TensorType(
-        (2, 10, 3, 3), "float32")
+    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
+    w = relay.var("w")
+    y = relay.nn.conv2d(x, w,
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=2)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 2, 224, 224), "float32")
+    assert yy.args[1].checked_type == relay.TensorType(
+        (2, 10, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 10, 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
-    w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.conv2d(x, w, out_dtype="int32"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, 2, 222, 222), "int32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
+    w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
+    y = relay.nn.conv2d(x, w, out_dtype="int32")
+    assert "out_dtype=\"int32\"" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 2, 222, 222), "int32")
 
     # Infer with a different layout
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = 4, 32, 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
-    w = ib.param("w", relay.ty.IncompleteType())
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.conv2d(x, w,
-                               kernel_size=(3, 3),
-                               padding=(1, 1),
-                               channels=16,
-                               data_layout="NCHW4n4c",
-                               weight_layout="OIHW4o4i",
-                               out_dtype="int32"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (1, 4, 224, 224, 4, 4), "int32")
-    assert ftype.arg_types[1] == relay.ty.TensorType(
-        (4, 8, 3, 3, 4, 4), "int8")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
+    w = relay.var("w")
+    y = relay.nn.conv2d(x, w,
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=16,
+                        data_layout="NCHW4n4c",
+                        weight_layout="OIHW4o4i",
+                        out_dtype="int32")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (1, 4, 224, 224, 4, 4), "int32")
+    assert yy.args[1].checked_type == relay.TensorType(
+        (4, 8, 3, 3, 4, 4), "int8")
 
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 10, 10, 12
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    w = ib.param("w", relay.ty.IncompleteType())
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.conv2d_transpose(x, w,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         channels=15))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, 15, 10, 12), "float32")
-    assert ftype.arg_types[1] == relay.ty.TensorType(
-        (10, 15, 3, 3), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    w = relay.var("w", relay.IncompleteType())
+    y = relay.nn.conv2d_transpose(x, w,
+                                  kernel_size=(3, 3),
+                                  padding=(1, 1),
+                                  channels=15)
+    assert "channels=15" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 15, 10, 12), "float32")
+    assert yy.args[1].checked_type == relay.TensorType(
+        (10, 15, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 10, 10, 12
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.conv2d_transpose(x, w,
-                                         output_padding=(1, 1),
-                                         channels=11,
-                                         data_layout="NHWC"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, 15, 15, 11), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
+    y = relay.nn.conv2d_transpose(x, w,
+                                  output_padding=(1, 1),
+                                  channels=11,
+                                  data_layout="NHWC")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 15, 15, 11), "float32")
 
 def test_upsampling_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h*2, w*2), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
+    "method=\"BINLINEAR\"" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
     n, c = tvm.var("n"), tvm.var("c")
-    x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, 200, 400), "float32")
+    x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
+    y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
 
 def _test_pool2d_infer_type(opfunc):
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 10, 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(opfunc(x, pool_size=(1, 1)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, 10, 224, 224), "float32")
-
-    ph, pw = tvm.var("ph"), tvm.var("pw")
-    sh, sw = tvm.var("sh"), tvm.var("sw")
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), 10, 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, 10, (((224 - ph)/sh) + 1), (((224 - pw)/sw) + 1)), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    y = opfunc(x, pool_size=(1, 1))
+    assert "pool_size=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32")
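(For the pooling tests, note how symbolic dimensions flow through inference unchanged; a sketch in the new style, mirroring the max_pool2d case above.)

    import tvm
    from tvm import relay

    n = tvm.var("n")
    x = relay.var("x", relay.TensorType((n, 10, 224, 224), "float32"))
    # a unit window with default strides leaves the spatial extent intact,
    # and the symbolic batch dimension n passes through untouched
    y = relay.nn.max_pool2d(x, pool_size=(1, 1))
    assert relay.ir_pass.infer_type(y).checked_type == \
        relay.TensorType((n, 10, 224, 224), "float32")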
 def _test_global_pool2d_infer_type(opfunc):
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
-    x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32"))
-    with ib.function(x) as func:
-        ib.ret(opfunc(x, layout="NHWC"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, 1, 1, c), "float32")
+    x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
+    y = opfunc(x, layout="NHWC")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(opfunc(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, 1, 1), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    y = opfunc(x)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32")
 
 def test_pool2d_infer_type():
     _test_pool2d_infer_type(relay.nn.max_pool2d)
...
@@ -167,101 +116,83 @@ def test_pool2d_infer_type():
     _test_global_pool2d_infer_type(relay.nn.global_avg_pool2d)
 
 def test_flatten_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
-    x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.batch_flatten(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((d1, ((d2*d3)*d4)), "float32")
+    x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))
+    y = relay.nn.batch_flatten(x)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((d1, ((d2*d3)*d4)), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
-    x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.batch_flatten(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((3, 24), "float32")
+    x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32"))
+    y = relay.nn.batch_flatten(x)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((3, 24), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
-    x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.batch_flatten(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((d1, ((2*d3)*3)), "float32")
+    x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32"))
+    y = relay.nn.batch_flatten(x)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32")
 
 def test_pad_infer_type():
     # entirely concrete case
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = 1, 2, 3, 4
-    t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
-    with ib.function(t) as func:
-        ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((3, 6, 9, 12), "float32")
+    t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
+    "pad_width=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32")
 
     # some symbolic values
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
-    t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
-    with ib.function(t) as func:
-        ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
+    t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
 
 def test_dense_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    w = ib.param("w", relay.ty.TensorType((w, 2), "float32"))
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.dense(x, w, units=2))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, 2), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    w = relay.var("w", relay.TensorType((w, 2), "float32"))
+    y = relay.nn.dense(x, w, units=2)
+    "units=2" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
     wh, ww = tvm.var("wh"), tvm.var("ww")
-    w = ib.param("w", relay.ty.TensorType((wh, ww), "float32"))
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.dense(x, w))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, ww), "float32")
+    w = relay.var("w", relay.TensorType((wh, ww), "float32"))
+    y = relay.nn.dense(x, w)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32")
 
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    w = ib.param("w", relay.ty.IncompleteType())
-    with ib.function(x, w) as func:
-        ib.ret(relay.nn.dense(x, w, units=2))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, 2), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    w = relay.var("w", relay.IncompleteType())
+    y = relay.nn.dense(x, w, units=2)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
+
+def test_lrn():
+    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+    x = relay.var("x", shape=(n, c, h, w))
+    y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)
+    "alpha=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, w))
+
+def test_l2_normalize():
+    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+    x = relay.var("x", shape=(n, c, h, w))
+    y = relay.nn.l2_normalize(x, eps=0.001, axis=[1])
+    "axis=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, w))
 
 if __name__ == "__main__":
+    test_lrn()
+    test_l2_normalize()
     test_conv2d_infer_type()
     test_pool2d_infer_type()
     test_upsampling_infer_type()
...
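(As the conv2d hunks above show, inference is bidirectional: an argument left untyped is solved from the op's type relation. A sketch mirroring the first conv2d case, assuming a TVM build at this commit.)

    import tvm
    from tvm import relay

    n = tvm.var("n")
    x = relay.var("x", relay.TensorType((n, 10, 224, 224), "float32"))
    w = relay.var("w")                    # type left incomplete on purpose
    y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=2)
    yy = relay.ir_pass.infer_type(y)
    # the weight's type is recovered from the conv2d type relation
    assert yy.args[1].checked_type == relay.TensorType((2, 10, 3, 3), "float32")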
tests/python/relay/test_op_level3.py    View file @ 8876eac8
...
@@ -3,154 +3,92 @@
 import tvm
 import numpy as np
 from tvm import relay
-from tvm.relay.ir_pass import infer_type
-from tvm.relay.ir_builder import IRBuilder, func_type
-from tvm.relay.env import Environment
 from nose.tools import raises
 
 def test_zeros_ones():
     for op in [relay.zeros, relay.ones]:
-        ib = relay.ir_builder.IRBuilder()
-        with ib.function() as func:
-            ib.ret(op((124, 50), "float64"))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.TensorType((124, 50), "float64")
+        y = op(shape=(124, 50), dtype="float64")
+        yy = relay.ir_pass.infer_type(y)
+        assert yy.checked_type == relay.TensorType((124, 50), "float64")
 
 def test_unary_identity():
-    for op in [relay.zeros_like, relay.ones_like]:
-        ib = relay.ir_builder.IRBuilder()
-        x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
-        with ib.function(x) as func:
-            ib.ret(op(x))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
+    for op in [relay.zeros_like,
+               relay.ones_like,
+               relay.ceil,
+               relay.floor,
+               relay.trunc,
+               relay.round,
+               relay.abs,
+               relay.copy,
+               relay.negative]:
+        x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
+        y = op(x)
+        yy = relay.ir_pass.infer_type(y)
+        assert yy.checked_type == relay.TensorType((8, 9, 4), "float32")
 
 def test_clip_type():
-    ib = relay.ir_builder.IRBuilder()
-    a = ib.param("a", relay.TensorType((10, 4), "float32"))
-    with ib.function(a) as func:
-        ib.ret(relay.clip(a, 1., 4.))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((10, 4), "float32")
+    a = relay.var("a", relay.TensorType((10, 4), "float32"))
+    y = relay.clip(a, 1., 4.)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((10, 4), "float32")
 
-def test_copy_infer_type():
-    ib = relay.ir_builder.IRBuilder()
-    n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.copy(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t, 100), "float32")
 
 def test_transpose_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = tvm.var("n"), tvm.var("t"), 100
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.transpose(x, axes=(1, 0, 2)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (t, n, 100), "float32")
+    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
+    y = relay.transpose(x, axes=(1, 0, 2))
+    "axes=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (t, n, 100), "float32")
 
-def test_squeeze_default_axes_infer_type():
-    ib = relay.ir_builder.IRBuilder()
+def test_squeeze_infer_type():
     n, t, d = 1, 4, 1
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.squeeze(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (4,), "float32")
+    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
+    y = relay.squeeze(x, axes=(2,))
+    assert "axes=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (1, 4), "float32")
 
-def test_squeeze_axes_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = 1, 4, 1
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.squeeze(x, axes=(2,)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (1, 4), "float32")
+    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
+    y = relay.squeeze(x)
+    assert "axes=" not in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (4,), "float32")
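(The two squeeze cases above pin down the axes semantics; condensed into one sketch, assuming a TVM build at this commit.)

    import tvm
    from tvm import relay

    x = relay.var("x", relay.TensorType((1, 4, 1), "float32"))
    # explicit axes: only the listed unit dimensions are dropped
    yy = relay.ir_pass.infer_type(relay.squeeze(x, axes=(2,)))
    assert yy.checked_type == relay.TensorType((1, 4), "float32")
    # no axes: every unit dimension is dropped
    yy = relay.ir_pass.infer_type(relay.squeeze(x))
    assert yy.checked_type == relay.TensorType((4,), "float32")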
 @raises(tvm._ffi.base.TVMError)
 def test_squeeze_bad_axes_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d = 1, 4, 1
-    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.squeeze(x, axes=(1,)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
+    x = relay.var("x", relay.TensorType((n, t, d), "float32"))
+    y = relay.squeeze(x, axes=(1,))
+    yy = relay.ir_pass.infer_type(y)
 
 def test_reshape_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
-    x = ib.param("x", relay.ty.TensorType((n, t, d1, d2), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.reshape(x, newshape=(n, t, 2000)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (n, t, 2000), "float32")
+    x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32"))
+    y = relay.reshape(x, newshape=(n, t, 2000))
+    assert "newshape=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, t, 2000), "float32")
 
-def assert_has_type(expr, typ, env=Environment({})):
-    checked_expr = infer_type(env, expr)
-    checked_type = checked_expr.checked_type
-    if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
-
-def test_single_op():
-    def check_single_op(opfunc):
-        "Program: fn (x : float32) { let t1 = f(x); t1 }"
-        b = IRBuilder()
-        with b.function(('x', 'float32')) as func:
-            x, = func.param_ids()
-            t1 = b.let('t1', opfunc(x))
-            b.ret(t1)
-        assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
-
-    for opfunc in [tvm.relay.ceil, tvm.relay.floor, tvm.relay.trunc,
-                   tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
-        check_single_op(opfunc)
 
 def test_take_infer_type():
     def verify_take(dshape, indices_shape, oshape, axis=None):
-        ib = relay.ir_builder.IRBuilder()
-        x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
-        indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
-        with ib.function(x, indices) as func:
-            ib.ret(relay.take(x, indices, axis=axis))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.ty.TensorType(oshape, "float32")
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        y = relay.take(x, indices, axis=axis)
+        y.astext()
+        yy = relay.ir_pass.infer_type(y)
+        assert yy.checked_type == relay.TensorType(oshape, "float32")
 
     d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3")
     d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6")
...
@@ -164,73 +102,52 @@ def test_take_infer_type():
 
 def test_full():
     # default settings: match input dtype
-    ib = relay.ir_builder.IRBuilder()
-    x = ib.param("x", relay.TensorType((), "int8"))
-    with ib.function(x) as func:
-        ib.ret(relay.full(x, ()))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((), "int8")
+    x = relay.var("x", relay.TensorType((), "int8"))
+    y = relay.full(x, ())
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((), "int8")
 
     # change the shape and dtype
-    ib = relay.ir_builder.IRBuilder()
-    x = ib.param("x", relay.TensorType((), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.full(x, (1, 2), "int8"))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((1, 2), "int8")
+    x = relay.var("x", relay.TensorType((), "float32"))
+    y = relay.full(x, (1, 2), "int8")
+    "shape=" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((1, 2), "int8")
 
 def test_full_like():
     # concrete shape
-    ib = relay.ir_builder.IRBuilder()
-    base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
-    fill = ib.param("fill", relay.TensorType((), "float32"))
-    with ib.function(base, fill) as func:
-        ib.ret(relay.full_like(base, fill))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((1, 2, 3), "float32")
+    base = relay.var("base", relay.TensorType((1, 2, 3), "float32"))
+    fill = relay.var("fill", relay.TensorType((), "float32"))
+    y = relay.full_like(base, fill)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((1, 2, 3), "float32")
 
     # symbolic shape
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
-    base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
-    fill = ib.param("fill", relay.TensorType((), "float32"))
-    with ib.function(base, fill) as func:
-        ib.ret(relay.full_like(base, fill))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((n, c, h, w), "float32")
+    base = relay.var("base", relay.TensorType((n, c, h, w), "float32"))
+    fill = relay.var("fill", relay.TensorType((), "float32"))
+    y = relay.full_like(base, fill)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
 
 def test_infer_type_leaky_relu():
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.nn.leaky_relu(x, alpha=0.1))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.nn.leaky_relu(x, alpha=0.1)
+    "alpha=0.1" in y.astext()
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
 
 if __name__ == "__main__":
-    test_single_op()
     test_zeros_ones()
     test_unary_identity()
     test_clip_type()
-    test_copy_infer_type()
     test_transpose_infer_type()
     test_reshape_infer_type()
     test_take_infer_type()
     test_full()
     test_full_like()
     test_infer_type_leaky_relu()
-    test_squeeze_axes_infer_type()
-    test_squeeze_default_axes_infer_type()
+    test_squeeze_infer_type()
+    test_squeeze_bad_axes_infer_type()
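(full_like in the hunks above also works with partly symbolic shapes; a sketch, assuming a TVM build at this commit.)

    import tvm
    from tvm import relay

    n, w = tvm.var("n"), tvm.var("w")
    base = relay.var("base", relay.TensorType((n, 2, 3, w), "float32"))
    fill = relay.var("fill", relay.TensorType((), "float32"))
    # the (partly symbolic) shape and dtype of `base` carry over to the result
    y = relay.full_like(base, fill)
    assert relay.ir_pass.infer_type(y).checked_type == \
        relay.TensorType((n, 2, 3, w), "float32")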
tests/python/relay/test_op_level4.py    View file @ 8876eac8

 import tvm
 import numpy as np
 from tvm import relay
-from tvm.relay.ir_pass import infer_type
-from tvm.relay.ir_builder import IRBuilder, func_type
-from tvm.relay.ir_builder import scalar_type, convert, tensor_type
-from tvm.relay.env import Environment
-
-def assert_has_type(expr, typ, env=Environment({})):
-    checked_expr = infer_type(env, expr)
-    checked_type = checked_expr.checked_type
-    if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
 
 def test_binary_op():
     def check_binary_op(opfunc):
-        """
-        Program:
-           fn (x, y) {
-               return x <op> y;
-           }
-        """
-        b = IRBuilder()
-        x = b.param('x', tensor_type(5, 5, 5))
-        y = b.param('y', tensor_type(5, 5, 5))
-        with b.function(x, y) as func:
-            b.ret(opfunc(x, y))
-        b.ret(func)
-        prog, env = b.get()
-        ttype = tensor_type(5, 5, 5)
-        expected_ty = func_type([ttype, ttype], ttype)
-        assert_has_type(func.to_func(), expected_ty)
+        n = tvm.var("n")
+        t1 = relay.TensorType((5, n, 5))
+        t2 = relay.TensorType((n, 1))
+        x = relay.var("x", t1)
+        y = relay.var("y", t2)
+        z = opfunc(x, y)
+        # test printer
+        assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
+        assert relay.ir_pass.infer_type(z).checked_type == t1
 
     for opfunc in [relay.pow]:
         check_binary_op(opfunc)
 
-def test_binary_broadcast_op():
-    def check_binary_broadcast_op(opfunc):
-        """
-        Program:
-            fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
-                return x <op> y;
-            }
-        """
-        b = IRBuilder()
-        x = b.param('x', tensor_type(10, 4))
-        y = b.param('y', tensor_type(5, 10, 1))
-        with b.function(x, y) as func:
-            b.ret(opfunc(x, y))
-        b.ret(func)
-        prog, env = b.get()
-        expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
-                                tensor_type(5, 10, 4))
-        assert_has_type(func.to_func(), expected_ty)
-
-    for opfunc in [relay.pow]:
-        check_binary_broadcast_op(opfunc)
-
 def test_cmp_type():
     for op in (relay.greater,
                relay.greater_equal,
...
@@ -68,138 +26,59 @@ def test_cmp_type():
                relay.less_equal,
                relay.equal,
                relay.not_equal):
-        ib = relay.ir_builder.IRBuilder()
-        x = ib.param("x", relay.TensorType((10, 4), "float32"))
-        y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
-        with ib.function(x, y) as func:
-            ib.ret(op(x, y))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
+        x = relay.var("x", relay.TensorType((10, 4), "float32"))
+        y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
+        z = op(x, y)
+        z.astext()
+        zz = relay.ir_pass.infer_type(z)
+        assert zz.checked_type == relay.TensorType((5, 10, 4), "bool")
 
-def test_binary_broadcast():
+def test_binary_int_broadcast():
     for op in [relay.right_shift,
                relay.left_shift,
                relay.maximum,
                relay.minimum]:
-        ib = relay.ir_builder.IRBuilder()
-        x = ib.param("x", relay.TensorType((10, 4), "int32"))
-        y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
-        with ib.function(x, y) as func:
-            ib.ret(op(x, y))
-        ib.ret(func)
-        func = relay.ir_pass.infer_type(ib.env, func.to_func())
-        ftype = func.checked_type
-        assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
+        x = relay.var("x", relay.TensorType((10, 4), "int32"))
+        y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
+        z = op(x, y)
+        zz = relay.ir_pass.infer_type(z)
+        assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
 
+def test_arg_reduce():
+    for op in [relay.argmax, relay.argmin]:
+        n, c, h, w = 10, 20, 3, 4
+        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
+        z = relay.argmax(x, axis=(1,))
+        "axis=" in z.astext()
+        zz = relay.ir_pass.infer_type(z)
+        assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32")
+
+        n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
+        z = relay.argmax(x, axis=(2,), keepdims=True)
+        zz = relay.ir_pass.infer_type(z)
+        assert zz.checked_type == relay.ty.TensorType((n, c, 1, w), "int32")
+
+        n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+        x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
+        z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True)
+        zz = relay.ir_pass.infer_type(z)
+        assert zz.checked_type == relay.ty.TensorType((1, 1, h, 1), "int32")
 
-def test_argmax():
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(1,)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, 1, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((1, 1, h, 1), "int32")
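(The keepdims/exclude combinations in test_arg_reduce are easiest to read as one worked case; a sketch, assuming a TVM build at this commit.)

    import tvm
    from tvm import relay

    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    # exclude=True reduces over every axis *except* axis 2; with
    # keepdims=True the reduced axes remain as unit dimensions
    z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True)
    assert relay.ir_pass.infer_type(z).checked_type == \
        relay.ty.TensorType((1, 1, h, 1), "int32")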
def
test_argmin
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
argmax
(
x
,
axis
=
(
1
,)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
n
,
h
,
w
),
"int32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
argmin
(
x
,
axis
=
(
2
,),
keepdims
=
True
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
n
,
c
,
1
,
w
),
"int32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
argmin
(
x
,
axis
=
(
2
,),
keepdims
=
True
,
exclude
=
True
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
1
,
1
,
h
,
1
),
"int32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
argmin
(
x
,
axis
=
(
2
,
1
),
keepdims
=
True
,
exclude
=
True
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
1
,
c
,
h
,
1
),
"int32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
argmin
(
x
,
axis
=
None
,
keepdims
=
True
,
exclude
=
True
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
1
,
1
,
1
,
1
),
"int32"
)
 def test_where():
-    ib = relay.ir_builder.IRBuilder()
-    cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
-    x = ib.param("x", relay.TensorType((3, 4), "float32"))
-    y = ib.param("y", relay.TensorType((3, 4), "float32"))
-    with ib.function(cond, x, y) as func:
-        ib.ret(relay.where(cond, x, y))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.TensorType((3, 4), "float32")
+    cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
+    x = relay.var("x", relay.TensorType((3, 4), "float32"))
+    y = relay.var("y", relay.TensorType((3, 4), "float32"))
+    z = relay.where(cond, x, y)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((3, 4), "float32")
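For reference, the elementwise selection that test_where type-checks behaves like numpy's where; a runnable numpy sketch of the same (3, 4) case:

import numpy as np

# cond picks, per element, whether the result comes from x or from y;
# all three operands and the result share the (3, 4) float32 shape the
# Relay test asserts.
cond = np.random.rand(3, 4) > 0.5
x = np.ones((3, 4), dtype="float32")
y = np.zeros((3, 4), dtype="float32")
z = np.where(cond, x, y)
assert z.shape == (3, 4) and z.dtype == np.float32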
 if __name__ == "__main__":
     test_binary_op()
-    test_binary_broadcast_op()
     test_cmp_type()
-    test_binary_broadcast()
+    test_binary_int_broadcast()
     test_where()
-    test_multibox_prior()
-    test_argmax()
-    test_argmin()
+    test_arg_reduce()
tests/python/relay/test_op_level5.py
View file @ 8876eac8
@@ -4,26 +4,18 @@ import tvm
 from tvm import relay

 def test_resize_infer_type():
-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
+    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
     th, tw = tvm.var("th"), tvm.var("tw")
-    with ib.function(x) as func:
-        ib.ret(relay.image.resize(x, (th, tw)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, th, tw), "int8")
+    z = relay.image.resize(x, (th, tw))
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")

-    ib = relay.ir_builder.IRBuilder()
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
-    with ib.function(x) as func:
-        ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
+    z = relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)
+    assert "size=" in z.astext()
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
@@ -34,29 +26,21 @@ def test_multibox_prior():
     offsets = (0.2, 0.3)
     clip = True

-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 3, 56, 56
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.vision.multibox_prior(x, sizes, ratios,
-                                           steps, offsets, clip))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    z = relay.vision.multibox_prior(x, sizes, ratios,
+                                    steps, offsets, clip)
+    assert "sizes=" in z.astext()
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType(
+        (1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")

-    ib = relay.ir_builder.IRBuilder()
     n, c, h, w = tvm.var("n"), 24, 32, 32
-    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.vision.multibox_prior(x))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType(
-        (1, h * w, 4), "float32")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    z = relay.vision.multibox_prior(x)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.TensorType(
+        (1, h * w, 4), "float32")
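Both multibox_prior asserts encode one counting rule: every one of the h * w feature-map positions contributes len(sizes) + len(ratios) - 1 anchors, each with four coordinates. A quick arithmetic check; the concrete sizes and ratios are defined above the visible hunk, so the values below are stand-ins:

# Stand-in values; the test's actual sizes/ratios sit above the shown hunk.
# Only the counting rule matters here.
sizes = (0.3, 1.5, 0.7)
ratios = (1.3, 2.4)
h, w = 56, 56
num_anchors = len(sizes) + len(ratios) - 1     # anchors per feature-map position
assert h * w * num_anchors == 56 * 56 * 4      # rows in the (1, N, 4) output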
tests/python/relay/test_pass_alpha_equal.py
View file @ 8876eac8
@@ -2,7 +2,6 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay.ir_pass import alpha_equal
-from tvm.relay.ir_builder import convert

 def test_tensor_type_alpha_equal():
     t1 = relay.TensorType((3, 4), "float32")
@@ -29,9 +28,9 @@ def test_incomplete_type_alpha_equal():

 def test_type_param_alpha_equal():
-    t1 = relay.TypeParam("v1", relay.Kind.Type)
-    t2 = relay.TypeParam("v2", relay.Kind.Shape)
-    t3 = relay.TypeParam("v3", relay.Kind.Type)
+    t1 = relay.TypeVar("v1", relay.Kind.Type)
+    t2 = relay.TypeVar("v2", relay.Kind.Shape)
+    t3 = relay.TypeVar("v3", relay.Kind.Type)

     # only pointer equality and eq_map allow equal params
     assert t1 == t1
@@ -54,10 +53,10 @@ def test_func_type_alpha_equal():
     t1 = relay.TensorType((1, 2), "float32")
     t2 = relay.TensorType((1, 2, 3), "float32")

-    tp1 = relay.TypeParam("v1", relay.Kind.Type)
-    tp2 = relay.TypeParam("v2", relay.Kind.Type)
-    tp3 = relay.TypeParam("v3", relay.Kind.Shape)
-    tp4 = relay.TypeParam("v3", relay.Kind.Shape)
+    tp1 = relay.TypeVar("v1", relay.Kind.Type)
+    tp2 = relay.TypeVar("v2", relay.Kind.Type)
+    tp3 = relay.TypeVar("v3", relay.Kind.Shape)
+    tp4 = relay.TypeVar("v3", relay.Kind.Shape)

     broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
     identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
@@ -113,8 +112,8 @@ def test_func_type_alpha_equal():

 def test_tuple_type_alpha_equal():
     t1 = relay.TensorType((1, 2, 3), "float32")
     t2 = relay.TensorType((1, 2, 3, 4), "float32")
-    tp1 = relay.TypeParam("v1", relay.Kind.Type)
-    tp2 = relay.TypeParam("v2", relay.Kind.Type)
+    tp1 = relay.TypeVar("v1", relay.Kind.Type)
+    tp2 = relay.TypeVar("v2", relay.Kind.Type)
     tup1 = relay.TupleType(tvm.convert([t1, t2, tp1]))
     tup2 = relay.TupleType(tvm.convert([t1, t2, tp1]))
@@ -164,11 +163,11 @@ def test_type_relation_alpha_equal():

 def test_constant_alpha_equal():
-    x = convert(1)
-    y = convert(2)
+    x = relay.const(1)
+    y = relay.const(2)
     assert alpha_equal(x, x)
     assert not alpha_equal(x, y)
-    assert alpha_equal(x, convert(1))
+    assert alpha_equal(x, relay.const(1))

 def test_var_alpha_equal():
@@ -180,9 +179,9 @@ def test_var_alpha_equal():
     assert not alpha_equal(v1, v2)

     # let node allows for setting the eq_map
-    l1 = relay.Let(v1, convert(1), v1)
-    l2 = relay.Let(v2, convert(1), v2)
-    l3 = relay.Let(v1, convert(1), v2)
+    l1 = relay.Let(v1, relay.const(1), v1)
+    l2 = relay.Let(v2, relay.const(1), v2)
+    l3 = relay.Let(v1, relay.const(1), v2)
     assert alpha_equal(l1, l2)
     assert not alpha_equal(l1, l3)
@@ -223,34 +222,34 @@ def test_tuple_alpha_equal():
     # unit value is a valid tuple
     assert alpha_equal(relay.Tuple([]), relay.Tuple([]))

-    tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
-    same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
+    tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
+    same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
     assert alpha_equal(tup, same)

     # use the eq_map
     let_tup = relay.Let(v1, tup, v1)
-    let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
-                                            relay.Tuple([convert(4)])]),
+    let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3),
+                                            relay.Tuple([relay.const(4)])]),
                            v2)
     assert alpha_equal(let_tup, let_mapped)

-    more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
+    more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
     assert not alpha_equal(tup, more_fields)

-    fewer_fields = relay.Tuple([v1, convert(2), convert(3)])
+    fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)])
     assert not alpha_equal(tup, fewer_fields)

-    different_end = relay.Tuple([v1, convert(2), convert(3),
-                                 relay.Tuple([convert(5)])])
+    different_end = relay.Tuple([v1, relay.const(2), relay.const(3),
+                                 relay.Tuple([relay.const(5)])])
     assert not alpha_equal(tup, different_end)

-    different_start = relay.Tuple([v2, convert(2), convert(3),
-                                   relay.Tuple([convert(4)])])
+    different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
+                                   relay.Tuple([relay.const(4)])])
     assert not alpha_equal(tup, different_start)

-    longer_at_end = relay.Tuple([v1, convert(2), convert(3),
-                                 relay.Tuple([convert(4), convert(5)])])
+    longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3),
+                                 relay.Tuple([relay.const(4), relay.const(5)])])
     assert not alpha_equal(tup, longer_at_end)
@@ -273,10 +272,10 @@ def test_function_alpha_equal():
     v4 = relay.Var("v4", tt2)
     vret = relay.Constant(tvm.nd.array(np.ones(1)))

-    tp1 = relay.TypeParam("tp1", relay.Kind.Type)
-    tp2 = relay.TypeParam("tp2", relay.Kind.Type)
-    tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
-    tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
+    tp1 = relay.TypeVar("tp1", relay.Kind.Type)
+    tp2 = relay.TypeVar("tp2", relay.Kind.Type)
+    tp3 = relay.TypeVar("tp3", relay.Kind.Shape)
+    tp4 = relay.TypeVar("tp4", relay.Kind.Shape)
     basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
     basic_tps = [tp1, tp2]
@@ -346,11 +345,11 @@ def test_call_alpha_equal():
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((), "int8")

-    basic_args = [convert(1), convert(2), v2, relay.Tuple([])]
+    basic_args = [relay.const(1), relay.const(2), v2, relay.Tuple([])]

     # manually writing out args to ensure that args does not rely on
     # pointer equality
-    call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])],
+    call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
                       attr1, [tt1])
     same = relay.Call(v1, basic_args, attr1, [tt1])
     assert alpha_equal(call, same)
@@ -358,19 +357,19 @@ def test_call_alpha_equal():
     different_fn = relay.Call(v2, basic_args, attr1, [tt1])
     assert not alpha_equal(call, different_fn)

-    fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1])
+    fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
     assert not alpha_equal(call, fewer_args)

-    reordered_args = relay.Call(v1, [convert(2), convert(1),
+    reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
                                      relay.Tuple([]), v2], attr1, [tt1])
     assert not alpha_equal(call, reordered_args)

-    different_args = relay.Call(v1, [convert(1), convert(2), convert(3)],
+    different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)],
                                 attr1, [tt1])
     assert not alpha_equal(call, different_args)

-    more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]),
-                                convert(3), convert(4)], attr1, [tt1])
+    more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
+                                relay.const(3), relay.const(4)], attr1, [tt1])
     assert not alpha_equal(call, more_args)

     different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
@@ -394,27 +393,27 @@ def test_let_alpha_equal():
     v2 = relay.Var("v2")
     v3 = relay.Var("v3")
-    let = relay.Let(v1, convert(2), v1)
-    mapped = relay.Let(v2, convert(2), v2)
+    let = relay.Let(v1, relay.const(2), v1)
+    mapped = relay.Let(v2, relay.const(2), v2)
     assert alpha_equal(let, mapped)

-    mismatched_var = relay.Let(v2, convert(2), v3)
+    mismatched_var = relay.Let(v2, relay.const(2), v3)
     assert not alpha_equal(let, mismatched_var)

-    different_value = relay.Let(v2, convert(3), v2)
+    different_value = relay.Let(v2, relay.const(3), v2)
     assert not alpha_equal(let, different_value)

-    different_body = relay.Let(v2, convert(3), convert(12))
+    different_body = relay.Let(v2, relay.const(3), relay.const(12))
     assert not alpha_equal(let, different_body)

     # specified types must match
-    let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
-    same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
+    let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
+    same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
     assert alpha_equal(let_with_type, same_type)
     assert not alpha_equal(let, let_with_type)
     v2 = relay.Var("v1", tt2)
-    different_type = relay.Let(v2, convert(2), v2)
+    different_type = relay.Let(v2, relay.const(2), v2)
     assert not alpha_equal(let_with_type, different_type)
@@ -422,17 +421,17 @@ def test_if_alpha_equal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
-    if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
-    same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
+    if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
+    same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
     assert alpha_equal(if_sample, same)

-    different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)]))
+    different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
     assert not alpha_equal(if_sample, different_cond)

-    different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)]))
+    different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)]))
     assert not alpha_equal(if_sample, different_true)

-    different_false = relay.If(v1, convert(1), relay.Tuple([]))
+    different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
     assert not alpha_equal(if_sample, different_false)
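A condensed sketch of the invariant this file pins down, built only from constructors that appear in the hunks above: alpha equality identifies let-bound variables up to consistent renaming, while differing bound values break it.

from tvm import relay
from tvm.relay.ir_pass import alpha_equal

# Two lets that differ only in the bound variable's name are alpha-equal;
# binding a different constant is not.
v1 = relay.Var("v1")
v2 = relay.Var("v2")
assert alpha_equal(relay.Let(v1, relay.const(1), v1),
                   relay.Let(v2, relay.const(1), v2))
assert not alpha_equal(relay.Let(v1, relay.const(1), v1),
                       relay.Let(v2, relay.const(2), v2))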
tests/python/relay/test_pass_check_kind.py
View file @ 8876eac8
@@ -4,7 +4,7 @@ from tvm.relay.ir_pass import check_kind

 def test_tuple_kind():
     # only contain type kinds
-    tp = relay.TypeParam('tp', relay.Kind.Type)
+    tp = relay.TypeVar('tp', relay.Kind.Type)
     tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
     tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
     fields = tvm.convert([tp, tf, tt])
@@ -15,8 +15,8 @@ def test_tuple_kind():

 def test_func_kind():
     # only contain type kinds
-    tp1 = relay.TypeParam('tp1', relay.Kind.Type)
-    tp2 = relay.TypeParam('tp2', relay.Kind.Type)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Type)
+    tp2 = relay.TypeVar('tp2', relay.Kind.Type)
     shape = tvm.convert([1, 2, 3])
     dtype = 'float32'
@@ -35,7 +35,7 @@ def test_func_kind():

 def test_relation_kind():
     # only have type kinds for arguments
-    tp = relay.TypeParam('tp', relay.Kind.Type)
+    tp = relay.TypeVar('tp', relay.Kind.Type)
     tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
     tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
     args = tvm.convert([tf, tt, tp])
@@ -45,9 +45,9 @@ def test_relation_kind():

 def test_invalid_tuple_kind():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
-    tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
-    tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
+    tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
+    tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
     fields = tvm.convert([tp1, tp2, tp3])
     tup_ty = relay.TupleType(fields)
@@ -55,9 +55,9 @@ def test_invalid_tuple_kind():

 def test_invalid_func_kind():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
-    tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
-    tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
+    tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
+    tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
     type_params = tvm.convert([tp1, tp2, tp3])
     type_constraints = tvm.convert([])
@@ -69,9 +69,9 @@ def test_invalid_func_kind():

 def test_invalid_relation_kind():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
-    tp2 = relay.TypeParam('tp2', relay.Kind.BaseType)
-    tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
+    tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
+    tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
     args = tvm.convert([tp1, tp2, tp3])
     tr = relay.TypeRelation(None, args, 2, None)
@@ -79,19 +79,19 @@ def test_invalid_relation_kind():

 def test_func_with_invalid_ret_type():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Type)
-    tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Type)
+    tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
     tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))

 def test_func_with_invalid_arg_types():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
-    tp2 = relay.TypeParam('tp2', relay.Kind.Type)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
+    tp2 = relay.TypeVar('tp2', relay.Kind.Type)
     tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))

 def test_func_with_invalid_tuple():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
     ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
@@ -100,9 +100,9 @@ def test_func_with_invalid_tuple():

 def test_func_with_invalid_relation():
-    tp1 = relay.TypeParam('tp1', relay.Kind.Type)
-    tp2 = relay.TypeParam('tp2', relay.Kind.Shape)
-    tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Type)
+    tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
+    tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
     tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
@@ -113,7 +113,7 @@ def test_func_with_invalid_relation():

 def test_tuple_with_invalid_func():
     tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
-    tp1 = relay.TypeParam('tp1', relay.Kind.Shape)
+    tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
     tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
     tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
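A minimal sketch of the kind discipline these tests encode, assuming a TVM build from around this commit. Whether check_kind signals an ill-kinded type by returning false or by raising is left to the pass; the point is which tuple is well-kinded:

import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind

# Well-kinded: every member of the tuple type is type-kinded.
tp = relay.TypeVar('tp', relay.Kind.Type)
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
check_kind(relay.TupleType(tvm.convert([tp, tt])))

# By contrast, test_invalid_tuple_kind above builds the same tuple from
# Shape/BaseType/ShapeVar-kinded variables, which check_kind rejects.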
tests/python/relay/test_pass_dead_code_elimination.py
View file @ 8876eac8
 import tvm
 from tvm import relay
 from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
-from tvm.relay.ir_builder import convert, IRBuilder
 from tvm.relay.op import log, add, equal, subtract
@@ -19,9 +18,9 @@ class env:
         self.tt = relay.TensorType(self.shape, "float32")
         self.int32 = relay.TensorType([], "int32")
         self.float32 = relay.TensorType([], "float32")
-        self.one = convert(1.0)
-        self.two = convert(2.0)
-        self.three = convert(3.0)
+        self.one = relay.const(1.0)
+        self.two = relay.const(2.0)
+        self.three = relay.const(3.0)

 e = env()
@@ -58,9 +57,12 @@ def test_recursion():
     f = relay.Var("f")
     n = relay.Var("n", e.int32)
     data = relay.Var("data", e.float32)
-    funcbody = relay.If(equal(n, convert(0)), data,
-                        f(subtract(n, convert(1.0)), log(data)))
+    funcbody = relay.If(equal(n, relay.const(0)), data,
+                        relay.Call(f, [subtract(n, relay.const(1.0)),
+                                       log(data)]))
     value = relay.Function([n, data], funcbody, e.float32, [])
-    orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
+    orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
     assert alpha_equal(dead_code_elimination(orig), orig)
     assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
@@ -70,8 +72,10 @@ def test_op_let():

 def test_if():
-    orig = relay.If(convert(True), e.a, e.b)
-    assert alpha_equal(dead_code_elimination(orig), e.a)
+    cond = relay.const(True)
+    orig = relay.If(cond, e.a, e.b)
+    y = dead_code_elimination(orig)
+    assert alpha_equal(y, e.a)

 def test_tuple_get_item():
@@ -82,10 +86,10 @@ def test_tuple_get_item():

 if __name__ == "__main__":
+    test_if()
     test_let()
     test_used_let()
     test_chain_unused_let()
     test_recursion()
     test_op_let()
-    test_if()
     test_tuple_get_item()
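The rewritten test_if captures the pass's behavior on constant branching. A condensed sketch of the same property, assuming this commit's Python API:

from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal

# An If whose condition is a literal constant is dead control flow:
# the pass keeps only the branch that can execute.
a = relay.Var("a")
b = relay.Var("b")
orig = relay.If(relay.const(True), a, b)
assert alpha_equal(dead_code_elimination(orig), a)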
tests/python/relay/test_pass_free_vars.py
View file @ 8876eac8
@@ -28,7 +28,7 @@ def test_tuple():

 def test_free_type_vars():
-    tp = relay.TypeParam("")
+    tp = relay.TypeVar("")
     ty = relay.TupleType([tp, relay.TensorType([], "int32")])
     x = relay.Var("x", ty)
     y = relay.Var("y")
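For context on test_free_type_vars: the annotation on x mentions the type variable tp, so tp escapes as a free type variable even when x itself is let-bound. A small sketch using the same ir_pass helpers this file imports:

from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars

tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)
assert len(free_vars(let)) == 1       # x is bound by the let; only y is free
assert len(free_type_vars(let)) == 1  # tp is free through x's annotation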
tests/python/relay/test_type_infer.py
View file @ 8876eac8
@@ -4,34 +4,17 @@
 import tvm
 import numpy as np
 from tvm.relay.ir_pass import infer_type
-from tvm.relay.ir_builder import IRBuilder, func_type
-from tvm.relay.ir_builder import scalar_type, convert, tensor_type
-from tvm.relay.env import Environment
-from tvm.relay.op import log, add, equal, subtract, concatenate
-from tvm.relay.expr import Function
 from tvm import relay

-def assert_has_type(expr, typ, env=Environment({})):
-    checked_expr = infer_type(env, expr)
-    checked_type = checked_expr.checked_type
-    if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
-
-def assert_decl_has_type(env, name, typ):
-    func = env[name]
-    assert func.checked_type == typ

 def test_monomorphic_let():
     "Program: let x = 1; return x"
-    b = IRBuilder()
-    x = b.let('x', 1.0, value_type=scalar_type('float64'))
-    b.ret(x)
-    prog, env = b.get()
-    assert_has_type(prog, scalar_type('float64'))
+    sb = relay.ScopeBuilder()
+    x = sb.let('x', relay.const(1.0, "float64"))
+    sb.ret(x)
+    xchecked = relay.ir_pass.infer_type(sb.get())
+    assert xchecked.checked_type == relay.scalar_type("float64")

 def test_dual_op():
     """Program:
@@ -41,31 +24,29 @@ def test_dual_op():
     return t1;
     }
     """
-    b = IRBuilder()
-    with b.function(('x', tensor_type(10, 10))) as func:
-        x, = func.param_ids()
-        t1 = b.let('t1', log(x))
-        t2 = b.let('t2', add(t1, x))
-        b.ret(t2)
-    assert_has_type(func.to_func(),
-                    func_type([tensor_type(10, 10)], tensor_type(10, 10)))
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", relay.log(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = relay.ir_pass.infer_type(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)

 def test_decl():
     """Program:
-    def f(x : Tensor[f32, (10, 10)]) {
-        let lx = log(x);
-        return lx;
+    def f(x : Tensor[(10, 10), f32]) {
+        return log(x);
     }
     """
-    b = IRBuilder()
-    x = b.param('x')
-    with b.decl('f', x):
-        lx = b.let('lx', log(x))
-        b.ret(lx)
-    _, env = b.get()
-    assert_decl_has_type(env, 'f', func_type(['float32'], 'float32'))
+    sb = relay.ScopeBuilder()
+    tp = relay.TensorType((10, 10))
+    x = relay.var("x", tp)
+    f = relay.Function([x], relay.log(x))
+    fchecked = relay.ir_pass.infer_type(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)

 def test_recursion():
@@ -78,54 +59,44 @@ def test_recursion():
         return f(n - 1, log(data));
     }
 }
-f(2, 10000);
 """
-    b = IRBuilder()
-    f = b.global_var('f')
-    n = b.param('n', ty='int32')
-    data = b.param('data', ty='float32')
-    with b.decl(f, n, data):
-        with b.if_scope(equal(n, convert(0))):
-            b.ret(data)
-        with b.else_scope():
-            b.ret(f(subtract(n, convert(1)), log(data)))
-    b.ret(f(convert(2.0), convert(10000.0)))
-    assert_decl_has_type(b.env, 'f', func_type(
-        ['int32', 'float32'], 'float32'))
+    sb = relay.ScopeBuilder()
+    f = relay.GlobalVar("f")
+    ti32 = relay.scalar_type("int32")
+    tf32 = relay.scalar_type("float32")
+    n = relay.var("n", ti32)
+    data = relay.var("data", tf32)
+    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
+        sb.ret(data)
+    with sb.else_scope():
+        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
+    # TODO(@jroesch): need evaluator or new runtime
+    # to execute this.
+    env = relay.Environment()
+    env[f] = relay.Function([n, data], sb.get())
+    assert "%3 = @f(%1, %2)" in env.astext()
+    assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)

-def test_concat():
-    """
-    Program:
-    def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
-        return concatenate((x, y), axis=0);
-    }
-    """
-    ib = IRBuilder()
-    try_concat2 = ib.global_var('try_concat2')
-    x = ib.param('x', ty=tensor_type(3, 2))
-    y = ib.param('y', ty=tensor_type(2, 2))
-    with ib.decl(try_concat2, x, y):
-        ib.ret(concatenate((x, y), axis=0))
-    fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
-    assert_decl_has_type(ib.env, try_concat2, fn_ty)

 def test_tuple():
-    ib = IRBuilder()
-    dup = ib.global_var('dup')
-    x = ib.param('x')
-    with ib.decl(dup, x):
-        ib.ret(relay.Tuple([x, x]))
-    # todo: why is this not generalized?
-    fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
-    assert_decl_has_type(ib.env, dup, fn_ty)
+    tp = relay.TensorType((10,))
+    x = relay.var("x", tp)
+    res = relay.Tuple([x, x])
+    assert (relay.ir_pass.infer_type(res).checked_type ==
+            relay.TupleType([tp, tp]))

+def test_free_expr():
+    x = relay.var("x", "float32")
+    y = relay.add(x, x)
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.scalar_type("float32")

 if __name__ == "__main__":
+    test_free_expr()
     test_dual_op()
     test_recursion()
     test_monomorphic_let()
     test_decl()
     test_recursion()
-    test_concat()
     test_tuple()
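The migration pattern running through this file: IRBuilder's param/decl/function scaffolding is gone, and tests now build expressions from relay.var plus a ScopeBuilder, then call infer_type directly. A minimal combined sketch, assuming this commit's API, that merges the let idiom of test_dual_op with the if/else idiom of test_recursion:

from tvm import relay

x = relay.var("x", relay.scalar_type("float32"))
sb = relay.ScopeBuilder()
t1 = sb.let("t1", relay.log(x))                 # let-bind an intermediate
with sb.if_scope(relay.equal(t1, relay.const(0.0))):
    sb.ret(t1)                                  # value of the true branch
with sb.else_scope():
    sb.ret(relay.add(t1, x))                    # value of the false branch
f = relay.Function([x], sb.get())               # sb.get() yields the built body
fchecked = relay.ir_pass.infer_type(f)
assert fchecked.checked_type == relay.FuncType(
    [relay.scalar_type("float32")], relay.scalar_type("float32"))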
tests/python/relay/test_type_solver.py
View file @ 8876eac8
 import tvm
 from tvm import relay
-from tvm.relay.ir_builder import scalar_type, convert, tensor_type

 def make_rel(name, args, num_inputs=None, attrs=None):