Commit e2ae388a by Eric Junyuan Xie Committed by Tianqi Chen

improve infer shape/type error message (#4)

* improve infer shape/type error message

* fix dense infer shape
parent 986caf71
...@@ -35,20 +35,23 @@ inline bool DenseInferShape(const nnvm::NodeAttrs& attrs, ...@@ -35,20 +35,23 @@ inline bool DenseInferShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
} }
CHECK_EQ(out_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[DenseParam::kData]; if ((*out_shape)[0].ndim() != 0) {
TShape oshape = (*out_shape)[0]; // reverse infer
// require data to be known TShape dshape = (*out_shape)[0];
if (dshape.ndim() == 0) return false; dshape[dshape.ndim() - 1] = 0;
dim_t num_input; NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kData, dshape);
num_input = dshape.ProdShape(1, dshape.ndim()); }
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kWeight, TShape({param.units, num_input})); dim_t num_inputs = 0;
if (param.use_bias) { if ((*in_shape)[DenseParam::kData].ndim() != 0) {
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kBias, TShape({param.units})); TShape oshape = (*in_shape)[DenseParam::kData];
num_inputs = oshape[oshape.ndim() - 1];
oshape[oshape.ndim() - 1] = param.units;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
} }
SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], param.units})); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kWeight,
if (oshape.ndim() != 0) { TShape({param.units, num_inputs}));
dshape[0] = oshape[0]; if (param.use_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kData, dshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kBias, TShape({param.units}));
} }
return true; return true;
} }
......
...@@ -13,29 +13,6 @@ ...@@ -13,29 +13,6 @@
namespace nnvm { namespace nnvm {
namespace top { namespace top {
/*! \brief exception throwed by InferShape error */
struct InferShapeError : public dmlc::Error {
/*! \brief analyze message */
std::string msg;
/*! \brief corresponding input index */
int index;
// constructor
InferShapeError(const std::string& msg_, int index)
: dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief exception throwed by InferShape error */
struct InferTypeError : public dmlc::Error {
/*! \brief analyze message */
std::string msg;
/*! \brief corresponding input index */
int index;
// constructor
InferTypeError(const std::string& msg_, int index)
: dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! /*!
* \brief Parse keyword arguments as PType arguments and save to parsed * \brief Parse keyword arguments as PType arguments and save to parsed
* \tparam PType the arameter type. * \tparam PType the arameter type.
...@@ -128,41 +105,88 @@ inline bool type_assign(int *y, const int& x) { ...@@ -128,41 +105,88 @@ inline bool type_assign(int *y, const int& x) {
return true; return true;
} }
template<typename AttrType>
inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
int index, bool is_input,
const AttrType& expected,
const AttrType& actual,
const char* attr_name) {
static const auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static const auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
const auto& flist = is_input ? flist_inputs : flist_outputs;
std::string name;
if (flist.count(attrs.op)) {
name = flist[attrs.op](attrs)[index];
} else {
name = (is_input ? "data" : "output") + std::to_string(index);
}
std::ostringstream msg;
msg << "Operator " << attrs.op->name << "(";
for (const auto& kv : attrs.dict) msg << kv.first << "=" << kv.second << ", ";
msg << "name=" << attrs.name << ") expects " << name << "\'s " << attr_name
<< " to be " << expected << ", but got " << actual << ".";
return msg.str();
}
/*! /*!
* \brief macro assign shape to out if out is unknown otherwise check consistency * \brief macro assign shape to out if out is unknown otherwise check consistency
* Use macro so we can see the error file more clearly * Use macro so we can see the error file more clearly
* \param shape_array the shape array to store the result * \param inputs the shape array to store the result
* \param index the index of in the array * \param index the index of in the array
* \param shape the inferred shape * \param shape the inferred shape
*/ */
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ #define NNVM_ASSIGN_INPUT_SHAPE(attrs, inputs, index, shape) \
{ \ { \
if (!shape_assign(&(shape_array)[index], TShape(shape))) { \ if (!shape_assign(&(inputs)[index], TShape(shape))) { \
std::ostringstream os; \ LOG(FATAL) << attr_assign_error_msg(attrs, index, true, shape, \
os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \ (inputs)[index], "shape"); \
<< " inferred shape=" << shape; \ } \
throw InferShapeError(os.str(), index); \ }
/*!
* \brief macro assign shape to out if out is unknown otherwise check consistency
* Use macro so we can see the error file more clearly
* \param inputs the shape array to store the result
* \param index the index of in the array
* \param shape the inferred shape
*/
#define NNVM_ASSIGN_OUTPUT_SHAPE(attrs, outputs, index, shape) \
{ \
if (!shape_assign(&(outputs)[index], TShape(shape))) { \
LOG(FATAL) << attr_assign_error_msg(attrs, index, false, shape, \
(outputs)[index], "shape"); \
} \ } \
} }
/*! /*!
* \brief macro assign type to out if out is unknown (-1) otherwise check consistency * \brief macro assign type to out if out is unknown (-1) otherwise check consistency
* Use macro so we can see the error file more clearly * Use macro so we can see the error file more clearly
* \param type_array the type array to store the result * \param inputs the type array to store the result
* \param index the index of in the array * \param index the index of in the array
* \param type the inferred type * \param type the inferred type
*/ */
#define TYPE_ASSIGN_CHECK(type_array, index, type) \ #define NNVM_ASSIGN_INPUT_TYPE(attrs, inputs, index, type) \
{ \ { \
if (!type_assign(&(type_array)[index], type)) { \ if (!type_assign(&(inputs)[index], type)) { \
std::ostringstream os; \ LOG(FATAL) << attr_assign_error_msg(attrs, index, true, type, \
os << "Type inconsistent, Provided=" \ (inputs)[index], "type"); \
<< type_string((type_array)[index]) << ',' \
<< " inferred type=" << type_string(type); \
throw InferTypeError(os.str(), index); \
} \ } \
} }
/*!
* \brief macro assign type to out if out is unknown (-1) otherwise check consistency
* Use macro so we can see the error file more clearly
* \param inputs the type array to store the result
* \param index the index of in the array
* \param type the inferred type
*/
#define NNVM_ASSIGN_OUTPUT_TYPE(attrs, outputs, index, type) \
{ \
if (!type_assign(&(outputs)[index], type)) { \
LOG(FATAL) << attr_assign_error_msg(attrs, index, false, type, \
(outputs)[index], "type"); \
} \
}
// simply return the shape as same // simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs, inline bool SameShape(const NodeAttrs& attrs,
......
...@@ -64,7 +64,7 @@ inline bool FlattenInferShape(const nnvm::NodeAttrs& attrs, ...@@ -64,7 +64,7 @@ inline bool FlattenInferShape(const nnvm::NodeAttrs& attrs,
for (uint32_t i = 1; i < dshape.ndim(); ++i) { for (uint32_t i = 1; i < dshape.ndim(); ++i) {
target_dim *= dshape[i]; target_dim *= dshape[i];
} }
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({dshape[0], target_dim})); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, TShape({dshape[0], target_dim}));
return true; return true;
} }
...@@ -130,11 +130,11 @@ inline bool ConcatenateInferShape(const nnvm::NodeAttrs& attrs, ...@@ -130,11 +130,11 @@ inline bool ConcatenateInferShape(const nnvm::NodeAttrs& attrs,
if (dshape.ndim() == 0) return false; if (dshape.ndim() == 0) return false;
for (size_t i = 0; i < in_shape->size(); ++i) { for (size_t i = 0; i < in_shape->size(); ++i) {
SHAPE_ASSIGN_CHECK(*in_shape, i, dshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
} }
if (!has_zero) dshape[param.axis] = size; if (!has_zero) dshape[param.axis] = size;
SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
return dshape.Size() != 0; return dshape.Size() != 0;
} }
...@@ -210,7 +210,7 @@ inline bool CastInferType(const nnvm::NodeAttrs& attrs, ...@@ -210,7 +210,7 @@ inline bool CastInferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) { std::vector<int> *out_attrs) {
const CastParam& param = nnvm::get<CastParam>(attrs.parsed); const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype); NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
return true; return true;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment