Commit 1f7712ae by Tianqi Chen Committed by GitHub

[LANG] Add reflection routine to construct node (#265)

parent 68c4400e
/*!
* Copyright (c) 2017 by Contributors
* \file target_info.h
* \brief Various information about target.
*/
#ifndef TVM_TARGET_INFO_H_
#define TVM_TARGET_INFO_H_
#include "./base.h"
#include "./expr.h"
namespace tvm {
/*!
* \brief Memory information of special memory region.
* Use MemoryInfo as its container type
*/
struct MemoryInfoNode : public Node {
/*! \brief The addressable unit */
int unit_bits;
/*! \brief Maximum number of bits supported in the memory */
int max_num_bits;
/*! \brief maximum number of bits to be used in simd op */
int max_simd_bits;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("unit_bits", &unit_bits);
v->Visit("max_num_bits", &max_num_bits);
v->Visit("max_simd_bits", &max_simd_bits);
}
static constexpr const char* _type_key = "MemoryInfo";
TVM_DECLARE_NODE_TYPE_INFO(MemoryInfoNode, Node);
};
/*! \brief Defines memory info */
TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode);
} // namespace tvm
#endif // TVM_TARGET_INFO_H_
......@@ -30,6 +30,33 @@ def range_by_min_extent(min_value, extent):
return _range_by_min_extent(min_value, extent)
def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
Parameters
----------
type_key : str
The type key of the node.
**kwargs : dict
The fields of the node.
Example
-------
The following code constructs a IntImm object
.. code-block:: python
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
assert x.value == 10
"""
args = [type_key]
for k, v in kwargs.items():
args += [k, v]
return _Node(*args)
def stmt_seq(*args):
"""Make sequence of statements
......
......@@ -30,7 +30,7 @@ TVM_REGISTER_API("_save_json")
TVM_REGISTER_API("_load_json")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = NodeRef(LoadJSON_(args[0]));
*ret = LoadJSON<NodeRef>(args[0]);
});
TVM_REGISTER_API("_nop")
......
......@@ -8,10 +8,6 @@
#include <ir/IRPrinter.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
using Halide::IR::RangeNode;
......
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \brief Utilities to save/load TVM objects.
* \file reflection.cc
* \brief Utilities to save/load/construct TVM objects
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <dmlc/json.h>
#include <string>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
inline std::string Type2String(const Type& t) {
......@@ -334,4 +339,75 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
return nodes.at(jgraph.root);
}
class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
template<typename T>
void SetValue(const char* key, T* value) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
*value = it->second.operator T();
attrs.erase(it);
}
void Visit(const char* key, double* value) final {
SetValue(key, value);
}
void Visit(const char* key, int64_t* value) final {
SetValue(key, value);
}
void Visit(const char* key, uint64_t* value) final {
SetValue(key, value);
}
void Visit(const char* key, int* value) final {
SetValue(key, value);
}
void Visit(const char* key, bool* value) final {
SetValue(key, value);
}
void Visit(const char* key, std::string* value) final {
SetValue(key, value);
}
void Visit(const char* key, Type* value) final {
SetValue(key, value);
}
void Visit(const char* key, NodeRef* value) final {
SetValue(key, value);
}
};
// API function to make node.
// args format:
// type_key, key1, value1, ..., key_n, value_n
void MakeNode(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
NodeAttrSetter setter;
setter.type_key = args[0].operator std::string();
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
setter.attrs.emplace(
args[i].operator std::string(),
runtime::TVMArgValue(args.values[i + 1], args.type_codes[i + 1]));
}
auto* f = dmlc::Registry<NodeFactoryReg>::Find(setter.type_key);
CHECK(f != nullptr)
<< "Node type \'" << setter.type_key << "\' is not registered in TVM";
std::shared_ptr<Node> n = f->body();
n->VisitAttrs(&setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
os << setter.type_key << " does not contain field ";
for (const auto &kv : setter.attrs) {
os << " " << kv.first;
}
LOG(FATAL) << os.str();
}
*rv = NodeRef(n);
}
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file target_info.cc
*/
#include <tvm/target_info.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const MemoryInfoNode *op, IRPrinter *p) {
p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", "
<< "max_num_bits=" << op->max_num_bits << ", "
<< "max_simd_bits=" << op->max_simd_bits << ")";
});
TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
} // namespace tvm
......@@ -6,21 +6,10 @@ def test_const():
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)
def test_const_saveload_json():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
z = z + z
json_str = tvm.save_json(z)
zz = tvm.load_json(json_str)
assert tvm.save_json(zz) == tvm.save_json(z)
def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
print(z)
def test_ir():
x = tvm.const(1)
......@@ -125,11 +114,10 @@ def test_all():
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
if __name__ == "__main__":
test_attr()
test_const()
test_const_saveload_json()
test_make()
test_ir()
test_basic()
......
import tvm
def test_const_saveload_json():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
z = z + z
json_str = tvm.save_json(z)
zz = tvm.load_json(json_str)
assert tvm.save_json(zz) == tvm.save_json(z)
def test_make_node():
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
assert x.value == 10
A = tvm.placeholder((10, ), name='A')
AA = tvm.make.node("Tensor",
shape=A.shape,
dtype=A.dtype,
op=A.op,
value_index=A.value_index)
assert AA.op == A.op
assert AA.value_index == A.value_index
if __name__ == "__main__":
test_make_node()
test_const_saveload_json()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment