Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b9038343
Unverified
Commit
b9038343
authored
Nov 25, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 25, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][OP] Move computes to cxx, enable concat as injective (#2166)
parent
f8f06595
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
139 additions
and
265 deletions
+139
-265
python/tvm/relay/backend/graph_runtime_codegen.py
+6
-10
python/tvm/relay/frontend/mxnet.py
+2
-2
python/tvm/relay/op/_tensor.py
+3
-190
src/relay/backend/compile_engine.cc
+13
-16
src/relay/backend/interpreter.cc
+10
-13
src/relay/op/tensor/binary.cc
+53
-16
src/relay/op/tensor/unary.cc
+44
-12
tests/python/relay/test_op_level1.py
+8
-6
No files found.
python/tvm/relay/backend/graph_runtime_codegen.py
View file @
b9038343
...
@@ -236,18 +236,14 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -236,18 +236,14 @@ class GraphRuntimeCodegen(ExprFunctor):
self
.
lowered_funcs
.
add
(
loweredf
)
self
.
lowered_funcs
.
add
(
loweredf
)
inputs
=
[]
inputs
=
[]
tuple_arg_count
=
0
# flatten tuple in the call.
for
arg
in
call
.
args
:
for
arg
in
call
.
args
:
res
=
self
.
visit
(
arg
)
if
isinstance
(
arg
.
checked_type
,
TupleType
):
if
isinstance
(
arg
.
checked_type
,
TupleType
):
tuple_arg_count
+=
1
assert
isinstance
(
res
,
tuple
)
inputs
.
append
(
self
.
visit
(
arg
))
inputs
+=
res
# We need to specially handle tuple inputs and
else
:
# tuple output cases.
inputs
.
append
(
res
)
# Tuple input function(e.g. concat)
if
tuple_arg_count
:
assert
len
(
call
.
args
)
==
1
assert
isinstance
(
inputs
[
0
],
tuple
)
inputs
=
list
(
inputs
[
0
])
inputs
=
[
x
.
to_json
()
for
x
in
inputs
]
inputs
=
[
x
.
to_json
()
for
x
in
inputs
]
op_name
=
cached_func
.
func_name
op_name
=
cached_func
.
func_name
...
...
python/tvm/relay/frontend/mxnet.py
View file @
b9038343
...
@@ -589,11 +589,11 @@ def from_mxnet(symbol,
...
@@ -589,11 +589,11 @@ def from_mxnet(symbol,
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
sym
=
_from_mxnet_impl
(
symbol
,
shape
,
dtype
)
sym
=
_from_mxnet_impl
(
symbol
,
shape
,
dtype
)
elif
isinstance
(
symbol
,
mx
.
gluon
.
HybridBlock
):
elif
isinstance
(
symbol
,
mx
.
gluon
.
HybridBlock
):
if
arg
s
_params
is
not
None
or
aux_params
is
not
None
:
if
arg_params
is
not
None
or
aux_params
is
not
None
:
raise
ValueError
(
"arg_params and aux_params ae not used when importing HybridBlock"
)
raise
ValueError
(
"arg_params and aux_params ae not used when importing HybridBlock"
)
params
=
{}
params
=
{}
for
k
,
v
in
symbol
.
collect_params
()
.
items
():
for
k
,
v
in
symbol
.
collect_params
()
.
items
():
params
[
k
]
=
tvm
.
nd
.
array
(
v
.
data
()
.
asnumpy
())
params
[
k
]
=
_
nd
.
array
(
v
.
data
()
.
asnumpy
())
data
=
mx
.
sym
.
Variable
(
"data"
)
data
=
mx
.
sym
.
Variable
(
"data"
)
sym
=
symbol
(
data
)
sym
=
symbol
(
data
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
...
...
python/tvm/relay/op/_tensor.py
View file @
b9038343
...
@@ -5,223 +5,37 @@ import topi
...
@@ -5,223 +5,37 @@ import topi
from
.op
import
register_compute
,
register_schedule
,
register_pattern
from
.op
import
register_compute
,
register_schedule
,
register_pattern
from
.op
import
schedule_injective
,
OpPattern
from
.op
import
schedule_injective
,
OpPattern
schedule_broadcast
=
schedule_injective
schedule_broadcast
=
schedule_injective
schedule_elemwise
=
schedule_injective
schedule_elemwise
=
schedule_injective
# log
@register_compute
(
"log"
)
def
log_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
log
(
inputs
[
0
])]
register_schedule
(
"log"
,
schedule_broadcast
)
register_schedule
(
"log"
,
schedule_broadcast
)
# exp
@register_compute
(
"exp"
)
def
exp_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
exp
(
inputs
[
0
])]
register_schedule
(
"exp"
,
schedule_broadcast
)
register_schedule
(
"exp"
,
schedule_broadcast
)
# sqrt
@register_compute
(
"sqrt"
)
def
sqrt_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
sqrt
(
inputs
[
0
])]
register_schedule
(
"sqrt"
,
schedule_broadcast
)
register_schedule
(
"sqrt"
,
schedule_broadcast
)
# sigmoid
@register_compute
(
"sigmoid"
)
def
sigmoid_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
sigmoid
(
inputs
[
0
])]
register_schedule
(
"sigmoid"
,
schedule_broadcast
)
register_schedule
(
"sigmoid"
,
schedule_broadcast
)
# floor
@register_compute
(
"floor"
)
def
floor_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
floor
(
inputs
[
0
])]
register_schedule
(
"floor"
,
schedule_broadcast
)
register_schedule
(
"floor"
,
schedule_broadcast
)
# ceil
@register_compute
(
"ceil"
)
def
ceil_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
ceil
(
inputs
[
0
])]
register_schedule
(
"ceil"
,
schedule_broadcast
)
register_schedule
(
"ceil"
,
schedule_broadcast
)
# trunc
@register_compute
(
"trunc"
)
def
trunc_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
trunc
(
inputs
[
0
])]
register_schedule
(
"trunc"
,
schedule_broadcast
)
register_schedule
(
"trunc"
,
schedule_broadcast
)
# round
@register_compute
(
"round"
)
def
round_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
round
(
inputs
[
0
])]
register_schedule
(
"round"
,
schedule_broadcast
)
register_schedule
(
"round"
,
schedule_broadcast
)
# abs
@register_compute
(
"abs"
)
def
abs_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
abs
(
inputs
[
0
])]
register_schedule
(
"abs"
,
schedule_broadcast
)
register_schedule
(
"abs"
,
schedule_broadcast
)
# tanh
@register_compute
(
"tanh"
)
def
tanh_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
tanh
(
inputs
[
0
])]
register_schedule
(
"tanh"
,
schedule_broadcast
)
register_schedule
(
"tanh"
,
schedule_broadcast
)
# negative
@register_compute
(
"negative"
)
def
negative_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
1
return
[
topi
.
negative
(
inputs
[
0
])]
register_schedule
(
"negative"
,
schedule_broadcast
)
register_schedule
(
"negative"
,
schedule_broadcast
)
# add
register_schedule
(
"add"
,
schedule_broadcast
)
@register_compute
(
"add"
)
def
add_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
add
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"add"
,
schedule_injective
)
# subtract
@register_compute
(
"subtract"
)
def
subtract_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
subtract
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"subtract"
,
schedule_broadcast
)
register_schedule
(
"subtract"
,
schedule_broadcast
)
# multiply
@register_compute
(
"multiply"
)
def
multiply_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
multiply
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"multiply"
,
schedule_broadcast
)
register_schedule
(
"multiply"
,
schedule_broadcast
)
# divide
@register_compute
(
"divide"
)
def
divide_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
divide
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"divide"
,
schedule_broadcast
)
register_schedule
(
"divide"
,
schedule_broadcast
)
# power
@register_compute
(
"power"
)
def
power_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
power
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"power"
,
schedule_injective
)
register_schedule
(
"power"
,
schedule_injective
)
# mod
@register_compute
(
"mod"
)
def
mod_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
mod
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"mod"
,
schedule_broadcast
)
register_schedule
(
"mod"
,
schedule_broadcast
)
# equal
@register_compute
(
"equal"
)
def
equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
equal
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"equal"
,
schedule_broadcast
)
register_schedule
(
"equal"
,
schedule_broadcast
)
# not_equal
@register_compute
(
"not_equal"
)
def
not_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
not_equal
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"not_equal"
,
schedule_broadcast
)
register_schedule
(
"not_equal"
,
schedule_broadcast
)
# less
@register_compute
(
"less"
)
def
less_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
less
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"less"
,
schedule_broadcast
)
register_schedule
(
"less"
,
schedule_broadcast
)
# less equal
@register_compute
(
"less_equal"
)
def
less_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
less_equal
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"less_equal"
,
schedule_broadcast
)
register_schedule
(
"less_equal"
,
schedule_broadcast
)
# greater
@register_compute
(
"greater"
)
def
greater_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
greater
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"greater"
,
schedule_broadcast
)
register_schedule
(
"greater"
,
schedule_broadcast
)
# greater equal
@register_compute
(
"greater_equal"
)
def
greater_equal_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
greater_equal
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"greater_equal"
,
schedule_broadcast
)
register_schedule
(
"greater_equal"
,
schedule_broadcast
)
# maximum
@register_compute
(
"maximum"
)
def
maximum_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
maximum
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"maximum_compute"
,
schedule_injective
)
register_schedule
(
"maximum_compute"
,
schedule_injective
)
# minimum
@register_compute
(
"minimum"
)
def
minimum_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
minimum
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"minimum"
,
schedule_injective
)
register_schedule
(
"minimum"
,
schedule_injective
)
# right shift
@register_compute
(
"right_shift"
)
def
right_shift_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
right_shift
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"right_shift"
,
schedule_injective
)
register_schedule
(
"right_shift"
,
schedule_injective
)
# left shift
@register_compute
(
"left_shift"
)
def
left_shift_compute
(
attrs
,
inputs
,
output_type
,
target
):
assert
len
(
inputs
)
==
2
return
[
topi
.
left_shift
(
inputs
[
0
],
inputs
[
1
])]
register_schedule
(
"left_shift"
,
schedule_injective
)
register_schedule
(
"left_shift"
,
schedule_injective
)
# zeros
# zeros
...
@@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target):
...
@@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target):
return
[
topi
.
concatenate
(
inputs
,
axis
=
attrs
.
axis
)]
return
[
topi
.
concatenate
(
inputs
,
axis
=
attrs
.
axis
)]
register_schedule
(
"concatenate"
,
schedule_injective
)
register_schedule
(
"concatenate"
,
schedule_injective
)
# TODO(tqchen): renable concat as injective
register_pattern
(
"concatenate"
,
OpPattern
.
INJECTIVE
)
register_pattern
(
"concatenate"
,
OpPattern
.
OPAQUE
)
src/relay/backend/compile_engine.cc
View file @
b9038343
...
@@ -56,30 +56,26 @@ class ScheduleGetter :
...
@@ -56,30 +56,26 @@ class ScheduleGetter :
Op
::
GetAttr
<
FTVMSchedule
>
(
"FTVMSchedule"
);
Op
::
GetAttr
<
FTVMSchedule
>
(
"FTVMSchedule"
);
auto
cache_node
=
make_node
<
CachedFuncNode
>
();
auto
cache_node
=
make_node
<
CachedFuncNode
>
();
cache_node
->
target
=
target_
;
cache_node
->
target
=
target_
;
for
(
Var
param
:
prim_func
->
params
)
{
if
(
prim_func
->
params
.
size
()
==
1
&&
prim_func
->
params
[
0
]
->
checked_type
().
as
<
TupleTypeNode
>
())
{
// Handle tuple input type by flattening them.
// This is the current calling convention of tuple input.
Array
<
tvm
::
Tensor
>
inputs
;
Array
<
tvm
::
Tensor
>
inputs
;
for
(
Type
field
:
prim_func
->
params
[
0
]
->
type_as
<
TupleTypeNode
>
()
->
fields
)
{
if
(
const
auto
*
ttype
=
param
->
checked_type
().
as
<
TensorTypeNode
>
())
{
const
auto
*
ttype
=
field
.
as
<
TensorTypeNode
>
();
CHECK
(
ttype
!=
nullptr
);
tvm
::
Tensor
tensor
=
tvm
::
placeholder
(
tvm
::
Tensor
tensor
=
tvm
::
placeholder
(
GetShape
(
ttype
->
shape
),
ttype
->
dtype
);
GetShape
(
ttype
->
shape
),
ttype
->
dtype
);
cache_node
->
inputs
.
push_back
(
tensor
);
cache_node
->
inputs
.
push_back
(
tensor
);
inputs
.
push_back
(
tensor
);
inputs
.
push_back
(
tensor
);
}
memo_
[
prim_func
->
params
[
0
]]
=
inputs
;
}
else
{
}
else
{
for
(
Var
param
:
prim_func
->
params
)
{
// flatten tuple of tensor type.
const
auto
*
ttype
=
param
->
type_as
<
TensorTypeNode
>
();
const
auto
*
tuple_type
=
param
->
type_as
<
TupleTypeNode
>
();
for
(
Type
field
:
tuple_type
->
fields
)
{
const
auto
*
ttype
=
field
.
as
<
TensorTypeNode
>
();
CHECK
(
ttype
!=
nullptr
);
tvm
::
Tensor
tensor
=
tvm
::
placeholder
(
tvm
::
Tensor
tensor
=
tvm
::
placeholder
(
GetShape
(
ttype
->
shape
),
ttype
->
dtype
);
GetShape
(
ttype
->
shape
),
ttype
->
dtype
);
cache_node
->
inputs
.
push_back
(
tensor
);
cache_node
->
inputs
.
push_back
(
tensor
);
memo_
[
param
]
=
Array
<
Tensor
>
({
tensor
});
inputs
.
push_back
(
tensor
);
}
}
}
memo_
[
param
]
=
inputs
;
}
}
readable_name_stream_
<<
"fused"
;
readable_name_stream_
<<
"fused"
;
cache_node
->
outputs
=
this
->
VisitExpr
(
prim_func
->
body
);
cache_node
->
outputs
=
this
->
VisitExpr
(
prim_func
->
body
);
...
@@ -161,8 +157,9 @@ class ScheduleGetter :
...
@@ -161,8 +157,9 @@ class ScheduleGetter :
int
op_pattern
=
fpattern
[
op
];
int
op_pattern
=
fpattern
[
op
];
if
(
op_pattern
>=
kCommReduce
)
{
if
(
op_pattern
>=
kCommReduce
)
{
CHECK
(
!
master_op_
.
defined
())
CHECK
(
!
master_op_
.
defined
()
||
master_op_patetrn_
<
kCommReduce
)
<<
"Two complicated op in a primitive function"
;
<<
"Two complicated op in a primitive function "
<<
" master="
<<
master_op_
<<
" current="
<<
op
;
}
}
if
(
op_pattern
>=
master_op_patetrn_
)
{
if
(
op_pattern
>=
master_op_patetrn_
)
{
master_op_
=
op
;
master_op_
=
op
;
...
...
src/relay/backend/interpreter.cc
View file @
b9038343
...
@@ -212,7 +212,7 @@ class Interpreter :
...
@@ -212,7 +212,7 @@ class Interpreter :
// Marshal the arguments.
// Marshal the arguments.
// Handle tuple input/output by flattening them.
// Handle tuple input/output by flattening them.
size_t
arg_len
=
0
;
size_t
arg_len
=
0
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
if
(
args
[
i
].
as
<
TensorValueNode
>
())
{
if
(
args
[
i
].
as
<
TensorValueNode
>
())
{
++
arg_len
;
++
arg_len
;
}
else
{
}
else
{
...
@@ -242,22 +242,19 @@ class Interpreter :
...
@@ -242,22 +242,19 @@ class Interpreter :
<<
context_
<<
", but get "
<<
arg_ctx
;
<<
context_
<<
", but get "
<<
arg_ctx
;
};
};
if
(
func
->
params
.
size
()
==
1
&&
int
arg_counter
=
0
;
func
->
params
[
0
]
->
checked_type
().
as
<
TupleTypeNode
>
())
{
for
(
Value
arg
:
args
)
{
// handle tuple input.
if
(
arg
.
as
<
TensorValueNode
>
())
{
const
TupleValueNode
*
tuple
=
args
[
0
].
as
<
TupleValueNode
>
();
fset_input
(
arg_counter
++
,
arg
);
CHECK
(
tuple
);
}
else
{
const
TupleValueNode
*
tuple
=
arg
.
as
<
TupleValueNode
>
();
CHECK
(
tuple
!=
nullptr
);
for
(
size_t
i
=
0
;
i
<
tuple
->
fields
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tuple
->
fields
.
size
();
++
i
)
{
fset_input
(
i
,
tuple
->
fields
[
i
]);
fset_input
(
arg_counter
++
,
tuple
->
fields
[
i
]);
}
}
}
else
{
CHECK_EQ
(
num_inputs
,
args
.
size
());
// Decide the target context.
// Primitive functions always sit in the same context.
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
fset_input
(
i
,
args
[
i
]);
}
}
}
}
// TVM's calling convention is that the final argument is the output
// TVM's calling convention is that the final argument is the output
// buffer. To preserve the illusion of being a functional language
// buffer. To preserve the illusion of being a functional language
// we need to allocate space for the output buffer based on the
// we need to allocate space for the output buffer based on the
...
...
src/relay/op/tensor/binary.cc
View file @
b9038343
...
@@ -5,54 +5,75 @@
...
@@ -5,54 +5,75 @@
*/
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op.h>
#include <topi/broadcast.h>
#include "../type_relations.h"
#include "../type_relations.h"
#include "../op_common.h"
#include "../op_common.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
#define RELAY_BINARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<Tensor> { \
CHECK_EQ(inputs.size(), 2U); \
return {FTOPI(inputs[0], inputs[1])}; \
} \
// Addition
// Addition
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"add"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"add"
)
.
describe
(
"Elementwise add with with broadcasting"
)
.
describe
(
"Elementwise add with with broadcasting"
)
.
set_support_level
(
1
);
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
add
));
// Subtraction
// Subtraction
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"subtract"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"subtract"
)
.
describe
(
"Elementwise substract with broadcasting"
)
.
describe
(
"Elementwise substract with broadcasting"
)
.
set_support_level
(
1
);
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
subtract
));
// Right shift
// Right shift
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"right_shift"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"right_shift"
)
.
describe
(
"Elementwise right shift with broadcasting"
)
.
describe
(
"Elementwise right shift with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
right_shift
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"left_shift"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"left_shift"
)
.
describe
(
"Elementwise left shift with broadcasting"
)
.
describe
(
"Elementwise left shift with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
left_shift
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"maximum"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"maximum"
)
.
describe
(
"Elementwise maximum of two tensors with broadcasting"
)
.
describe
(
"Elementwise maximum of two tensors with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
maximum
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"minimum"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"minimum"
)
.
describe
(
"Elementwise minimum of two tensors with broadcasting"
)
.
describe
(
"Elementwise minimum of two tensors with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
minimum
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"divide"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"divide"
)
.
describe
(
"Elementwise divide with broadcasting"
)
.
describe
(
"Elementwise divide with broadcasting"
)
.
set_support_level
(
1
);
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
divide
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"multiply"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"multiply"
)
.
describe
(
"Elementwise multiply with broadcasting"
)
.
describe
(
"Elementwise multiply with broadcasting"
)
.
set_support_level
(
1
);
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
multiply
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"power"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"power"
)
.
describe
(
"Elementwise power with broadcasting"
)
.
describe
(
"Elementwise power with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
power
));
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"mod"
)
RELAY_REGISTER_BINARY_OP
(
"relay.op._make."
,
"mod"
)
.
describe
(
"Elementwise mod with broadcasting"
)
.
describe
(
"Elementwise mod with broadcasting"
)
.
set_support_level
(
1
);
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
mod
));
// Comparisons
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
#define RELAY_REGISTER_CMP_OP(OpName) \
...
@@ -70,22 +91,38 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
...
@@ -70,22 +91,38 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
RELAY_REGISTER_CMP_OP
(
"equal"
)
RELAY_REGISTER_CMP_OP
(
"equal"
)
.
describe
(
"Elementwise equal compare with broadcasting"
)
.
describe
(
"Elementwise equal compare with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
equal
));
RELAY_REGISTER_CMP_OP
(
"not_equal"
)
RELAY_REGISTER_CMP_OP
(
"not_equal"
)
.
describe
(
"Elementwise not equal with broadcasting"
)
.
describe
(
"Elementwise not equal with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
not_equal
));
RELAY_REGISTER_CMP_OP
(
"less"
)
RELAY_REGISTER_CMP_OP
(
"less"
)
.
describe
(
"Elementwise less than with broadcasting"
)
.
describe
(
"Elementwise less than with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
less
));
RELAY_REGISTER_CMP_OP
(
"less_equal"
)
RELAY_REGISTER_CMP_OP
(
"less_equal"
)
.
describe
(
"Elementwise less than or equal compare with broadcasting"
)
.
describe
(
"Elementwise less than or equal compare with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
less_equal
));
RELAY_REGISTER_CMP_OP
(
"greater"
)
RELAY_REGISTER_CMP_OP
(
"greater"
)
.
describe
(
"Elementwise greater than compare with broadcasting"
)
.
describe
(
"Elementwise greater than compare with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
greater
));
RELAY_REGISTER_CMP_OP
(
"greater_equal"
)
RELAY_REGISTER_CMP_OP
(
"greater_equal"
)
.
describe
(
"Elementwise greater than or equal compare with broadcasting"
)
.
describe
(
"Elementwise greater than or equal compare with broadcasting"
)
.
set_support_level
(
4
);
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
greater_equal
));
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/op/tensor/unary.cc
View file @
b9038343
...
@@ -5,12 +5,21 @@
...
@@ -5,12 +5,21 @@
*/
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op.h>
#include <topi/elemwise.h>
#include "../type_relations.h"
#include "../type_relations.h"
#include "../op_common.h"
#include "../op_common.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
#define RELAY_UNARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<Tensor> { \
return {FTOPI(inputs[0])}; \
} \
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"log"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"log"
)
.
describe
(
R"code(Returns the log input array, computed element-wise.
.
describe
(
R"code(Returns the log input array, computed element-wise.
...
@@ -20,7 +29,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
...
@@ -20,7 +29,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
log
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"exp"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"exp"
)
.
describe
(
R"code(Returns the exp input array, computed element-wise.
.
describe
(
R"code(Returns the exp input array, computed element-wise.
...
@@ -30,7 +41,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
...
@@ -30,7 +41,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
exp
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"sqrt"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"sqrt"
)
...
@@ -41,7 +53,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
...
@@ -41,7 +53,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
sqrt
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"zeros_like"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"zeros_like"
)
.
describe
(
R"code(Returns an array of zeros, with same type and shape as the input.
.
describe
(
R"code(Returns an array of zeros, with same type and shape as the input.
...
@@ -49,6 +63,7 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
...
@@ -49,6 +63,7 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
);
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"ones_like"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"ones_like"
)
.
describe
(
R"code(Returns an array of ones, with same type and shape as the input.
.
describe
(
R"code(Returns an array of ones, with same type and shape as the input.
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
...
@@ -63,13 +78,17 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
...
@@ -63,13 +78,17 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
sigmoid
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"copy"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"copy"
)
.
describe
(
R"code(Copy a tensor.
.
describe
(
R"code(Copy a tensor.
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
identity
));
// Clip
// Clip
struct
ClipAttrs
:
public
tvm
::
AttrsNode
<
ClipAttrs
>
{
struct
ClipAttrs
:
public
tvm
::
AttrsNode
<
ClipAttrs
>
{
...
@@ -107,7 +126,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
...
@@ -107,7 +126,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
.
describe
(
R"code(Returns the floor of input array, computed element-wise.
.
describe
(
R"code(Returns the floor of input array, computed element-wise.
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
floor
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"ceil"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"ceil"
)
.
describe
(
R"code(Returns the ceil of input array, computed element-wise.
.
describe
(
R"code(Returns the ceil of input array, computed element-wise.
...
@@ -117,7 +138,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
...
@@ -117,7 +138,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
ceil
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"trunc"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"trunc"
)
.
describe
(
R"code(Returns the trunc of input array, computed element-wise.
.
describe
(
R"code(Returns the trunc of input array, computed element-wise.
...
@@ -127,7 +150,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
...
@@ -127,7 +150,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
trunc
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"round"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"round"
)
.
describe
(
R"code(Returns the round of input array, computed element-wise.
.
describe
(
R"code(Returns the round of input array, computed element-wise.
...
@@ -137,7 +162,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
...
@@ -137,7 +162,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
round
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"abs"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"abs"
)
.
describe
(
R"code(Returns the abs of input array, computed element-wise.
.
describe
(
R"code(Returns the abs of input array, computed element-wise.
...
@@ -147,7 +174,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
...
@@ -147,7 +174,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
abs
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"tanh"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"tanh"
)
.
describe
(
R"code(Returns the tanh of input array, computed element-wise.
.
describe
(
R"code(Returns the tanh of input array, computed element-wise.
...
@@ -157,7 +186,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
...
@@ -157,7 +186,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
tanh
));
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"negative"
)
RELAY_REGISTER_UNARY_OP
(
"relay.op._make."
,
"negative"
)
.
describe
(
R"code(Returns the numeric negative of input array, computed element-wise.
.
describe
(
R"code(Returns the numeric negative of input array, computed element-wise.
...
@@ -167,7 +198,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
...
@@ -167,7 +198,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Identity"
,
IdentityRel
);
.
add_type_rel
(
"Identity"
,
IdentityRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_UNARY_COMPUTE
(
topi
::
negative
));
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_op_level1.py
View file @
b9038343
...
@@ -188,20 +188,22 @@ def test_concatenate():
...
@@ -188,20 +188,22 @@ def test_concatenate():
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
5
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
5
))
y
=
relay
.
var
(
"y"
,
shape
=
(
10
,
5
))
y
=
relay
.
var
(
"y"
,
shape
=
(
10
,
5
))
t
=
relay
.
var
(
"z"
,
shape
=
())
z
=
relay
.
concatenate
((
x
,
y
),
axis
=
1
)
z
=
relay
.
concatenate
((
x
,
y
),
axis
=
1
)
z
=
relay
.
add
(
z
,
t
)
# Check result.
# Check result.
func
=
relay
.
Function
([
x
,
y
],
z
)
func
=
relay
.
Function
([
x
,
y
,
t
],
z
)
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
ref_res
=
np
.
concatenate
((
x_data
,
y_data
),
axis
=
1
)
t_data
=
np
.
random
.
uniform
(
size
=
())
.
astype
(
'float32'
)
ref_res
=
np
.
concatenate
((
x_data
,
y_data
),
axis
=
1
)
+
t_data
for
target
,
ctx
in
ctx_list
():
for
target
,
ctx
in
ctx_list
():
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
intrp2
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
intrp2
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
op_res1
=
intrp1
.
evaluate
(
func
)(
x_data
,
y_data
)
op_res1
=
intrp1
.
evaluate
(
func
)(
x_data
,
y_data
,
t_data
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
op_res2
=
intrp2
.
evaluate
(
func
)(
x_data
,
y_data
)
op_res2
=
intrp2
.
evaluate
(
func
)(
x_data
,
y_data
,
t_data
)
tvm
.
testing
.
assert_allclose
(
op_res2
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
tvm
.
testing
.
assert_allclose
(
op_res2
.
asnumpy
(),
ref_res
,
rtol
=
0.01
)
def
test_dropout
():
def
test_dropout
():
...
@@ -306,11 +308,11 @@ def test_dense():
...
@@ -306,11 +308,11 @@ def test_dense():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_concatenate
()
test_bias_add
()
test_bias_add
()
test_unary_op
()
test_unary_op
()
test_binary_op
()
test_binary_op
()
test_expand_dims_infer_type
()
test_expand_dims_infer_type
()
test_concatenate
()
test_expand_dims
()
test_expand_dims
()
test_softmax
()
test_softmax
()
test_log_softmax
()
test_log_softmax
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment