Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4e77eeb2
Commit
4e77eeb2
authored
Nov 04, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Nov 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][RUNTIME] Add compute and schedule attributes for all ops in relay/op/tensor.py (#2050)
parent
ead3ac6c
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
453 additions
and
105 deletions
+453
-105
include/tvm/attrs.h
+8
-8
include/tvm/build_module.h
+1
-1
python/tvm/relay/interpreter.py
+18
-3
python/tvm/relay/op/__init__.py
+1
-1
python/tvm/relay/op/_tensor.py
+246
-23
python/tvm/relay/op/op.py
+5
-0
python/tvm/relay/op/tensor.py
+15
-18
src/relay/pass/lower_ops.cc
+8
-6
tests/python/relay/test_op_level1.py
+65
-18
tests/python/relay/test_op_level3.py
+33
-14
tests/python/relay/test_op_level4.py
+53
-13
No files found.
include/tvm/attrs.h
View file @
4e77eeb2
...
@@ -735,12 +735,12 @@ template<typename DerivedType>
...
@@ -735,12 +735,12 @@ template<typename DerivedType>
class
AttrsNode
:
public
BaseAttrsNode
{
class
AttrsNode
:
public
BaseAttrsNode
{
public
:
public
:
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
detail
::
AttrNormalVisitor
vis
(
v
);
::
tvm
::
detail
::
AttrNormalVisitor
vis
(
v
);
self
()
->
__VisitAttrs__
(
vis
);
self
()
->
__VisitAttrs__
(
vis
);
}
}
void
VisitNonDefaultAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitNonDefaultAttrs
(
AttrVisitor
*
v
)
final
{
detail
::
AttrNonDefaultVisitor
vis
(
v
);
::
tvm
::
detail
::
AttrNonDefaultVisitor
vis
(
v
);
self
()
->
__VisitAttrs__
(
vis
);
self
()
->
__VisitAttrs__
(
vis
);
}
}
...
@@ -761,7 +761,7 @@ class AttrsNode : public BaseAttrsNode {
...
@@ -761,7 +761,7 @@ class AttrsNode : public BaseAttrsNode {
}
}
return
false
;
return
false
;
};
};
auto
vis
=
detail
::
CreateInitVisitor
(
DerivedType
::
_type_key
,
ffind
);
auto
vis
=
::
tvm
::
detail
::
CreateInitVisitor
(
DerivedType
::
_type_key
,
ffind
);
self
()
->
__VisitAttrs__
(
vis
);
self
()
->
__VisitAttrs__
(
vis
);
hit_count
=
vis
.
hit_count_
;
hit_count
=
vis
.
hit_count_
;
}
else
{
}
else
{
...
@@ -779,14 +779,14 @@ class AttrsNode : public BaseAttrsNode {
...
@@ -779,14 +779,14 @@ class AttrsNode : public BaseAttrsNode {
}
}
return
false
;
return
false
;
};
};
auto
vis
=
detail
::
CreateInitVisitor
(
DerivedType
::
_type_key
,
ffind
);
auto
vis
=
::
tvm
::
detail
::
CreateInitVisitor
(
DerivedType
::
_type_key
,
ffind
);
self
()
->
__VisitAttrs__
(
vis
);
self
()
->
__VisitAttrs__
(
vis
);
hit_count
=
vis
.
hit_count_
;
hit_count
=
vis
.
hit_count_
;
}
}
// error handling, slow path
// error handling, slow path
if
(
hit_count
*
2
!=
args
.
size
()
&&
!
allow_unknown
)
{
if
(
hit_count
*
2
!=
args
.
size
()
&&
!
allow_unknown
)
{
for
(
int
i
=
0
;
i
<
args
.
size
();
i
+=
2
)
{
for
(
int
i
=
0
;
i
<
args
.
size
();
i
+=
2
)
{
detail
::
AttrExistVisitor
visitor
;
::
tvm
::
detail
::
AttrExistVisitor
visitor
;
visitor
.
key_
=
args
[
i
].
operator
std
::
string
();
visitor
.
key_
=
args
[
i
].
operator
std
::
string
();
self
()
->
__VisitAttrs__
(
visitor
);
self
()
->
__VisitAttrs__
(
visitor
);
if
(
!
visitor
.
exist_
)
{
if
(
!
visitor
.
exist_
)
{
...
@@ -803,7 +803,7 @@ class AttrsNode : public BaseAttrsNode {
...
@@ -803,7 +803,7 @@ class AttrsNode : public BaseAttrsNode {
}
}
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
final
{
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
final
{
detail
::
AttrDocVisitor
visitor
;
::
tvm
::
detail
::
AttrDocVisitor
visitor
;
self
()
->
__VisitAttrs__
(
visitor
);
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
fields_
;
return
visitor
.
fields_
;
}
}
...
@@ -813,13 +813,13 @@ class AttrsNode : public BaseAttrsNode {
...
@@ -813,13 +813,13 @@ class AttrsNode : public BaseAttrsNode {
if
(
pself
==
other
)
return
true
;
if
(
pself
==
other
)
return
true
;
if
(
other
==
nullptr
)
return
false
;
if
(
other
==
nullptr
)
return
false
;
if
(
pself
->
type_index
()
!=
other
->
type_index
())
return
false
;
if
(
pself
->
type_index
()
!=
other
->
type_index
())
return
false
;
detail
::
AttrsEqualVisitor
visitor
(
pself
,
other
,
equal
);
::
tvm
::
detail
::
AttrsEqualVisitor
visitor
(
pself
,
other
,
equal
);
self
()
->
__VisitAttrs__
(
visitor
);
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
return
visitor
.
result_
;
}
}
size_t
ContentHash
(
AttrsHash
hasher
)
const
final
{
size_t
ContentHash
(
AttrsHash
hasher
)
const
final
{
detail
::
AttrsHashVisitor
visitor
(
hasher
);
::
tvm
::
detail
::
AttrsHashVisitor
visitor
(
hasher
);
visitor
.
result_
=
std
::
hash
<
std
::
string
>
()(
this
->
type_key
());
visitor
.
result_
=
std
::
hash
<
std
::
string
>
()(
this
->
type_key
());
self
()
->
__VisitAttrs__
(
visitor
);
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
return
visitor
.
result_
;
...
...
include/tvm/build_module.h
View file @
4e77eeb2
...
@@ -417,7 +417,7 @@ inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
...
@@ -417,7 +417,7 @@ inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
const
int
kArraySize
=
kNumArgs
>
0
?
kNumArgs
:
1
;
const
int
kArraySize
=
kNumArgs
>
0
?
kNumArgs
:
1
;
TVMValue
values
[
kArraySize
];
TVMValue
values
[
kArraySize
];
int
type_codes
[
kArraySize
];
int
type_codes
[
kArraySize
];
detail
::
for_each
(
TVMArgsSetter
(
values
,
type_codes
),
runtime
::
detail
::
for_each
(
TVMArgsSetter
(
values
,
type_codes
),
std
::
forward
<
Args
>
(
args
)...);
std
::
forward
<
Args
>
(
args
)...);
TVMRetValue
rv
;
TVMRetValue
rv
;
CallPacked
(
TVMArgs
(
values
,
type_codes
,
kNumArgs
),
&
rv
);
CallPacked
(
TVMArgs
(
values
,
type_codes
,
kNumArgs
),
&
rv
);
...
...
python/tvm/relay/interpreter.py
View file @
4e77eeb2
...
@@ -138,7 +138,8 @@ class Executor(object):
...
@@ -138,7 +138,8 @@ class Executor(object):
"""
"""
if
params
:
if
params
:
scope_builder
=
ScopeBuilder
()
scope_builder
=
ScopeBuilder
()
for
key
,
value
in
params
:
for
key
in
params
:
value
=
params
[
key
]
scope_builder
.
let
(
key
,
value
)
scope_builder
.
let
(
key
,
value
)
scope_builder
.
ret
(
expr
)
scope_builder
.
ret
(
expr
)
expr
=
scope_builder
.
get
()
expr
=
scope_builder
.
get
()
...
@@ -146,7 +147,17 @@ class Executor(object):
...
@@ -146,7 +147,17 @@ class Executor(object):
if
isinstance
(
expr
,
Function
):
if
isinstance
(
expr
,
Function
):
assert
not
ir_pass
.
free_vars
(
expr
)
assert
not
ir_pass
.
free_vars
(
expr
)
return
self
.
_make_executor
(
expr
)
executor
=
self
.
_make_executor
(
expr
)
# If we are evaluating a function or top-level defintion
# the user must call the function themselves.
#
# If we are evaluating an open term with parameters we will
# just return them the result.
if
isinstance
(
expr
,
(
Function
,
GlobalVar
)):
return
executor
else
:
return
executor
()
class
Interpreter
(
Executor
):
class
Interpreter
(
Executor
):
...
@@ -168,10 +179,14 @@ class Interpreter(Executor):
...
@@ -168,10 +179,14 @@ class Interpreter(Executor):
self
.
mod
.
_add
(
expr
,
func
,
True
)
self
.
mod
.
_add
(
expr
,
func
,
True
)
opt_expr
=
Call
(
expr
,
relay_args
)
opt_expr
=
Call
(
expr
,
relay_args
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
el
se
:
el
if
isinstance
(
expr
,
Function
)
:
call
=
Call
(
expr
,
relay_args
)
call
=
Call
(
expr
,
relay_args
)
opt_expr
=
self
.
optimize
(
call
)
opt_expr
=
self
.
optimize
(
call
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
else
:
assert
not
args
opt_expr
=
self
.
optimize
(
expr
)
return
_interpreter
.
evaluate
(
self
.
mod
,
opt_expr
)
return
_interp_wrapper
return
_interp_wrapper
...
...
python/tvm/relay/op/__init__.py
View file @
4e77eeb2
#pylint: disable=wildcard-import, redefined-builtin
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
"""Relay core operators."""
# operator defs
# operator defs
from
.op
import
get
,
register
,
Op
from
.op
import
get
,
register
,
register_schedule
,
register_compute
,
Op
# Operators
# Operators
from
.reduce
import
*
from
.reduce
import
*
...
...
python/tvm/relay/op/_tensor.py
View file @
4e77eeb2
#pylint: disable=invalid-name, unused-argument
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
"""Backend compiler related feature registration"""
from
__future__
import
absolute_import
import
tvm
import
tvm
import
topi
import
topi
from
.
import
register
import
topi.cuda
from
.
import
register_schedule
,
register_compute
def
schedule_injective
(
outputs
,
target
):
"""Generic schedule for binary broadcast."""
with
tvm
.
target
.
create
(
target
):
return
topi
.
generic
.
schedule_injective
(
outputs
)
schedule_broadcast
=
schedule_injective
schedule_elemwise
=
schedule_injective
# log
def
log_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
log
(
inputs
[
0
])]
register_compute
(
"log"
,
log_compute
)
register_schedule
(
"log"
,
schedule_broadcast
)
# exp
def
exp_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
exp
(
inputs
[
0
])]
register_compute
(
"exp"
,
exp_compute
)
register_schedule
(
"exp"
,
schedule_broadcast
)
# sqrt
def
sqrt_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
sqrt
(
inputs
[
0
])]
register_compute
(
"sqrt"
,
sqrt_compute
)
register_schedule
(
"sqrt"
,
schedule_broadcast
)
# sigmoid
def
sigmoid_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
sigmoid
(
inputs
[
0
])]
register_compute
(
"sigmoid"
,
sigmoid_compute
)
register_schedule
(
"sigmoid"
,
schedule_broadcast
)
# floor
def
floor_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
floor
(
inputs
[
0
])]
register_compute
(
"floor"
,
floor_compute
)
register_schedule
(
"floor"
,
schedule_broadcast
)
# ceil
def
ceil_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
ceil
(
inputs
[
0
])]
register_compute
(
"ceil"
,
ceil_compute
)
register_schedule
(
"ceil"
,
schedule_broadcast
)
# trunc
def
trunc_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
trunc
(
inputs
[
0
])]
register_compute
(
"trunc"
,
trunc_compute
)
register_schedule
(
"trunc"
,
schedule_broadcast
)
# round
def
round_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
round
(
inputs
[
0
])]
register_compute
(
"round"
,
round_compute
)
register_schedule
(
"round"
,
schedule_broadcast
)
# abs
def
abs_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
abs
(
inputs
[
0
])]
register_compute
(
"abs"
,
abs_compute
)
register_schedule
(
"abs"
,
schedule_broadcast
)
# tanh
def
tanh_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
tanh
(
inputs
[
0
])]
register_compute
(
"tanh"
,
tanh_compute
)
register_schedule
(
"tanh"
,
schedule_broadcast
)
# negative
def
negative_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
negative
(
inputs
[
0
])]
register_compute
(
"negative"
,
negative_compute
)
register_schedule
(
"negative"
,
schedule_broadcast
)
# add
def
add_compute
(
attrs
,
inputs
,
output_type
,
target
):
def
add_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
[
topi
.
add
(
inputs
[
0
],
inputs
[
1
])]
return
[
topi
.
add
(
inputs
[
0
],
inputs
[
1
])]
def
add_schedule
(
outputs
,
target
):
register_compute
(
"add"
,
add_compute
)
assert
len
(
outputs
)
==
1
register_schedule
(
"add"
,
schedule_injective
)
return
tvm
.
create_schedule
(
outputs
[
0
]
.
op
)
register
(
"add"
,
"FTVMCompute"
,
add_compute
)
register
(
"add"
,
"FTVMSchedule"
,
add_schedule
)
# subtract
def
subtract_compute
(
attrs
,
inputs
,
output_type
,
target
):
def
subtract_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
[
topi
.
subtract
(
inputs
[
0
],
inputs
[
1
])]
return
[
topi
.
subtract
(
inputs
[
0
],
inputs
[
1
])]
def
subtract_schedule
(
outputs
,
target
):
register_compute
(
"subtract"
,
subtract_compute
)
assert
len
(
outputs
)
==
1
register_schedule
(
"subtract"
,
schedule_broadcast
)
return
tvm
.
create_schedule
(
outputs
[
0
]
.
op
)
register
(
"subtract"
,
"FTVMCompute"
,
subtract_compute
)
register
(
"subtract"
,
"FTVMSchedule"
,
subtract_schedule
)
# multiply
def
multiply_compute
(
attrs
,
inputs
,
output_type
,
target
):
def
multiply_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
[
topi
.
multiply
(
inputs
[
0
],
inputs
[
1
])]
return
[
topi
.
multiply
(
inputs
[
0
],
inputs
[
1
])]
def
multiply_schedule
(
outputs
,
target
):
register_compute
(
"multiply"
,
multiply_compute
)
assert
len
(
outputs
)
==
1
register_schedule
(
"multiply"
,
schedule_broadcast
)
return
tvm
.
create_schedule
(
outputs
[
0
]
.
op
)
# divide
def
divide_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
divide
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"divide"
,
divide_compute
)
register_schedule
(
"divide"
,
schedule_broadcast
)
register
(
"multiply"
,
"FTVMCompute"
,
multiply_compute
)
# pow
register
(
"multiply"
,
"FTVMSchedule"
,
multiply_schedule
)
def
pow_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
power
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"pow"
,
pow_compute
)
register_schedule
(
"pow"
,
schedule_injective
)
# mod
def
mod_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
mod
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"mod"
,
mod_compute
)
register_schedule
(
"mod"
,
schedule_broadcast
)
# equal
def
equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
def
equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
[
topi
.
equal
(
inputs
[
0
],
inputs
[
1
])]
return
[
topi
.
equal
(
inputs
[
0
],
inputs
[
1
])]
def
equal_schedule
(
outputs
,
target
):
register_compute
(
"equal"
,
equal_compute
)
assert
len
(
outputs
)
==
1
register_schedule
(
"equal"
,
schedule_broadcast
)
return
tvm
.
create_schedule
(
outputs
[
0
]
.
op
)
# not_equal
def
not_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
not_equal
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"not_equal"
,
not_equal_compute
)
register_schedule
(
"not_equal"
,
schedule_broadcast
)
# less
def
less_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
less
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"less"
,
less_compute
)
register_schedule
(
"less"
,
schedule_broadcast
)
# less equal
def
less_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
less_equal
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"less_equal"
,
less_equal_compute
)
register_schedule
(
"less_equal"
,
schedule_broadcast
)
# greater
def
greater_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
greater
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"greater"
,
greater_compute
)
register_schedule
(
"greater"
,
schedule_broadcast
)
# greater equal
def
greater_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
greater_equal
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"greater_equal"
,
greater_equal_compute
)
register_schedule
(
"greater_equal"
,
schedule_broadcast
)
# maximum
def
maximum_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
maximum
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"maximum_compute"
,
maximum_compute
)
register_schedule
(
"maximum_compute"
,
schedule_injective
)
# minimum
def
minimum_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
minimum
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"minimum"
,
minimum_compute
)
register_schedule
(
"minimum"
,
schedule_injective
)
# right shift
def
right_shift_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
right_shift
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"right_shift"
,
right_shift_compute
)
register_schedule
(
"right_shift"
,
schedule_injective
)
# lift shift
def
left_shift_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
left_shift
(
inputs
[
0
],
inputs
[
1
])]
register_compute
(
"left_shift"
,
left_shift_compute
)
register_schedule
(
"left_shift"
,
schedule_injective
)
# zeros
def
zeros_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
not
inputs
return
[
topi
.
full
(
output_type
.
shape
,
output_type
.
dtype
,
0.0
)]
register_compute
(
"zeros"
,
zeros_compute
)
register_schedule
(
"zeros"
,
schedule_injective
)
# zeros_like
def
zeros_like_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
full_like
(
inputs
[
0
],
0.0
)]
register_compute
(
"zeros_like"
,
zeros_like_compute
)
register_schedule
(
"zeros_like"
,
schedule_injective
)
# ones
def
ones_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
not
inputs
return
[
topi
.
full
(
output_type
.
shape
,
output_type
.
dtype
,
1.0
)]
register_compute
(
"ones"
,
ones_compute
)
register_schedule
(
"ones"
,
schedule_injective
)
# ones_like
def
ones_like
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
full_like
(
inputs
[
0
],
1.0
)]
register_compute
(
"ones_like"
,
ones_like
)
register_schedule
(
"ones_like"
,
schedule_injective
)
# clip
def
clip_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
clip
(
inputs
[
0
],
attrs
.
a_min
,
attrs
.
a_max
)]
register
(
"equal"
,
"FTVMCompute"
,
equal
_compute
)
register
_compute
(
"clip"
,
clip
_compute
)
register
(
"equal"
,
"FTVMSchedule"
,
equal_schedul
e
)
register
_schedule
(
"clip"
,
schedule_injectiv
e
)
python/tvm/relay/op/op.py
View file @
4e77eeb2
...
@@ -74,6 +74,11 @@ def register(op_name, attr_key, value=None, level=10):
...
@@ -74,6 +74,11 @@ def register(op_name, attr_key, value=None, level=10):
return
v
return
v
return
_register
(
value
)
if
value
else
_register
return
_register
(
value
)
if
value
else
_register
def
register_schedule
(
op_name
,
schedule
):
register
(
op_name
,
"FTVMSchedule"
,
schedule
)
def
register_compute
(
op_name
,
compute
):
register
(
op_name
,
"FTVMCompute"
,
compute
)
_init_api
(
"relay.op"
,
__name__
)
_init_api
(
"relay.op"
,
__name__
)
...
...
python/tvm/relay/op/tensor.py
View file @
4e77eeb2
...
@@ -213,9 +213,8 @@ def add(lhs, rhs):
...
@@ -213,9 +213,8 @@ def add(lhs, rhs):
"""
"""
return
_make
.
add
(
lhs
,
rhs
)
return
_make
.
add
(
lhs
,
rhs
)
def
subtract
(
lhs
,
rhs
):
def
multiply
(
lhs
,
rhs
):
"""Subtraction with numpy-style broadcasting.
"""Multiplication with numpy-style broadcasting.
Parameters
Parameters
----------
----------
...
@@ -229,11 +228,10 @@ def multiply(lhs, rhs):
...
@@ -229,11 +228,10 @@ def multiply(lhs, rhs):
result : relay.Expr
result : relay.Expr
The computed result.
The computed result.
"""
"""
return
_make
.
multiply
(
lhs
,
rhs
)
return
_make
.
subtract
(
lhs
,
rhs
)
def
divide
(
lhs
,
rhs
):
def
multiply
(
lhs
,
rhs
):
"""
Divis
ion with numpy-style broadcasting.
"""
Multiplicat
ion with numpy-style broadcasting.
Parameters
Parameters
----------
----------
...
@@ -247,11 +245,11 @@ def divide(lhs, rhs):
...
@@ -247,11 +245,11 @@ def divide(lhs, rhs):
result : relay.Expr
result : relay.Expr
The computed result.
The computed result.
"""
"""
return
_make
.
divide
(
lhs
,
rhs
)
return
_make
.
multiply
(
lhs
,
rhs
)
def
pow
(
lhs
,
rhs
):
def
divide
(
lhs
,
rhs
):
"""
Power
with numpy-style broadcasting.
"""
Division
with numpy-style broadcasting.
Parameters
Parameters
----------
----------
...
@@ -265,11 +263,11 @@ def pow(lhs, rhs):
...
@@ -265,11 +263,11 @@ def pow(lhs, rhs):
result : relay.Expr
result : relay.Expr
The computed result.
The computed result.
"""
"""
return
_make
.
pow
(
lhs
,
rhs
)
return
_make
.
divide
(
lhs
,
rhs
)
def
mod
(
lhs
,
rhs
):
def
pow
(
lhs
,
rhs
):
"""
Mod
with numpy-style broadcasting.
"""
Power
with numpy-style broadcasting.
Parameters
Parameters
----------
----------
...
@@ -283,11 +281,11 @@ def mod(lhs, rhs):
...
@@ -283,11 +281,11 @@ def mod(lhs, rhs):
result : relay.Expr
result : relay.Expr
The computed result.
The computed result.
"""
"""
return
_make
.
mod
(
lhs
,
rhs
)
return
_make
.
pow
(
lhs
,
rhs
)
def
subtract
(
lhs
,
rhs
):
def
mod
(
lhs
,
rhs
):
"""
Subtraction
with numpy-style broadcasting.
"""
Mod
with numpy-style broadcasting.
Parameters
Parameters
----------
----------
...
@@ -301,7 +299,7 @@ def subtract(lhs, rhs):
...
@@ -301,7 +299,7 @@ def subtract(lhs, rhs):
result : relay.Expr
result : relay.Expr
The computed result.
The computed result.
"""
"""
return
_make
.
subtract
(
lhs
,
rhs
)
return
_make
.
mod
(
lhs
,
rhs
)
def
equal
(
lhs
,
rhs
):
def
equal
(
lhs
,
rhs
):
...
@@ -553,7 +551,6 @@ def ones_like(data):
...
@@ -553,7 +551,6 @@ def ones_like(data):
"""
"""
return
_make
.
ones_like
(
data
)
return
_make
.
ones_like
(
data
)
def
clip
(
a
,
a_min
,
a_max
):
def
clip
(
a
,
a_min
,
a_max
):
"""Clip the elements in `a` between `a_min` and `a_max`.
"""Clip the elements in `a` between `a_min` and `a_max`.
`a_min` and `a_max` are cast to `a`'s dtype.
`a_min` and `a_max` are cast to `a`'s dtype.
...
...
src/relay/pass/lower_ops.cc
View file @
4e77eeb2
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
*/
*/
#include <tvm/lowered_func.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/pass.h>
...
@@ -155,8 +156,8 @@ struct LiveFunctions : ExprVisitor {
...
@@ -155,8 +156,8 @@ struct LiveFunctions : ExprVisitor {
};
};
using
FCompute
=
TypedPackedFunc
<
Array
<
Tensor
>
(
using
FCompute
=
TypedPackedFunc
<
Array
<
Tensor
>
(
const
Attrs
&
,
const
Array
<
Tensor
>&
,
Type
,
std
::
string
)
>
;
const
Attrs
&
,
const
Array
<
Tensor
>&
,
Type
,
tvm
::
Target
)
>
;
using
FSchedule
=
TypedPackedFunc
<
Schedule
(
const
Array
<
Tensor
>&
,
std
::
string
)
>
;
using
FSchedule
=
TypedPackedFunc
<
Schedule
(
const
Array
<
Tensor
>&
,
tvm
::
Target
)
>
;
/*! \brief Return the set of operators in their TVM format. */
/*! \brief Return the set of operators in their TVM format. */
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
e
,
Array
<
LoweredOp
>
LowerOps
(
const
Module
&
mod
,
const
Expr
&
e
,
...
@@ -179,7 +180,7 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
...
@@ -179,7 +180,7 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
auto
func
=
mod
->
Lookup
(
func_name
);
auto
func
=
mod
->
Lookup
(
func_name
);
auto
call
=
Downcast
<
Call
>
(
func
->
body
);
auto
call
=
Downcast
<
Call
>
(
func
->
body
);
auto
op_node
=
call
->
op
.
as
<
OpNode
>
();
auto
op_node
=
call
->
op
.
as
<
OpNode
>
();
CHECK
(
op_node
)
<<
"violated invariant that primti
i
ve calls contain a single op call"
;
CHECK
(
op_node
)
<<
"violated invariant that primtive calls contain a single op call"
;
auto
op
=
GetRef
<
Op
>
(
op_node
);
auto
op
=
GetRef
<
Op
>
(
op_node
);
RELAY_LOG
(
INFO
)
<<
"LowerOps: Lowering "
<<
op
->
name
;
RELAY_LOG
(
INFO
)
<<
"LowerOps: Lowering "
<<
op
->
name
;
...
@@ -197,10 +198,11 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
...
@@ -197,10 +198,11 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
i
++
;
i
++
;
}
}
auto
output_tt
=
op
->
op_type
->
ret_type
;
auto
output_tt
=
call
->
checked_type
();
auto
target_node
=
Target
::
create
(
target
);
Array
<
Tensor
>
outputs
=
Array
<
Tensor
>
outputs
=
compute_reg
[
op
](
call
->
attrs
,
inputs
,
output_tt
,
target
);
compute_reg
[
op
](
call
->
attrs
,
inputs
,
output_tt
,
target
_node
);
auto
schedule
=
schedule_reg
[
op
](
outputs
,
target
);
auto
schedule
=
schedule_reg
[
op
](
outputs
,
target
_node
);
size_t
hash
=
StructuralHash
()(
func
);
size_t
hash
=
StructuralHash
()(
func
);
LoweredFunc
lf
=
LoweredFunc
lf
=
flower
(
op
->
name
+
std
::
to_string
(
hash
),
schedule
,
inputs
,
outputs
);
flower
(
op
->
name
+
std
::
to_string
(
hash
),
schedule
,
inputs
,
outputs
);
...
...
tests/python/relay/test_op_level1.py
View file @
4e77eeb2
import
math
import
tvm
import
tvm
import
numpy
as
np
import
numpy
as
np
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.interpreter
import
create_executor
def
sigmoid
(
x
):
one
=
np
.
ones_like
(
x
)
return
one
/
(
one
+
np
.
exp
(
-
x
))
def
relu
(
x
):
x_copy
=
np
.
copy
(
x
)
np
.
maximum
(
x_copy
,
0
,
x_copy
)
return
x_copy
def
test_unary_op
():
def
test_unary_op
():
def
check_single_op
(
opfunc
):
def
check_single_op
(
opfunc
,
ref
):
tp
=
relay
.
TensorType
((
10
,
4
),
"float32"
)
shape
=
(
10
,
4
)
dtype
=
'float32'
tp
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
tp
)
x
=
relay
.
var
(
"x"
,
tp
)
y
=
opfunc
(
x
)
y
=
opfunc
(
x
)
# test printer
# test printer
...
@@ -13,20 +25,33 @@ def test_unary_op():
...
@@ -13,20 +25,33 @@ def test_unary_op():
# test type inference
# test type inference
assert
relay
.
ir_pass
.
infer_type
(
y
)
.
checked_type
==
tp
assert
relay
.
ir_pass
.
infer_type
(
y
)
.
checked_type
==
tp
for
opfunc
in
[
tvm
.
relay
.
log
,
if
ref
is
not
None
:
tvm
.
relay
.
exp
,
data
=
np
.
random
.
rand
(
*
shape
)
.
astype
(
dtype
)
tvm
.
relay
.
sqrt
,
intrp
=
create_executor
()
tvm
.
relay
.
sigmoid
,
op_res
=
intrp
.
evaluate
(
y
,
{
x
:
relay
.
const
(
data
)
})
tvm
.
relay
.
tanh
,
ref_res
=
ref
(
data
)
relay
.
nn
.
relu
]:
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
check_single_op
(
opfunc
)
for
opfunc
,
ref
in
[(
tvm
.
relay
.
log
,
np
.
log
),
(
tvm
.
relay
.
exp
,
np
.
exp
),
(
tvm
.
relay
.
sqrt
,
np
.
sqrt
),
(
tvm
.
relay
.
sigmoid
,
sigmoid
),
(
tvm
.
relay
.
tanh
,
np
.
tanh
),
(
relay
.
nn
.
relu
,
None
)]:
# Just add RELU here after registering.
check_single_op
(
opfunc
,
ref
)
def
test_binary_op
():
def
test_binary_op
():
def
check_binary_op
(
opfunc
):
def
inst
(
vars
,
sh
):
return
[
vars
.
get
(
s
,
s
)
for
s
in
sh
]
def
check_binary_op
(
opfunc
,
ref
):
# TODO(@jroesch): this piece of code improperly uses type variables.
n
=
tvm
.
var
(
"n"
)
n
=
tvm
.
var
(
"n"
)
t1
=
relay
.
TensorType
((
5
,
n
,
5
))
s1
=
(
5
,
n
,
5
)
t2
=
relay
.
TensorType
((
n
,
1
))
s2
=
(
n
,
1
)
t1
=
relay
.
TensorType
(
s1
)
t2
=
relay
.
TensorType
(
s2
)
x
=
relay
.
var
(
"x"
,
t1
)
x
=
relay
.
var
(
"x"
,
t1
)
y
=
relay
.
var
(
"y"
,
t2
)
y
=
relay
.
var
(
"y"
,
t2
)
z
=
opfunc
(
x
,
y
)
z
=
opfunc
(
x
,
y
)
...
@@ -34,12 +59,25 @@ def test_binary_op():
...
@@ -34,12 +59,25 @@ def test_binary_op():
assert
(
"
%0
= {}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
(
"
%0
= {}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
for
opfunc
in
[
relay
.
add
,
if
ref
is
not
None
:
relay
.
subtract
,
t1
=
relay
.
TensorType
((
5
,
10
,
5
))
relay
.
mod
,
t2
=
relay
.
TensorType
((
5
,
10
,
5
))
relay
.
multiply
,
x
=
relay
.
var
(
"x"
,
t1
)
relay
.
divide
]:
y
=
relay
.
var
(
"y"
,
t2
)
check_binary_op
(
opfunc
)
z
=
opfunc
(
x
,
y
)
x_data
=
np
.
random
.
rand
(
5
,
10
,
5
)
.
astype
(
t1
.
dtype
)
y_data
=
np
.
random
.
rand
(
5
,
10
,
5
)
.
astype
(
t2
.
dtype
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
z
,
{
x
:
relay
.
const
(
x_data
),
y
:
relay
.
const
(
y_data
)
})
ref_res
=
ref
(
x_data
,
y_data
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
for
opfunc
,
ref
in
[(
relay
.
add
,
np
.
add
),
(
relay
.
subtract
,
np
.
subtract
),
(
relay
.
mod
,
np
.
mod
),
(
relay
.
multiply
,
np
.
multiply
),
(
relay
.
divide
,
np
.
divide
)]:
check_binary_op
(
opfunc
,
ref
)
def
test_bias_add
():
def
test_bias_add
():
...
@@ -96,6 +134,15 @@ def test_concatenate_infer_type():
...
@@ -96,6 +134,15 @@ def test_concatenate_infer_type():
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
n
,
t
+
t
,
100
))
assert
zz
.
checked_type
==
relay
.
TensorType
((
n
,
t
+
t
,
100
))
# x = relay.var("x", shape=(10, 5))
# y = relay.var("y", shape=(10, 5))
# z = relay.concatenate((x, y), axis=1)
# intrp = create_executor()
# x_data = np.random.rand(10, 5).astype('float32')
# y_data = np.random.rand(10, 5).astype('float32')
# op_res = intrp.evaluate(z, { x: relay.const(x_data), y: relay.const(y_data) })
# ref_res = np.concatenate(x_data, y_data, axis=1)
# np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
def
test_dropout
():
def
test_dropout
():
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
tvm
.
var
(
"d"
)
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
tvm
.
var
(
"d"
)
...
...
tests/python/relay/test_op_level3.py
View file @
4e77eeb2
...
@@ -3,29 +3,40 @@
...
@@ -3,29 +3,40 @@
import
tvm
import
tvm
import
numpy
as
np
import
numpy
as
np
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay
import
create_executor
from
nose.tools
import
raises
from
nose.tools
import
raises
def
test_zeros_ones
():
def
test_zeros_ones
():
for
op
in
[
relay
.
zeros
,
relay
.
ones
]:
for
op
,
ref
in
[(
relay
.
zeros
,
np
.
zeros
),
(
relay
.
ones
,
np
.
ones
)
]:
y
=
op
(
shape
=
(
124
,
50
),
dtype
=
"float64"
)
y
=
op
(
shape
=
(
124
,
50
),
dtype
=
"float64"
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
124
,
50
),
"float64"
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
124
,
50
),
"float64"
)
intrp
=
create_executor
()
intrp_res
=
intrp
.
evaluate
(
y
)
.
asnumpy
()
np
.
testing
.
assert_allclose
(
intrp_res
,
ref
((
124
,
50
),
'float64'
))
def
test_unary_identity
():
def
test_unary_identity
():
for
op
in
[
relay
.
zeros_like
,
for
op
,
ref
in
[(
relay
.
zeros_like
,
np
.
zeros_like
),
relay
.
ones_like
,
(
relay
.
ones_like
,
np
.
ones_like
),
relay
.
ceil
,
(
relay
.
ceil
,
np
.
ceil
),
relay
.
floor
,
(
relay
.
floor
,
np
.
floor
),
relay
.
trunc
,
(
relay
.
trunc
,
np
.
trunc
),
relay
.
round
,
(
relay
.
round
,
np
.
round
),
relay
.
abs
,
(
relay
.
abs
,
np
.
abs
),
relay
.
copy
,
(
relay
.
copy
,
None
),
# np.copy
relay
.
negative
]:
(
relay
.
negative
,
np
.
negative
)]:
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"float32"
))
shape
=
(
8
,
9
,
4
)
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
(
shape
,
"float32"
))
y
=
op
(
x
)
y
=
op
(
x
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
TensorType
(
(
8
,
9
,
4
)
,
"float32"
)
assert
yy
.
checked_type
==
relay
.
TensorType
(
shape
,
"float32"
)
if
ref
is
not
None
:
data
=
np
.
random
.
rand
(
*
shape
)
.
astype
(
'float32'
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
y
,
{
x
:
relay
.
const
(
data
)
})
ref_res
=
ref
(
data
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
def
test_cast
():
def
test_cast
():
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"float32"
))
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
8
,
9
,
4
),
"float32"
))
...
@@ -35,12 +46,20 @@ def test_cast():
...
@@ -35,12 +46,20 @@ def test_cast():
assert
yy
.
checked_type
==
relay
.
TensorType
((
8
,
9
,
4
),
"int32"
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
8
,
9
,
4
),
"int32"
)
def
test_clip
_type
():
def
test_clip
():
a
=
relay
.
var
(
"a"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
a
=
relay
.
var
(
"a"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
relay
.
clip
(
a
,
1.
,
4.
)
y
=
relay
.
clip
(
a
,
1.
,
4.
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
10
,
4
),
"float32"
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
10
,
4
),
"float32"
)
data
=
np
.
random
.
rand
(
10
,
4
)
.
astype
(
'float32'
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
y
,
{
a
:
relay
.
const
(
data
)
})
ref_res
=
np
.
clip
(
data
,
1.
,
4.
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
def
test_transpose_infer_type
():
def
test_transpose_infer_type
():
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
n
,
t
,
d
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
...
@@ -226,7 +245,7 @@ if __name__ == "__main__":
...
@@ -226,7 +245,7 @@ if __name__ == "__main__":
test_cast
()
test_cast
()
test_zeros_ones
()
test_zeros_ones
()
test_unary_identity
()
test_unary_identity
()
test_clip
_type
()
test_clip
()
test_transpose_infer_type
()
test_transpose_infer_type
()
test_reshape_infer_type
()
test_reshape_infer_type
()
test_reshape_like
()
test_reshape_like
()
...
...
tests/python/relay/test_op_level4.py
View file @
4e77eeb2
import
tvm
import
tvm
import
numpy
as
np
import
numpy
as
np
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay
import
create_executor
def
test_binary_op
():
def
test_binary_op
():
def
check_binary_op
(
opfunc
):
def
check_binary_op
(
opfunc
,
ref
):
n
=
tvm
.
var
(
"n"
)
n
=
tvm
.
var
(
"n"
)
t1
=
relay
.
TensorType
((
5
,
n
,
5
))
t1
=
relay
.
TensorType
((
5
,
n
,
5
))
t2
=
relay
.
TensorType
((
n
,
1
))
t2
=
relay
.
TensorType
((
n
,
1
))
...
@@ -15,17 +16,30 @@ def test_binary_op():
...
@@ -15,17 +16,30 @@ def test_binary_op():
assert
(
"
%0
= {}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
(
"
%0
= {}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
for
opfunc
in
[
relay
.
pow
]:
if
ref
is
not
None
:
check_binary_op
(
opfunc
)
t1
=
relay
.
TensorType
((
5
,
10
,
5
))
t2
=
relay
.
TensorType
((
5
,
10
,
5
))
x
=
relay
.
var
(
"x"
,
t1
)
y
=
relay
.
var
(
"y"
,
t2
)
z
=
opfunc
(
x
,
y
)
x_data
=
np
.
random
.
rand
(
5
,
10
,
5
)
.
astype
(
t1
.
dtype
)
y_data
=
np
.
random
.
rand
(
5
,
10
,
5
)
.
astype
(
t2
.
dtype
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
z
,
{
x
:
relay
.
const
(
x_data
),
y
:
relay
.
const
(
y_data
)
})
ref_res
=
ref
(
x_data
,
y_data
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
for
opfunc
,
ref
in
[(
relay
.
pow
,
np
.
power
)]:
check_binary_op
(
opfunc
,
ref
)
def
test_cmp_type
():
def
test_cmp_type
():
for
op
in
(
relay
.
greater
,
for
op
,
ref
in
((
relay
.
greater
,
np
.
greater
)
,
relay
.
greater_equal
,
(
relay
.
greater_equal
,
np
.
greater_equal
)
,
relay
.
less
,
(
relay
.
less
,
np
.
less
)
,
relay
.
less_equal
,
(
relay
.
less_equal
,
np
.
less_equal
)
,
relay
.
equal
,
(
relay
.
equal
,
np
.
equal
)
,
relay
.
not_equal
):
(
relay
.
not_equal
,
np
.
not_equal
)
):
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"float32"
))
z
=
op
(
x
,
y
)
z
=
op
(
x
,
y
)
...
@@ -33,18 +47,44 @@ def test_cmp_type():
...
@@ -33,18 +47,44 @@ def test_cmp_type():
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"bool"
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"bool"
)
if
ref
is
not
None
:
x_shape
=
(
10
,
4
)
y_shape
=
(
5
,
10
,
1
)
t1
=
relay
.
TensorType
(
x_shape
)
t2
=
relay
.
TensorType
(
y_shape
)
x
=
relay
.
var
(
"x"
,
t1
)
y
=
relay
.
var
(
"y"
,
t2
)
z
=
op
(
x
,
y
)
x_data
=
np
.
random
.
rand
(
*
x_shape
)
.
astype
(
t1
.
dtype
)
y_data
=
np
.
random
.
rand
(
*
y_shape
)
.
astype
(
t2
.
dtype
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
z
,
{
x
:
relay
.
const
(
x_data
),
y
:
relay
.
const
(
y_data
)
})
ref_res
=
ref
(
x_data
,
y_data
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
def
test_binary_int_broadcast
():
def
test_binary_int_broadcast
():
for
op
in
[
relay
.
right_shift
,
for
op
,
ref
in
[(
relay
.
right_shift
,
np
.
right_shift
)
,
relay
.
left_shift
,
(
relay
.
left_shift
,
np
.
left_shift
)
,
relay
.
maximum
,
(
relay
.
maximum
,
np
.
maximum
)
,
relay
.
minimum
]:
(
relay
.
minimum
,
np
.
minimum
)
]:
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
10
,
4
),
"int32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
5
,
10
,
1
),
"int32"
))
z
=
op
(
x
,
y
)
z
=
op
(
x
,
y
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"int32"
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"int32"
)
if
ref
is
not
None
:
x_shape
=
(
10
,
4
)
y_shape
=
(
5
,
10
,
1
)
t1
=
relay
.
TensorType
(
x_shape
,
'int32'
)
t2
=
relay
.
TensorType
(
y_shape
,
'int32'
)
x_data
=
np
.
random
.
rand
(
*
x_shape
)
.
astype
(
t1
.
dtype
)
y_data
=
np
.
random
.
rand
(
*
y_shape
)
.
astype
(
t2
.
dtype
)
intrp
=
create_executor
()
op_res
=
intrp
.
evaluate
(
z
,
{
x
:
relay
.
const
(
x_data
),
y
:
relay
.
const
(
y_data
)
})
ref_res
=
ref
(
x_data
,
y_data
)
np
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
def
test_where
():
def
test_where
():
cond
=
relay
.
var
(
"cond"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
cond
=
relay
.
var
(
"cond"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment