Commit 2406757d by abergeron Committed by Tianqi Chen

Add support for missing uint types. (#272)

parent 02141d4a
...@@ -57,20 +57,26 @@ enum TypeFlag { ...@@ -57,20 +57,26 @@ enum TypeFlag {
kInt8 = 5, kInt8 = 5,
kInt64 = 6, kInt64 = 6,
kInt16 = 7, kInt16 = 7,
kUint16 = 8,
kUint32 = 9,
kUint64 = 10,
}; };
struct CastParam : public dmlc::Parameter<CastParam> { struct CastParam : public dmlc::Parameter<CastParam> {
int dtype; int dtype;
DMLC_DECLARE_PARAMETER(CastParam) { DMLC_DECLARE_PARAMETER(CastParam) {
DMLC_DECLARE_FIELD(dtype) DMLC_DECLARE_FIELD(dtype)
.add_enum("float16", kFloat16)
.add_enum("float32", kFloat32) .add_enum("float32", kFloat32)
.add_enum("float64", kFloat64) .add_enum("float64", kFloat64)
.add_enum("float16", kFloat16) .add_enum("uint8", kUint8)
.add_enum("uint8", kUint8) .add_enum("uint16", kUint16)
.add_enum("uint32", kUint32)
.add_enum("uint64", kUint64)
.add_enum("int8", kInt8)
.add_enum("int16", kInt16)
.add_enum("int32", kInt32) .add_enum("int32", kInt32)
.add_enum("int8", kInt8)
.add_enum("int64", kInt64) .add_enum("int64", kInt64)
.add_enum("int16", kInt16)
.describe("Output data type."); .describe("Output data type.");
} }
}; };
......
...@@ -36,6 +36,9 @@ DTYPE_TO_TCODE = { ...@@ -36,6 +36,9 @@ DTYPE_TO_TCODE = {
"int8": 5, "int8": 5,
"int64": 6, "int64": 6,
"int16": 7, "int16": 7,
"uint16": 8,
"uint32": 9,
"uint64": 10,
} }
TCODE_TO_DTYPE = { TCODE_TO_DTYPE = {
...@@ -48,6 +51,9 @@ TCODE_TO_DTYPE = { ...@@ -48,6 +51,9 @@ TCODE_TO_DTYPE = {
5: "int8", 5: "int8",
6: "int64", 6: "int64",
7: "int16", 7: "int16",
8: "uint16",
9: "uint32",
10: "uint64",
} }
def set_dtype_inputs(g, dtype): def set_dtype_inputs(g, dtype):
......
...@@ -34,6 +34,9 @@ int GetTypeFlag(tvm::Type type) { ...@@ -34,6 +34,9 @@ int GetTypeFlag(tvm::Type type) {
if (type == tvm::Int(8)) return 5; if (type == tvm::Int(8)) return 5;
if (type == tvm::Int(64)) return 6; if (type == tvm::Int(64)) return 6;
if (type == tvm::Int(16)) return 7; if (type == tvm::Int(16)) return 7;
if (type == tvm::UInt(16)) return 8;
if (type == tvm::UInt(32)) return 9;
if (type == tvm::UInt(64)) return 10;
LOG(FATAL) << "cannot convert " << type; LOG(FATAL) << "cannot convert " << type;
return 0; return 0;
} }
...@@ -56,6 +59,12 @@ Type GetTVMType(int type_flag) { ...@@ -56,6 +59,12 @@ Type GetTVMType(int type_flag) {
return tvm::Int(64); return tvm::Int(64);
case 7: case 7:
return tvm::Int(16); return tvm::Int(16);
case 8:
return tvm::UInt(16);
case 9:
return tvm::UInt(32);
case 10:
return tvm::UInt(64);
default: default:
LOG(FATAL) << "unknown type_flag=" << type_flag; LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32); return Float(32);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment