Unverified Commit fc0149d5 by Tianqi Chen Committed by GitHub

[ATTR] Introduce Integer container (#1994)

parent 5742995d
......@@ -35,6 +35,7 @@
#include <string>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"
namespace tvm {
......@@ -73,7 +74,6 @@ inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
}
/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!
......
......@@ -29,6 +29,7 @@ using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
......@@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr {
};
/*!
* \brief Container of constant integer (IntImm).
*
* This is used to store and automate the type checking of
* attributes that must be constant integers.
*/
class Integer : public Expr {
public:
Integer() : Expr() {}
/*!
* \brief constructor from node.
*/
explicit Integer(NodePtr<Node> node) : Expr(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : Expr(value) {} // NOLINT(*)
/*!
* \brief Assign another Integer to this Integer.
* \param other the other Integer.
* \return reference to self.
*/
Integer& operator=(const Integer& other) {
node_ = other.node_;
return *this;
}
/*!
* \brief Get a pointer to the internal IntImm node.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(node_.get());
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
return (*this)->value;
}
/*! \brief type of the internal container */
using ContainerType = IntImm;
};
/*! \brief container class of iteration variable. */
class IterVarNode;
......
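A minimal usage sketch of the new Integer container (illustrative only, not part of the diff; it assumes the declarations above are in scope):

// Integer wraps a constant IntImm and converts implicitly to and from int.
tvm::Integer axis = 2;        // implicit construction from int
int64_t v = axis;             // implicit conversion via operator int64_t
CHECK_EQ(axis->value, 2);     // operator-> exposes the underlying IntImm
CHECK_EQ(v, 2);

tvm::Integer none;            // default-constructed, holds a null node
// int64_t bad = none;        // would trip the CHECK: "Trying to reference a null Integer"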
......@@ -10,6 +10,7 @@
#include <sstream>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>
#include "base.h"
......@@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
inline TVMArgValue::operator HalideIR::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
......@@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr);
}
inline TVMArgValue::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Integer(sptr);
}
inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<NodePtr<Node> >();
......
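For context, a hedged sketch of how this conversion is exercised from a packed function; the registration name "relay._test.integer_echo" is hypothetical, not part of this commit:

// Illustrative only: receive an Integer argument through the FFI.
// The int path range-checks value_.v_int64 before wrapping it in an IntImm;
// the node path type-checks the node against Integer::ContainerType.
TVM_REGISTER_API("relay._test.integer_echo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  tvm::Integer axis = args[0];        // invokes TVMArgValue::operator tvm::Integer()
  *ret = static_cast<int64_t>(axis);  // convert back to a plain int64_t
});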
......@@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
IndexExpr size;
IndexExpr axis;
int axis;
double bias;
double alpha;
double beta;
......@@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
/*! \brief Attributes for L2Normalize operator */
struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
double eps;
Array<IndexExpr> axis;
Array<Integer> axis;
TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
TVM_ATTR_FIELD(eps)
......
......@@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<IndexExpr> axes;
Array<Integer> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The target axes order, reverse order if not specified.");
......@@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}; // struct ReshapeAttrs
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
IndexExpr axis;
Integer axis;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
}
};
......
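Following the TakeAttrs pattern above, a hedged sketch of how an operator would declare Integer-typed attribute fields; the struct name "MyOpAttrs" and its fields are illustrative, not part of this commit:

// Hypothetical attrs node using the new Integer container.
struct MyOpAttrs : public tvm::AttrsNode<MyOpAttrs> {
  Integer axis;          // a single axis that must be a constant integer
  Array<Integer> axes;   // an optional list of constant integer axes

  TVM_DECLARE_ATTRS(MyOpAttrs, "relay.attrs.MyOpAttrs") {
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
        .describe("The axis this operator acts on.");
    TVM_ATTR_FIELD(axes)
        .describe("The axes to operate over; undefined means all axes.");
  }
};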
......@@ -32,6 +32,9 @@ struct Expr;
#endif
namespace tvm {
// forward declarations
class Integer;
namespace runtime {
// forward declarations
class TVMArgs;
......@@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ {
inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
};
......
......@@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs);
Expr MakeLRN(Expr data,
IndexExpr size,
IndexExpr axis,
int axis,
double alpha,
double beta,
double bias) {
......@@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn")
});
RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.
.describe(R"code(LRN layer.
Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
......@@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
Expr MakeL2Normalize(Expr data,
double eps,
Array<IndexExpr> axis) {
Array<Integer> axis) {
auto attrs = make_node<L2NormalizeAttrs>();
attrs->eps = eps;
attrs->axis = std::move(axis);
......
......@@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types,
}
const auto* param = attrs.as<TransposeAttrs>();
const int ndim = data->shape.size();
const Array<IndexExpr>& axes = param->axes;
const Array<Integer>& axes = param->axes;
// check dimension match
CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim)
CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
<< "Dimension mismatch: axes has " << axes.size() << " elements"
<< ", but data.ndim = " << ndim;
// construct int_axes
std::vector<int> int_axes;
int_axes.reserve(ndim);
if (axes.empty()) {
// use !defined() to check whether axes is None.
if (!axes.defined()) {
for (int i = ndim - 1; i >= 0; --i) {
int_axes.push_back(i);
}
} else {
std::vector<int> axis_used(ndim, 0);
for (const IndexExpr& e : axes) {
const int64_t *axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
int axis = *axis_ptr;
for (const Integer& e : axes) {
int64_t axis = e;
// sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim)
<< "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
......@@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types,
// sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
axis_used[axis] = 1;
int_axes.push_back(axis);
int_axes.push_back(static_cast<int>(axis));
}
}
std::vector<IndexExpr> oshape;
......@@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types,
}
Expr MakeTranspose(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axes) {
auto attrs = make_node<TransposeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose");
......@@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types,
std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis));
int axis = static_cast<int>(param->axis->value);
if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape"
......@@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types,
Expr MakeTake(Expr data,
Expr indices,
IndexExpr axis) {
Integer axis) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = axis;
attrs->axis = std::move(axis);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}
......