Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0702d2c0
Commit
0702d2c0
authored
Jun 17, 2018
by
Tianqi Chen
Committed by
GitHub
Jun 17, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[OP] Introduces auxiliary attrs into compute (#1293)
parent
146714ac
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
175 additions
and
57 deletions
+175
-57
include/tvm/operation.h
+31
-11
include/tvm/runtime/c_runtime_api.h
+4
-0
python/tvm/api.py
+28
-6
src/api/api_lang.cc
+6
-3
src/lang/reflection.cc
+26
-2
src/op/compute_op.cc
+15
-9
src/op/extern_op.cc
+8
-6
src/op/scan_op.cc
+13
-9
src/schedule/schedule_dataflow_rewrite.cc
+6
-3
tests/python/unittest/test_lang_reflection.py
+13
-0
tests/python/unittest/test_lang_tag.py
+17
-3
topi/include/topi/contrib/cublas.h
+1
-1
topi/include/topi/contrib/rocblas.h
+1
-1
topi/include/topi/detail/extern.h
+6
-3
No files found.
include/tvm/operation.h
View file @
0702d2c0
...
@@ -41,6 +41,8 @@ class OperationNode : public FunctionBaseNode {
...
@@ -41,6 +41,8 @@ class OperationNode : public FunctionBaseNode {
std
::
string
name
;
std
::
string
name
;
/*! \brief optional tag of the operation */
/*! \brief optional tag of the operation */
std
::
string
tag
;
std
::
string
tag
;
/*! \brief addtitional attributes of the operation*/
Map
<
std
::
string
,
NodeRef
>
attrs
;
/*! \return name of the operation */
/*! \return name of the operation */
const
std
::
string
&
func_name
()
const
final
{
const
std
::
string
&
func_name
()
const
final
{
return
name
;
return
name
;
...
@@ -167,6 +169,8 @@ class PlaceholderOpNode : public OperationNode {
...
@@ -167,6 +169,8 @@ class PlaceholderOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"attrs"
,
&
attrs
);
v
->
Visit
(
"shape"
,
&
shape
);
v
->
Visit
(
"shape"
,
&
shape
);
v
->
Visit
(
"dtype"
,
&
dtype
);
v
->
Visit
(
"dtype"
,
&
dtype
);
}
}
...
@@ -220,12 +224,14 @@ class TVM_DLL ComputeOpNode : public OperationNode {
...
@@ -220,12 +224,14 @@ class TVM_DLL ComputeOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"attrs"
,
&
attrs
);
v
->
Visit
(
"axis"
,
&
axis
);
v
->
Visit
(
"axis"
,
&
axis
);
v
->
Visit
(
"reduce_axis"
,
&
reduce_axis
);
v
->
Visit
(
"reduce_axis"
,
&
reduce_axis
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"body"
,
&
body
);
}
}
static
Operation
make
(
std
::
string
name
,
static
Operation
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
Array
<
IterVar
>
axis
,
Array
<
IterVar
>
axis
,
Array
<
Expr
>
body
);
Array
<
Expr
>
body
);
...
@@ -292,6 +298,7 @@ class ScanOpNode : public OperationNode {
...
@@ -292,6 +298,7 @@ class ScanOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"attrs"
,
&
attrs
);
v
->
Visit
(
"scan_axis"
,
&
scan_axis
);
v
->
Visit
(
"scan_axis"
,
&
scan_axis
);
v
->
Visit
(
"init"
,
&
init
);
v
->
Visit
(
"init"
,
&
init
);
v
->
Visit
(
"update"
,
&
update
);
v
->
Visit
(
"update"
,
&
update
);
...
@@ -301,6 +308,7 @@ class ScanOpNode : public OperationNode {
...
@@ -301,6 +308,7 @@ class ScanOpNode : public OperationNode {
}
}
static
Operation
make
(
std
::
string
name
,
static
Operation
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
IterVar
axis
,
IterVar
axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
update
,
...
@@ -356,11 +364,13 @@ class ExternOpNode : public OperationNode {
...
@@ -356,11 +364,13 @@ class ExternOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"tag"
,
&
tag
);
v
->
Visit
(
"attrs"
,
&
attrs
);
v
->
Visit
(
"inputs"
,
&
inputs
);
v
->
Visit
(
"inputs"
,
&
inputs
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"body"
,
&
body
);
}
}
EXPORT
static
Operation
make
(
std
::
string
name
,
EXPORT
static
Operation
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
Array
<
Tensor
>
inputs
,
Array
<
Tensor
>
inputs
,
Array
<
Buffer
>
input_placeholders
,
Array
<
Buffer
>
input_placeholders
,
Array
<
Buffer
>
output_placeholders
,
Array
<
Buffer
>
output_placeholders
,
...
@@ -393,11 +403,13 @@ TVM_DLL Tensor placeholder(Array<Expr> shape,
...
@@ -393,11 +403,13 @@ TVM_DLL Tensor placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
*/
TVM_DLL
Tensor
compute
(
Array
<
Expr
>
shape
,
TVM_DLL
Tensor
compute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
);
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{});
/*!
/*!
* \brief Construct a new tensor by computing over shape,
* \brief Construct a new tensor by computing over shape,
...
@@ -406,11 +418,13 @@ TVM_DLL Tensor compute(Array<Expr> shape,
...
@@ -406,11 +418,13 @@ TVM_DLL Tensor compute(Array<Expr> shape,
* \param fcompute The compute function to create the tensors.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
*/
TVM_DLL
Array
<
Tensor
>
compute
(
Array
<
Expr
>
shape
,
TVM_DLL
Array
<
Tensor
>
compute
(
Array
<
Expr
>
shape
,
FBatchCompute
fcompute
,
FBatchCompute
fcompute
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
);
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{});
/*!
/*!
* \brief Construct new tensors by scan.
* \brief Construct new tensors by scan.
...
@@ -422,42 +436,48 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
...
@@ -422,42 +436,48 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
* but recommended to provide concrete information about scan body.
* but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
*/
TVM_DLL
Array
<
Tensor
>
scan
(
Array
<
Tensor
>
init
,
TVM_DLL
Array
<
Tensor
>
scan
(
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
inputs
=
Array
<
Tensor
>
(),
Array
<
Tensor
>
inputs
=
Array
<
Tensor
>
(),
std
::
string
name
=
"scan"
,
std
::
string
name
=
"scan"
,
std
::
string
tag
=
""
);
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{});
// same as compute, specialized for different fcompute function
// same as compute, specialized for different fcompute function
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
)
{
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{})
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
return
compute
(
shape
,
fc
,
name
,
tag
);
return
compute
(
shape
,
fc
,
name
,
tag
,
attrs
);
}
}
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
)
{
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{})
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
return
compute
(
shape
,
fc
,
name
,
tag
);
return
compute
(
shape
,
fc
,
name
,
tag
,
attrs
);
}
}
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
)
{
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{})
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
return
compute
(
shape
,
fc
,
name
,
tag
);
return
compute
(
shape
,
fc
,
name
,
tag
,
attrs
);
}
}
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
inline
Tensor
compute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
""
)
{
std
::
string
tag
=
""
,
Map
<
std
::
string
,
NodeRef
>
attrs
=
{})
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
return
compute
(
shape
,
fc
,
name
,
tag
);
return
compute
(
shape
,
fc
,
name
,
tag
,
attrs
);
}
}
// inline function.
// inline function.
...
...
include/tvm/runtime/c_runtime_api.h
View file @
0702d2c0
...
@@ -42,6 +42,10 @@
...
@@ -42,6 +42,10 @@
#endif
#endif
#endif
#endif
// TVM version
#define TVM_VERSION "0.4.0"
// TVM Runtime is DLPack compatible.
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#include <dlpack/dlpack.h>
...
...
python/tvm/api.py
View file @
0702d2c0
...
@@ -189,7 +189,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
...
@@ -189,7 +189,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
shape
,
dtype
,
name
)
shape
,
dtype
,
name
)
def
compute
(
shape
,
fcompute
,
name
=
"compute"
,
tag
=
""
):
def
compute
(
shape
,
fcompute
,
name
=
"compute"
,
tag
=
""
,
attrs
=
None
):
"""Construct a new tensor by computing over the shape domain.
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
The compute rule is result[axis] = fcompute(axis)
...
@@ -205,6 +205,12 @@ def compute(shape, fcompute, name="compute", tag=""):
...
@@ -205,6 +205,12 @@ def compute(shape, fcompute, name="compute", tag=""):
name: str, optional
name: str, optional
The name hint of the tensor
The name hint of the tensor
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
Returns
-------
-------
tensor: Tensor
tensor: Tensor
...
@@ -232,13 +238,13 @@ def compute(shape, fcompute, name="compute", tag=""):
...
@@ -232,13 +238,13 @@ def compute(shape, fcompute, name="compute", tag=""):
body
=
[
body
]
body
=
[
body
]
body
=
convert
(
body
)
body
=
convert
(
body
)
op_node
=
_api_internal
.
_ComputeOp
(
op_node
=
_api_internal
.
_ComputeOp
(
name
,
tag
,
dim_var
,
body
)
name
,
tag
,
attrs
,
dim_var
,
body
)
num
=
op_node
.
num_outputs
num
=
op_node
.
num_outputs
outputs
=
tuple
(
op_node
.
output
(
i
)
for
i
in
range
(
num
))
outputs
=
tuple
(
op_node
.
output
(
i
)
for
i
in
range
(
num
))
return
outputs
[
0
]
if
num
==
1
else
outputs
return
outputs
[
0
]
if
num
==
1
else
outputs
def
scan
(
init
,
update
,
state_placeholder
,
inputs
=
None
,
name
=
"scan"
,
tag
=
""
):
def
scan
(
init
,
update
,
state_placeholder
,
inputs
=
None
,
name
=
"scan"
,
tag
=
""
,
attrs
=
None
):
"""Construct new tensors by scanning over axis.
"""Construct new tensors by scanning over axis.
Parameters
Parameters
...
@@ -259,6 +265,12 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
...
@@ -259,6 +265,12 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
name: str, optional
name: str, optional
The name hint of the tensor
The name hint of the tensor
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
Returns
-------
-------
tensor: Tensor or list of Tensors
tensor: Tensor or list of Tensors
...
@@ -294,7 +306,8 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
...
@@ -294,7 +306,8 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
if
len
(
init
)
!=
len
(
update
)
or
len
(
init
)
!=
len
(
state_placeholder
):
if
len
(
init
)
!=
len
(
update
)
or
len
(
init
)
!=
len
(
state_placeholder
):
raise
ValueError
(
"init, update, state_placeholder must have same length"
)
raise
ValueError
(
"init, update, state_placeholder must have same length"
)
axis
=
_IterVar
((
init
[
0
]
.
shape
[
0
],
update
[
0
]
.
shape
[
0
]),
"
%
s.idx"
%
name
,
3
)
axis
=
_IterVar
((
init
[
0
]
.
shape
[
0
],
update
[
0
]
.
shape
[
0
]),
"
%
s.idx"
%
name
,
3
)
op
=
_api_internal
.
_ScanOp
(
name
,
tag
,
axis
,
init
,
update
,
op
=
_api_internal
.
_ScanOp
(
name
,
tag
,
attrs
,
axis
,
init
,
update
,
state_placeholder
,
inputs
)
state_placeholder
,
inputs
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
update
))]
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
update
))]
return
res
[
0
]
if
len
(
res
)
==
1
else
res
return
res
[
0
]
if
len
(
res
)
==
1
else
res
...
@@ -307,7 +320,8 @@ def extern(shape,
...
@@ -307,7 +320,8 @@ def extern(shape,
dtype
=
None
,
dtype
=
None
,
in_buffers
=
None
,
in_buffers
=
None
,
out_buffers
=
None
,
out_buffers
=
None
,
tag
=
""
):
tag
=
""
,
attrs
=
None
):
"""Compute several tensor via extern function.
"""Compute several tensor via extern function.
Parameters
Parameters
...
@@ -345,6 +359,13 @@ def extern(shape,
...
@@ -345,6 +359,13 @@ def extern(shape,
out_buffers: Buffer or list of Buffers, optional
out_buffers: Buffer or list of Buffers, optional
Output buffers.
Output buffers.
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
Returns
-------
-------
tensor: Tensor or list of Tensors
tensor: Tensor or list of Tensors
...
@@ -406,7 +427,8 @@ def extern(shape,
...
@@ -406,7 +427,8 @@ def extern(shape,
if
isinstance
(
body
,
_expr
.
Expr
):
if
isinstance
(
body
,
_expr
.
Expr
):
body
=
_make
.
Evaluate
(
body
)
body
=
_make
.
Evaluate
(
body
)
op
=
_api_internal
.
_ExternOp
(
name
,
tag
,
inputs
,
input_placeholders
,
op
=
_api_internal
.
_ExternOp
(
name
,
tag
,
attrs
,
inputs
,
input_placeholders
,
output_placeholders
,
body
)
output_placeholders
,
body
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
output_placeholders
))]
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
output_placeholders
))]
return
res
[
0
]
if
len
(
res
)
==
1
else
res
return
res
[
0
]
if
len
(
res
)
==
1
else
res
...
...
src/api/api_lang.cc
View file @
0702d2c0
...
@@ -262,7 +262,8 @@ TVM_REGISTER_API("_ComputeOp")
...
@@ -262,7 +262,8 @@ TVM_REGISTER_API("_ComputeOp")
*
ret
=
ComputeOpNode
::
make
(
args
[
0
],
*
ret
=
ComputeOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
1
],
args
[
2
],
args
[
2
],
args
[
3
]);
args
[
3
],
args
[
4
]);
});
});
TVM_REGISTER_API
(
"_ScanOp"
)
TVM_REGISTER_API
(
"_ScanOp"
)
...
@@ -273,7 +274,8 @@ TVM_REGISTER_API("_ScanOp")
...
@@ -273,7 +274,8 @@ TVM_REGISTER_API("_ScanOp")
args
[
3
],
args
[
3
],
args
[
4
],
args
[
4
],
args
[
5
],
args
[
5
],
args
[
6
]);
args
[
6
],
args
[
7
]);
});
});
TVM_REGISTER_API
(
"_ExternOp"
)
TVM_REGISTER_API
(
"_ExternOp"
)
...
@@ -283,7 +285,8 @@ TVM_REGISTER_API("_ExternOp")
...
@@ -283,7 +285,8 @@ TVM_REGISTER_API("_ExternOp")
args
[
2
],
args
[
2
],
args
[
3
],
args
[
3
],
args
[
4
],
args
[
4
],
args
[
5
]);
args
[
5
],
args
[
6
]);
});
});
TVM_REGISTER_API
(
"_OpGetOutput"
)
TVM_REGISTER_API
(
"_OpGetOutput"
)
...
...
src/lang/reflection.cc
View file @
0702d2c0
...
@@ -84,6 +84,11 @@ class NodeIndexer : public AttrVisitor {
...
@@ -84,6 +84,11 @@ class NodeIndexer : public AttrVisitor {
MakeIndex
(
kv
.
first
.
get
());
MakeIndex
(
kv
.
first
.
get
());
MakeIndex
(
kv
.
second
.
get
());
MakeIndex
(
kv
.
second
.
get
());
}
}
}
else
if
(
node
->
is_type
<
StrMapNode
>
())
{
StrMapNode
*
n
=
static_cast
<
StrMapNode
*>
(
node
);
for
(
const
auto
&
kv
:
n
->
data
)
{
MakeIndex
(
kv
.
second
.
get
());
}
}
else
{
}
else
{
node
->
VisitAttrs
(
this
);
node
->
VisitAttrs
(
this
);
}
}
...
@@ -99,6 +104,8 @@ struct JSONNode {
...
@@ -99,6 +104,8 @@ struct JSONNode {
std
::
string
type_key
;
std
::
string
type_key
;
// the attributes
// the attributes
AttrMap
attrs
;
AttrMap
attrs
;
// container keys
std
::
vector
<
std
::
string
>
keys
;
// container data
// container data
std
::
vector
<
size_t
>
data
;
std
::
vector
<
size_t
>
data
;
...
@@ -108,6 +115,9 @@ struct JSONNode {
...
@@ -108,6 +115,9 @@ struct JSONNode {
if
(
attrs
.
size
()
!=
0
)
{
if
(
attrs
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
}
}
if
(
keys
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"keys"
,
keys
);
}
if
(
data
.
size
()
!=
0
)
{
if
(
data
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"data"
,
data
);
writer
->
WriteObjectKeyValue
(
"data"
,
data
);
}
}
...
@@ -121,6 +131,7 @@ struct JSONNode {
...
@@ -121,6 +131,7 @@ struct JSONNode {
dmlc
::
JSONObjectReadHelper
helper
;
dmlc
::
JSONObjectReadHelper
helper
;
helper
.
DeclareOptionalField
(
"type_key"
,
&
type_key
);
helper
.
DeclareOptionalField
(
"type_key"
,
&
type_key
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
DeclareOptionalField
(
"keys"
,
&
keys
);
helper
.
DeclareOptionalField
(
"data"
,
&
data
);
helper
.
DeclareOptionalField
(
"data"
,
&
data
);
helper
.
ReadAllFields
(
reader
);
helper
.
ReadAllFields
(
reader
);
}
}
...
@@ -176,13 +187,19 @@ class JSONAttrGetter : public AttrVisitor {
...
@@ -176,13 +187,19 @@ class JSONAttrGetter : public AttrVisitor {
}
}
}
else
if
(
node
->
is_type
<
MapNode
>
())
{
}
else
if
(
node
->
is_type
<
MapNode
>
())
{
MapNode
*
n
=
static_cast
<
MapNode
*>
(
node
);
MapNode
*
n
=
static_cast
<
MapNode
*>
(
node
);
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>
>
elems
;
for
(
const
auto
&
kv
:
n
->
data
)
{
for
(
const
auto
&
kv
:
n
->
data
)
{
node_
->
data
.
push_back
(
node_
->
data
.
push_back
(
node_index_
->
at
(
kv
.
first
.
get
()));
node_index_
->
at
(
kv
.
first
.
get
()));
node_
->
data
.
push_back
(
node_
->
data
.
push_back
(
node_index_
->
at
(
kv
.
second
.
get
()));
node_index_
->
at
(
kv
.
second
.
get
()));
}
}
}
else
if
(
node
->
is_type
<
StrMapNode
>
())
{
StrMapNode
*
n
=
static_cast
<
StrMapNode
*>
(
node
);
for
(
const
auto
&
kv
:
n
->
data
)
{
node_
->
keys
.
push_back
(
kv
.
first
);
node_
->
data
.
push_back
(
node_index_
->
at
(
kv
.
second
.
get
()));
}
}
else
{
}
else
{
node
->
VisitAttrs
(
this
);
node
->
VisitAttrs
(
this
);
}
}
...
@@ -256,6 +273,13 @@ class JSONAttrSetter : public AttrVisitor {
...
@@ -256,6 +273,13 @@ class JSONAttrSetter : public AttrVisitor {
n
->
data
[
node_list_
->
at
(
node_
->
data
[
i
])]
n
->
data
[
node_list_
->
at
(
node_
->
data
[
i
])]
=
node_list_
->
at
(
node_
->
data
[
i
+
1
]);
=
node_list_
->
at
(
node_
->
data
[
i
+
1
]);
}
}
}
else
if
(
node
->
is_type
<
StrMapNode
>
())
{
StrMapNode
*
n
=
static_cast
<
StrMapNode
*>
(
node
);
CHECK_EQ
(
node_
->
data
.
size
(),
node_
->
keys
.
size
());
for
(
size_t
i
=
0
;
i
<
node_
->
data
.
size
();
++
i
)
{
n
->
data
[
node_
->
keys
[
i
]]
=
node_list_
->
at
(
node_
->
data
[
i
]);
}
}
else
{
}
else
{
node
->
VisitAttrs
(
this
);
node
->
VisitAttrs
(
this
);
}
}
...
@@ -302,7 +326,7 @@ struct JSONGraph {
...
@@ -302,7 +326,7 @@ struct JSONGraph {
getter
.
Get
(
n
);
getter
.
Get
(
n
);
g
.
nodes
.
emplace_back
(
std
::
move
(
jnode
));
g
.
nodes
.
emplace_back
(
std
::
move
(
jnode
));
}
}
g
.
attrs
[
"tvm_version"
]
=
"0.1.0"
;
g
.
attrs
[
"tvm_version"
]
=
TVM_VERSION
;
g
.
root
=
indexer
.
node_index
.
at
(
root
.
node_
.
get
());
g
.
root
=
indexer
.
node_index
.
at
(
root
.
node_
.
get
());
return
g
;
return
g
;
}
}
...
...
src/op/compute_op.cc
View file @
0702d2c0
...
@@ -66,7 +66,8 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
...
@@ -66,7 +66,8 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
Tensor
compute
(
Array
<
Expr
>
shape
,
Tensor
compute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
FCompute
fcompute
,
std
::
string
name
,
std
::
string
name
,
std
::
string
tag
)
{
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
// compute dimension.
// compute dimension.
size_t
ndim
=
shape
.
size
();
size_t
ndim
=
shape
.
size
();
...
@@ -80,13 +81,15 @@ Tensor compute(Array<Expr> shape,
...
@@ -80,13 +81,15 @@ Tensor compute(Array<Expr> shape,
args
.
push_back
(
axis
.
back
()
->
var
);
args
.
push_back
(
axis
.
back
()
->
var
);
}
}
return
ComputeOpNode
::
make
(
name
,
tag
,
axis
,
{
fcompute
(
args
)}).
output
(
0
);
return
ComputeOpNode
::
make
(
name
,
tag
,
attrs
,
axis
,
{
fcompute
(
args
)}).
output
(
0
);
}
}
Array
<
Tensor
>
compute
(
Array
<
Expr
>
shape
,
Array
<
Tensor
>
compute
(
Array
<
Expr
>
shape
,
FBatchCompute
fcompute
,
FBatchCompute
fcompute
,
std
::
string
name
,
std
::
string
name
,
std
::
string
tag
)
{
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
// compute dimension.
// compute dimension.
size_t
ndim
=
shape
.
size
();
size_t
ndim
=
shape
.
size
();
...
@@ -100,7 +103,7 @@ Array<Tensor> compute(Array<Expr> shape,
...
@@ -100,7 +103,7 @@ Array<Tensor> compute(Array<Expr> shape,
args
.
push_back
(
axis
.
back
()
->
var
);
args
.
push_back
(
axis
.
back
()
->
var
);
}
}
Operation
op
=
ComputeOpNode
::
make
(
name
,
tag
,
axis
,
fcompute
(
args
));
Operation
op
=
ComputeOpNode
::
make
(
name
,
tag
,
a
ttrs
,
a
xis
,
fcompute
(
args
));
Array
<
Tensor
>
outputs
;
Array
<
Tensor
>
outputs
;
for
(
int
idx
=
0
;
idx
<
op
->
num_outputs
();
++
idx
)
{
for
(
int
idx
=
0
;
idx
<
op
->
num_outputs
();
++
idx
)
{
outputs
.
push_back
(
op
.
output
(
idx
));
outputs
.
push_back
(
op
.
output
(
idx
));
...
@@ -110,13 +113,15 @@ Array<Tensor> compute(Array<Expr> shape,
...
@@ -110,13 +113,15 @@ Array<Tensor> compute(Array<Expr> shape,
Operation
ComputeOpNode
::
make
(
std
::
string
name
,
Operation
ComputeOpNode
::
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
Array
<
IterVar
>
axis
,
Array
<
IterVar
>
axis
,
Array
<
Expr
>
body
)
{
Array
<
Expr
>
body
)
{
auto
n
=
std
::
make_shared
<
ComputeOpNode
>
();
auto
n
=
std
::
make_shared
<
ComputeOpNode
>
();
n
->
name
=
name
;
n
->
name
=
std
::
move
(
name
);
n
->
tag
=
tag
;
n
->
tag
=
std
::
move
(
tag
);
n
->
axis
=
axis
;
n
->
attrs
=
std
::
move
(
attrs
);
n
->
body
=
body
;
n
->
axis
=
std
::
move
(
axis
);
n
->
body
=
std
::
move
(
body
);
if
(
n
->
body
[
0
]
->
is_type
<
ir
::
Reduce
>
())
{
if
(
n
->
body
[
0
]
->
is_type
<
ir
::
Reduce
>
())
{
const
ir
::
Reduce
*
reduce
=
n
->
body
[
0
].
as
<
ir
::
Reduce
>
();
const
ir
::
Reduce
*
reduce
=
n
->
body
[
0
].
as
<
ir
::
Reduce
>
();
n
->
reduce_axis
=
reduce
->
axis
;
n
->
reduce_axis
=
reduce
->
axis
;
...
@@ -171,7 +176,8 @@ Operation ComputeOpNode::ReplaceInputs(
...
@@ -171,7 +176,8 @@ Operation ComputeOpNode::ReplaceInputs(
});
});
}
}
if
(
!
arr
.
same_as
(
this
->
body
))
{
if
(
!
arr
.
same_as
(
this
->
body
))
{
return
ComputeOpNode
::
make
(
name
,
tag
,
axis
,
arr
);
return
ComputeOpNode
::
make
(
this
->
name
,
this
->
tag
,
this
->
attrs
,
this
->
axis
,
arr
);
}
else
{
}
else
{
return
self
;
return
self
;
}
}
...
...
src/op/extern_op.cc
View file @
0702d2c0
...
@@ -38,23 +38,25 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
...
@@ -38,23 +38,25 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation
ExternOpNode
::
make
(
std
::
string
name
,
Operation
ExternOpNode
::
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
Array
<
Tensor
>
inputs
,
Array
<
Tensor
>
inputs
,
Array
<
Buffer
>
input_placeholders
,
Array
<
Buffer
>
input_placeholders
,
Array
<
Buffer
>
output_placeholders
,
Array
<
Buffer
>
output_placeholders
,
Stmt
body
)
{
Stmt
body
)
{
auto
n
=
std
::
make_shared
<
ExternOpNode
>
();
auto
n
=
std
::
make_shared
<
ExternOpNode
>
();
n
->
name
=
name
;
n
->
name
=
std
::
move
(
name
);
n
->
tag
=
tag
;
n
->
tag
=
std
::
move
(
tag
);
n
->
attrs
=
std
::
move
(
attrs
);
CHECK_EQ
(
inputs
.
size
(),
input_placeholders
.
size
());
CHECK_EQ
(
inputs
.
size
(),
input_placeholders
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
CHECK_EQ
(
inputs
[
i
]
->
dtype
,
input_placeholders
[
i
]
->
dtype
);
CHECK_EQ
(
inputs
[
i
]
->
dtype
,
input_placeholders
[
i
]
->
dtype
);
CHECK
(
inputs
[
i
]
->
shape
.
same_as
(
input_placeholders
[
i
]
->
shape
));
CHECK
(
inputs
[
i
]
->
shape
.
same_as
(
input_placeholders
[
i
]
->
shape
));
CHECK_EQ
(
input_placeholders
[
i
]
->
strides
.
size
(),
0U
);
CHECK_EQ
(
input_placeholders
[
i
]
->
strides
.
size
(),
0U
);
}
}
n
->
inputs
=
inputs
;
n
->
inputs
=
std
::
move
(
inputs
)
;
n
->
input_placeholders
=
input_placeholders
;
n
->
input_placeholders
=
std
::
move
(
input_placeholders
)
;
n
->
output_placeholders
=
output_placeholders
;
n
->
output_placeholders
=
std
::
move
(
output_placeholders
)
;
n
->
body
=
body
;
n
->
body
=
std
::
move
(
body
)
;
return
Operation
(
n
);
return
Operation
(
n
);
}
}
...
...
src/op/scan_op.cc
View file @
0702d2c0
...
@@ -45,6 +45,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
...
@@ -45,6 +45,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
Operation
ScanOpNode
::
make
(
std
::
string
name
,
Operation
ScanOpNode
::
make
(
std
::
string
name
,
std
::
string
tag
,
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
,
IterVar
axis
,
IterVar
axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
update
,
...
@@ -86,13 +87,14 @@ Operation ScanOpNode::make(std::string name,
...
@@ -86,13 +87,14 @@ Operation ScanOpNode::make(std::string name,
init
[
i
]
->
shape
[
k
],
state_placeholder
[
i
]
->
shape
[
k
]));
init
[
i
]
->
shape
[
k
],
state_placeholder
[
i
]
->
shape
[
k
]));
}
}
}
}
n
->
name
=
name
;
n
->
name
=
std
::
move
(
name
);
n
->
tag
=
tag
;
n
->
tag
=
std
::
move
(
tag
);
n
->
scan_axis
=
axis
;
n
->
attrs
=
std
::
move
(
attrs
);
n
->
init
=
init
;
n
->
scan_axis
=
std
::
move
(
axis
);
n
->
update
=
update
;
n
->
init
=
std
::
move
(
init
);
n
->
state_placeholder
=
state_placeholder
;
n
->
update
=
std
::
move
(
update
);
n
->
inputs
=
inputs
;
n
->
state_placeholder
=
std
::
move
(
state_placeholder
);
n
->
inputs
=
std
::
move
(
inputs
);
return
Operation
(
n
);
return
Operation
(
n
);
}
}
...
@@ -101,14 +103,16 @@ Array<Tensor> scan(Array<Tensor> init,
...
@@ -101,14 +103,16 @@ Array<Tensor> scan(Array<Tensor> init,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
inputs
,
Array
<
Tensor
>
inputs
,
std
::
string
name
,
std
::
string
name
,
std
::
string
tag
)
{
std
::
string
tag
,
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
IterVar
scan_axis
=
IterVar
scan_axis
=
IterVarNode
::
make
(
IterVarNode
::
make
(
Range
::
make_by_min_extent
(
Range
::
make_by_min_extent
(
init
[
0
]
->
shape
[
0
],
update
[
0
]
->
shape
[
0
]
-
init
[
0
]
->
shape
[
0
]),
init
[
0
]
->
shape
[
0
],
update
[
0
]
->
shape
[
0
]
-
init
[
0
]
->
shape
[
0
]),
Var
(
name
+
".idx"
),
kOrdered
);
Var
(
name
+
".idx"
),
kOrdered
);
Operation
op
=
ScanOpNode
::
make
(
Operation
op
=
ScanOpNode
::
make
(
name
,
tag
,
scan_axis
,
init
,
update
,
state_placeholder
,
inputs
);
name
,
tag
,
attrs
,
scan_axis
,
init
,
update
,
state_placeholder
,
inputs
);
Array
<
Tensor
>
res
;
Array
<
Tensor
>
res
;
for
(
int
i
=
0
;
i
<
op
->
num_outputs
();
++
i
)
{
for
(
int
i
=
0
;
i
<
op
->
num_outputs
();
++
i
)
{
res
.
push_back
(
op
.
output
(
i
));
res
.
push_back
(
op
.
output
(
i
));
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
0702d2c0
...
@@ -232,7 +232,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
...
@@ -232,7 +232,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
}
}
}
}
Operation
cache_op
=
ComputeOpNode
::
make
(
Operation
cache_op
=
ComputeOpNode
::
make
(
compute
->
name
+
"."
+
scope
,
compute
->
tag
,
new_axis
,
body_list
);
compute
->
name
+
"."
+
scope
,
compute
->
tag
,
compute
->
attrs
,
new_axis
,
body_list
);
Array
<
Tensor
>
cache_tensor_list
;
Array
<
Tensor
>
cache_tensor_list
;
Array
<
Expr
>
cache_expr_list
;
Array
<
Expr
>
cache_expr_list
;
for
(
size_t
i
=
0
;
i
<
tensor_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
tensor_size
;
i
++
)
{
...
@@ -241,7 +242,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
...
@@ -241,7 +242,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
cache_expr_list
.
push_back
(
cache_tensor
(
args
));
cache_expr_list
.
push_back
(
cache_tensor
(
args
));
}
}
Operation
orig_new_op
=
ComputeOpNode
::
make
(
Operation
orig_new_op
=
ComputeOpNode
::
make
(
compute
->
name
,
compute
->
tag
,
compute
->
axis
,
cache_expr_list
);
compute
->
name
,
compute
->
tag
,
compute
->
attrs
,
compute
->
axis
,
cache_expr_list
);
// The replace of the dataflow
// The replace of the dataflow
std
::
unordered_map
<
Tensor
,
Tensor
>
vmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
vmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
rvmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
rvmap
;
...
@@ -430,7 +432,8 @@ void InjectInline(ScheduleNode* sch) {
...
@@ -430,7 +432,8 @@ void InjectInline(ScheduleNode* sch) {
Operation
op
=
s
->
op
;
Operation
op
=
s
->
op
;
if
(
changed
[
i
])
{
if
(
changed
[
i
])
{
op
=
ComputeOpNode
::
make
(
op
=
ComputeOpNode
::
make
(
compute
->
name
,
compute
->
tag
,
compute
->
axis
,
new_body
[
i
]);
compute
->
name
,
compute
->
tag
,
compute
->
attrs
,
compute
->
axis
,
new_body
[
i
]);
}
}
op
=
op
->
ReplaceInputs
(
op
,
repl
);
op
=
op
->
ReplaceInputs
(
op
,
repl
);
if
(
!
op
.
same_as
(
s
->
op
))
{
if
(
!
op
.
same_as
(
s
->
op
))
{
...
...
tests/python/unittest/test_lang_reflection.py
View file @
0702d2c0
...
@@ -11,6 +11,18 @@ def test_const_saveload_json():
...
@@ -11,6 +11,18 @@ def test_const_saveload_json():
assert
tvm
.
save_json
(
zz
)
==
tvm
.
save_json
(
z
)
assert
tvm
.
save_json
(
zz
)
==
tvm
.
save_json
(
z
)
def
test_make_smap
():
# save load json
x
=
tvm
.
const
(
1
)
y
=
tvm
.
const
(
10
)
z
=
x
+
y
smap
=
tvm
.
convert
({
"z"
:
z
,
"x"
:
x
})
json_str
=
tvm
.
save_json
(
tvm
.
convert
([
smap
]))
arr
=
tvm
.
load_json
(
json_str
)
assert
len
(
arr
)
==
1
assert
arr
[
0
][
"z"
]
.
a
==
arr
[
0
][
"x"
]
def
test_make_node
():
def
test_make_node
():
x
=
tvm
.
make
.
node
(
"IntImm"
,
dtype
=
"int32"
,
value
=
10
)
x
=
tvm
.
make
.
node
(
"IntImm"
,
dtype
=
"int32"
,
value
=
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
...
@@ -35,5 +47,6 @@ def test_make_sum():
...
@@ -35,5 +47,6 @@ def test_make_sum():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_make_node
()
test_make_node
()
test_make_smap
()
test_const_saveload_json
()
test_const_saveload_json
()
test_make_sum
()
test_make_sum
()
tests/python/unittest/test_lang_tag.py
View file @
0702d2c0
import
json
import
tvm
import
tvm
@tvm.tag_scope
(
tag
=
"conv"
)
@tvm.tag_scope
(
tag
=
"conv"
)
...
@@ -24,8 +25,19 @@ def test_with():
...
@@ -24,8 +25,19 @@ def test_with():
B
=
tvm
.
placeholder
((
m
,
l
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
m
,
l
),
name
=
'B'
)
with
tvm
.
tag_scope
(
tag
=
"gemm"
):
with
tvm
.
tag_scope
(
tag
=
"gemm"
):
k
=
tvm
.
reduce_axis
((
0
,
l
),
name
=
'k'
)
k
=
tvm
.
reduce_axis
((
0
,
l
),
name
=
'k'
)
C
=
tvm
.
compute
((
n
,
m
),
lambda
i
,
j
:
tvm
.
sum
(
A
[
i
,
k
]
*
B
[
j
,
k
],
axis
=
k
))
C
=
tvm
.
compute
((
n
,
m
),
lambda
i
,
j
:
tvm
.
sum
(
A
[
i
,
k
]
*
B
[
j
,
k
],
axis
=
k
),
attrs
=
{
"hello"
:
1
,
"arr"
:
[
10
,
12
]})
assert
C
.
op
.
tag
==
'gemm'
assert
C
.
op
.
tag
==
'gemm'
assert
"hello"
in
C
.
op
.
attrs
assert
"xx"
not
in
C
.
op
.
attrs
assert
C
.
op
.
attrs
[
"hello"
]
.
value
==
1
CC
=
tvm
.
load_json
(
tvm
.
save_json
(
C
))
assert
CC
.
op
.
attrs
[
"hello"
]
.
value
==
1
assert
CC
.
op
.
attrs
[
"arr"
][
0
]
.
value
==
10
# str format happened to be json compatible
assert
json
.
loads
(
str
(
CC
.
op
.
attrs
))[
"arr"
][
1
]
==
12
def
test_decorator
():
def
test_decorator
():
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
...
@@ -39,6 +51,7 @@ def test_decorator():
...
@@ -39,6 +51,7 @@ def test_decorator():
B
=
tvm
.
placeholder
((
c
,
c
,
kh
,
kw
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
c
,
c
,
kh
,
kw
),
name
=
'B'
)
C
=
compute_conv
(
A
,
B
)
C
=
compute_conv
(
A
,
B
)
assert
C
.
op
.
tag
==
'conv'
assert
C
.
op
.
tag
==
'conv'
assert
len
(
C
.
op
.
attrs
)
==
0
def
test_nested
():
def
test_nested
():
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
...
@@ -59,5 +72,6 @@ def test_nested():
...
@@ -59,5 +72,6 @@ def test_nested():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
nose
test_with
()
nose
.
runmodule
()
test_decorator
()
test_nested
()
topi/include/topi/contrib/cublas.h
View file @
0702d2c0
...
@@ -40,7 +40,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
...
@@ -40,7 +40,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
pack_buffer
(
outs
[
0
]),
pack_buffer
(
outs
[
0
]),
transa
,
transa
,
transb
});
transb
});
},
"C"
,
""
)[
0
];
},
"C"
,
""
,
{}
)[
0
];
}
}
}
// namespace contrib
}
// namespace contrib
...
...
topi/include/topi/contrib/rocblas.h
View file @
0702d2c0
...
@@ -39,7 +39,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
...
@@ -39,7 +39,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
pack_buffer
(
outs
[
0
]),
pack_buffer
(
outs
[
0
]),
transa
,
transa
,
transb
});
transb
});
},
"C"
,
""
)[
0
];
},
"C"
,
""
,
{}
)[
0
];
}
}
}
// namespace contrib
}
// namespace contrib
...
...
topi/include/topi/detail/extern.h
View file @
0702d2c0
...
@@ -6,10 +6,10 @@
...
@@ -6,10 +6,10 @@
#ifndef TOPI_DETAIL_EXTERN_H_
#ifndef TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_
#include <tvm/tvm.h>
#include <vector>
#include <vector>
#include <string>
#include <string>
#include "tvm/tvm.h"
namespace
topi
{
namespace
topi
{
namespace
detail
{
namespace
detail
{
...
@@ -51,6 +51,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
...
@@ -51,6 +51,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
* the external function given the input and output buffers.
* the external function given the input and output buffers.
* \param name The name of the operation
* \param name The name of the operation
* \param tag The tag to mark the operation
* \param tag The tag to mark the operation
* \param attrs The additional auxiliary attributes of the operation.
*
*
* \return An array of Tensors representing the outputs of the function invocation. There will
* \return An array of Tensors representing the outputs of the function invocation. There will
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
...
@@ -61,7 +62,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
...
@@ -61,7 +62,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
inputs
,
FExtern
fextern
,
FExtern
fextern
,
std
::
string
name
,
std
::
string
name
,
std
::
string
tag
)
{
std
::
string
tag
,
::
tvm
::
Map
<
std
::
string
,
NodeRef
>
attrs
)
{
CHECK_EQ
(
out_shapes
.
size
(),
out_types
.
size
())
CHECK_EQ
(
out_shapes
.
size
(),
out_types
.
size
())
<<
"make_extern: out_shapes and out_types must have equal size"
;
<<
"make_extern: out_shapes and out_types must have equal size"
;
...
@@ -78,7 +80,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
...
@@ -78,7 +80,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
auto
body_stmt
=
tvm
::
ir
::
Evaluate
::
make
(
body
);
auto
body_stmt
=
tvm
::
ir
::
Evaluate
::
make
(
body
);
auto
op
=
ExternOpNode
::
make
(
auto
op
=
ExternOpNode
::
make
(
name
,
tag
,
inputs
,
input_placeholders
,
output_placeholders
,
body_stmt
);
name
,
tag
,
attrs
,
inputs
,
input_placeholders
,
output_placeholders
,
body_stmt
);
Array
<
Tensor
>
outputs
;
Array
<
Tensor
>
outputs
;
for
(
size_t
i
=
0
;
i
<
output_placeholders
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_placeholders
.
size
();
++
i
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment