Commit c9da7254 by ziheng (committed by GitHub)

[TAG] Add tvm.tag module for tagging operators (#217)

* [TAG] Add op_tag module for tagging operators

* Fix according to comments

* Add example

* Add into doc

* Add --fix-missing for docker
parent 8a66ac23
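For context, this change lets an operator declaration carry an optional string tag, passed either explicitly or picked up from an enclosing scope. A minimal sketch of the API the commit introduces:

.. code-block:: python

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')

    # Tag a single operator through the new `tag` keyword argument.
    B = tvm.compute((n,), lambda i: A[i] * 2, name='B', tag='elemwise')
    assert B.op.tag == 'elemwise'

    # Or tag every operator declared inside a scope.
    with tvm.tag_scope(tag='elemwise'):
        C = tvm.compute((n,), lambda i: A[i] + 1, name='C')
    assert C.op.tag == 'elemwise'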
@@ -20,6 +20,7 @@ The user facing API for computation declaration.
tvm.sum
tvm.min
tvm.max
tvm.tag_scope
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
@@ -37,3 +38,4 @@ The user facing API for computation declaration.
.. autofunction:: tvm.sum
.. autofunction:: tvm.min
.. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope
@@ -39,6 +39,8 @@ class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
@@ -213,11 +215,13 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<Expr> body);
@@ -282,6 +286,7 @@ class ScanOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
@@ -290,6 +295,7 @@ class ScanOpNode : public OperationNode {
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
std::string tag,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
@@ -343,10 +349,12 @@ class ExternOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("inputs", &inputs);
v->Visit("body", &body);
}
static Operation make(std::string name,
std::string tag,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
@@ -378,8 +386,12 @@ Tensor placeholder(Array<Expr> shape,
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
*/
-Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+Tensor compute(Array<Expr> shape,
+               FCompute fcompute,
+               std::string name = "tensor",
+               std::string tag = "");
/*!
* \brief Construct a new tensor by computing over shape,
@@ -387,8 +399,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
*/
-Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
+Array<Tensor> compute(Array<Expr> shape,
+                      FBatchCompute fcompute,
+                      std::string name = "tensor",
+                      std::string tag = "");
/*!
* \brief Construct new tensors by scan.
@@ -399,37 +415,43 @@ Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string nam
* \param inputs The inputs to the scan body; this is optional,
* but recommended in order to provide concrete information about the scan body.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
*/
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan");
std::string name = "scan",
std::string tag = "");
// same as compute, specialized for different fcompute function
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
std::string name = "tensor",
std::string tag = "") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
std::string name = "tensor",
std::string tag = "") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
std::string name = "tensor",
std::string tag = "") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
std::string name = "tensor",
std::string tag = "") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return compute(shape, fc, name);
+  return compute(shape, fc, name, tag);
}
// inline function.
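Since `tag` defaults to the empty string in every overload above (and in the Python wrappers below), existing call sites compile unchanged; an untagged operator simply reports an empty tag. A small sketch of that default from the Python side:

.. code-block:: python

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda i: A[i] * 2, name='B')
    # No tag argument and no enclosing tag_scope: the tag stays empty.
    assert B.op.tag == ''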
@@ -27,3 +27,4 @@ from .node import register_node
from .ndarray import register_extension
from .schedule import create_schedule
from .build import build, lower, build_config
from .tag import tag_scope
@@ -16,6 +16,7 @@ from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import collections as _collections
from . import tag as _tag
int32 = "int32"
float32 = "float32"
@@ -186,7 +187,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
shape, dtype, name)
def compute(shape, fcompute, name="compute"):
def compute(shape, fcompute, name="compute", tag=""):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
@@ -207,6 +208,10 @@ def compute(shape, fcompute, name="compute"):
tensor: Tensor
The created tensor
"""
if _tag.TagScope.current is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
ndim = len(shape)
code = fcompute.__code__
@@ -225,13 +230,13 @@ def compute(shape, fcompute, name="compute"):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
-        name, dim_var, body)
+        name, tag, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan"):
def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
"""Construct new tensors by scanning over axis.
Parameters
@@ -270,6 +275,10 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.current is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
@@ -283,13 +292,13 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
-    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
+    op = _api_internal._ScanOp(name, tag, axis, init, update,
+                               state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
-def extern(shape, inputs, fcompute,
-           name="extern", dtype=None):
+def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
"""Compute several tensor via extern function.
Parameters
@@ -340,6 +349,10 @@ def extern(shape, inputs, fcompute,
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if _tag.TagScope.current is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
input_placeholders = []
@@ -364,8 +377,8 @@ def extern(shape, inputs, fcompute,
if isinstance(body, _expr.Expr):
body = _make.Evaluate(body)
-    op = _api_internal._ExternOp(
-        name, inputs, input_placeholders, output_placeholders, body)
+    op = _api_internal._ExternOp(name, tag, inputs, input_placeholders,
+                                 output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
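All three declaration entry points (`compute`, `scan`, and `extern`) now share the same guard: inside a `tag_scope`, an explicit `tag=` argument raises. A quick illustration of the behavior added above:

.. code-block:: python

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')

    with tvm.tag_scope(tag='elemwise'):
        try:
            # An explicit tag inside a scope hits the
            # "nested tag is not allowed for now" check.
            B = tvm.compute((n,), lambda i: A[i] + 1, tag='other')
            assert False
        except ValueError:
            pass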
"""Tag class for TVM operators."""
from functools import wraps
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
"""
current = None
def __init__(self, tag):
self._old_scope = None
self.tag = tag
def __enter__(self):
if TagScope.current is not None:
raise ValueError("nested op_tag is not allowed for now")
self._old_scope = TagScope.current
TagScope.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
TagScope.current = self._old_scope
def __call__(self, fdecl):
@wraps(fdecl)
def tagged_fdecl(*args, **kwargs):
with self:
return fdecl(*args, **kwargs)
return tagged_fdecl
def tag_scope(tag):
"""The operator tag scope.
Parameters
----------
tag: str
    The tag name.

Returns
-------
tag_scope: TagScope
    The tag scope object, which can be used as a decorator or a
    context manager.
Example
-------
.. code-block:: python

    n = tvm.var('n')
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    k = tvm.reduce_axis((0, l), name='k')

    with tvm.tag_scope(tag='matmul'):
        C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))

    # or use tag_scope as decorator
    @tvm.tag_scope(tag="relu")
    def compute_relu(data):
        return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
"""
return TagScope(tag)
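The tag ends up as an ordinary string attribute on the operation node, so downstream code such as a scheduling routine can dispatch on it. A sketch of that pattern; the split below is purely illustrative, not something this commit prescribes:

.. code-block:: python

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')

    with tvm.tag_scope(tag='elemwise'):
        B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

    s = tvm.create_schedule(B.op)
    if B.op.tag == 'elemwise':
        # Pick a schedule template based on the tag.
        xo, xi = s[B].split(B.op.axis[0], factor=8)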
@@ -185,7 +185,8 @@ TVM_REGISTER_API("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
-                             args[2]);
+                             args[2],
+                             args[3]);
});
TVM_REGISTER_API("_ScanOp")
@@ -195,7 +196,8 @@ TVM_REGISTER_API("_ScanOp")
args[2],
args[3],
args[4],
-                          args[5]);
+                          args[5],
+                          args[6]);
});
TVM_REGISTER_API("_ExternOp")
@@ -204,7 +206,8 @@ TVM_REGISTER_API("_ExternOp")
args[1],
args[2],
args[3],
-                          args[4]);
+                          args[4],
+                          args[5]);
});
TVM_REGISTER_API("_OpGetOutput")
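Note that these registrations forward purely positional arguments, so `tag` has to occupy the same slot on both sides of the FFI: the Python wrappers above pass it second, immediately after `name`, matching the updated `make` signatures below.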
@@ -53,7 +53,10 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
return Array<Expr>(shape);
}
-Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
+Tensor compute(Array<Expr> shape,
+               FCompute fcompute,
+               std::string name,
+               std::string tag) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
@@ -67,10 +70,13 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
args.push_back(axis.back()->var);
}
-  return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0);
+  return ComputeOpNode::make(name, tag, axis, {fcompute(args)}).output(0);
}
-Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name) {
+Array<Tensor> compute(Array<Expr> shape,
+                      FBatchCompute fcompute,
+                      std::string name,
+                      std::string tag) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
@@ -84,7 +90,7 @@ Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string nam
args.push_back(axis.back()->var);
}
-  Operation op = ComputeOpNode::make(name, axis, fcompute(args));
+  Operation op = ComputeOpNode::make(name, tag, axis, fcompute(args));
Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx));
@@ -100,10 +106,12 @@ bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
}
Operation ComputeOpNode::make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
n->tag = tag;
n->axis = axis;
n->body = body;
if (n->body[0]->is_type<ir::Reduce>()) {
@@ -147,7 +155,7 @@ Operation ComputeOpNode::ReplaceInputs(
return op::ReplaceTensor(e, rmap);
});
if (!arr.same_as(this->body)) {
-    return ComputeOpNode::make(name, axis, arr);
+    return ComputeOpNode::make(name, tag, axis, arr);
} else {
return self;
}
@@ -37,12 +37,14 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation ExternOpNode::make(std::string name,
std::string tag,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
auto n = std::make_shared<ExternOpNode>();
n->name = name;
n->tag = tag;
CHECK_EQ(inputs.size(), input_placeholders.size());
for (size_t i = 0; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
@@ -44,6 +44,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
}
Operation ScanOpNode::make(std::string name,
std::string tag,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
@@ -86,6 +87,7 @@ Operation ScanOpNode::make(std::string name,
}
}
n->name = name;
n->tag = tag;
n->scan_axis = axis;
n->init = init;
n->update = update;
@@ -98,14 +100,15 @@ Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs,
-                   std::string name) {
+                   std::string name,
+                   std::string tag) {
IterVar scan_axis =
IterVarNode::make(
Range::make_with_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
-      name, scan_axis, init, update, state_placeholder, inputs);
+      name, tag, scan_axis, init, update, state_placeholder, inputs);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
@@ -123,10 +123,10 @@ Tensor Schedule::cache_write(const Tensor& tensor,
VarReplacer repl(vsub);
Expr body = repl.Mutate(compute->body[tensor->value_index]);
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, new_axis, {body});
compute->name + "." + scope, compute->tag, new_axis, {body});
Tensor cache_tensor = cache_op.output(0);
Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->axis,
+      compute->name, compute->tag, compute->axis,
{cache_tensor(args)});
std::unordered_map<Tensor, Tensor> vmap;
@@ -246,7 +246,7 @@ void InjectInline(ScheduleNode* sch) {
Operation op = s->op;
if (changed[i]) {
op = ComputeOpNode::make(
-          compute->name, compute->axis, new_body[i]);
+          compute->name, compute->tag, compute->axis, new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
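The schedule rewrites above now thread `compute->tag` through `cache_write` and inlining, so a tag survives those transformations. A quick check, assuming the Python `cache_write` binding is available at this point:

.. code-block:: python

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda i: A[i] + 1, name='B', tag='elemwise')

    s = tvm.create_schedule(B.op)
    BW = s.cache_write(B, 'local')
    # Both the cache stage and the rewritten original keep the tag.
    assert BW.op.tag == 'elemwise'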
# For CPU
FROM ubuntu:14.04
-RUN apt-get update
+RUN apt-get update --fix-missing
COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh
FROM ioft/i386-ubuntu:14.04
-RUN apt-get update
+RUN apt-get update --fix-missing
COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh
import tvm
@tvm.tag_scope(tag="conv")
def compute_conv(data, weight):
N, IC, H, W = data.shape
OC, IC, KH, KW = weight.shape
OH = H - KH + 1
OW = W - KW + 1
ic = tvm.reduce_axis((0, IC), name='ic')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
return tvm.compute((N, OC, OH, OW), lambda i, oc, h, w: \
tvm.sum(data[i, ic, h+dh, w+dw] * weight[oc, ic, dh, dw],
axis=[ic, dh, dw]))
def test_with():
n = tvm.var('n')
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
with tvm.tag_scope(tag="gemm"):
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
assert C.op.tag == 'gemm'
def test_decorator():
n = tvm.var('n')
c = tvm.var('c')
h = tvm.var('h')
w = tvm.var('w')
kh = tvm.var('kh')
kw = tvm.var('kw')
A = tvm.placeholder((n, c, h, w), name='A')
B = tvm.placeholder((c, c, kh, kw), name='B')
C = compute_conv(A, B)
assert C.op.tag == 'conv'
def test_nested():
n = tvm.var('n')
c = tvm.var('c')
h = tvm.var('h')
w = tvm.var('w')
kh = tvm.var('kh')
kw = tvm.var('kw')
A = tvm.placeholder((n, c, h, w), name='A')
B = tvm.placeholder((c, c, kh, kw), name='B')
try:
with tvm.tag_scope(tag='conv'):
C = compute_conv(A, B)
assert False
except ValueError:
pass
if __name__ == "__main__":
import nose
nose.runmodule()
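The tests above exercise `compute`; since the same guard and constructor changes were applied to `scan` and `extern`, an analogous check for `scan` might look like this (a sketch, not part of this commit):

.. code-block:: python

    def test_scan_tag():
        m = tvm.var("m")
        n = tvm.var("n")
        X = tvm.placeholder((m, n), name="X")
        s_state = tvm.placeholder((m, n))
        s_init = tvm.compute((1, n), lambda _, i: X[0, i])
        s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
        with tvm.tag_scope(tag="scan_add"):
            res = tvm.scan(s_init, s_update, s_state, X)
        assert res.op.tag == "scan_add"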