Commit 0702d2c0 by Tianqi Chen, committed by GitHub

[OP] Introduces auxiliary attrs into compute (#1293)

parent 146714ac
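The change threads an optional string-keyed attribute map (Map<std::string, NodeRef>) through every operation node (compute, scan, extern, placeholder) and teaches the JSON serializer to round-trip it. A minimal sketch of the resulting Python-side usage, adapted from the test changes near the end of this diff:

    import tvm

    n, m, l = tvm.var('n'), tvm.var('m'), tvm.var('l')
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    # attrs is a plain dict; values are converted to TVM nodes.
    C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k),
                    attrs={"hello": 1, "arr": [10, 12]})
    assert C.op.attrs["hello"].value == 1
    # attrs survive a save_json/load_json round trip.
    CC = tvm.load_json(tvm.save_json(C))
    assert CC.op.attrs["arr"][0].value == 10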
......@@ -41,6 +41,8 @@ class OperationNode : public FunctionBaseNode {
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation */
Map<std::string, NodeRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
......@@ -167,6 +169,8 @@ class PlaceholderOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
......@@ -220,12 +224,14 @@ class TVM_DLL ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body);
......@@ -292,6 +298,7 @@ class ScanOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
......@@ -301,6 +308,7 @@ class ScanOpNode : public OperationNode {
}
static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
......@@ -356,11 +364,13 @@ class ExternOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
......@@ -393,11 +403,13 @@ TVM_DLL Tensor placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
TVM_DLL Tensor compute(Array<Expr> shape,
FCompute fcompute,
std::string name = "tensor",
std::string tag = "");
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
/*!
* \brief Construct a new tensor by computing over shape,
......@@ -406,11 +418,13 @@ TVM_DLL Tensor compute(Array<Expr> shape,
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
TVM_DLL Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute,
std::string name = "tensor",
std::string tag = "");
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
/*!
* \brief Construct new tensors by scan.
......@@ -422,42 +436,48 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
* but it is recommended to provide concrete information about the scan body.
* \param name The optional name of the tensor.
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
TVM_DLL Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan",
std::string tag = "");
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
// same as compute, specialized for different fcompute function
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor",
std::string tag = "") {
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag);
return compute(shape, fc, name, tag, attrs);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor",
std::string tag = "") {
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag);
return compute(shape, fc, name, tag, attrs);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "") {
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return compute(shape, fc, name, tag);
return compute(shape, fc, name, tag, attrs);
}
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "") {
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return compute(shape, fc, name, tag);
return compute(shape, fc, name, tag, attrs);
}
// inline function.
......
......@@ -42,6 +42,10 @@
#endif
#endif
// TVM version
#define TVM_VERSION "0.4.0"
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
......
......@@ -189,7 +189,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
shape, dtype, name)
def compute(shape, fcompute, name="compute", tag=""):
def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
......@@ -205,6 +205,12 @@ def compute(shape, fcompute, name="compute", tag=""):
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes of the compute.
Returns
-------
tensor: Tensor
......@@ -232,13 +238,13 @@ def compute(shape, fcompute, name="compute", tag=""):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, dim_var, body)
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis.
Parameters
......@@ -259,6 +265,12 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes of the compute.
Returns
-------
tensor: Tensor or list of Tensors
......@@ -294,7 +306,8 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, tag, axis, init, update,
op = _api_internal._ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
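The scan wrapper forwards the new attrs dict to _ScanOp unchanged. A hedged sketch of the standard cumulative-sum scan carrying an attribute (the attribute name "n_iter" is purely illustrative):

    import tvm

    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n), name="state")
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    # attrs lands on the ScanOp itself, not on the init/update computes.
    res = tvm.scan(s_init, s_update, s_state, inputs=[X],
                   attrs={"n_iter": 1})
    assert res.op.attrs["n_iter"].value == 1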
......@@ -307,7 +320,8 @@ def extern(shape,
dtype=None,
in_buffers=None,
out_buffers=None,
tag=""):
tag="",
attrs=None):
"""Compute several tensor via extern function.
Parameters
......@@ -345,6 +359,13 @@ def extern(shape,
out_buffers: Buffer or list of Buffers, optional
Output buffers.
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes of the compute.
Returns
-------
tensor: Tensor or list of Tensors
......@@ -406,7 +427,8 @@ def extern(shape,
if isinstance(body, _expr.Expr):
body = _make.Evaluate(body)
op = _api_internal._ExternOp(name, tag, inputs, input_placeholders,
op = _api_internal._ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
......
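tvm.extern gets the same treatment. A sketch under assumptions: "my_packed_func" is a hypothetical packed-function name and "workspace" an illustrative attribute, neither defined by this commit:

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # The lambda receives input/output buffers and emits the extern call.
    C = tvm.extern((n,), [A],
                   lambda ins, outs: tvm.call_packed(
                       "my_packed_func", ins[0], outs[0]),
                   name="C", dtype="float32", attrs={"workspace": 1})
    assert C.op.attrs["workspace"].value == 1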
......@@ -262,7 +262,8 @@ TVM_REGISTER_API("_ComputeOp")
*ret = ComputeOpNode::make(args[0],
args[1],
args[2],
args[3]);
args[3],
args[4]);
});
TVM_REGISTER_API("_ScanOp")
......@@ -273,7 +274,8 @@ TVM_REGISTER_API("_ScanOp")
args[3],
args[4],
args[5],
args[6]);
args[6],
args[7]);
});
TVM_REGISTER_API("_ExternOp")
......@@ -283,7 +285,8 @@ TVM_REGISTER_API("_ExternOp")
args[2],
args[3],
args[4],
args[5]);
args[5],
args[6]);
});
TVM_REGISTER_API("_OpGetOutput")
......
......@@ -84,6 +84,11 @@ class NodeIndexer : public AttrVisitor {
MakeIndex(kv.first.get());
MakeIndex(kv.second.get());
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
MakeIndex(kv.second.get());
}
} else {
node->VisitAttrs(this);
}
......@@ -99,6 +104,8 @@ struct JSONNode {
std::string type_key;
// the attributes
AttrMap attrs;
// container keys
std::vector<std::string> keys;
// container data
std::vector<size_t> data;
......@@ -108,6 +115,9 @@ struct JSONNode {
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
if (keys.size() != 0) {
writer->WriteObjectKeyValue("keys", keys);
}
if (data.size() != 0) {
writer->WriteObjectKeyValue("data", data);
}
......@@ -121,6 +131,7 @@ struct JSONNode {
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);
}
......@@ -176,13 +187,19 @@ class JSONAttrGetter : public AttrVisitor {
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
std::vector<std::pair<size_t, size_t> > elems;
for (const auto& kv : n->data) {
node_->data.push_back(
node_index_->at(kv.first.get()));
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
node_->keys.push_back(kv.first);
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else {
node->VisitAttrs(this);
}
......@@ -256,6 +273,13 @@ class JSONAttrSetter : public AttrVisitor {
n->data[node_list_->at(node_->data[i])]
= node_list_->at(node_->data[i + 1]);
}
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
CHECK_EQ(node_->data.size(), node_->keys.size());
for (size_t i = 0; i < node_->data.size(); ++i) {
n->data[node_->keys[i]]
= node_list_->at(node_->data[i]);
}
} else {
node->VisitAttrs(this);
}
......@@ -302,7 +326,7 @@ struct JSONGraph {
getter.Get(n);
g.nodes.emplace_back(std::move(jnode));
}
g.attrs["tvm_version"] = "0.1.0";
g.attrs["tvm_version"] = TVM_VERSION;
g.root = indexer.node_index.at(root.node_.get());
return g;
}
......
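With the serializer changes above, string-keyed maps get their own encoding: the keys go into the new "keys" field and the values' node indices into "data", pairing up by position on load. The new test_make_smap below exercises this; condensed:

    import tvm

    x = tvm.const(1)
    y = tvm.const(10)
    z = x + y
    smap = tvm.convert({"z": z, "x": x})  # a StrMap node
    arr = tvm.load_json(tvm.save_json(tvm.convert([smap])))
    # Node sharing survives the round trip: z's left operand
    # is the very node stored under "x".
    assert arr[0]["z"].a == arr[0]["x"]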
......@@ -66,7 +66,8 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
Tensor compute(Array<Expr> shape,
FCompute fcompute,
std::string name,
std::string tag) {
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
......@@ -80,13 +81,15 @@ Tensor compute(Array<Expr> shape,
args.push_back(axis.back()->var);
}
return ComputeOpNode::make(name, tag, axis, {fcompute(args)}).output(0);
return ComputeOpNode::make(
name, tag, attrs, axis, {fcompute(args)}).output(0);
}
Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute,
std::string name,
std::string tag) {
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
......@@ -100,7 +103,7 @@ Array<Tensor> compute(Array<Expr> shape,
args.push_back(axis.back()->var);
}
Operation op = ComputeOpNode::make(name, tag, axis, fcompute(args));
Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx));
......@@ -110,13 +113,15 @@ Array<Tensor> compute(Array<Expr> shape,
Operation ComputeOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
n->tag = tag;
n->axis = axis;
n->body = body;
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->axis = std::move(axis);
n->body = std::move(body);
if (n->body[0]->is_type<ir::Reduce>()) {
const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
n->reduce_axis = reduce->axis;
......@@ -171,7 +176,8 @@ Operation ComputeOpNode::ReplaceInputs(
});
}
if (!arr.same_as(this->body)) {
return ComputeOpNode::make(name, tag, axis, arr);
return ComputeOpNode::make(
this->name, this->tag, this->attrs, this->axis, arr);
} else {
return self;
}
......
......@@ -38,23 +38,25 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation ExternOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
auto n = std::make_shared<ExternOpNode>();
n->name = name;
n->tag = tag;
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
CHECK_EQ(inputs.size(), input_placeholders.size());
for (size_t i = 0; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
}
n->inputs = inputs;
n->input_placeholders = input_placeholders;
n->output_placeholders = output_placeholders;
n->body = body;
n->inputs = std::move(inputs);
n->input_placeholders = std::move(input_placeholders);
n->output_placeholders = std::move(output_placeholders);
n->body = std::move(body);
return Operation(n);
}
......
......@@ -45,6 +45,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
Operation ScanOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
......@@ -86,13 +87,14 @@ Operation ScanOpNode::make(std::string name,
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->tag = tag;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
n->inputs = inputs;
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->scan_axis = std::move(axis);
n->init = std::move(init);
n->update = std::move(update);
n->state_placeholder = std::move(state_placeholder);
n->inputs = std::move(inputs);
return Operation(n);
}
......@@ -101,14 +103,16 @@ Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> state_placeholder,
Array<Tensor> inputs,
std::string name,
std::string tag) {
std::string tag,
Map<std::string, NodeRef> attrs) {
IterVar scan_axis =
IterVarNode::make(
Range::make_by_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
name, tag, scan_axis, init, update, state_placeholder, inputs);
name, tag, attrs, scan_axis,
init, update, state_placeholder, inputs);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
......
......@@ -232,7 +232,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
}
}
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, compute->tag, new_axis, body_list);
compute->name + "." + scope, compute->tag, compute->attrs,
new_axis, body_list);
Array<Tensor> cache_tensor_list;
Array<Expr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
......@@ -241,7 +242,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
cache_expr_list.push_back(cache_tensor(args));
}
Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->tag, compute->axis, cache_expr_list);
compute->name, compute->tag, compute->attrs,
compute->axis, cache_expr_list);
// The replace of the dataflow
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
......@@ -430,7 +432,8 @@ void InjectInline(ScheduleNode* sch) {
Operation op = s->op;
if (changed[i]) {
op = ComputeOpNode::make(
compute->name, compute->tag, compute->axis, new_body[i]);
compute->name, compute->tag, compute->attrs,
compute->axis, new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
......
......@@ -11,6 +11,18 @@ def test_const_saveload_json():
assert tvm.save_json(zz) == tvm.save_json(z)
def test_make_smap():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap]))
arr = tvm.load_json(json_str)
assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"]
def test_make_node():
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
......@@ -35,5 +47,6 @@ def test_make_sum():
if __name__ == "__main__":
test_make_node()
test_make_smap()
test_const_saveload_json()
test_make_sum()
import json
import tvm
@tvm.tag_scope(tag="conv")
......@@ -24,8 +25,19 @@ def test_with():
B = tvm.placeholder((m, l), name='B')
with tvm.tag_scope(tag="gemm"):
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k),
attrs={"hello" : 1, "arr": [10, 12]})
assert C.op.tag == 'gemm'
assert "hello" in C.op.attrs
assert "xx" not in C.op.attrs
assert C.op.attrs["hello"].value == 1
CC = tvm.load_json(tvm.save_json(C))
assert CC.op.attrs["hello"].value == 1
assert CC.op.attrs["arr"][0].value == 10
# str format happens to be JSON compatible
assert json.loads(str(CC.op.attrs))["arr"][1] == 12
def test_decorator():
n = tvm.var('n')
......@@ -39,6 +51,7 @@ def test_decorator():
B = tvm.placeholder((c, c, kh, kw), name='B')
C = compute_conv(A, B)
assert C.op.tag == 'conv'
assert len(C.op.attrs) == 0
def test_nested():
n = tvm.var('n')
......@@ -59,5 +72,6 @@ def test_nested():
if __name__ == "__main__":
import nose
nose.runmodule()
test_with()
test_decorator()
test_nested()
......@@ -40,7 +40,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
pack_buffer(outs[0]),
transa,
transb });
}, "C", "")[0];
}, "C", "", {})[0];
}
} // namespace contrib
......
......@@ -39,7 +39,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
pack_buffer(outs[0]),
transa,
transb });
}, "C", "")[0];
}, "C", "", {})[0];
}
} // namespace contrib
......
......@@ -6,10 +6,10 @@
#ifndef TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_
#include <tvm/tvm.h>
#include <vector>
#include <string>
#include "tvm/tvm.h"
namespace topi {
namespace detail {
......@@ -51,6 +51,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
* the external function given the input and output buffers.
* \param name The name of the operation
* \param tag The tag to mark the operation
* \param attrs The additional auxiliary attributes of the operation.
*
* \return An array of Tensors representing the outputs of the function invocation. There will
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
......@@ -61,7 +62,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
const Array<Tensor>& inputs,
FExtern fextern,
std::string name,
std::string tag) {
std::string tag,
::tvm::Map<std::string, NodeRef> attrs) {
CHECK_EQ(out_shapes.size(), out_types.size())
<< "make_extern: out_shapes and out_types must have equal size";
......@@ -78,7 +80,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
auto body_stmt = tvm::ir::Evaluate::make(body);
auto op = ExternOpNode::make(
name, tag, inputs, input_placeholders, output_placeholders, body_stmt);
name, tag, attrs, inputs,
input_placeholders, output_placeholders, body_stmt);
Array<Tensor> outputs;
for (size_t i = 0; i < output_placeholders.size(); ++i) {
......