Unverified Commit 20c495e9 by Tianqi Chen Committed by GitHub

[NODEREF] Introduce named attribute system. (#1618)

parent b00aabc5
......@@ -223,6 +223,12 @@ class ExtTypeVTable {
class TVMPODValue_ {
public:
operator double() const {
// Allow automatic conversion from int to float
// This avoids errors when user pass in int from
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
......@@ -310,6 +316,8 @@ class TVMPODValue_ {
*/
class TVMArgValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMArgValue() {}
/*!
* \brief constructor
* \param value of the function
......
......@@ -71,6 +71,17 @@ def node(type_key, **kwargs):
**kwargs : dict
The fields of the node.
Returns
-------
node : Node
The corresponding DSL Node
Note
----
If the created node is instance of AttrsNode, then
the creator function will also run bound checks and
default value setup as supported by Attrs.
Example
-------
The following code constructs a IntImm object
......
......@@ -33,18 +33,6 @@ TVM_REGISTER_API("_load_json")
*ret = LoadJSON<NodeRef>(args[0]);
});
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
*ret = (nd.use_count() - 1);
});
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
......
/*!
* Copyright (c) 2018 by Contributors
* Code mainly used for test purposes.
* \file api_test.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/attrs.h>
#include <tvm/api_registry.h>
namespace tvm {
// Attrs used to python API
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Array<Expr> padding;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(10)
.set_lower_bound(1)
.set_upper_bound(10)
.describe("axis field");
TVM_ATTR_FIELD(name)
.describe("name");
TVM_ATTR_FIELD(padding)
.describe("padding of input")
.set_default(Array<Expr>({0, 0}));
}
};
TVM_REGISTER_NODE_TYPE(TestAttrs);
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
*ret = (nd.use_count() - 1);
});
} // namespace tvm
......@@ -7,6 +7,7 @@
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <vector>
#include <string>
#include <exception>
......@@ -124,22 +125,35 @@ class DSLAPIImpl : public DSLAPI {
(*static_cast<TVMAPINode*>(handle))->type_index());
}
void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) const final {
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) const final {
TVMRetValue rv;
APIAttrGetter getter;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
getter.skey = key;
getter.ret = &rv;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_type_code = kStr;
*ret_success = 1;
} else {
return;
} else if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_ref_object || rv.type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
auto it = dnode->dict.find(key);
if (it != dnode->dict.end()) {
*ret_success = 1;
rv = (*it).second;
} else {
*ret_success = 0;
}
}
if (*ret_success) {
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
......@@ -159,7 +173,16 @@ class DSLAPIImpl : public DSLAPI {
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir);
if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&dir);
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
for (const auto& kv : dnode->dict) {
ret->ret_vec_str.push_back(kv.first);
}
}
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......
/*!
* Copyright (c) 2018 by Contributors
* \file attrs.cc
*/
#include <tvm/attrs.h>
namespace tvm {
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::InitByPackedArgs(
const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
std::string key = args[i];
runtime::TVMArgValue val = args[i + 1];
if (val.type_code() == kNodeHandle) {
dict.Set(key, val.operator NodeRef());
} else if (val.type_code() == kStr) {
dict.Set(key, Expr(val.operator std::string()));
} else {
dict.Set(key, val.operator Expr());
}
}
}
std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {};
}
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
std::shared_ptr<DictAttrsNode> n = std::make_shared<DictAttrsNode>();
n->dict = std::move(dict);
return Attrs(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode *op, IRPrinter *p) {
p->stream << op->dict;
});
TVM_REGISTER_NODE_TYPE(DictAttrsNode);
} // namespace tvm
......@@ -5,6 +5,7 @@
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/attrs.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
......@@ -467,22 +468,15 @@ class NodeAttrSetter : public AttrVisitor {
}
};
// API function to make node.
// args format:
// type_key, key1, value1, ..., key_n, value_n
void MakeNode(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
void InitNodeByPackedArgs(Node* n, const TVMArgs& args) {
NodeAttrSetter setter;
setter.type_key = args[0].operator std::string();
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
setter.attrs.emplace(
args[i].operator std::string(),
runtime::TVMArgValue(args.values[i + 1], args.type_codes[i + 1]));
}
auto* f = dmlc::Registry<NodeFactoryReg>::Find(setter.type_key);
CHECK(f != nullptr)
<< "Node type \'" << setter.type_key << "\' is not registered in TVM";
std::shared_ptr<Node> n = f->body();
setter.type_key = n->type_key();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
setter.attrs.emplace(args[i].operator std::string(),
args[i + 1]);
}
n->VisitAttrs(&setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
......@@ -492,10 +486,26 @@ void MakeNode(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
}
LOG(FATAL) << os.str();
}
}
// API function to make node.
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key);
CHECK(f != nullptr)
<< "Node type \'" << type_key << "\' is not registered in TVM";
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
std::shared_ptr<Node> n = f->body();
if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
}
*rv = NodeRef(n);
}
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace tvm
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/attrs.h>
#include <tvm/ir.h>
namespace tvm {
namespace test {
// test example usage docs
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Expr expr;
double learning_rate;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(10)
.set_lower_bound(1)
.set_upper_bound(10)
.describe("axis field");
TVM_ATTR_FIELD(name)
.describe("name of the field");
TVM_ATTR_FIELD(expr)
.describe("expression field")
.set_default(make_const(Int(32), 1));
TVM_ATTR_FIELD(learning_rate)
.describe("learning_rate")
.set_default(0.1);
}
};
}
}
TEST(Attrs, Basic) {
using namespace tvm;
using namespace tvm::test;
std::shared_ptr<TestAttrs> n = std::make_shared<TestAttrs>();
try {
n->InitBySeq("axis", 10);
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axis", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axisx", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
std::string what = e.what();
CHECK(what.find("expr : Expr, default=1") != std::string::npos);
CHECK(what.find("axisx") != std::string::npos);
}
n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx");
CHECK_EQ(n->learning_rate, 1.0);
n->InitBySeq("name", "xxx", "expr", 128);
CHECK_EQ(n->name, "xxx");
CHECK_EQ(n->axis, 10);
CHECK_EQ(n->expr.as<tvm::ir::IntImm>()->value, 128);
// Check docstring
std::ostringstream os;
n->PrintDocString(os);
LOG(INFO) << "docstring\n"<< os.str();
CHECK(os.str().find("expr : Expr, default=1") != std::string::npos);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -36,6 +36,31 @@ def test_make_node():
assert AA.op == A.op
assert AA.value_index == A.value_index
def test_make_attrs():
try:
x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
def test_make_sum():
A = tvm.placeholder((2, 10), name='A')
k = tvm.reduce_axis((0,10), "k")
......@@ -46,6 +71,7 @@ def test_make_sum():
assert BB.op.body[0].combiner is not None
if __name__ == "__main__":
test_make_attrs()
test_make_node()
test_make_smap()
test_const_saveload_json()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment