Commit c9da7254 by ziheng (committed by GitHub)

[TAG] Add tvm.tag module for tagging operator (#217)

* [TAG] Add op_tag module for tagging operator

* Fix according to review comments

* Add example

* Add into doc

* Add --fix-missing for docker
parent 8a66ac23
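
For orientation, a minimal usage sketch of the tag_scope API this commit introduces, adapted from the docstring and tests in the diff below (assumes the tvm 0.x API of this era):

    import tvm

    # Tag an operator by wrapping its declaration in a tag scope;
    # the tag is recorded on the resulting operation as C.op.tag.
    n = tvm.var('n')
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    with tvm.tag_scope(tag='gemm'):
        C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
    assert C.op.tag == 'gemm'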
@@ -20,6 +20,7 @@ The user facing API for computation declaration.
    tvm.sum
    tvm.min
    tvm.max
+   tvm.tag_scope
 .. autofunction:: tvm.load_json
 .. autofunction:: tvm.save_json
@@ -37,3 +38,4 @@ The user facing API for computation declaration.
 .. autofunction:: tvm.sum
 .. autofunction:: tvm.min
 .. autofunction:: tvm.max
+.. autofunction:: tvm.tag_scope
@@ -39,6 +39,8 @@ class OperationNode : public FunctionBaseNode {
  public:
   /*! \brief optional name of the operation */
   std::string name;
+  /*! \brief optional tag of the operation */
+  std::string tag;
   /*! \return name of the operation */
   const std::string& func_name() const final {
     return name;
@@ -213,11 +215,13 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
+    v->Visit("tag", &tag);
     v->Visit("axis", &axis);
     v->Visit("reduce_axis", &reduce_axis);
     v->Visit("body", &body);
   }
   static Operation make(std::string name,
+                        std::string tag,
                         Array<IterVar> axis,
                         Array<Expr> body);
@@ -282,6 +286,7 @@ class ScanOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
+    v->Visit("tag", &tag);
     v->Visit("scan_axis", &scan_axis);
     v->Visit("init", &init);
     v->Visit("update", &update);
@@ -290,6 +295,7 @@ class ScanOpNode : public OperationNode {
     v->Visit("spatial_axis_", &spatial_axis_);
   }
   static Operation make(std::string name,
+                        std::string tag,
                         IterVar axis,
                         Array<Tensor> init,
                         Array<Tensor> update,
@@ -343,10 +349,12 @@ class ExternOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
+    v->Visit("tag", &tag);
     v->Visit("inputs", &inputs);
     v->Visit("body", &body);
   }
   static Operation make(std::string name,
+                        std::string tag,
                         Array<Tensor> inputs,
                         Array<Buffer> input_placeholders,
                         Array<Buffer> output_placeholders,
@@ -378,8 +386,12 @@ Tensor placeholder(Array<Expr> shape,
  * \param shape Shape of the tensor.
  * \param fcompute The compute function to create the tensor.
  * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
  */
-Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+Tensor compute(Array<Expr> shape,
+               FCompute fcompute,
+               std::string name = "tensor",
+               std::string tag = "");
 /*!
  * \brief Construct a new tensor by computing over shape,
@@ -387,8 +399,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
  * \param shape Shape of the tensor.
  * \param fcompute The compute function to create the tensors.
  * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
  */
-Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
+Array<Tensor> compute(Array<Expr> shape,
+                      FBatchCompute fcompute,
+                      std::string name = "tensor",
+                      std::string tag = "");
 /*!
  * \brief Construct new tensors by scan.
@@ -399,37 +415,43 @@ Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string nam
  * \param inputs The inputs to the scan body, this is optional,
  *        but recommended to provide concrete information about scan body.
  * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
  */
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    Array<Tensor> inputs = Array<Tensor>(),
-                   std::string name = "scan");
+                   std::string name = "scan",
+                   std::string tag = "");
 // same as compute, specialized for different fcompute function
 inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
-                      std::string name = "tensor") {
+                      std::string name = "tensor",
+                      std::string tag = "") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
 }
 inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var)> f,
-                      std::string name = "tensor") {
+                      std::string name = "tensor",
+                      std::string tag = "") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
 }
 inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var)> f,
-                      std::string name = "tensor") {
+                      std::string name = "tensor",
+                      std::string tag = "") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
 }
 inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var, Var)> f,
-                      std::string name = "tensor") {
+                      std::string name = "tensor",
+                      std::string tag = "") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
 }
 // inline function.
...
@@ -27,3 +27,4 @@ from .node import register_node
 from .ndarray import register_extension
 from .schedule import create_schedule
 from .build import build, lower, build_config
+from .tag import tag_scope
@@ -16,6 +16,7 @@ from . import expr as _expr
 from . import tensor as _tensor
 from . import schedule as _schedule
 from . import collections as _collections
+from . import tag as _tag
 int32 = "int32"
 float32 = "float32"
@@ -186,7 +187,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
         shape, dtype, name)
-def compute(shape, fcompute, name="compute"):
+def compute(shape, fcompute, name="compute", tag=""):
     """Construct a new tensor by computing over the shape domain.
     The compute rule is result[axis] = fcompute(axis)
@@ -207,6 +208,10 @@ def compute(shape, fcompute, name="compute"):
     tensor: Tensor
         The created tensor
     """
+    if _tag.TagScope.current is not None:
+        if tag != "":
+            raise ValueError("nested tag is not allowed for now")
+        tag = _tag.TagScope.current.tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     ndim = len(shape)
     code = fcompute.__code__
@@ -225,13 +230,13 @@ def compute(shape, fcompute, name="compute"):
         body = [body]
     body = convert(body)
     op_node = _api_internal._ComputeOp(
-        name, dim_var, body)
+        name, tag, dim_var, body)
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
-def scan(init, update, state_placeholder, inputs=None, name="scan"):
+def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
     """Construct new tensors by scanning over axis.
     Parameters
@@ -270,6 +275,10 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
     s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
     res = tvm.scan(s_init, s_update, s_state, X)
     """
+    if _tag.TagScope.current is not None:
+        if tag != "":
+            raise ValueError("nested tag is not allowed for now")
+        tag = _tag.TagScope.current.tag
     if isinstance(init, _tensor.Tensor):
         init = [init]
     if isinstance(update, _tensor.Tensor):
@@ -283,13 +292,13 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
     axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
-    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
+    op = _api_internal._ScanOp(name, tag, axis, init, update,
+                               state_placeholder, inputs)
     res = [op.output(i) for i in range(len(update))]
     return res[0] if len(res) == 1 else res
-def extern(shape, inputs, fcompute,
-           name="extern", dtype=None):
+def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
     """Compute several tensor via extern function.
     Parameters
@@ -340,6 +349,10 @@ def extern(shape, inputs, fcompute,
                 "tvm.contrib.cblas.matmul",
                 ins[0], ins[1], outs[0], 0, 0), name="C")
     """
+    if _tag.TagScope.current is not None:
+        if tag != "":
+            raise ValueError("nested tag is not allowed for now")
+        tag = _tag.TagScope.current.tag
     shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
     shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
     input_placeholders = []
@@ -364,8 +377,8 @@ def extern(shape, inputs, fcompute,
     if isinstance(body, _expr.Expr):
         body = _make.Evaluate(body)
-    op = _api_internal._ExternOp(
-        name, inputs, input_placeholders, output_placeholders, body)
+    op = _api_internal._ExternOp(name, tag, inputs, input_placeholders,
+                                 output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
     return res[0] if len(res) == 1 else res
...
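
The guard added to compute, scan, and extern above rejects an explicit tag argument while a tag_scope is active. A small sketch of the resulting behavior (hypothetical snippet, mirroring test_nested further down):

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    try:
        with tvm.tag_scope(tag='outer'):
            # passing tag= inside an active scope trips the new check
            B = tvm.compute((n,), lambda i: A[i] + 1, tag='inner')
        assert False, "expected ValueError"
    except ValueError:
        pass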
"""Tag class for TVM operators."""
from functools import wraps
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
"""
current = None
def __init__(self, tag):
self._old_scope = None
self.tag = tag
def __enter__(self):
if TagScope.current is not None:
raise ValueError("nested op_tag is not allowed for now")
self._old_scope = TagScope.current
TagScope.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
TagScope.current = self._old_scope
def __call__(self, fdecl):
@wraps(fdecl)
def tagged_fdecl(*args, **kwargs):
with self:
return fdecl(*args, **kwargs)
return tagged_fdecl
def tag_scope(tag):
"""The operator tag scope.
Parameters
----------
tag: str
The tag name.
Returns
-------
tag_scope: TagScope
The tag scope object, which can be used as decorator or
context manger.
Example
-------
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
l = tvm.var('m')
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='A')
k = tvm.reduce_axis((0, l), name='k')
with tvm.tag_scope(tag='matmul'):
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
# or use tag_scope as decorator
@tvm.tag_scope(tag="conv")
def compute_relu(data):
return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
"""
return TagScope(tag)
@@ -185,7 +185,8 @@ TVM_REGISTER_API("_ComputeOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = ComputeOpNode::make(args[0],
                                args[1],
-                               args[2]);
+                               args[2],
+                               args[3]);
   });
 TVM_REGISTER_API("_ScanOp")
@@ -195,7 +196,8 @@ TVM_REGISTER_API("_ScanOp")
                             args[2],
                             args[3],
                             args[4],
-                            args[5]);
+                            args[5],
+                            args[6]);
   });
 TVM_REGISTER_API("_ExternOp")
@@ -204,7 +206,8 @@ TVM_REGISTER_API("_ExternOp")
                              args[1],
                              args[2],
                              args[3],
-                             args[4]);
+                             args[4],
+                             args[5]);
   });
 TVM_REGISTER_API("_OpGetOutput")
...
@@ -53,7 +53,10 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
   return Array<Expr>(shape);
 }
-Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
+Tensor compute(Array<Expr> shape,
+               FCompute fcompute,
+               std::string name,
+               std::string tag) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
   size_t ndim = shape.size();
@@ -67,10 +70,13 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     args.push_back(axis.back()->var);
   }
-  return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0);
+  return ComputeOpNode::make(name, tag, axis, {fcompute(args)}).output(0);
 }
-Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name) {
+Array<Tensor> compute(Array<Expr> shape,
+                      FBatchCompute fcompute,
+                      std::string name,
+                      std::string tag) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
   size_t ndim = shape.size();
@@ -84,7 +90,7 @@ Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string nam
     args.push_back(axis.back()->var);
   }
-  Operation op = ComputeOpNode::make(name, axis, fcompute(args));
+  Operation op = ComputeOpNode::make(name, tag, axis, fcompute(args));
   Array<Tensor> outputs;
   for (int idx = 0; idx < op->num_outputs(); ++idx) {
     outputs.push_back(op.output(idx));
@@ -100,10 +106,12 @@ bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
 }
 Operation ComputeOpNode::make(std::string name,
+                              std::string tag,
                               Array<IterVar> axis,
                               Array<Expr> body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->name = name;
+  n->tag = tag;
   n->axis = axis;
   n->body = body;
   if (n->body[0]->is_type<ir::Reduce>()) {
@@ -147,7 +155,7 @@ Operation ComputeOpNode::ReplaceInputs(
       return op::ReplaceTensor(e, rmap);
     });
   if (!arr.same_as(this->body)) {
-    return ComputeOpNode::make(name, axis, arr);
+    return ComputeOpNode::make(name, tag, axis, arr);
   } else {
     return self;
   }
...
@@ -37,12 +37,14 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
 Operation ExternOpNode::make(std::string name,
+                             std::string tag,
                              Array<Tensor> inputs,
                              Array<Buffer> input_placeholders,
                              Array<Buffer> output_placeholders,
                              Stmt body) {
   auto n = std::make_shared<ExternOpNode>();
   n->name = name;
+  n->tag = tag;
   CHECK_EQ(inputs.size(), input_placeholders.size());
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
...
@@ -44,6 +44,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
 }
 Operation ScanOpNode::make(std::string name,
+                           std::string tag,
                            IterVar axis,
                            Array<Tensor> init,
                            Array<Tensor> update,
@@ -86,6 +87,7 @@ Operation ScanOpNode::make(std::string name,
     }
   }
   n->name = name;
+  n->tag = tag;
   n->scan_axis = axis;
   n->init = init;
   n->update = update;
@@ -98,14 +100,15 @@ Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    Array<Tensor> inputs,
-                   std::string name) {
+                   std::string name,
+                   std::string tag) {
   IterVar scan_axis =
       IterVarNode::make(
           Range::make_with_min_extent(
               init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
           Var(name + ".idx"), kOrdered);
   Operation op = ScanOpNode::make(
-      name, scan_axis, init, update, state_placeholder, inputs);
+      name, tag, scan_axis, init, update, state_placeholder, inputs);
   Array<Tensor> res;
   for (int i = 0; i < op->num_outputs(); ++i) {
     res.push_back(op.output(i));
...
@@ -123,10 +123,10 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   VarReplacer repl(vsub);
   Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, {body});
+      compute->name + "." + scope, compute->tag, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->axis,
+      compute->name, compute->tag, compute->axis,
       {cache_tensor(args)});
   std::unordered_map<Tensor, Tensor> vmap;
@@ -246,7 +246,7 @@ void InjectInline(ScheduleNode* sch) {
     Operation op = s->op;
     if (changed[i]) {
       op = ComputeOpNode::make(
-          compute->name, compute->axis, new_body[i]);
+          compute->name, compute->tag, compute->axis, new_body[i]);
     }
     op = op->ReplaceInputs(op, repl);
     if (!op.same_as(s->op)) {
...
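
The schedule changes above thread compute->tag through cache_write and InjectInline, so transformed ops keep their tag. A hedged sketch of the observable effect (assumes s.cache_write(tensor, scope) as in this era's Python API; not taken verbatim from this commit):

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    with tvm.tag_scope(tag='ewise'):
        B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

    s = tvm.create_schedule(B.op)
    BL = s.cache_write(B, 'local')  # creates a new compute op for the cache stage
    # With this commit, the cached op inherits the original tag instead of dropping it.
    assert BL.op.tag == 'ewise'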
 # For CPU
 FROM ubuntu:14.04
-RUN apt-get update
+RUN apt-get update --fix-missing
 COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
 RUN bash /install/ubuntu_install_core.sh
...
 FROM ioft/i386-ubuntu:14.04
-RUN apt-get update
+RUN apt-get update --fix-missing
 COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
 RUN bash /install/ubuntu_install_core.sh
...
import tvm


@tvm.tag_scope(tag="conv")
def compute_conv(data, weight):
    N, IC, H, W = data.shape
    OC, IC, KH, KW = weight.shape
    OH = H - KH + 1
    OW = W - KW + 1

    ic = tvm.reduce_axis((0, IC), name='ic')
    dh = tvm.reduce_axis((0, KH), name='dh')
    dw = tvm.reduce_axis((0, KW), name='dw')

    return tvm.compute((N, OC, OH, OW), lambda i, oc, h, w: \
        tvm.sum(data[i, ic, h+dh, w+dw] * weight[oc, ic, dh, dw],
                axis=[ic, dh, dw]))


def test_with():
    n = tvm.var('n')
    m = tvm.var('m')
    l = tvm.var('l')

    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')

    with tvm.tag_scope(tag="gemm"):
        k = tvm.reduce_axis((0, l), name='k')
        C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
    assert C.op.tag == 'gemm'


def test_decorator():
    n = tvm.var('n')
    c = tvm.var('c')
    h = tvm.var('h')
    w = tvm.var('w')
    kh = tvm.var('kh')
    kw = tvm.var('kw')

    A = tvm.placeholder((n, c, h, w), name='A')
    B = tvm.placeholder((c, c, kh, kw), name='B')
    C = compute_conv(A, B)
    assert C.op.tag == 'conv'


def test_nested():
    n = tvm.var('n')
    c = tvm.var('c')
    h = tvm.var('h')
    w = tvm.var('w')
    kh = tvm.var('kh')
    kw = tvm.var('kw')

    A = tvm.placeholder((n, c, h, w), name='A')
    B = tvm.placeholder((c, c, kh, kw), name='B')
    try:
        with tvm.tag_scope(tag='conv'):
            C = compute_conv(A, B)
        assert False
    except ValueError:
        pass


if __name__ == "__main__":
    import nose
    nose.runmodule()