Commit e62d909a by Jared Roesch Committed by Tianqi Chen

Fix serialization issue (#2263)

parent d8bd4762
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <string> #include <string>
...@@ -25,34 +26,12 @@ namespace tvm { ...@@ -25,34 +26,12 @@ namespace tvm {
} }
inline std::string Type2String(const Type& t) { inline std::string Type2String(const Type& t) {
if (t.code() ==Type::Handle) return "handle"; return runtime::TVMType2String(Type2TVMType(t));
std::ostringstream os;
os << t;
return os.str();
} }
inline Type String2Type(std::string s) { inline Type String2Type(std::string s) {
std::istringstream is(s); return TVMType2Type(runtime::String2TVMType(s));
halideir_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
} }
......
...@@ -140,7 +140,7 @@ TVM_REGISTER_API("relay.op._Register") ...@@ -140,7 +140,7 @@ TVM_REGISTER_API("relay.op._Register")
NodePtr<Node> CreateOp(const std::string& name) { NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name); auto op = Op::Get(name);
CHECK(!op.defined()) << "Cannot find op \'" << name << '\''; CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
return op.node_; return op.node_;
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.expr import * from tvm.expr import *
from tvm.relay import op
from tvm.relay.ir_pass import graph_equal from tvm.relay.ir_pass import graph_equal
...@@ -209,6 +210,24 @@ def test_tuple_get_item(): ...@@ -209,6 +210,24 @@ def test_tuple_get_item():
check_json_roundtrip(get) check_json_roundtrip(get)
def test_op():
add = op.op.get("add")
check_json_roundtrip(add)
def test_conv2d_attrs():
data = relay.var('data', shape=(1, 3, 224, 224))
param = relay.var('param', shape=(64, 3, 7, 7))
out = op.nn.conv2d(
data,
param,
strides=(2, 2),
padding=(3, 3),
channels=64,
kernel_size=(7, 7))
check_json_roundtrip(out)
if __name__ == "__main__": if __name__ == "__main__":
test_bad_constructor() test_bad_constructor()
test_span() test_span()
...@@ -226,3 +245,5 @@ if __name__ == "__main__": ...@@ -226,3 +245,5 @@ if __name__ == "__main__":
test_let() test_let()
test_if() test_if()
test_tuple_get_item() test_tuple_get_item()
test_op()
test_conv2d_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment