Commit ddbbacc3 by ziheng Committed by Tianqi Chen

[CMPL] Add Support for Other Data Types (#252)

* [CMPL] Add Support for Other Data Types

* [CMPL] Add test

* [CMPL] Fix
parent 97aaadeb
......@@ -56,6 +56,7 @@ enum TypeFlag {
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kInt16 = 7,
};
struct CastParam : public dmlc::Parameter<CastParam> {
......@@ -67,6 +68,9 @@ struct CastParam : public dmlc::Parameter<CastParam> {
.add_enum("float16", kFloat16)
.add_enum("uint8", kUint8)
.add_enum("int32", kInt32)
.add_enum("int8", kInt8)
.add_enum("int64", kInt64)
.add_enum("int16", kInt16)
.describe("Output data type.");
}
};
......
......@@ -28,12 +28,26 @@ def set_shape_inputs(g, shape):
DTYPE_TO_TCODE = {
"default": -1,
"float32": 0
"float32": 0,
"float64": 1,
"float16": 2,
"uint8": 3,
"int32": 4,
"int8": 5,
"int64": 6,
"int16": 7,
}
TCODE_TO_DTYPE = {
-1: None,
0: "float32"
0: "float32",
1: "float64",
2: "float16",
3: "uint8",
4: "int32",
5: "int8",
6: "int64",
7: "int16",
}
def set_dtype_inputs(g, dtype):
......
......@@ -27,14 +27,39 @@ using namespace tvm;
*/
int GetTypeFlag(tvm::Type type) {
if (type == tvm::Float(32)) return 0;
if (type == tvm::Float(64)) return 1;
if (type == tvm::Float(16)) return 2;
if (type == tvm::UInt(8)) return 3;
if (type == tvm::Int(32)) return 4;
if (type == tvm::Int(8)) return 5;
if (type == tvm::Int(64)) return 6;
if (type == tvm::Int(16)) return 7;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
// convert from type flag to tvm type.
Type GetTVMType(int type_flag) {
if (type_flag == 0) return tvm::Float(32);
switch (type_flag) {
case 0:
return tvm::Float(32);
case 1:
return tvm::Float(64);
case 2:
return tvm::Float(16);
case 3:
return tvm::UInt(8);
case 4:
return tvm::Int(32);
case 5:
return tvm::Int(8);
case 6:
return tvm::Int(64);
case 7:
return tvm::Int(16);
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32);
}
}
// internal compile engine
......
......@@ -90,6 +90,22 @@ GraphFunc GraphLower(Graph graph,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr);
/*!
* \brief Get type flag from TVM Type
*
* \param type the tvm type
* \return corresponding DLDataType
*/
int GetTypeFlag(tvm::Type type);
/*!
* \brief Get TVM Type from type flag
*
* \param type_flag the type flag
* \return corresponding TVM type
*/
tvm::Type GetTVMType(int type_flag);
} // namespace compiler
} // namespace nnvm
......
......@@ -36,9 +36,7 @@ enum class FuseRule {
* \return corresponding DLDataType
*/
DLDataType GetDLType(int type_flag) {
if (type_flag == 0) return tvm::Type2TVMType(tvm::Float(32));
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Type2TVMType(Float(32));
return Type2TVMType(GetTVMType(type_flag));
}
// Partition the graph into segments
......
......@@ -77,7 +77,26 @@ def test_precompute_prune():
res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
def test_dtypes():
x = sym.Variable("x")
y = sym.relu(x)
dshape = (1, 3, 32, 32)
oshape = dshape
for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
graph, lib, _ = nnvm.compiler.build(y, 'llvm', {"x": dshape}, dtype=dtype)
m = graph_runtime.create(graph, lib, tvm.cpu())
if 'float' in dtype:
data = np.random.uniform(size=dshape).astype(dtype)
elif 'int' in dtype:
data = np.random.randint(-127, 127, dshape).astype(dtype)
m.run(x=data)
data = (data > 0) * data
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
np.testing.assert_allclose(out.asnumpy(), data, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
test_precompute_prune()
test_compile()
test_run()
test_dtypes()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment