Unverified Commit 20c495e9 by Tianqi Chen Committed by GitHub

[NODEREF] Introduce named attribute system. (#1618)

parent b00aabc5
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/attrs.h
* \brief TVM attribute module
*
* This module enables declaration of named attributes
* which support default value setup and bound checking.
*
* \code
* struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
* float learning_rate;
* int num_hidden;
* std::string name;
* // declare attribute fields in header file
* TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
* TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
* TVM_ATTR_FIELD(learning_rate).set_default(0.01f);
* TVM_ATTR_FIELD(name).set_default("hello");
* }
* };
* // register it in cc file
* TVM_REGISTER_NODE_TYPE(MyAttrs);
* \endcode
*
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
*/
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#include <unordered_map>
#include <vector>
#include <type_traits>
#include <string>
#include "./ir.h"
#include "./base.h"
#include "./packed_func_ext.h"
namespace tvm {
/*!
* \brief Declare an attribute function.
* \param ClassName The name of the class.
* \param TypeKey The type key to be used by the TVM node system.
*/
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode); \
template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
/*!
* \brief Declare an attribute field.
* \param FieldName The field name.
*/
#define TVM_ATTR_FIELD(FieldName) \
__fvisit__(#FieldName, &FieldName)
/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!
* \brief constructor
* \param msg error message
*/
explicit AttrError(const std::string &msg)
: dmlc::Error(msg) {}
};
/*!
* \brief Information about attribute fields in string representations.
*/
struct AttrFieldInfo {
/*! \brief name of the field */
std::string name;
/*! \brief type docstring information in str. */
std::string type_info;
/*! \brief detailed description of the type */
std::string description;
};
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
* subclass AttrsNode instead.
* \sa AttrsNode
*/
class BaseAttrsNode : public Node {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
/*!
* \brief Initialize the attributes by sequence of arguments
* \param args The postional arguments in the form
* [key0, value0, key1, value1, ..., key_n, value_n]
*/
template<typename... Args>
inline void InitBySeq(Args&& ...args);
/*!
* \brief Print readible docstring to ostream, add newline.
* \param os the stream to print the docstring to.
*/
inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
/*!
* \brief Get the field information about the
* \note This function throws when the required a field is not present.
*/
TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0;
/*!
* \brief Initialize the attributes by arguments.
* \param kwargs The key value pairs for initialization.
* [key0, value0, key1, value1, ..., key_n, value_n]
* \param allow_unknown Whether allow additional unknown fields.
* \note This function throws when the required a field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
};
/*! \brief Base attribute container for all attributes */
class Attrs : public NodeRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
return ptr();
}
/*! \brief specify container node */
using ContainerType = BaseAttrsNode;
private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(node_.get());
}
};
/*!
* \brief Specialized attribute type that is backed by a map.
* The DictAttrsNode implements the Attrs behavior,
* its fields are directly accessible via object.field_name
* like other normal nodes.
*/
class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, NodeRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
std::vector<AttrFieldInfo> ListFieldInfo() const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
};
// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
using TSelf = AttrNopEntry;
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
return *this;
}
template<typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
return *this;
}
template<typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
return *this;
}
};
// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
explicit AttrNormalVisitor(AttrVisitor* visitor)
: visitor_(visitor) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
visitor_->Visit(key, value);
return AttrNopEntry();
}
private:
AttrVisitor* visitor_;
};
// helper entry that does initialization, set default.
template<typename T>
struct AttrInitEntry {
// The attributes
using TSelf = AttrInitEntry<T>;
// The type key
const char* type_key_;
// field name
const char* key_;
// internal value.
T* value_;
// whether the value is missing.
bool value_missing_{true};
// If the value is still missing in destruction time throw an error.
~AttrInitEntry() DMLC_THROW_EXCEPTION {
if (value_missing_) {
std::ostringstream os;
os << type_key_ << ": Cannot find required field \'" << key_
<< "\' during initialization";
throw AttrError(os.str());
}
}
// override fields.
// This function sets the lower bound of the attribute
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
if (this->value_missing_) return *this;
const T& val = *value_;
if (begin > val) {
std::ostringstream os;
os << type_key_ << "." << key_ << ": "
<< "value " << val
<< " is smaller than the lower bound " << begin;
throw AttrError(os.str());
}
return *this;
}
// This function sets the upper bound of the attribute
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
if (this->value_missing_) return *this;
const T& val = *value_;
if (val > end) {
std::ostringstream os;
os << type_key_ << "." << key_ << ": "
<< "value " << val
<< " is bigger than the upper bound " << end;
throw AttrError(os.str());
}
return *this;
}
// set default when
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
if (!value_missing_) return *this;
*value_ = value;
value_missing_ = false;
return *this;
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
return *this;
}
};
// Template function to allow smart conversion
// from Expr types into the constants.
template<typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T();
}
template<typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
*ptr = static_cast<T>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->type_key();
}
}
}
template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
Expr expr = val;
const ir::StringImm* op = expr.as<ir::StringImm>();
CHECK(op != nullptr);
*ptr = op->value;
}
}
template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->type_key();
}
}
}
template<>
inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
template<>
inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
template<>
inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
template<>
inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
SetIntValue(ptr, val);
}
// Visitor for value initialization
template<typename FFind>
class AttrInitVisitor {
public:
// Counter of number of matched attributes during visit.
// This is used to decide if there is additional unmatched attributes.
size_t hit_count_{0};
// constructor
AttrInitVisitor(const char* type_key, FFind ffind)
: type_key_(type_key), ffind_(ffind) {
}
template<typename T>
AttrInitEntry<T> operator()(const char* key, T* value) {
TVMArgValue val;
AttrInitEntry<T> opt;
opt.type_key_ = type_key_;
opt.key_ = key;
opt.value_ = value;
if (ffind_(key, &val)) {
SetValue(value, val);
opt.value_missing_ = false;
++hit_count_;
} else {
opt.value_missing_ = true;
}
return opt;
}
private:
// the type key
const char* type_key_;
FFind ffind_;
};
template<typename FFind>
inline AttrInitVisitor<FFind> CreateInitVisitor(
const char* type_key,
FFind ffind) {
return AttrInitVisitor<FFind>(type_key, ffind);
}
/*!
* \brief Helper struct to get the type name known to tvm.
* \tparam T the type we are interested in.
*/
template<typename T>
struct TypeName {
static constexpr const char* value = T::ContainerType::_type_key;
};
template<>
struct TypeName<int> {
static constexpr const char* value = "int";
};
template<>
struct TypeName<int64_t> {
static constexpr const char* value = "int64";
};
template<>
struct TypeName<uint64_t> {
static constexpr const char* value = "uint64_t";
};
template<>
struct TypeName<Type> {
static constexpr const char* value = "Type";
};
template<>
struct TypeName<std::string> {
static constexpr const char* value = "str";
};
template<>
struct TypeName<bool> {
static constexpr const char* value = "bool";
};
template<>
struct TypeName<void*> {
static constexpr const char* value = "handle";
};
template<>
struct TypeName<double> {
static constexpr const char* value = "double";
};
class AttrDocEntry {
public:
using TSelf = AttrDocEntry;
explicit AttrDocEntry(AttrFieldInfo* info)
: info_(info) {
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
info_->description = str;
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
std::ostringstream os;
os << info_->type_info << ", default=" << value;
info_->type_info = os.str();
return *this;
}
template<typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
return *this;
}
template<typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
return *this;
}
private:
AttrFieldInfo* info_;
};
class AttrDocVisitor {
public:
template<typename T>
AttrDocEntry operator()(const char* key, T* v) {
AttrFieldInfo info;
info.name = key;
info.type_info = TypeName<T>::value;
fields_.emplace_back(std::move(info));
return AttrDocEntry(&(fields_.back()));
}
std::vector<AttrFieldInfo> fields_;
};
class AttrExistVisitor {
public:
std::string key_;
bool exist_{false};
template<typename T>
AttrNopEntry operator()(const char* key, T* v) {
if (exist_) return AttrNopEntry();
if (key == key_) exist_ = true;
return AttrNopEntry();
}
};
} // namespace detail
/*!
* \brief The base class of the all the
* Use "curiously recurring template pattern".
*
* \tparam DerivedType The final attribute type.
*/
template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) final {
detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
CHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16;
int hit_count = 0;
// applies two stratgies to lookup
if (args.size() < kLinearSearchBound) {
// linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
if (!std::strcmp(key, args.values[i].v_str)) {
*val = args[i + 1];
return true;
}
}
return false;
};
auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
self()->__VisitAttrs__(vis);
hit_count = vis.hit_count_;
} else {
// construct a map then do lookup.
std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
kwargs[args[i].operator std::string()] = args[i + 1];
}
auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
auto it = kwargs.find(key);
if (it != kwargs.end()) {
*val = it->second;
return true;
}
return false;
};
auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
self()->__VisitAttrs__(vis);
hit_count = vis.hit_count_;
}
// error handling, slow path
if (hit_count * 2 != args.size() && !allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
detail::AttrExistVisitor visitor;
visitor.key_ = args[i].operator std::string();
self()->__VisitAttrs__(visitor);
if (!visitor.exist_) {
std::ostringstream os;
os << DerivedType::_type_key
<< ": does not have field \'" << visitor.key_
<< "\', Possible fields:\n";
os << "----------------\n";
this->PrintDocString(os);
throw AttrError(os.str());
}
}
}
}
std::vector<AttrFieldInfo> ListFieldInfo() const final {
detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
return visitor.fields_;
}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(
static_cast<const DerivedType*>(this));
}
};
template<typename... Args>
inline void BaseAttrsNode::InitBySeq(Args&& ...args) {
runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) {
this->InitByPackedArgs(args);
});
pf(std::forward<Args>(args)...);
}
inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
std::vector<AttrFieldInfo> entry = this->ListFieldInfo();
for (AttrFieldInfo info : entry) {
os << info.name << " : " << info.type_info << '\n';
if (info.description.length() != 0) {
os << " " << info.description << '\n';
}
}
}
} // namespace tvm
#endif // TVM_ATTRS_H_
......@@ -223,6 +223,12 @@ class ExtTypeVTable {
class TVMPODValue_ {
public:
operator double() const {
// Allow automatic conversion from int to float
// This avoids errors when user pass in int from
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
......@@ -310,6 +316,8 @@ class TVMPODValue_ {
*/
class TVMArgValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMArgValue() {}
/*!
* \brief constructor
* \param value of the function
......
......@@ -71,6 +71,17 @@ def node(type_key, **kwargs):
**kwargs : dict
The fields of the node.
Returns
-------
node : Node
The corresponding DSL Node
Note
----
If the created node is instance of AttrsNode, then
the creator function will also run bound checks and
default value setup as supported by Attrs.
Example
-------
The following code constructs a IntImm object
......
......@@ -33,18 +33,6 @@ TVM_REGISTER_API("_load_json")
*ret = LoadJSON<NodeRef>(args[0]);
});
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
*ret = (nd.use_count() - 1);
});
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
......
/*!
* Copyright (c) 2018 by Contributors
* Code mainly used for test purposes.
* \file api_test.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/attrs.h>
#include <tvm/api_registry.h>
namespace tvm {
// Attrs used to python API
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Array<Expr> padding;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(10)
.set_lower_bound(1)
.set_upper_bound(10)
.describe("axis field");
TVM_ATTR_FIELD(name)
.describe("name");
TVM_ATTR_FIELD(padding)
.describe("padding of input")
.set_default(Array<Expr>({0, 0}));
}
};
TVM_REGISTER_NODE_TYPE(TestAttrs);
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
*ret = (nd.use_count() - 1);
});
} // namespace tvm
......@@ -7,6 +7,7 @@
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <vector>
#include <string>
#include <exception>
......@@ -124,22 +125,35 @@ class DSLAPIImpl : public DSLAPI {
(*static_cast<TVMAPINode*>(handle))->type_index());
}
void NodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) const final {
const char* key,
TVMValue* ret_val,
int* ret_type_code,
int* ret_success) const final {
TVMRetValue rv;
APIAttrGetter getter;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
getter.skey = key;
getter.ret = &rv;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_type_code = kStr;
*ret_success = 1;
} else {
return;
} else if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&getter);
*ret_success = getter.found_ref_object || rv.type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
auto it = dnode->dict.find(key);
if (it != dnode->dict.end()) {
*ret_success = 1;
rv = (*it).second;
} else {
*ret_success = 0;
}
}
if (*ret_success) {
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
......@@ -159,7 +173,16 @@ class DSLAPIImpl : public DSLAPI {
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir);
if (!(*tnode)->is_type<DictAttrsNode>()) {
(*tnode)->VisitAttrs(&dir);
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
for (const auto& kv : dnode->dict) {
ret->ret_vec_str.push_back(kv.first);
}
}
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......
/*!
* Copyright (c) 2018 by Contributors
* \file attrs.cc
*/
#include <tvm/attrs.h>
namespace tvm {
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::InitByPackedArgs(
const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
std::string key = args[i];
runtime::TVMArgValue val = args[i + 1];
if (val.type_code() == kNodeHandle) {
dict.Set(key, val.operator NodeRef());
} else if (val.type_code() == kStr) {
dict.Set(key, Expr(val.operator std::string()));
} else {
dict.Set(key, val.operator Expr());
}
}
}
std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {};
}
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
std::shared_ptr<DictAttrsNode> n = std::make_shared<DictAttrsNode>();
n->dict = std::move(dict);
return Attrs(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode *op, IRPrinter *p) {
p->stream << op->dict;
});
TVM_REGISTER_NODE_TYPE(DictAttrsNode);
} // namespace tvm
......@@ -5,6 +5,7 @@
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/attrs.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
......@@ -467,22 +468,15 @@ class NodeAttrSetter : public AttrVisitor {
}
};
// API function to make node.
// args format:
// type_key, key1, value1, ..., key_n, value_n
void MakeNode(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
void InitNodeByPackedArgs(Node* n, const TVMArgs& args) {
NodeAttrSetter setter;
setter.type_key = args[0].operator std::string();
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
setter.attrs.emplace(
args[i].operator std::string(),
runtime::TVMArgValue(args.values[i + 1], args.type_codes[i + 1]));
}
auto* f = dmlc::Registry<NodeFactoryReg>::Find(setter.type_key);
CHECK(f != nullptr)
<< "Node type \'" << setter.type_key << "\' is not registered in TVM";
std::shared_ptr<Node> n = f->body();
setter.type_key = n->type_key();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
setter.attrs.emplace(args[i].operator std::string(),
args[i + 1]);
}
n->VisitAttrs(&setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
......@@ -492,10 +486,26 @@ void MakeNode(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
}
LOG(FATAL) << os.str();
}
}
// API function to make node.
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key);
CHECK(f != nullptr)
<< "Node type \'" << type_key << "\' is not registered in TVM";
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
std::shared_ptr<Node> n = f->body();
if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
}
*rv = NodeRef(n);
}
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace tvm
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/attrs.h>
#include <tvm/ir.h>
namespace tvm {
namespace test {
// test example usage docs
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Expr expr;
double learning_rate;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(10)
.set_lower_bound(1)
.set_upper_bound(10)
.describe("axis field");
TVM_ATTR_FIELD(name)
.describe("name of the field");
TVM_ATTR_FIELD(expr)
.describe("expression field")
.set_default(make_const(Int(32), 1));
TVM_ATTR_FIELD(learning_rate)
.describe("learning_rate")
.set_default(0.1);
}
};
}
}
TEST(Attrs, Basic) {
using namespace tvm;
using namespace tvm::test;
std::shared_ptr<TestAttrs> n = std::make_shared<TestAttrs>();
try {
n->InitBySeq("axis", 10);
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axis", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axisx", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
std::string what = e.what();
CHECK(what.find("expr : Expr, default=1") != std::string::npos);
CHECK(what.find("axisx") != std::string::npos);
}
n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx");
CHECK_EQ(n->learning_rate, 1.0);
n->InitBySeq("name", "xxx", "expr", 128);
CHECK_EQ(n->name, "xxx");
CHECK_EQ(n->axis, 10);
CHECK_EQ(n->expr.as<tvm::ir::IntImm>()->value, 128);
// Check docstring
std::ostringstream os;
n->PrintDocString(os);
LOG(INFO) << "docstring\n"<< os.str();
CHECK(os.str().find("expr : Expr, default=1") != std::string::npos);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -36,6 +36,31 @@ def test_make_node():
assert AA.op == A.op
assert AA.value_index == A.value_index
def test_make_attrs():
try:
x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
def test_make_sum():
A = tvm.placeholder((2, 10), name='A')
k = tvm.reduce_axis((0,10), "k")
......@@ -46,6 +71,7 @@ def test_make_sum():
assert BB.op.body[0].combiner is not None
if __name__ == "__main__":
test_make_attrs()
test_make_node()
test_make_smap()
test_const_saveload_json()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment