Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3b3b8cbe
Commit
3b3b8cbe
authored
Dec 07, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 07, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
allows constant param in op construct (#2257)
parent
d50f7b66
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
38 deletions
+81
-38
python/tvm/hybrid/parser.py
+27
-18
python/tvm/hybrid/util.py
+14
-16
python/tvm/hybrid/var_decl.py
+10
-1
tests/python/unittest/test_hybrid_script.py
+30
-3
No files found.
python/tvm/hybrid/parser.py
View file @
3b3b8cbe
...
@@ -144,14 +144,14 @@ class HybridParser(ast.NodeVisitor):
...
@@ -144,14 +144,14 @@ class HybridParser(ast.NodeVisitor):
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
_id
=
node
.
id
_id
=
node
.
id
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
_expr
.
Var
):
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
(
_expr
.
Var
,
_expr
.
ConstExpr
)
):
return
self
.
_args
[
_id
]
return
self
.
_args
[
_id
]
elif
_id
in
self
.
loops_above
.
keys
():
elif
_id
in
self
.
loops_above
.
keys
():
return
self
.
loops_above
[
_id
]
return
self
.
loops_above
[
_id
]
_internal_assert
(
_id
not
in
self
.
_args
.
keys
(),
\
_internal_assert
(
_id
not
in
self
.
_args
.
keys
(),
\
"This id
%
s should be handled in visit_Subscript!"
%
_id
)
"This id
%
s should be handled in visit_Subscript!"
%
_id
)
_internal_assert
(
_id
in
self
.
usage
.
keys
(),
\
_internal_assert
(
_id
in
self
.
usage
.
keys
(),
\
"This id
%
s is expected to be a defined variable!"
%
_id
)
"This id
%
s is expected to be a defined variable!"
%
_id
)
# Buffer
# Buffer
if
_id
in
self
.
alloc_buffers
.
keys
():
if
_id
in
self
.
alloc_buffers
.
keys
():
_buf
,
_
=
self
.
alloc_buffers
[
_id
]
_buf
,
_
=
self
.
alloc_buffers
[
_id
]
...
@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
...
@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
return
_api
.
const
(
node
.
n
)
return
_api
.
const
(
node
.
n
)
def
visit_AugAssign
(
self
,
node
):
lhs
=
self
.
visit
(
node
.
target
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
lhs
,
rhs
)
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
"The LHS of an AugAssign is supposed to be a call!"
)
return
_make
.
Provide
(
lhs
.
func
,
0
,
rhs
,
lhs
.
args
)
def
visit_Assign
(
self
,
node
):
def
visit_Assign
(
self
,
node
):
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
lhs
=
node
.
targets
[
0
]
lhs
=
node
.
targets
[
0
]
...
@@ -177,7 +186,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -177,7 +186,7 @@ class HybridParser(ast.NodeVisitor):
lhs_
=
lhs
lhs_
=
lhs
lhs
=
lhs
.
id
lhs
=
lhs
.
id
_internal_assert
(
lhs
not
in
self
.
loops_above
.
keys
(),
\
_internal_assert
(
lhs
not
in
self
.
loops_above
.
keys
(),
\
"Loop variable cannot be overwritten!"
)
"Loop variable cannot be overwritten!"
)
decl
,
_
,
rw
=
self
.
usage
[
lhs
]
decl
,
_
,
rw
=
self
.
usage
[
lhs
]
if
decl
==
lhs_
:
if
decl
==
lhs_
:
_internal_assert
(
lhs
not
in
self
.
var_consts
.
keys
(),
\
_internal_assert
(
lhs
not
in
self
.
var_consts
.
keys
(),
\
...
@@ -227,16 +236,16 @@ class HybridParser(ast.NodeVisitor):
...
@@ -227,16 +236,16 @@ class HybridParser(ast.NodeVisitor):
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
"Only variable and attribute's subscript supported so far"
)
"Only variable and attribute's subscript supported so far"
)
_internal_assert
(
isinstance
(
node
.
value
.
value
,
ast
.
Name
),
\
_internal_assert
(
isinstance
(
node
.
value
.
value
,
ast
.
Name
),
\
"The root of array access is expect to be a id!"
)
"The root of array access is expect to be a id!"
)
_internal_assert
(
node
.
value
.
attr
==
"shape"
,
\
_internal_assert
(
node
.
value
.
attr
==
"shape"
,
\
"Attribute access so far only 'shape' is supported!"
)
"Attribute access so far only 'shape' is supported!"
)
_internal_assert
(
len
(
args
)
==
1
,
"For 'shape' access the argument should be only one!"
)
_internal_assert
(
len
(
args
)
==
1
,
"For 'shape' access the argument should be only one!"
)
args
=
args
[
0
]
args
=
args
[
0
]
#TODO: maybe support non-constant value later?
#TODO: maybe support non-constant value later?
_internal_assert
(
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)),
\
_internal_assert
(
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)),
\
"So far only constant shape access supported!"
)
"So far only constant shape access supported!"
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
value
.
id
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
value
.
id
)
return
buf
.
shape
[
args
.
value
]
return
buf
.
shape
[
args
.
value
]
...
@@ -294,7 +303,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -294,7 +303,7 @@ class HybridParser(ast.NodeVisitor):
def
visit_Call
(
self
,
node
):
def
visit_Call
(
self
,
node
):
# Yet, no function pointer supported
# Yet, no function pointer supported
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
\
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
\
"Only id-function function call is supported so far!"
)
"Only id-function function call is supported so far!"
)
func_id
=
node
.
func
.
id
func_id
=
node
.
func
.
id
n
=
len
(
node
.
args
)
n
=
len
(
node
.
args
)
if
func_id
in
LOOP_INTRIN
.
keys
()
and
func_id
!=
'bind'
:
if
func_id
in
LOOP_INTRIN
.
keys
()
and
func_id
!=
'bind'
:
...
@@ -311,7 +320,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -311,7 +320,7 @@ class HybridParser(ast.NodeVisitor):
elif
func_id
==
'bind'
:
elif
func_id
==
'bind'
:
_internal_assert
(
n
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
n
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Str
),
\
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Str
),
\
"A loop bind's first argument should be a string!"
)
"A loop bind's first argument should be a string!"
)
_vn
=
node
.
args
[
0
]
.
s
_vn
=
node
.
args
[
0
]
.
s
iter_var
=
thread_axis
(
node
.
args
[
0
]
.
s
)
iter_var
=
thread_axis
(
node
.
args
[
0
]
.
s
)
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
1
])
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
1
])
...
@@ -321,11 +330,11 @@ class HybridParser(ast.NodeVisitor):
...
@@ -321,11 +330,11 @@ class HybridParser(ast.NodeVisitor):
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
elif
func_id
in
[
'allocate'
,
'output_tensor'
]:
elif
func_id
in
[
'allocate'
,
'output_tensor'
]:
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Tuple
),
\
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Tuple
),
\
"allocate's first argument should be a tuple of shape!"
)
"allocate's first argument should be a tuple of shape!"
)
shape
=
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
args
[
0
]
.
elts
)
shape
=
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
args
[
0
]
.
elts
)
if
func_id
==
'output_tensor'
:
if
func_id
==
'output_tensor'
:
_internal_assert
(
not
self
.
loops_above
,
\
_internal_assert
(
not
self
.
loops_above
,
\
"Are you sure to allocate a output buffer multiple times?"
)
"Are you sure to allocate a output buffer multiple times?"
)
for
i
in
shape
:
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
if
n
>
1
:
...
@@ -333,18 +342,18 @@ class HybridParser(ast.NodeVisitor):
...
@@ -333,18 +342,18 @@ class HybridParser(ast.NodeVisitor):
dtype
=
node
.
args
[
1
]
.
s
dtype
=
node
.
args
[
1
]
.
s
else
:
else
:
_internal_assert
(
isinstance
(
node
.
args
[
1
],
ast
.
Attribute
),
\
_internal_assert
(
isinstance
(
node
.
args
[
1
],
ast
.
Attribute
),
\
"Unable to evaluate to get data type"
)
"Unable to evaluate to get data type"
)
to_eval
=
node
.
args
[
1
]
to_eval
=
node
.
args
[
1
]
_internal_assert
(
isinstance
(
to_eval
.
value
,
ast
.
Name
),
\
_internal_assert
(
isinstance
(
to_eval
.
value
,
ast
.
Name
),
\
"Unable to evaluate the attribute to get data type"
)
"Unable to evaluate the attribute to get data type"
)
_internal_assert
(
to_eval
.
attr
==
'dtype'
,
\
_internal_assert
(
to_eval
.
attr
==
'dtype'
,
\
"Only dtype attribute is supported so far"
)
"Only dtype attribute is supported so far"
)
dtype
=
self
.
_get_buffer_from_id
(
to_eval
.
value
.
id
)
.
dtype
dtype
=
self
.
_get_buffer_from_id
(
to_eval
.
value
.
id
)
.
dtype
else
:
else
:
dtype
=
'float32'
dtype
=
'float32'
if
n
>
2
:
if
n
>
2
:
_internal_assert
(
isinstance
(
node
.
args
[
2
],
ast
.
Str
),
\
_internal_assert
(
isinstance
(
node
.
args
[
2
],
ast
.
Str
),
\
"The data scope should be an string"
)
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
node
.
args
[
2
]
.
s
scope
=
node
.
args
[
2
]
.
s
else
:
else
:
...
@@ -361,7 +370,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -361,7 +370,7 @@ class HybridParser(ast.NodeVisitor):
def
visit_For
(
self
,
node
):
def
visit_For
(
self
,
node
):
iter_var
,
low
,
ext
,
for_type
=
self
.
visit
(
node
.
iter
)
iter_var
,
low
,
ext
,
for_type
=
self
.
visit
(
node
.
iter
)
_internal_assert
(
isinstance
(
node
.
target
,
ast
.
Name
),
\
_internal_assert
(
isinstance
(
node
.
target
,
ast
.
Name
),
\
"The loop iterator should be a variable!"
)
"The loop iterator should be a variable!"
)
_name
=
node
.
target
.
id
_name
=
node
.
target
.
id
if
iter_var
is
None
:
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
...
@@ -389,7 +398,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -389,7 +398,7 @@ class HybridParser(ast.NodeVisitor):
ids
.
append
(
node
.
value
.
id
)
ids
.
append
(
node
.
value
.
id
)
else
:
else
:
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Tuple
),
\
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Tuple
),
\
"You should return either a single tensor or a tuple"
)
"You should return either a single tensor or a tuple"
)
for
i
in
node
.
value
.
elts
:
for
i
in
node
.
value
.
elts
:
_internal_assert
(
isinstance
(
i
,
ast
.
Name
),
"What do you return?"
)
_internal_assert
(
isinstance
(
i
,
ast
.
Name
),
"What do you return?"
)
ids
.
append
(
i
.
id
)
ids
.
append
(
i
.
id
)
...
...
python/tvm/hybrid/util.py
View file @
3b3b8cbe
...
@@ -15,9 +15,14 @@ from ..tensor import Tensor
...
@@ -15,9 +15,14 @@ from ..tensor import Tensor
#pylint: disable=invalid-name
#pylint: disable=invalid-name
np_arg_types
=
tuple
(
list
(
numeric_types
)
+
[
numpy
.
ndarray
])
np_arg_types
=
tuple
(
list
(
numeric_types
)
+
[
numpy
.
ndarray
])
tvm_arg_types
=
(
Tensor
,
_expr
.
Var
)
tvm_arg_types
=
(
Tensor
,
_expr
.
Var
,
_expr
.
ConstExpr
)
halide_imm_types
=
(
_expr
.
IntImm
,
_expr
.
FloatImm
,
_expr
.
UIntImm
)
halide_imm_types
=
(
_expr
.
IntImm
,
_expr
.
FloatImm
,
_expr
.
UIntImm
)
def
_internal_assert
(
cond
,
err
):
"""Simplify the code segment like if not XXX then raise an error"""
if
not
cond
:
raise
ValueError
(
err
)
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def
make_nop
():
def
make_nop
():
...
@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
...
@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
If neither is true, raise a value error."""
If neither is true, raise a value error."""
if
isinstance
(
args
[
0
],
tvm_arg_types
):
if
isinstance
(
args
[
0
],
tvm_arg_types
):
for
elem
in
args
[
1
:]:
for
elem
in
args
[
1
:]:
if
not
isinstance
(
elem
,
tvm_arg_types
):
_internal_assert
(
isinstance
(
elem
,
tvm_arg_types
),
raise
ValueError
(
"Expect a Var or Tensor instance but
%
get!"
%
str
(
type
(
elem
)))
"Expecting a Var, Tensor or ConstExpr instance but
%
s get!"
\
%
str
(
type
(
elem
)))
return
True
return
True
if
not
isinstance
(
args
[
0
],
np_arg_types
):
raise
ValueError
(
"Expect a numpy type but
%
get!"
%
str
(
type
(
args
[
0
])))
_internal_assert
(
isinstance
(
args
[
0
],
np_arg_types
),
\
"Expect a numpy type but
%
s get!"
%
str
(
type
(
args
[
0
])))
for
elem
in
args
[
1
:]:
for
elem
in
args
[
1
:]:
if
not
isinstance
(
elem
,
np_arg_types
):
_internal_assert
(
isinstance
(
elem
,
np_arg_types
),
\
raise
ValueError
(
"Expect a numpy type but
%
get!"
%
str
(
type
(
elem
)))
"Expect a numpy type but
%
s
get!"
%
str
(
type
(
elem
)))
return
False
return
False
...
@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
...
@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
_globals
.
pop
(
elem
)
_globals
.
pop
(
elem
)
for
k
,
v
in
intersect
:
for
k
,
v
in
intersect
:
_globals
[
k
]
=
v
_globals
[
k
]
=
v
def
_internal_assert
(
cond
,
err
):
"""Simplify the code segment like if not XXX then raise an error"""
if
not
cond
:
raise
ValueError
(
err
)
# Almost the same functionality as the one above, but in this case,
# the error is caused by users inproper usage.
_user_assert
=
_internal_assert
python/tvm/hybrid/var_decl.py
View file @
3b3b8cbe
...
@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
scope_level
=
[]
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
_args
=
{}
self
.
args
=
args
self
.
args
=
args
self
.
aug_assign_
=
False
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
visit
(
elem
)
self
.
visit
(
elem
)
def
visit_AugAssign
(
self
,
node
):
self
.
aug_assign_
=
True
self
.
generic_visit
(
node
)
self
.
aug_assign_
=
False
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
# If it is from the argument list or loop variable, we do not worry about it!
# If it is from the argument list or loop variable, we do not worry about it!
if
node
.
id
in
self
.
_args
.
keys
():
if
node
.
id
in
self
.
_args
.
keys
():
...
@@ -61,7 +68,9 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -61,7 +68,9 @@ class PyVariableUsage(ast.NodeVisitor):
if
node
.
id
not
in
self
.
status
.
keys
():
if
node
.
id
not
in
self
.
status
.
keys
():
_internal_assert
(
isinstance
(
node
.
ctx
,
ast
.
Store
),
\
_internal_assert
(
isinstance
(
node
.
ctx
,
ast
.
Store
),
\
'Undeclared variable
%
s'
%
node
.
id
)
'Undeclared variable
%
s'
%
node
.
id
)
if
self
.
aug_assign_
:
raise
ValueError
(
'"First store" cannot be an AugAssign'
)
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
else
:
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
...
...
tests/python/unittest/test_hybrid_script.py
View file @
3b3b8cbe
...
@@ -115,7 +115,7 @@ def test_fanout():
...
@@ -115,7 +115,7 @@ def test_fanout():
for
i
in
range
(
a
.
shape
[
0
]
-
3
):
for
i
in
range
(
a
.
shape
[
0
]
-
3
):
sigma
=
0.0
sigma
=
0.0
for
j
in
range
(
3
):
for
j
in
range
(
3
):
sigma
=
sigma
+
a
[
i
+
j
]
sigma
+=
a
[
i
+
j
]
sigma
=
sigma
/
three
sigma
=
sigma
/
three
b
[
i
]
=
sigma
b
[
i
]
=
sigma
return
b
return
b
...
@@ -246,7 +246,7 @@ def test_bind():
...
@@ -246,7 +246,7 @@ def test_bind():
def
vec_add
(
a
,
b
):
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
c
[
tx
]
=
b
[
tx
]
+
c
[
tx
]
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
return
c
return
c
a
=
tvm
.
placeholder
((
1000
,
),
dtype
=
'float32'
,
name
=
'a'
)
a
=
tvm
.
placeholder
((
1000
,
),
dtype
=
'float32'
,
name
=
'a'
)
...
@@ -308,7 +308,7 @@ def test_non_zero():
...
@@ -308,7 +308,7 @@ def test_non_zero():
s
=
0.0
s
=
0.0
for
di
in
range
(
3
):
for
di
in
range
(
3
):
for
dj
in
range
(
3
):
for
dj
in
range
(
3
):
s
=
s
+
a
[
i
-
di
,
j
-
dj
]
s
+=
a
[
i
-
di
,
j
-
dj
]
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
return
b
return
b
...
@@ -419,6 +419,32 @@ def test_downstream():
...
@@ -419,6 +419,32 @@ def test_downstream():
module
(
tvm_a
,
tvm_c
)
module
(
tvm_a
,
tvm_c
)
tvm
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
tvm
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
def
test_const_param
():
@tvm.hybrid.script
def
add_something
(
a
,
b
):
c
=
output_tensor
((
11
,
),
'int32'
)
for
i
in
range
(
11
):
c
[
i
]
=
a
[
i
]
+
b
return
c
a
=
tvm
.
placeholder
((
11
,
),
dtype
=
'int32'
,
name
=
'a'
)
b
=
tvm
.
const
(
11
,
'int32'
)
c
=
add_something
(
a
,
b
)
sch
=
tvm
.
create_schedule
(
c
.
op
)
module
=
tvm
.
build
(
sch
,
[
a
,
c
],
'llvm'
)
assert
(
module
)
np_a
=
numpy
.
arange
(
11
)
.
astype
(
'int32'
)
np_b
=
11
np_c
=
numpy
.
zeros
((
11
,
))
.
astype
(
'int32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_c
=
tvm
.
ndarray
.
array
(
numpy
.
zeros
((
11
,
))
.
astype
(
'int32'
))
module
(
nd_a
,
nd_c
)
ref
=
add_something
(
np_a
,
11
)
tvm
.
testing
.
assert_allclose
(
nd_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
...
@@ -432,5 +458,6 @@ if __name__ == "__main__":
...
@@ -432,5 +458,6 @@ if __name__ == "__main__":
#test_inplace()
#test_inplace()
test_upstream
()
test_upstream
()
test_downstream
()
test_downstream
()
test_const_param
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment