Commit 2406757d by abergeron Committed by Tianqi Chen

Add support for missing uint types. (#272)

parent 02141d4a
......@@ -57,20 +57,26 @@ enum TypeFlag {
kInt8 = 5,
kInt64 = 6,
kInt16 = 7,
kUint16 = 8,
kUint32 = 9,
kUint64 = 10,
};
struct CastParam : public dmlc::Parameter<CastParam> {
int dtype;
DMLC_DECLARE_PARAMETER(CastParam) {
DMLC_DECLARE_FIELD(dtype)
.add_enum("float16", kFloat16)
.add_enum("float32", kFloat32)
.add_enum("float64", kFloat64)
.add_enum("float16", kFloat16)
.add_enum("uint8", kUint8)
.add_enum("uint8", kUint8)
.add_enum("uint16", kUint16)
.add_enum("uint32", kUint32)
.add_enum("uint64", kUint64)
.add_enum("int8", kInt8)
.add_enum("int16", kInt16)
.add_enum("int32", kInt32)
.add_enum("int8", kInt8)
.add_enum("int64", kInt64)
.add_enum("int16", kInt16)
.describe("Output data type.");
}
};
......
......@@ -36,6 +36,9 @@ DTYPE_TO_TCODE = {
"int8": 5,
"int64": 6,
"int16": 7,
"uint16": 8,
"uint32": 9,
"uint64": 10,
}
TCODE_TO_DTYPE = {
......@@ -48,6 +51,9 @@ TCODE_TO_DTYPE = {
5: "int8",
6: "int64",
7: "int16",
8: "uint16",
9: "uint32",
10: "uint64",
}
def set_dtype_inputs(g, dtype):
......
......@@ -34,6 +34,9 @@ int GetTypeFlag(tvm::Type type) {
if (type == tvm::Int(8)) return 5;
if (type == tvm::Int(64)) return 6;
if (type == tvm::Int(16)) return 7;
if (type == tvm::UInt(16)) return 8;
if (type == tvm::UInt(32)) return 9;
if (type == tvm::UInt(64)) return 10;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
......@@ -56,6 +59,12 @@ Type GetTVMType(int type_flag) {
return tvm::Int(64);
case 7:
return tvm::Int(16);
case 8:
return tvm::UInt(16);
case 9:
return tvm::UInt(32);
case 10:
return tvm::UInt(64);
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment