Commit e2ae388a authored by Eric Junyuan Xie, committed by Tianqi Chen

improve infer shape/type error message (#4)

* improve infer shape/type error message

* fix dense infer shape
parent 986caf71
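In effect, the commit swaps the old context-free exceptions for messages that identify the operator, its attributes, and the offending argument by name. As an illustration only (the operator instance and the shapes below are made up), a failure that used to surface as

  Shape inconsistent, Provided=(100,50), inferred shape=(10,100)

now reads roughly as

  Operator dense(units=10, name=fc1) expects weight's shape to be (10,100), but got (100,50).

with the argument name taken from the operator's FListInputNames/FListOutputNames registration, falling back to data&lt;i&gt;/output&lt;i&gt; when none is registered.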
@@ -35,20 +35,23 @@ inline bool DenseInferShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
   }
   CHECK_EQ(out_shape->size(), 1U);
-  TShape dshape = (*in_shape)[DenseParam::kData];
-  TShape oshape = (*out_shape)[0];
-  // require data to be known
-  if (dshape.ndim() == 0) return false;
-  dim_t num_input;
-  num_input = dshape.ProdShape(1, dshape.ndim());
-  SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kWeight, TShape({param.units, num_input}));
-  if (param.use_bias) {
-    SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kBias, TShape({param.units}));
+  if ((*out_shape)[0].ndim() != 0) {
+    // reverse infer
+    TShape dshape = (*out_shape)[0];
+    dshape[dshape.ndim() - 1] = 0;
+    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kData, dshape);
   }
+  dim_t num_inputs = 0;
+  if ((*in_shape)[DenseParam::kData].ndim() != 0) {
+    TShape oshape = (*in_shape)[DenseParam::kData];
+    num_inputs = oshape[oshape.ndim() - 1];
+    oshape[oshape.ndim() - 1] = param.units;
+    NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  }
-  SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], param.units}));
-  if (oshape.ndim() != 0) {
-    dshape[0] = oshape[0];
-    SHAPE_ASSIGN_CHECK(*in_shape, DenseParam::kData, dshape);
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kWeight,
+                          TShape({param.units, num_inputs}));
+  if (param.use_bias) {
+    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, DenseParam::kBias, TShape({param.units}));
   }
   return true;
 }
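A concrete walk-through of the rewritten DenseInferShape (the numbers are illustrative): with units=512 and a known output shape (N, 512), the reverse-infer step copies that shape to the data input with the last dimension zeroed out (unknown); once data is known as, say, (N, 100), num_inputs becomes 100, the output is assigned (N, 512), the weight (512, 100), and, when use_bias is set, the bias (512,). Note the semantic fix alongside the message change: the old code flattened every trailing dimension of data into the weight via ProdShape(1, ndim) and always emitted a 2-D output (dshape[0], units), whereas the new code treats only the last axis of data as the feature axis and preserves the leading dimensions.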
@@ -13,29 +13,6 @@
 namespace nnvm {
 namespace top {
-/*! \brief exception throwed by InferShape error */
-struct InferShapeError : public dmlc::Error {
-  /*! \brief analyze message */
-  std::string msg;
-  /*! \brief corresponding input index */
-  int index;
-  // constructor
-  InferShapeError(const std::string& msg_, int index)
-    : dmlc::Error(msg_), msg(msg_), index(index) {}
-};
-/*! \brief exception throwed by InferShape error */
-struct InferTypeError : public dmlc::Error {
-  /*! \brief analyze message */
-  std::string msg;
-  /*! \brief corresponding input index */
-  int index;
-  // constructor
-  InferTypeError(const std::string& msg_, int index)
-    : dmlc::Error(msg_), msg(msg_), index(index) {}
-};
 /*!
  * \brief Parse keyword arguments as PType arguments and save to parsed
  * \tparam PType the arameter type.
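Removing these structs also changes how failures propagate: instead of throwing an InferShapeError/InferTypeError that carries only the offending index for the caller to decode, the macros below report through LOG(FATAL), so (depending on how dmlc logging is configured) the process either aborts or a plain dmlc::Error is raised with the full, human-readable message already assembled.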
@@ -128,41 +128,88 @@ inline bool type_assign(int *y, const int& x) {
   return true;
 }
+template<typename AttrType>
+inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
+                                         int index, bool is_input,
+                                         const AttrType& expected,
+                                         const AttrType& actual,
+                                         const char* attr_name) {
+  static const auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
+  static const auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
+  const auto& flist = is_input ? flist_inputs : flist_outputs;
+  std::string name;
+  if (flist.count(attrs.op)) {
+    name = flist[attrs.op](attrs)[index];
+  } else {
+    name = (is_input ? "data" : "output") + std::to_string(index);
+  }
+  std::ostringstream msg;
+  msg << "Operator " << attrs.op->name << "(";
+  for (const auto& kv : attrs.dict) msg << kv.first << "=" << kv.second << ", ";
+  msg << "name=" << attrs.name << ") expects " << name << "\'s " << attr_name
+      << " to be " << expected << ", but got " << actual << ".";
+  return msg.str();
+}
 /*!
  * \brief macro assign shape to out if out is unknown otherwise check consistency
  * Use macro so we can see the error file more clearly
- * \param shape_array the shape array to store the result
+ * \param inputs the shape array to store the result
  * \param index the index of in the array
  * \param shape the inferred shape
  */
-#define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                     \
-  {                                                                       \
-    if (!shape_assign(&(shape_array)[index], TShape(shape))) {            \
-      std::ostringstream os;                                              \
-      os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \
-         << " inferred shape=" << shape;                                  \
-      throw InferShapeError(os.str(), index);                             \
-    }                                                                     \
+#define NNVM_ASSIGN_INPUT_SHAPE(attrs, inputs, index, shape)              \
+  {                                                                       \
+    if (!shape_assign(&(inputs)[index], TShape(shape))) {                 \
+      LOG(FATAL) << attr_assign_error_msg(attrs, index, true, shape,      \
+                                          (inputs)[index], "shape");      \
+    }                                                                     \
   }
+/*!
+ * \brief macro assign shape to out if out is unknown otherwise check consistency
+ * Use macro so we can see the error file more clearly
+ * \param inputs the shape array to store the result
+ * \param index the index of in the array
+ * \param shape the inferred shape
+ */
+#define NNVM_ASSIGN_OUTPUT_SHAPE(attrs, outputs, index, shape)            \
+  {                                                                       \
+    if (!shape_assign(&(outputs)[index], TShape(shape))) {                \
+      LOG(FATAL) << attr_assign_error_msg(attrs, index, false, shape,     \
+                                          (outputs)[index], "shape");     \
+    }                                                                     \
+  }
 /*!
  * \brief macro assign type to out if out is unknown (-1) otherwise check consistency
  * Use macro so we can see the error file more clearly
- * \param type_array the type array to store the result
+ * \param inputs the type array to store the result
 * \param index the index of in the array
 * \param type the inferred type
 */
-#define TYPE_ASSIGN_CHECK(type_array, index, type)                        \
-  {                                                                       \
-    if (!type_assign(&(type_array)[index], type)) {                       \
-      std::ostringstream os;                                              \
-      os << "Type inconsistent, Provided="                                \
-         << type_string((type_array)[index]) << ','                       \
-         << " inferred type=" << type_string(type);                       \
-      throw InferTypeError(os.str(), index);                              \
-    }                                                                     \
+#define NNVM_ASSIGN_INPUT_TYPE(attrs, inputs, index, type)                \
+  {                                                                       \
+    if (!type_assign(&(inputs)[index], type)) {                           \
+      LOG(FATAL) << attr_assign_error_msg(attrs, index, true, type,       \
+                                          (inputs)[index], "type");       \
+    }                                                                     \
   }
+/*!
+ * \brief macro assign type to out if out is unknown (-1) otherwise check consistency
+ * Use macro so we can see the error file more clearly
+ * \param inputs the type array to store the result
+ * \param index the index of in the array
+ * \param type the inferred type
+ */
+#define NNVM_ASSIGN_OUTPUT_TYPE(attrs, outputs, index, type)              \
+  {                                                                       \
+    if (!type_assign(&(outputs)[index], type)) {                          \
+      LOG(FATAL) << attr_assign_error_msg(attrs, index, false, type,      \
+                                          (outputs)[index], "type");      \
+    }                                                                     \
+  }
 // simply return the shape as same
 inline bool SameShape(const NodeAttrs& attrs,
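The call sites below are converted mechanically. For reference, here is a minimal sketch (not part of this commit) of how an operator-side infer-shape function would use the new macro pair; MyOpInferShape is a hypothetical name, and the function is assumed to live next to the macros in the same header so that NodeAttrs, TShape, and CHECK_EQ are in scope:

// Minimal sketch (hypothetical operator): forward the input shape to the
// output and, if the output is already known, propagate it back to the
// input. Any conflict with an already-assigned shape trips the macro,
// which calls LOG(FATAL) with attr_assign_error_msg, naming the operator,
// its attribute dict, and the offending argument.
inline bool MyOpInferShape(const NodeAttrs& attrs,
                           std::vector<TShape>* in_shape,
                           std::vector<TShape>* out_shape) {
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);
  if ((*in_shape)[0].ndim() != 0) {
    NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, (*in_shape)[0]);
  }
  if ((*out_shape)[0].ndim() != 0) {
    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, (*out_shape)[0]);
  }
  // Report success only once the output shape is fully known.
  return (*out_shape)[0].ndim() != 0;
}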
@@ -64,7 +64,7 @@ inline bool FlattenInferShape(const nnvm::NodeAttrs& attrs,
   for (uint32_t i = 1; i < dshape.ndim(); ++i) {
     target_dim *= dshape[i];
   }
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({dshape[0], target_dim}));
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, TShape({dshape[0], target_dim}));
   return true;
 }
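For instance, a (2, 3, 4) input gives target_dim = 3 * 4 = 12 and an output shape of (2, 12); if a conflicting output shape was already recorded, the failure is now attributed to the flatten operator and its output by name rather than surfacing as a bare "Shape inconsistent" exception.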
@@ -130,11 +130,11 @@ inline bool ConcatenateInferShape(const nnvm::NodeAttrs& attrs,
   if (dshape.ndim() == 0) return false;
   for (size_t i = 0; i < in_shape->size(); ++i) {
-    SHAPE_ASSIGN_CHECK(*in_shape, i, dshape);
+    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
   }
   if (!has_zero) dshape[param.axis] = size;
-  SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
   return dshape.Size() != 0;
 }
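Because the input index i is forwarded into the macro, a disagreement among the concatenated inputs is now pinned to the specific argument (its FListInputNames entry, or the data&lt;i&gt; fallback) together with whatever parameters are recorded in attrs.dict, instead of the bare index that the old InferShapeError carried.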
@@ -210,7 +210,7 @@ inline bool CastInferType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *out_attrs) {
   const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
   CHECK_EQ(out_attrs->size(), 1U);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
+  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
   return true;
 }
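CastInferType simply pins the single output's type to param.dtype; with the new macro, a conflicting previously-assigned type is reported as a type mismatch on the output (named via FListOutputNames, or output0 by default) with the operator's attributes visible in the message, rather than as an anonymous InferTypeError.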