Commit 5fced923 by Tianqi Chen Committed by Haichen Shen

[LANG] Enable json load/save and pickle (#10)

parent 7250005d
...@@ -21,6 +21,41 @@ using ::tvm::Node; ...@@ -21,6 +21,41 @@ using ::tvm::Node;
using ::tvm::NodeRef; using ::tvm::NodeRef;
using ::tvm::AttrVisitor; using ::tvm::AttrVisitor;
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
std::shared_ptr<Node> LoadJSON_(std::string json_str);
/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}
/*! \brief typedef the factory function of data iterator */ /*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>; using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*! /*!
...@@ -32,7 +67,8 @@ struct NodeFactoryReg ...@@ -32,7 +67,8 @@ struct NodeFactoryReg
}; };
#define TVM_REGISTER_NODE_TYPE(TypeName) \ #define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); }) .set_body([]() { return std::make_shared<TypeName>(); })
} // namespace tvm } // namespace tvm
......
...@@ -15,14 +15,15 @@ ...@@ -15,14 +15,15 @@
/*! \brief TVM_DLL prefix for windows */ /*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32 #ifdef _WIN32
#ifdef TVM_EXPORTS #ifdef TVM_EXPORTS
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport) #define TVM_DLL __declspec(dllexport)
#else #else
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport) #define TVM_DLL __declspec(dllimport)
#endif #endif
#else #else
#define TVM_DLL TVM_EXTERN_C #define TVM_DLL
#endif #endif
TVM_EXTERN_C {
/*! \brief handle to functions */ /*! \brief handle to functions */
typedef void* FunctionHandle; typedef void* FunctionHandle;
/*! \brief handle to node */ /*! \brief handle to node */
...@@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, ...@@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size, int *out_size,
const char*** out_array); const char*** out_array);
} // TVM_EXTERN_C
#endif // TVM_C_API_H_ #endif // TVM_C_API_H_
...@@ -89,7 +89,6 @@ class NodeBase(object): ...@@ -89,7 +89,6 @@ class NodeBase(object):
"'%s' object has no attribute '%s'" % (str(type(self)), name)) "'%s' object has no attribute '%s'" % (str(type(self)), name))
return value return value
def __hash__(self): def __hash__(self):
return _function_internal._raw_ptr(self) return _function_internal._raw_ptr(self)
...@@ -111,6 +110,29 @@ class NodeBase(object): ...@@ -111,6 +110,29 @@ class NodeBase(object):
names.append(py_str(plist[i])) names.append(py_str(plist[i]))
return names return names
def __reduce__(self):
return (type(self), (None,), self.__getstate__())
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _function_internal._save_json(self)}
else:
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
_push_arg(json_str)
other = _function_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def const(value, dtype=None): def const(value, dtype=None):
"""construct a constant""" """construct a constant"""
if dtype is None: if dtype is None:
......
...@@ -19,6 +19,38 @@ def const(value, dtype=None): ...@@ -19,6 +19,38 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype) return _function_internal._const(value, dtype)
def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Node
The loaded tvm node.
"""
return _function_internal._load_json(json_str)
def save_json(node):
"""Load tvm object as json string.
Parameters
----------
node : Node
A TVM Node object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return _function_internal._save_json(node)
def Var(name="tindex", dtype=int32): def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype """Create a new variable with specified name and dtype
......
/*!
* Copyright (c) 2016 by Contributors
* \file common.h
* \brief Common utilities
*/
#ifndef TVM_BASE_COMMON_H_
#define TVM_BASE_COMMON_H_
#include <tvm/base.h>
#include <string>
namespace tvm {
inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
}
} // namespace tvm
#endif // TVM_BASE_COMMON_H_
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \brief Utilities to save/load TVM objects.
*/
#include <tvm/base.h>
#include <tvm/container.h>
#include <dmlc/json.h>
#include <string>
#include "./common.h"
namespace tvm {
// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
public:
std::unordered_map<Node*, size_t> node_index{{nullptr, 0}};
std::vector<Node*> node_list{nullptr};
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
void Visit(const char* key, uint64_t* value) final {}
void Visit(const char* key, int* value) final {}
void Visit(const char* key, bool* value) final {}
void Visit(const char* key, std::string* value) final {}
void Visit(const char* key, Type* value) final {}
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
if (node_index.count(node)) return;
CHECK_EQ(node_index.size(), node_list.size());
node_index[node] = node_list.size();
node_list.push_back(node);
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
for (const auto& sp : n->data) {
MakeIndex(sp.get());
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
MakeIndex(kv.first.get());
MakeIndex(kv.second.get());
}
} else {
node->VisitAttrs(this);
}
}
};
// use map so attributes are ordered.
using AttrMap = std::map<std::string, std::string>;
// A Node structure for JSON node.
struct JSONNode {
// The type key of the data
std::string type_key;
// the attributes
AttrMap attrs;
// container data
std::vector<size_t> data;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
if (data.size() != 0) {
writer->WriteObjectKeyValue("data", data);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
attrs.clear();
data.clear();
type_key.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);
}
};
class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, int64_t* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, uint64_t* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, int* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, bool* value) final {
node_->attrs[key] = std::to_string(*value);
}
void Visit(const char* key, std::string* value) final {
node_->attrs[key] = *value;
}
void Visit(const char* key, Type* value) final {
node_->attrs[key] = Type2String(*value);
}
void Visit(const char* key, NodeRef* value) final {
node_->attrs[key] = std::to_string(
node_index_->at(value->node_.get()));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
node_->type_key.clear();
return;
}
node_->type_key = node->type_key();
node_->attrs.clear();
node_->data.clear();
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
for (size_t i = 0; i < n->data.size(); ++i) {
node_->data.push_back(
node_index_->at(n->data[i].get()));
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
std::vector<std::pair<size_t, size_t> > elems;
for (const auto& kv : n->data) {
node_->data.push_back(
node_index_->at(kv.first.get()));
node_->data.push_back(
node_index_->at(kv.second.get()));
}
} else {
node->VisitAttrs(this);
}
}
};
class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<std::shared_ptr<Node> >* node_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
auto it = node_->attrs.find(key);
if (it == node_->attrs.end()) {
LOG(FATAL) << "JSONReader: cannot find field " << key;
}
return it->second;
}
template<typename T>
void ParseValue(const char* key, T* value) const {
std::istringstream is(GetValue(key));
is >> *value;
if (is.fail()) {
LOG(FATAL) << "Wrong value format for field " << key;
}
}
void Visit(const char* key, double* value) final {
ParseValue(key, value);
}
void Visit(const char* key, int64_t* value) final {
ParseValue(key, value);
}
void Visit(const char* key, uint64_t* value) final {
ParseValue(key, value);
}
void Visit(const char* key, int* value) final {
ParseValue(key, value);
}
void Visit(const char* key, bool* value) final {
ParseValue(key, value);
}
void Visit(const char* key, std::string* value) final {
*value = GetValue(key);
}
void Visit(const char* key, Type* value) final {
std::string stype = GetValue(key);
*value = String2Type(stype);
}
void Visit(const char* key, NodeRef* value) final {
size_t index;
ParseValue(key, &index);
value->node_ = node_list_->at(index);
}
// Get the node
void Set(Node* node) {
if (node == nullptr) return;
if (node->is_type<ArrayNode>()) {
ArrayNode* n = static_cast<ArrayNode*>(node);
n->data.clear();
for (size_t index : node_->data) {
n->data.push_back(node_list_->at(index));
}
} else if (node->is_type<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
CHECK_EQ(node_->data.size() % 2, 0U);
for (size_t i = 0; i < node_->data.size(); i += 2) {
n->data[node_list_->at(node_->data[i])]
= node_list_->at(node_->data[i + 1]);
}
} else {
node->VisitAttrs(this);
}
}
};
// json graph structure to store node
struct JSONGraph {
// the root of the graph
size_t root;
// the nodes of the graph
std::vector<JSONNode> nodes;
// global attributes
AttrMap attrs;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("root", root);
writer->WriteObjectKeyValue("nodes", nodes);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
attrs.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("root", &root);
helper.DeclareField("nodes", &nodes);
helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader);
}
static JSONGraph Create(const NodeRef& root) {
JSONGraph g;
NodeIndexer indexer;
indexer.MakeIndex(root.node_.get());
JSONAttrGetter getter;
getter.node_index_ = &indexer.node_index;
for (Node* n : indexer.node_list) {
JSONNode jnode;
getter.node_ = &jnode;
getter.Get(n);
g.nodes.emplace_back(std::move(jnode));
}
g.attrs["tvm_version"] = "0.1.0";
g.root = indexer.node_index.at(root.node_.get());
return g;
}
};
std::string SaveJSON(const NodeRef& n) {
auto jgraph = JSONGraph::Create(n);
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
return os.str();
}
std::shared_ptr<Node> LoadJSON_(std::string json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
std::vector<std::shared_ptr<Node> > nodes;
// node 0 is always null
nodes.reserve(jgraph.nodes.size());
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key);
CHECK(f != nullptr)
<< "Node type \'" << jnode.type_key << "\' is not registered in TVM";
nodes.emplace_back(f->body());
} else {
nodes.emplace_back(std::shared_ptr<Node>());
}
}
CHECK_EQ(nodes.size(), jgraph.nodes.size());
JSONAttrSetter setter;
setter.node_list_ = &nodes;
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
setter.Set(nodes[i].get());
}
return nodes.at(jgraph.root);
}
} // namespace tvm
...@@ -34,4 +34,16 @@ TVM_REGISTER_API(_raw_ptr) ...@@ -34,4 +34,16 @@ TVM_REGISTER_API(_raw_ptr)
}) })
.add_argument("src", "NodeBase", "the node base"); .add_argument("src", "NodeBase", "the node base");
TVM_REGISTER_API(_save_json)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = SaveJSON(args.at(0));
})
.add_argument("src", "json_str", "the node ");
TVM_REGISTER_API(_load_json)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = NodeRef(LoadJSON_(args.at(0)));
})
.add_argument("src", "NodeBase", "the node");
} // namespace tvm } // namespace tvm
...@@ -13,36 +13,10 @@ ...@@ -13,36 +13,10 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../base/common.h"
namespace tvm { namespace tvm {
inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
}
inline const char* TypeId2Str(ArgVariantID type_id) { inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) { switch (type_id) {
case kNull: return "Null"; case kNull: return "Null";
......
...@@ -13,8 +13,11 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); ...@@ -13,8 +13,11 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc } // namespace dmlc
namespace tvm { namespace tvm {
using Halide::IR::RangeNode;
Range::Range(Expr begin, Expr end) Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>( : Range(std::make_shared<RangeNode>(
begin, begin,
is_zero(begin) ? end : (end - begin))) { is_zero(begin) ? end : (end - begin))) {
} }
...@@ -67,10 +70,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -67,10 +70,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Halide::IR::RangeNode>([](const Halide::IR::RangeNode *op, IRPrinter *p) { .set_dispatch<RangeNode>([](const Halide::IR::RangeNode *op, IRPrinter *p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
}); });
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode); TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm } // namespace tvm
...@@ -206,5 +206,6 @@ IterVarRelation FuseNode::make( ...@@ -206,5 +206,6 @@ IterVarRelation FuseNode::make(
TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm } // namespace tvm
...@@ -6,8 +6,11 @@ TEST(Expr, Basic) { ...@@ -6,8 +6,11 @@ TEST(Expr, Basic) {
using namespace tvm; using namespace tvm;
Var x("x"); Var x("x");
auto z = max(x + 1 + 2, 100); auto z = max(x + 1 + 2, 100);
NodeRef tmp = z;
Expr zz(tmp.node_);
std::ostringstream os; std::ostringstream os;
os << z; os << z;
CHECK(zz.same_as(z));
CHECK(os.str() == "max(((x + 1) + 2), 100)"); CHECK(os.str() == "max(((x + 1) + 2), 100)");
} }
......
...@@ -5,6 +5,16 @@ def test_const(): ...@@ -5,6 +5,16 @@ def test_const():
assert x.dtype == 'int32' assert x.dtype == 'int32'
assert isinstance(x, tvm.expr.IntImm) assert isinstance(x, tvm.expr.IntImm)
def test_const_saveload_json():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
z = z + z
json_str = tvm.save_json(z)
zz = tvm.load_json(json_str)
assert tvm.save_json(zz) == tvm.save_json(z)
def test_make(): def test_make():
x = tvm.const(1) x = tvm.const(1)
y = tvm.make.IntImm('int32', 1) y = tvm.make.IntImm('int32', 1)
...@@ -57,6 +67,7 @@ def test_stmt(): ...@@ -57,6 +67,7 @@ def test_stmt():
if __name__ == "__main__": if __name__ == "__main__":
test_attr() test_attr()
test_const() test_const()
test_const_saveload_json()
test_make() test_make()
test_ir() test_ir()
test_basic() test_basic()
......
...@@ -4,6 +4,12 @@ def test_array(): ...@@ -4,6 +4,12 @@ def test_array():
a = tvm.convert([1,2,3]) a = tvm.convert([1,2,3])
assert len(a) == 3 assert len(a) == 3
def test_array_save_load_json():
a = tvm.convert([1,2,3])
json_str = tvm.save_json(a)
a_loaded = tvm.load_json(json_str)
assert(a[1].value == 2)
def test_map(): def test_map():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
...@@ -15,6 +21,20 @@ def test_map(): ...@@ -15,6 +21,20 @@ def test_map():
assert str(dd) == str(amap) assert str(dd) == str(amap)
assert a + 1 not in amap assert a + 1 not in amap
def test_map_save_load_json():
a = tvm.Var('a')
b = tvm.Var('b')
amap = tvm.convert({a: 2,
b: 3})
json_str = tvm.save_json(amap)
amap = tvm.load_json(json_str)
assert len(amap) == 2
dd = {kv[0].name : kv[1].value for kv in amap.items()}
assert(dd == {"a": 2, "b": 3})
if __name__ == "__main__": if __name__ == "__main__":
test_array() test_array()
test_map() test_map()
test_array_save_load_json()
test_map_save_load_json()
import tvm import tvm
import pickle as pkl
def test_schedule_create(): def test_schedule_create():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -17,6 +18,18 @@ def test_schedule_create(): ...@@ -17,6 +18,18 @@ def test_schedule_create():
s[T].reorder(xi2, xi1) s[T].reorder(xi2, xi1)
assert T.op.axis[1] in s[T].leaf_iter_vars assert T.op.axis[1] in s[T].leaf_iter_vars
# save load json
json_str = tvm.save_json(s)
s_loaded = tvm.load_json(json_str)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
# pickle unpickle
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
def test_reorder(): def test_reorder():
m = tvm.Var('m') m = tvm.Var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
......
...@@ -27,7 +27,11 @@ def test_tensor_reduce(): ...@@ -27,7 +27,11 @@ def test_tensor_reduce():
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
rv = tvm.IterVar((0, A.shape[1]), name="k") rv = tvm.IterVar((0, A.shape[1]), name="k")
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv)) C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
print(C.op.body) # json load save
C_json = tvm.save_json(C)
C_loaded = tvm.load_json(C_json)
assert(isinstance(C_loaded, tvm.tensor.Tensor))
assert(str(C_loaded) == str(C))
if __name__ == "__main__": if __name__ == "__main__":
test_tensor() test_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment