Commit 0702d2c0 by Tianqi Chen Committed by GitHub

[OP] Introduces auxiliary attrs into compute (#1293)

parent 146714ac
...@@ -41,6 +41,8 @@ class OperationNode : public FunctionBaseNode { ...@@ -41,6 +41,8 @@ class OperationNode : public FunctionBaseNode {
std::string name; std::string name;
/*! \brief optional tag of the operation */ /*! \brief optional tag of the operation */
std::string tag; std::string tag;
/*! \brief addtitional attributes of the operation*/
Map<std::string, NodeRef> attrs;
/*! \return name of the operation */ /*! \return name of the operation */
const std::string& func_name() const final { const std::string& func_name() const final {
return name; return name;
...@@ -167,6 +169,8 @@ class PlaceholderOpNode : public OperationNode { ...@@ -167,6 +169,8 @@ class PlaceholderOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
} }
...@@ -220,12 +224,14 @@ class TVM_DLL ComputeOpNode : public OperationNode { ...@@ -220,12 +224,14 @@ class TVM_DLL ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("axis", &axis); v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis); v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body); v->Visit("body", &body);
} }
static Operation make(std::string name, static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis, Array<IterVar> axis,
Array<Expr> body); Array<Expr> body);
...@@ -292,6 +298,7 @@ class ScanOpNode : public OperationNode { ...@@ -292,6 +298,7 @@ class ScanOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("scan_axis", &scan_axis); v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init); v->Visit("init", &init);
v->Visit("update", &update); v->Visit("update", &update);
...@@ -301,6 +308,7 @@ class ScanOpNode : public OperationNode { ...@@ -301,6 +308,7 @@ class ScanOpNode : public OperationNode {
} }
static Operation make(std::string name, static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
IterVar axis, IterVar axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
...@@ -356,11 +364,13 @@ class ExternOpNode : public OperationNode { ...@@ -356,11 +364,13 @@ class ExternOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("body", &body); v->Visit("body", &body);
} }
EXPORT static Operation make(std::string name, EXPORT static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Buffer> input_placeholders, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Array<Buffer> output_placeholders,
...@@ -393,11 +403,13 @@ TVM_DLL Tensor placeholder(Array<Expr> shape, ...@@ -393,11 +403,13 @@ TVM_DLL Tensor placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor. * \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
* \param tag The optional tag of the tensor. * \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/ */
TVM_DLL Tensor compute(Array<Expr> shape, TVM_DLL Tensor compute(Array<Expr> shape,
FCompute fcompute, FCompute fcompute,
std::string name = "tensor", std::string name = "tensor",
std::string tag = ""); std::string tag = "",
Map<std::string, NodeRef> attrs = {});
/*! /*!
* \brief Construct a new tensor by computing over shape, * \brief Construct a new tensor by computing over shape,
...@@ -406,11 +418,13 @@ TVM_DLL Tensor compute(Array<Expr> shape, ...@@ -406,11 +418,13 @@ TVM_DLL Tensor compute(Array<Expr> shape,
* \param fcompute The compute function to create the tensors. * \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
* \param tag The optional tag of the tensor. * \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/ */
TVM_DLL Array<Tensor> compute(Array<Expr> shape, TVM_DLL Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute, FBatchCompute fcompute,
std::string name = "tensor", std::string name = "tensor",
std::string tag = ""); std::string tag = "",
Map<std::string, NodeRef> attrs = {});
/*! /*!
* \brief Construct new tensors by scan. * \brief Construct new tensors by scan.
...@@ -422,42 +436,48 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape, ...@@ -422,42 +436,48 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
* but recommended to provide concrete information about scan body. * but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
* \param tag The optional tag of the tensor. * \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/ */
TVM_DLL Array<Tensor> scan(Array<Tensor> init, TVM_DLL Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(), Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan", std::string name = "scan",
std::string tag = ""); std::string tag = "",
Map<std::string, NodeRef> attrs = {});
// same as compute, specialized for different fcompute function // same as compute, specialized for different fcompute function
inline Tensor compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f, std::function<Expr(Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "") { std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag); return compute(shape, fc, name, tag, attrs);
} }
inline Tensor compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f, std::function<Expr(Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "") { std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag); return compute(shape, fc, name, tag, attrs);
} }
inline Tensor compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f, std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "") { std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return compute(shape, fc, name, tag); return compute(shape, fc, name, tag, attrs);
} }
inline Tensor compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f, std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "") { std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return compute(shape, fc, name, tag); return compute(shape, fc, name, tag, attrs);
} }
// inline function. // inline function.
......
...@@ -42,6 +42,10 @@ ...@@ -42,6 +42,10 @@
#endif #endif
#endif #endif
// TVM version
#define TVM_VERSION "0.4.0"
// TVM Runtime is DLPack compatible. // TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
......
...@@ -189,7 +189,7 @@ def placeholder(shape, dtype=None, name="placeholder"): ...@@ -189,7 +189,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
shape, dtype, name) shape, dtype, name)
def compute(shape, fcompute, name="compute", tag=""): def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain. """Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis) The compute rule is result[axis] = fcompute(axis)
...@@ -205,6 +205,12 @@ def compute(shape, fcompute, name="compute", tag=""): ...@@ -205,6 +205,12 @@ def compute(shape, fcompute, name="compute", tag=""):
name: str, optional name: str, optional
The name hint of the tensor The name hint of the tensor
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns Returns
------- -------
tensor: Tensor tensor: Tensor
...@@ -232,13 +238,13 @@ def compute(shape, fcompute, name="compute", tag=""): ...@@ -232,13 +238,13 @@ def compute(shape, fcompute, name="compute", tag=""):
body = [body] body = [body]
body = convert(body) body = convert(body)
op_node = _api_internal._ComputeOp( op_node = _api_internal._ComputeOp(
name, tag, dim_var, body) name, tag, attrs, dim_var, body)
num = op_node.num_outputs num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num)) outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""): def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis. """Construct new tensors by scanning over axis.
Parameters Parameters
...@@ -259,6 +265,12 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""): ...@@ -259,6 +265,12 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
name: str, optional name: str, optional
The name hint of the tensor The name hint of the tensor
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns Returns
------- -------
tensor: Tensor or list of Tensors tensor: Tensor or list of Tensors
...@@ -294,7 +306,8 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""): ...@@ -294,7 +306,8 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
if len(init) != len(update) or len(init) != len(state_placeholder): if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length") raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3) axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, tag, axis, init, update, op = _api_internal._ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs) state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))] res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res return res[0] if len(res) == 1 else res
...@@ -307,7 +320,8 @@ def extern(shape, ...@@ -307,7 +320,8 @@ def extern(shape,
dtype=None, dtype=None,
in_buffers=None, in_buffers=None,
out_buffers=None, out_buffers=None,
tag=""): tag="",
attrs=None):
"""Compute several tensor via extern function. """Compute several tensor via extern function.
Parameters Parameters
...@@ -345,6 +359,13 @@ def extern(shape, ...@@ -345,6 +359,13 @@ def extern(shape,
out_buffers: Buffer or list of Buffers, optional out_buffers: Buffer or list of Buffers, optional
Output buffers. Output buffers.
tag: str, optional
Additonal tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns Returns
------- -------
tensor: Tensor or list of Tensors tensor: Tensor or list of Tensors
...@@ -406,7 +427,8 @@ def extern(shape, ...@@ -406,7 +427,8 @@ def extern(shape,
if isinstance(body, _expr.Expr): if isinstance(body, _expr.Expr):
body = _make.Evaluate(body) body = _make.Evaluate(body)
op = _api_internal._ExternOp(name, tag, inputs, input_placeholders, op = _api_internal._ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body) output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))] res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res return res[0] if len(res) == 1 else res
......
...@@ -262,7 +262,8 @@ TVM_REGISTER_API("_ComputeOp") ...@@ -262,7 +262,8 @@ TVM_REGISTER_API("_ComputeOp")
*ret = ComputeOpNode::make(args[0], *ret = ComputeOpNode::make(args[0],
args[1], args[1],
args[2], args[2],
args[3]); args[3],
args[4]);
}); });
TVM_REGISTER_API("_ScanOp") TVM_REGISTER_API("_ScanOp")
...@@ -273,7 +274,8 @@ TVM_REGISTER_API("_ScanOp") ...@@ -273,7 +274,8 @@ TVM_REGISTER_API("_ScanOp")
args[3], args[3],
args[4], args[4],
args[5], args[5],
args[6]); args[6],
args[7]);
}); });
TVM_REGISTER_API("_ExternOp") TVM_REGISTER_API("_ExternOp")
...@@ -283,7 +285,8 @@ TVM_REGISTER_API("_ExternOp") ...@@ -283,7 +285,8 @@ TVM_REGISTER_API("_ExternOp")
args[2], args[2],
args[3], args[3],
args[4], args[4],
args[5]); args[5],
args[6]);
}); });
TVM_REGISTER_API("_OpGetOutput") TVM_REGISTER_API("_OpGetOutput")
......
...@@ -84,6 +84,11 @@ class NodeIndexer : public AttrVisitor { ...@@ -84,6 +84,11 @@ class NodeIndexer : public AttrVisitor {
MakeIndex(kv.first.get()); MakeIndex(kv.first.get());
MakeIndex(kv.second.get()); MakeIndex(kv.second.get());
} }
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
MakeIndex(kv.second.get());
}
} else { } else {
node->VisitAttrs(this); node->VisitAttrs(this);
} }
...@@ -99,6 +104,8 @@ struct JSONNode { ...@@ -99,6 +104,8 @@ struct JSONNode {
std::string type_key; std::string type_key;
// the attributes // the attributes
AttrMap attrs; AttrMap attrs;
// container keys
std::vector<std::string> keys;
// container data // container data
std::vector<size_t> data; std::vector<size_t> data;
...@@ -108,6 +115,9 @@ struct JSONNode { ...@@ -108,6 +115,9 @@ struct JSONNode {
if (attrs.size() != 0) { if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs); writer->WriteObjectKeyValue("attrs", attrs);
} }
if (keys.size() != 0) {
writer->WriteObjectKeyValue("keys", keys);
}
if (data.size() != 0) { if (data.size() != 0) {
writer->WriteObjectKeyValue("data", data); writer->WriteObjectKeyValue("data", data);
} }
...@@ -121,6 +131,7 @@ struct JSONNode { ...@@ -121,6 +131,7 @@ struct JSONNode {
dmlc::JSONObjectReadHelper helper; dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key); helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data); helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
} }
...@@ -176,13 +187,19 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -176,13 +187,19 @@ class JSONAttrGetter : public AttrVisitor {
} }
} else if (node->is_type<MapNode>()) { } else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node); MapNode* n = static_cast<MapNode*>(node);
std::vector<std::pair<size_t, size_t> > elems;
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
node_->data.push_back( node_->data.push_back(
node_index_->at(kv.first.get())); node_index_->at(kv.first.get()));
node_->data.push_back( node_->data.push_back(
node_index_->at(kv.second.get())); node_index_->at(kv.second.get()));
} }
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
for (const auto& kv : n->data) {
node_->keys.push_back(kv.first);
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else { } else {
node->VisitAttrs(this); node->VisitAttrs(this);
} }
...@@ -256,6 +273,13 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -256,6 +273,13 @@ class JSONAttrSetter : public AttrVisitor {
n->data[node_list_->at(node_->data[i])] n->data[node_list_->at(node_->data[i])]
= node_list_->at(node_->data[i + 1]); = node_list_->at(node_->data[i + 1]);
} }
} else if (node->is_type<StrMapNode>()) {
StrMapNode* n = static_cast<StrMapNode*>(node);
CHECK_EQ(node_->data.size(), node_->keys.size());
for (size_t i = 0; i < node_->data.size(); ++i) {
n->data[node_->keys[i]]
= node_list_->at(node_->data[i]);
}
} else { } else {
node->VisitAttrs(this); node->VisitAttrs(this);
} }
...@@ -302,7 +326,7 @@ struct JSONGraph { ...@@ -302,7 +326,7 @@ struct JSONGraph {
getter.Get(n); getter.Get(n);
g.nodes.emplace_back(std::move(jnode)); g.nodes.emplace_back(std::move(jnode));
} }
g.attrs["tvm_version"] = "0.1.0"; g.attrs["tvm_version"] = TVM_VERSION;
g.root = indexer.node_index.at(root.node_.get()); g.root = indexer.node_index.at(root.node_.get());
return g; return g;
} }
......
...@@ -66,7 +66,8 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const { ...@@ -66,7 +66,8 @@ Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
Tensor compute(Array<Expr> shape, Tensor compute(Array<Expr> shape,
FCompute fcompute, FCompute fcompute,
std::string name, std::string name,
std::string tag) { std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
...@@ -80,13 +81,15 @@ Tensor compute(Array<Expr> shape, ...@@ -80,13 +81,15 @@ Tensor compute(Array<Expr> shape,
args.push_back(axis.back()->var); args.push_back(axis.back()->var);
} }
return ComputeOpNode::make(name, tag, axis, {fcompute(args)}).output(0); return ComputeOpNode::make(
name, tag, attrs, axis, {fcompute(args)}).output(0);
} }
Array<Tensor> compute(Array<Expr> shape, Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute, FBatchCompute fcompute,
std::string name, std::string name,
std::string tag) { std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
...@@ -100,7 +103,7 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -100,7 +103,7 @@ Array<Tensor> compute(Array<Expr> shape,
args.push_back(axis.back()->var); args.push_back(axis.back()->var);
} }
Operation op = ComputeOpNode::make(name, tag, axis, fcompute(args)); Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
Array<Tensor> outputs; Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) { for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx)); outputs.push_back(op.output(idx));
...@@ -110,13 +113,15 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -110,13 +113,15 @@ Array<Tensor> compute(Array<Expr> shape,
Operation ComputeOpNode::make(std::string name, Operation ComputeOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis, Array<IterVar> axis,
Array<Expr> body) { Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>(); auto n = std::make_shared<ComputeOpNode>();
n->name = name; n->name = std::move(name);
n->tag = tag; n->tag = std::move(tag);
n->axis = axis; n->attrs = std::move(attrs);
n->body = body; n->axis = std::move(axis);
n->body = std::move(body);
if (n->body[0]->is_type<ir::Reduce>()) { if (n->body[0]->is_type<ir::Reduce>()) {
const ir::Reduce* reduce = n->body[0].as<ir::Reduce>(); const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
n->reduce_axis = reduce->axis; n->reduce_axis = reduce->axis;
...@@ -171,7 +176,8 @@ Operation ComputeOpNode::ReplaceInputs( ...@@ -171,7 +176,8 @@ Operation ComputeOpNode::ReplaceInputs(
}); });
} }
if (!arr.same_as(this->body)) { if (!arr.same_as(this->body)) {
return ComputeOpNode::make(name, tag, axis, arr); return ComputeOpNode::make(
this->name, this->tag, this->attrs, this->axis, arr);
} else { } else {
return self; return self;
} }
......
...@@ -38,23 +38,25 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const { ...@@ -38,23 +38,25 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation ExternOpNode::make(std::string name, Operation ExternOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Buffer> input_placeholders, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Array<Buffer> output_placeholders,
Stmt body) { Stmt body) {
auto n = std::make_shared<ExternOpNode>(); auto n = std::make_shared<ExternOpNode>();
n->name = name; n->name = std::move(name);
n->tag = tag; n->tag = std::move(tag);
n->attrs = std::move(attrs);
CHECK_EQ(inputs.size(), input_placeholders.size()); CHECK_EQ(inputs.size(), input_placeholders.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape)); CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
CHECK_EQ(input_placeholders[i]->strides.size(), 0U); CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
} }
n->inputs = inputs; n->inputs = std::move(inputs);
n->input_placeholders = input_placeholders; n->input_placeholders = std::move(input_placeholders);
n->output_placeholders = output_placeholders; n->output_placeholders = std::move(output_placeholders);
n->body = body; n->body = std::move(body);
return Operation(n); return Operation(n);
} }
......
...@@ -45,6 +45,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const { ...@@ -45,6 +45,7 @@ Array<Expr> ScanOpNode::output_shape(size_t i) const {
Operation ScanOpNode::make(std::string name, Operation ScanOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs,
IterVar axis, IterVar axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
...@@ -86,13 +87,14 @@ Operation ScanOpNode::make(std::string name, ...@@ -86,13 +87,14 @@ Operation ScanOpNode::make(std::string name,
init[i]->shape[k], state_placeholder[i]->shape[k])); init[i]->shape[k], state_placeholder[i]->shape[k]));
} }
} }
n->name = name; n->name = std::move(name);
n->tag = tag; n->tag = std::move(tag);
n->scan_axis = axis; n->attrs = std::move(attrs);
n->init = init; n->scan_axis = std::move(axis);
n->update = update; n->init = std::move(init);
n->state_placeholder = state_placeholder; n->update = std::move(update);
n->inputs = inputs; n->state_placeholder = std::move(state_placeholder);
n->inputs = std::move(inputs);
return Operation(n); return Operation(n);
} }
...@@ -101,14 +103,16 @@ Array<Tensor> scan(Array<Tensor> init, ...@@ -101,14 +103,16 @@ Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
Array<Tensor> inputs, Array<Tensor> inputs,
std::string name, std::string name,
std::string tag) { std::string tag,
Map<std::string, NodeRef> attrs) {
IterVar scan_axis = IterVar scan_axis =
IterVarNode::make( IterVarNode::make(
Range::make_by_min_extent( Range::make_by_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered); Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make( Operation op = ScanOpNode::make(
name, tag, scan_axis, init, update, state_placeholder, inputs); name, tag, attrs, scan_axis,
init, update, state_placeholder, inputs);
Array<Tensor> res; Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) { for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i)); res.push_back(op.output(i));
......
...@@ -232,7 +232,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, ...@@ -232,7 +232,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
} }
} }
Operation cache_op = ComputeOpNode::make( Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, compute->tag, new_axis, body_list); compute->name + "." + scope, compute->tag, compute->attrs,
new_axis, body_list);
Array<Tensor> cache_tensor_list; Array<Tensor> cache_tensor_list;
Array<Expr> cache_expr_list; Array<Expr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) { for (size_t i = 0; i < tensor_size; i++) {
...@@ -241,7 +242,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, ...@@ -241,7 +242,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
cache_expr_list.push_back(cache_tensor(args)); cache_expr_list.push_back(cache_tensor(args));
} }
Operation orig_new_op = ComputeOpNode::make( Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->tag, compute->axis, cache_expr_list); compute->name, compute->tag, compute->attrs,
compute->axis, cache_expr_list);
// The replace of the dataflow // The replace of the dataflow
std::unordered_map<Tensor, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap; std::unordered_map<Tensor, Tensor> rvmap;
...@@ -430,7 +432,8 @@ void InjectInline(ScheduleNode* sch) { ...@@ -430,7 +432,8 @@ void InjectInline(ScheduleNode* sch) {
Operation op = s->op; Operation op = s->op;
if (changed[i]) { if (changed[i]) {
op = ComputeOpNode::make( op = ComputeOpNode::make(
compute->name, compute->tag, compute->axis, new_body[i]); compute->name, compute->tag, compute->attrs,
compute->axis, new_body[i]);
} }
op = op->ReplaceInputs(op, repl); op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) { if (!op.same_as(s->op)) {
......
...@@ -11,6 +11,18 @@ def test_const_saveload_json(): ...@@ -11,6 +11,18 @@ def test_const_saveload_json():
assert tvm.save_json(zz) == tvm.save_json(z) assert tvm.save_json(zz) == tvm.save_json(z)
def test_make_smap():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap]))
arr = tvm.load_json(json_str)
assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"]
def test_make_node(): def test_make_node():
x = tvm.make.node("IntImm", dtype="int32", value=10) x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm) assert isinstance(x, tvm.expr.IntImm)
...@@ -35,5 +47,6 @@ def test_make_sum(): ...@@ -35,5 +47,6 @@ def test_make_sum():
if __name__ == "__main__": if __name__ == "__main__":
test_make_node() test_make_node()
test_make_smap()
test_const_saveload_json() test_const_saveload_json()
test_make_sum() test_make_sum()
import json
import tvm import tvm
@tvm.tag_scope(tag="conv") @tvm.tag_scope(tag="conv")
...@@ -24,8 +25,19 @@ def test_with(): ...@@ -24,8 +25,19 @@ def test_with():
B = tvm.placeholder((m, l), name='B') B = tvm.placeholder((m, l), name='B')
with tvm.tag_scope(tag="gemm"): with tvm.tag_scope(tag="gemm"):
k = tvm.reduce_axis((0, l), name='k') k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k)) C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k),
attrs={"hello" : 1, "arr": [10, 12]})
assert C.op.tag == 'gemm' assert C.op.tag == 'gemm'
assert "hello" in C.op.attrs
assert "xx" not in C.op.attrs
assert C.op.attrs["hello"].value == 1
CC = tvm.load_json(tvm.save_json(C))
assert CC.op.attrs["hello"].value == 1
assert CC.op.attrs["arr"][0].value == 10
# str format happened to be json compatible
assert json.loads(str(CC.op.attrs))["arr"][1] == 12
def test_decorator(): def test_decorator():
n = tvm.var('n') n = tvm.var('n')
...@@ -39,6 +51,7 @@ def test_decorator(): ...@@ -39,6 +51,7 @@ def test_decorator():
B = tvm.placeholder((c, c, kh, kw), name='B') B = tvm.placeholder((c, c, kh, kw), name='B')
C = compute_conv(A, B) C = compute_conv(A, B)
assert C.op.tag == 'conv' assert C.op.tag == 'conv'
assert len(C.op.attrs) == 0
def test_nested(): def test_nested():
n = tvm.var('n') n = tvm.var('n')
...@@ -59,5 +72,6 @@ def test_nested(): ...@@ -59,5 +72,6 @@ def test_nested():
if __name__ == "__main__": if __name__ == "__main__":
import nose test_with()
nose.runmodule() test_decorator()
test_nested()
...@@ -40,7 +40,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, ...@@ -40,7 +40,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
pack_buffer(outs[0]), pack_buffer(outs[0]),
transa, transa,
transb }); transb });
}, "C", "")[0]; }, "C", "", {})[0];
} }
} // namespace contrib } // namespace contrib
......
...@@ -39,7 +39,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, ...@@ -39,7 +39,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
pack_buffer(outs[0]), pack_buffer(outs[0]),
transa, transa,
transb }); transb });
}, "C", "")[0]; }, "C", "", {})[0];
} }
} // namespace contrib } // namespace contrib
......
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
#ifndef TOPI_DETAIL_EXTERN_H_ #ifndef TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_ #define TOPI_DETAIL_EXTERN_H_
#include <tvm/tvm.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include "tvm/tvm.h"
namespace topi { namespace topi {
namespace detail { namespace detail {
...@@ -51,6 +51,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>; ...@@ -51,6 +51,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
* the external function given the input and output buffers. * the external function given the input and output buffers.
* \param name The name of the operation * \param name The name of the operation
* \param tag The tag to mark the operation * \param tag The tag to mark the operation
* \param attrs The additional auxiliary attributes of the operation.
* *
* \return An array of Tensors representing the outputs of the function invocation. There will * \return An array of Tensors representing the outputs of the function invocation. There will
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
...@@ -61,7 +62,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes, ...@@ -61,7 +62,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
FExtern fextern, FExtern fextern,
std::string name, std::string name,
std::string tag) { std::string tag,
::tvm::Map<std::string, NodeRef> attrs) {
CHECK_EQ(out_shapes.size(), out_types.size()) CHECK_EQ(out_shapes.size(), out_types.size())
<< "make_extern: out_shapes and out_types must have equal size"; << "make_extern: out_shapes and out_types must have equal size";
...@@ -78,7 +80,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes, ...@@ -78,7 +80,8 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
auto body_stmt = tvm::ir::Evaluate::make(body); auto body_stmt = tvm::ir::Evaluate::make(body);
auto op = ExternOpNode::make( auto op = ExternOpNode::make(
name, tag, inputs, input_placeholders, output_placeholders, body_stmt); name, tag, attrs, inputs,
input_placeholders, output_placeholders, body_stmt);
Array<Tensor> outputs; Array<Tensor> outputs;
for (size_t i = 0; i < output_placeholders.size(); ++i) { for (size_t i = 0; i < output_placeholders.size(); ++i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment