Commit 8541e255 by YixinBao Committed by Tianqi Chen

add bfloat16 typeflag support (#4525)

parent 62aac9f1
...@@ -100,10 +100,14 @@ enum TypeFlag { ...@@ -100,10 +100,14 @@ enum TypeFlag {
kInt32 = 4, kInt32 = 4,
kInt8 = 5, kInt8 = 5,
kInt64 = 6, kInt64 = 6,
kInt16 = 7, // kBool = 7,
kUint16 = 8, // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in
kUint32 = 9, // https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L314
kUint64 = 10, kInt16 = 8,
kUint16 = 9,
kUint32 = 10,
kUint64 = 11,
kBfloat16 = 12,
}; };
enum IndicatorRuleFlag { enum IndicatorRuleFlag {
...@@ -125,7 +129,8 @@ enum IndicatorRuleFlag { ...@@ -125,7 +129,8 @@ enum IndicatorRuleFlag {
.add_enum("int8", kInt8) \ .add_enum("int8", kInt8) \
.add_enum("int16", kInt16) \ .add_enum("int16", kInt16) \
.add_enum("int32", kInt32) \ .add_enum("int32", kInt32) \
.add_enum("int64", kInt64) .add_enum("int64", kInt64) \
.add_enum("bfloat16", kBfloat16)
struct CastParam : public dmlc::Parameter<CastParam> { struct CastParam : public dmlc::Parameter<CastParam> {
int dtype; int dtype;
......
...@@ -40,6 +40,7 @@ static int GetDTypeSize(int type_flag) { ...@@ -40,6 +40,7 @@ static int GetDTypeSize(int type_flag) {
case kInt8: case kInt8:
return 1; return 1;
case kFloat16: case kFloat16:
case kBfloat16:
case kInt16: case kInt16:
case kUint16: case kUint16:
return 2; return 2;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment