Commit e62d909a by Jared Roesch Committed by Tianqi Chen

Fix serialization issue (#2263)

parent d8bd4762
......@@ -9,6 +9,7 @@
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
......@@ -25,34 +26,12 @@ namespace tvm {
}
inline std::string Type2String(const Type& t) {
if (t.code() ==Type::Handle) return "handle";
std::ostringstream os;
os << t;
return os.str();
return runtime::TVMType2String(Type2TVMType(t));
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halideir_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
return TVMType2Type(runtime::String2TVMType(s));
}
......
......@@ -140,7 +140,7 @@ TVM_REGISTER_API("relay.op._Register")
NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name);
CHECK(!op.defined()) << "Cannot find op \'" << name << '\'';
CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
return op.node_;
}
......
......@@ -2,6 +2,7 @@
import tvm
from tvm import relay
from tvm.expr import *
from tvm.relay import op
from tvm.relay.ir_pass import graph_equal
......@@ -209,6 +210,24 @@ def test_tuple_get_item():
check_json_roundtrip(get)
def test_op():
add = op.op.get("add")
check_json_roundtrip(add)
def test_conv2d_attrs():
data = relay.var('data', shape=(1, 3, 224, 224))
param = relay.var('param', shape=(64, 3, 7, 7))
out = op.nn.conv2d(
data,
param,
strides=(2, 2),
padding=(3, 3),
channels=64,
kernel_size=(7, 7))
check_json_roundtrip(out)
if __name__ == "__main__":
test_bad_constructor()
test_span()
......@@ -226,3 +245,5 @@ if __name__ == "__main__":
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment