Unverified commit 3a672e3e by Yizhi Liu, committed by GitHub

[Arith] add SizeVar representing non-neg valued variable in a tensor shape (#4684)

* [arith] add ShapeVar representing non-neg valued variable in a tensor shape

* bound remover; deal with div in int_set differently

* fix bound_remover

* migrate unittest to use shape_var

* use tvm.shape_var in integration & relay tests

* add test case; fix Var register

* fix lint

* fix lint again

* add default ShapeVar visitor in Relay

* fix override

* fix ShapeVar visit bug

* revert IntervalSet for shape_var

* remove bound_remover

* remove is_var; use constructor for shapevar/var instead

* ShapeVar -> SizeVar; add constructor comments

* shape_var -> size_var in doc

* tindex -> size
parent d756d3ca
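For context, a minimal sketch of how the new non-negative shape variable is meant to be used from the Python frontend; the placeholder and compute below are illustrative and follow the pre-0.7 top-level API that the tests in this diff use:

```python
import tvm

# tvm.size_var creates a variable the arithmetic layer may assume is >= 0,
# unlike tvm.var, whose value range is unconstrained.
n = tvm.size_var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute(A.shape, lambda i: A[i] * 2.0, name="B")
s = tvm.create_schedule(B.op)

# Lowering works exactly as with tvm.var; only bound analysis differs.
print(tvm.lower(s, [A, B], simple_mode=True))
```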
...@@ -24,6 +24,7 @@ The user facing API for computation declaration.
tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
tvm.convert
tvm.placeholder
...@@ -49,6 +50,7 @@ The user facing API for computation declaration.
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
.. autofunction:: tvm.convert
.. autofunction:: tvm.placeholder
......
...@@ -65,27 +65,33 @@ class Var; ...@@ -65,27 +65,33 @@ class Var;
*/ */
class VarNode : public PrimExprNode { class VarNode : public PrimExprNode {
public: public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);
/*! /*!
* \brief The hint to the variable name. * \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address. * \note Each variable is uniquely identified by its address.
*/ */
std::string name_hint; std::string name_hint;
static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("name", &name_hint); v->Visit("name", &name_hint);
} }
static constexpr const char* _type_key = "Variable"; static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode); TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
}; };
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
class Var : public PrimExpr { class Var : public PrimExpr {
public: public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {} explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit Var(std::string name_hint = "v", TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32)); DataType t = DataType::Int(32));
/*! /*!
...@@ -114,6 +120,53 @@ class Var : public PrimExpr { ...@@ -114,6 +120,53 @@ class Var : public PrimExpr {
using ContainerType = VarNode; using ContainerType = VarNode;
}; };
class SizeVar;
/*!
* \brief A variable node representing a tensor index size,
* whose value must be non-negative.
*/
class SizeVarNode : public VarNode {
public:
/*! \brief constructor */
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);
static constexpr const char* _type_key = "SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
/*! \brief a named variable representing a tensor index size */
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* get() const {
return static_cast<const SizeVarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = SizeVarNode;
};
/*! /*!
* \brief Container of constant int that adds more constructors. * \brief Container of constant int that adds more constructors.
* *
......
...@@ -38,6 +38,7 @@ namespace ir {
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;
using SizeVarNode = tvm::SizeVarNode;
/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
...@@ -679,7 +680,7 @@ class AnyNode : public PrimExprNode { ...@@ -679,7 +680,7 @@ class AnyNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */ /*! \brief Convert to var. */
Var ToVar() const { Var ToVar() const {
return VarNode::make(DataType::Int(32), "any_dim"); return Var("any_dim", DataType::Int(32));
} }
TVM_DLL static PrimExpr make(); TVM_DLL static PrimExpr make();
......
...@@ -133,6 +133,9 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
...@@ -174,6 +177,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(VarNode);
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
...@@ -297,6 +301,7 @@ class TVM_DLL ExprVisitor :
using ExprFunctor::VisitExpr;
// list of functions to override.
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
...@@ -341,6 +346,7 @@ class TVM_DLL ExprMutator :
using ExprFunctor::VisitExpr;
// list of functions to override.
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
......
...@@ -192,6 +192,25 @@ def var(name="tindex", dtype=int32):
return _api_internal._Var(name, dtype)

def size_var(name="size", dtype=int32):
"""Create a new variable representing a tensor shape size, which is non-negative.

Parameters
----------
name : str
The name

dtype : str
The data type

Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return _api_internal._SizeVar(name, dtype)

def any(*args):
"""Create a new expression of the union of all conditions in the arguments
......
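A short usage sketch for the new helper above, mirroring tvm.var; the shapes and dtypes here are illustrative:

```python
import tvm

m = tvm.size_var("m")                 # defaults to name="size", dtype="int32" when omitted
k = tvm.size_var("k", dtype="int64")
X = tvm.placeholder((m, k), name="X", dtype="float32")

print(type(m).__name__)   # SizeVar
print(m.dtype, k.dtype)   # int32 int64
```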
...@@ -279,6 +279,25 @@ class Var(PrimExpr):
@register_object
class SizeVar(Var):
"""Symbolic variable representing a tensor index size,
which is greater than or equal to zero.

Parameters
----------
name : str
The name

dtype : str
The data type
"""
# pylint: disable=super-init-not-called
def __init__(self, name, dtype):
self.__init_handle_by_constructor__(
_api_internal._SizeVar, name, dtype)

@register_object
class Reduce(PrimExpr):
"""Reduce node.
......
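Because SizeVar subclasses Var, existing Python code that type-checks against Var keeps working; a quick sketch assuming the registration above:

```python
import tvm
from tvm import expr

s = expr.SizeVar("n", "int32")       # equivalent to tvm.size_var("n")
assert isinstance(s, expr.Var)        # SizeVar is-a Var
assert isinstance(s, expr.PrimExpr)   # and therefore a PrimExpr
```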
...@@ -63,7 +63,7 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -63,7 +63,7 @@ class PyVariableUsage(ast.NodeVisitor):
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \ _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min', 'len'] + \ ['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \ list(self.symbols.keys()), \
"Function call id not in intrinsics' list") "Function call id " + func_id + " not in intrinsics' list")
for elem in node.args: for elem in node.args:
self.visit(elem) self.visit(elem)
......
...@@ -33,7 +33,12 @@ namespace ir { ...@@ -33,7 +33,12 @@ namespace ir {
TVM_REGISTER_GLOBAL("_Var") TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return VarNode::make(t, s); return Var(s, t);
});
TVM_REGISTER_GLOBAL("_SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
}); });
TVM_REGISTER_GLOBAL("make.abs") TVM_REGISTER_GLOBAL("make.abs")
......
...@@ -86,7 +86,7 @@ class BoundDeducer: public ExprVisitor { ...@@ -86,7 +86,7 @@ class BoundDeducer: public ExprVisitor {
void VisitExpr(const PrimExpr& e) final { void VisitExpr(const PrimExpr& e) final {
if (!success_) return; if (!success_) return;
if (e.get() == path_[iter_++]) { if (iter_ < path_.size() && e.get() == path_[iter_++]) {
ExprVisitor::VisitExpr(e); ExprVisitor::VisitExpr(e);
} else { } else {
success_ = false; success_ = false;
...@@ -297,6 +297,7 @@ void BoundDeducer::Transform() { ...@@ -297,6 +297,7 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() { void BoundDeducer::Deduce() {
Init(); Init();
if (!success_) return; if (!success_) return;
Relax(); Relax();
if (!success_) return; if (!success_) return;
// get the path // get the path
......
...@@ -284,6 +284,16 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const SizeVarNode* op) final {
SizeVar v = GetRef<SizeVar>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
return it->second;
} else {
return MakeBound(0, kPosInf);
}
}
Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
......
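The analyzer change above gives an otherwise unconstrained SizeVar the bound [0, +inf) instead of (-inf, +inf); a small check mirroring the unit test added later in this diff:

```python
import tvm

analyzer = tvm.arith.Analyzer()
x = tvm.size_var("x")
y = tvm.var("y")

bx = analyzer.const_int_bound(x)
by = analyzer.const_int_bound(y)
print(bx.min_value, bx.max_value)   # 0 POS_INF       (non-negative by construction)
print(by.min_value, by.max_value)   # NEG_INF POS_INF (a plain Var is unconstrained)
```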
...@@ -401,6 +401,7 @@ class IntervalSetEvaluator :
}
}
IntervalSet VisitExpr_(const AddNode* op) final {
return VisitBinaryExpr_(op);
}
......
...@@ -81,6 +81,9 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
// deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const SizeVarNode* op, Args... args) {
return VisitAttr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
...@@ -115,6 +118,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(FloatImmNode);
ATTR_FUNCTOR_DISPATCH(StringImmNode);
ATTR_FUNCTOR_DISPATCH(VarNode);
ATTR_FUNCTOR_DISPATCH(SizeVarNode);
ATTR_FUNCTOR_DISPATCH(AddNode);
ATTR_FUNCTOR_DISPATCH(SubNode);
ATTR_FUNCTOR_DISPATCH(MulNode);
......
...@@ -39,15 +39,19 @@ PrimExpr::PrimExpr(std::string str) ...@@ -39,15 +39,19 @@ PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {} : PrimExpr(ir::StringImmNode::make(str)) {}
Var::Var(std::string name_hint, DataType t) Var::Var(std::string name_hint, DataType t)
: Var(VarNode::make(t, name_hint)) {} : Var(make_object<VarNode>(t, name_hint)) {}
Var VarNode::make(DataType t, std::string name_hint) { VarNode::VarNode(DataType t, std::string name_hint) {
ObjectPtr<VarNode> node = make_object<VarNode>(); this->dtype = t;
node->dtype = t; this->name_hint = std::move(name_hint);
node->name_hint = std::move(name_hint);
return Var(node);
} }
SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}
SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}
Range::Range(PrimExpr begin, PrimExpr end) Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>( : Range(make_object<RangeNode>(
begin, begin,
......
...@@ -592,6 +592,10 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
.set_dispatch<SizeVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SizeVarNode*>(node.get());
p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
})
.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
...@@ -1143,6 +1147,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(SizeVarNode);
TVM_REGISTER_NODE_TYPE(AddNode);
TVM_REGISTER_NODE_TYPE(SubNode);
TVM_REGISTER_NODE_TYPE(MulNode);
......
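With the printer dispatch above, a SizeVar renders with its non-negativity constraint attached, so the assumption stays visible in textual IR; roughly:

```python
import tvm

n = tvm.size_var("n")
m = tvm.var("m")
print(n + 1)   # expected to print something like: ({n|n>=0} + 1)
print(m + 1)   # prints: (m + 1)
```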
...@@ -221,6 +221,10 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
void ExprVisitor::VisitExpr_(const VarNode* op) {}
void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}
void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
...@@ -596,6 +600,10 @@ PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
return GetRef<PrimExpr>(op);
}
PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
}
PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
......
...@@ -87,7 +87,7 @@ class IRConvertSSA final : public StmtExprMutator { ...@@ -87,7 +87,7 @@ class IRConvertSSA final : public StmtExprMutator {
const Var& v = op->var; const Var& v = op->var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
Var new_var = VarNode::make(v.dtype(), v->name_hint); Var new_var(v->name_hint, v.dtype());
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
PrimExpr body = this->VisitExpr(op->body); PrimExpr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
...@@ -123,7 +123,7 @@ class IRConvertSSA final : public StmtExprMutator { ...@@ -123,7 +123,7 @@ class IRConvertSSA final : public StmtExprMutator {
const Var& v = op->var; const Var& v = op->var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
Var new_var = VarNode::make(v.dtype(), v->name_hint); Var new_var(v->name_hint, v.dtype());
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
...@@ -136,7 +136,7 @@ class IRConvertSSA final : public StmtExprMutator { ...@@ -136,7 +136,7 @@ class IRConvertSSA final : public StmtExprMutator {
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode* op) final {
const Var& v = op->loop_var; const Var& v = op->loop_var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
Var new_var = VarNode::make(v.dtype(), v->name_hint); Var new_var(v->name_hint, v.dtype());
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op); Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
...@@ -151,7 +151,7 @@ class IRConvertSSA final : public StmtExprMutator { ...@@ -151,7 +151,7 @@ class IRConvertSSA final : public StmtExprMutator {
Stmt VisitStmt_(const AllocateNode* op) final { Stmt VisitStmt_(const AllocateNode* op) final {
const Var& v = op->buffer_var; const Var& v = op->buffer_var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
Var new_var = VarNode::make(v.dtype(), v->name_hint); Var new_var(v->name_hint, v.dtype());
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op); Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
......
...@@ -1108,7 +1108,7 @@ class TensorCoreIRMutator : public StmtExprMutator { ...@@ -1108,7 +1108,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
auto it2 = matrix_abc_.find(simplify_name(call->name)); auto it2 = matrix_abc_.find(simplify_name(call->name));
CHECK(it2 != matrix_abc_.end()) CHECK(it2 != matrix_abc_.end())
<< "Cannot find matrix info for " << call->name; << "Cannot find matrix info for " << call->name;
buffer_node->data = VarNode::make(DataType::Handle(), call->name); buffer_node->data = Var(call->name, DataType::Handle());
buffer_node->name = call->name; buffer_node->name = call->name;
buffer_node->scope = "wmma." + it2->second; buffer_node->scope = "wmma." + it2->second;
buffer_node->dtype = datatype; buffer_node->dtype = datatype;
......
...@@ -25,8 +25,8 @@ def test_static_tensor(): ...@@ -25,8 +25,8 @@ def test_static_tensor():
stype = 'csr' stype = 'csr'
target = 'llvm' target = 'llvm'
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype) A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype)
assert(A.stype == 'csr') assert(A.stype == 'csr')
n = 3 n = 3
...@@ -50,7 +50,7 @@ def test_dynamic_tensor(): ...@@ -50,7 +50,7 @@ def test_dynamic_tensor():
stype = 'csr' stype = 'csr'
target = 'llvm' target = 'llvm'
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n') nr, nc, n = tvm.size_var('nr'), tvm.size_var('nc'), tvm.size_var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype) A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr') assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter') C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
...@@ -76,7 +76,7 @@ def test_sparse_array_tuple(): ...@@ -76,7 +76,7 @@ def test_sparse_array_tuple():
stype = 'csr' stype = 'csr'
target = 'llvm' target = 'llvm'
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n') nr, nc, n = tvm.size_var('nr'), tvm.size_var('nc'), tvm.size_var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype) A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr') assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter') C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
......
...@@ -57,7 +57,7 @@ def test_exp(): ...@@ -57,7 +57,7 @@ def test_exp():
def test_fmod(): def test_fmod():
# graph # graph
def run(dtype): def run(dtype):
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A', dtype=dtype) A = tvm.placeholder((n,), name='A', dtype=dtype)
B = tvm.placeholder((n,), name='B', dtype=dtype) B = tvm.placeholder((n,), name='B', dtype=dtype)
C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C') C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C')
...@@ -140,7 +140,7 @@ def test_multiple_cache_write(): ...@@ -140,7 +140,7 @@ def test_multiple_cache_write():
def test_log_pow_llvm(): def test_log_pow_llvm():
# graph # graph
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: tvm.power(tvm.log(A(*i)), 2.0), name='B') B = tvm.compute(A.shape, lambda *i: tvm.power(tvm.log(A(*i)), 2.0), name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
...@@ -207,7 +207,7 @@ def test_popcount(): ...@@ -207,7 +207,7 @@ def test_popcount():
def test_add(): def test_add():
def run(dtype): def run(dtype):
# graph # graph
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A', dtype=dtype) A = tvm.placeholder((n,), name='A', dtype=dtype)
B = tvm.placeholder((n,), name='B', dtype=dtype) B = tvm.placeholder((n,), name='B', dtype=dtype)
bias = tvm.var("bias", dtype=dtype) bias = tvm.var("bias", dtype=dtype)
......
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
def test_gemm(): def test_gemm():
# graph # graph
nn = 1024 nn = 1024
n = tvm.var('n')
n = tvm.convert(nn) n = tvm.convert(nn)
m = n m = n
l = n l = n
......
...@@ -21,8 +21,8 @@ import numpy as np ...@@ -21,8 +21,8 @@ import numpy as np
def test_reduce_prims(): def test_reduce_prims():
def test_prim(reducer, np_reducer): def test_prim(reducer, np_reducer):
# graph # graph
n = tvm.var('n') n = tvm.size_var('n')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R') R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m)) k = tvm.reduce_axis((0, m))
...@@ -242,8 +242,8 @@ def test_argmax(): ...@@ -242,8 +242,8 @@ def test_argmax():
argmax = tvm.comm_reducer(fcombine, argmax = tvm.comm_reducer(fcombine,
fidentity, fidentity,
name='argmax') name='argmax')
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32') idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32') val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k') k = tvm.reduce_axis((0, n), 'k')
......
...@@ -18,8 +18,8 @@ import tvm ...@@ -18,8 +18,8 @@ import tvm
import numpy as np import numpy as np
def test_scan(): def test_scan():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
X = tvm.placeholder((m, n), name="X") X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n)) s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i]) s_init = tvm.compute((1, n), lambda _, i: X[0, i])
......
...@@ -70,7 +70,7 @@ def test_env(): ...@@ -70,7 +70,7 @@ def test_env():
def test_meta_data(): def test_meta_data():
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", shape=(n, c, h, w)) x = relay.var("x", shape=(n, c, h, w))
w = relay.var("w") w = relay.var("w")
z = relay.nn.conv2d(x, w, z = relay.nn.conv2d(x, w,
...@@ -82,8 +82,8 @@ def test_meta_data(): ...@@ -82,8 +82,8 @@ def test_meta_data():
text_no_meta = str(f) text_no_meta = str(f)
assert "channels=2" in text assert "channels=2" in text
assert "channels=2" in text_no_meta assert "channels=2" in text_no_meta
assert "meta[Variable][0]" in text assert "meta[SizeVar][0]" in text
assert "meta[Variable][0]" in text_no_meta assert "meta[SizeVar][0]" in text_no_meta
assert "type_key" in text assert "type_key" in text
assert "type_key" not in text_no_meta assert "type_key" not in text_no_meta
......
...@@ -177,7 +177,7 @@ def test_bias_add(): ...@@ -177,7 +177,7 @@ def test_bias_add():
def test_expand_dims_infer_type(): def test_expand_dims_infer_type():
for dtype in ['float16', 'float32']: for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
x = relay.var("x", shape=(n, t, d), dtype=dtype) x = relay.var("x", shape=(n, t, d), dtype=dtype)
y = relay.expand_dims(x, axis=2) y = relay.expand_dims(x, axis=2)
assert "axis=2" in y.astext() assert "axis=2" in y.astext()
...@@ -227,7 +227,7 @@ def test_log_softmax(): ...@@ -227,7 +227,7 @@ def test_log_softmax():
def test_concatenate(): def test_concatenate():
for dtype in ['float16', 'float32']: for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
x = relay.var("x", shape=(n, t, d)) x = relay.var("x", shape=(n, t, d))
y = relay.var("y", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d))
z = relay.concatenate((x, y), axis=-1) z = relay.concatenate((x, y), axis=-1)
...@@ -280,7 +280,7 @@ def test_concatenate(): ...@@ -280,7 +280,7 @@ def test_concatenate():
def test_dropout(): def test_dropout():
for dtype in ['float16', 'float32']: for dtype in ['float16', 'float32']:
n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") n, t, d = tvm.size_var("n"), tvm.size_var("t"), tvm.size_var("d")
input_ty = relay.TensorType((n, t, d), dtype) input_ty = relay.TensorType((n, t, d), dtype)
x = relay.var("x", input_ty) x = relay.var("x", input_ty)
y = relay.nn.dropout(x, rate=0.75) y = relay.nn.dropout(x, rate=0.75)
...@@ -342,7 +342,7 @@ def test_dense(): ...@@ -342,7 +342,7 @@ def test_dense():
# Dense accuracy for float16 is poor # Dense accuracy for float16 is poor
if dtype == 'float16': if dtype == 'float16':
return return
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
w = relay.var("w", relay.TensorType((2, w), dtype)) w = relay.var("w", relay.TensorType((2, w), dtype))
y = relay.nn.dense(x, w, units=2) y = relay.nn.dense(x, w, units=2)
...@@ -350,15 +350,15 @@ def test_dense(): ...@@ -350,15 +350,15 @@ def test_dense():
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), 2
x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
wh, ww = tvm.var("wh"), tvm.var("ww") wh, ww = tvm.size_var("wh"), tvm.size_var("ww")
w = relay.var("w", relay.TensorType((ww, wh), dtype)) w = relay.var("w", relay.TensorType((ww, wh), dtype))
y = relay.nn.dense(x, w) y = relay.nn.dense(x, w)
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype) assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), 2
x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
w = relay.var("w", relay.IncompleteType()) w = relay.var("w", relay.IncompleteType())
y = relay.nn.dense(x, w, units=2) y = relay.nn.dense(x, w, units=2)
...@@ -388,7 +388,7 @@ def test_dense_dtype(): ...@@ -388,7 +388,7 @@ def test_dense_dtype():
data_dtype = 'uint8' data_dtype = 'uint8'
weight_dtype = 'int8' weight_dtype = 'int8'
out_dtype = 'uint8' out_dtype = 'uint8'
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype)) x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype))
w = relay.var("w", relay.TensorType((2, w), weight_dtype)) w = relay.var("w", relay.TensorType((2, w), weight_dtype))
y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype) y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype)
...@@ -400,7 +400,7 @@ def test_dense_dtype(): ...@@ -400,7 +400,7 @@ def test_dense_dtype():
def test_bitserial_dense(): def test_bitserial_dense():
m, k = tvm.var("m"), tvm.var("k") m, k = tvm.size_var("m"), tvm.size_var("k")
x = relay.var("x", relay.TensorType((m, k), "int16")) x = relay.var("x", relay.TensorType((m, k), "int16"))
w = relay.var("w", relay.TensorType((k, 32), "int16")) w = relay.var("w", relay.TensorType((k, 32), "int16"))
y = relay.nn.bitserial_dense(x, w, units=32) y = relay.nn.bitserial_dense(x, w, units=32)
......
...@@ -309,7 +309,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): ...@@ -309,7 +309,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)
def test_batch_matmul(): def test_batch_matmul():
b, m, n, k = tvm.var("b"), tvm.var("m"), tvm.var("n"), tvm.var("k") b, m, n, k = tvm.size_var("b"), tvm.size_var("m"), tvm.size_var("n"), tvm.size_var("k")
x = relay.var("x", relay.TensorType((b, m, k), "float32")) x = relay.var("x", relay.TensorType((b, m, k), "float32"))
y = relay.var("y", relay.TensorType((b, n, k), "float32")) y = relay.var("y", relay.TensorType((b, n, k), "float32"))
z = relay.nn.batch_matmul(x, y) z = relay.nn.batch_matmul(x, y)
......
...@@ -128,7 +128,7 @@ def test_conv1d_run(): ...@@ -128,7 +128,7 @@ def test_conv1d_run():
def test_conv2d_infer_type(): def test_conv2d_infer_type():
# symbolic in batch dimension # symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = relay.var("w") w = relay.var("w")
y = relay.nn.conv2d(x, w, y = relay.nn.conv2d(x, w,
...@@ -142,7 +142,7 @@ def test_conv2d_infer_type(): ...@@ -142,7 +142,7 @@ def test_conv2d_infer_type():
(2, 10, 3, 3), "float32") (2, 10, 3, 3), "float32")
# infer by shape of w, mixed precision # infer by shape of w, mixed precision
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8")) w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, out_dtype="int32") y = relay.nn.conv2d(x, w, out_dtype="int32")
...@@ -152,7 +152,7 @@ def test_conv2d_infer_type(): ...@@ -152,7 +152,7 @@ def test_conv2d_infer_type():
(n, 2, 222, 222), "int32") (n, 2, 222, 222), "int32")
# infer shape in case of different dtypes for input and weight. # infer shape in case of different dtypes for input and weight.
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "uint8")) x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8")) w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, out_dtype="int32") y = relay.nn.conv2d(x, w, out_dtype="int32")
...@@ -391,7 +391,7 @@ def test_conv2d_winograd(): ...@@ -391,7 +391,7 @@ def test_conv2d_winograd():
def test_conv3d_infer_type(): def test_conv3d_infer_type():
# symbolic in batch dimension # symbolic in batch dimension
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 n, c, d, h, w = tvm.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32")) x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
w = relay.var("w") w = relay.var("w")
y = relay.nn.conv3d(x, w, y = relay.nn.conv3d(x, w,
...@@ -405,7 +405,7 @@ def test_conv3d_infer_type(): ...@@ -405,7 +405,7 @@ def test_conv3d_infer_type():
(2, 10, 3, 3, 3), "float32") (2, 10, 3, 3, 3), "float32")
# infer by shape of w, mixed precision # infer by shape of w, mixed precision
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 n, c, d, h, w = tvm.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8")) w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32") y = relay.nn.conv3d(x, w, out_dtype="int32")
...@@ -415,7 +415,7 @@ def test_conv3d_infer_type(): ...@@ -415,7 +415,7 @@ def test_conv3d_infer_type():
(n, 2, 222, 222, 222), "int32") (n, 2, 222, 222, 222), "int32")
# infer shape in case of different dtypes for input and weight. # infer shape in case of different dtypes for input and weight.
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 n, c, d, h, w = tvm.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8")) x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8")) w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32") y = relay.nn.conv3d(x, w, out_dtype="int32")
...@@ -530,7 +530,7 @@ def test_conv3d_ndhwc_run(): ...@@ -530,7 +530,7 @@ def test_conv3d_ndhwc_run():
def test_conv2d_transpose_infer_type(): def test_conv2d_transpose_infer_type():
# symbolic in batch dimension # symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 10, 12 n, c, h, w = tvm.size_var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
w = relay.var("w", relay.IncompleteType()) w = relay.var("w", relay.IncompleteType())
y = relay.nn.conv2d_transpose(x, w, y = relay.nn.conv2d_transpose(x, w,
...@@ -545,7 +545,7 @@ def test_conv2d_transpose_infer_type(): ...@@ -545,7 +545,7 @@ def test_conv2d_transpose_infer_type():
(10, 15, 3, 3), "float32") (10, 15, 3, 3), "float32")
# infer by shape of w, mixed precision # infer by shape of w, mixed precision
n, h, w, c = tvm.var("n"), 10, 10, 12 n, h, w, c = tvm.size_var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32")) w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
y = relay.nn.conv2d_transpose(x, w, y = relay.nn.conv2d_transpose(x, w,
...@@ -630,7 +630,7 @@ def test_conv1d_transpose_ncw_run(): ...@@ -630,7 +630,7 @@ def test_conv1d_transpose_ncw_run():
def test_upsampling_infer_type(): def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
scale = tvm.const(2.0, "float64") scale = tvm.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
...@@ -639,14 +639,15 @@ def test_upsampling_infer_type(): ...@@ -639,14 +639,15 @@ def test_upsampling_infer_type():
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))), tvm.expr.Cast("int32", tvm.round(w*scale))),
"float32") "float32")
n, c = tvm.var("n"), tvm.var("c") n, c = tvm.size_var("n"), tvm.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
def test_upsampling3d_infer_type(): def test_upsampling3d_infer_type():
n, c, d, h, w = tvm.var("n"), tvm.var("c"), tvm.var("d"), tvm.var("h"), tvm.var("w") n, c, d, h, w = tvm.size_var("n"), tvm.size_var("c"),\
tvm.size_var("d"), tvm.size_var("h"), tvm.size_var("w")
scale = tvm.const(2.0, "float64") scale = tvm.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear") y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
...@@ -656,14 +657,14 @@ def test_upsampling3d_infer_type(): ...@@ -656,14 +657,14 @@ def test_upsampling3d_infer_type():
tvm.expr.Cast("int32", tvm.round(h*scale)), tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))), tvm.expr.Cast("int32", tvm.round(w*scale))),
"float32") "float32")
n, c = tvm.var("n"), tvm.var("c") n, c = tvm.size_var("n"), tvm.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32")) x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear") y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32") assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32")
def _test_pool2d(opfunc, reffunc): def _test_pool2d(opfunc, reffunc):
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = opfunc(x, pool_size=(1, 1)) y = opfunc(x, pool_size=(1, 1))
assert "pool_size=" in y.astext() assert "pool_size=" in y.astext()
...@@ -683,7 +684,7 @@ def _test_pool2d(opfunc, reffunc): ...@@ -683,7 +684,7 @@ def _test_pool2d(opfunc, reffunc):
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
def _test_pool2d_int(opfunc, reffunc, dtype): def _test_pool2d_int(opfunc, reffunc, dtype):
n, c, h, w = tvm.var("n"), 10, 224, 224 n, c, h, w = tvm.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
y = opfunc(x, pool_size=(1, 1)) y = opfunc(x, pool_size=(1, 1))
assert "pool_size=" in y.astext() assert "pool_size=" in y.astext()
...@@ -703,13 +704,13 @@ def _test_pool2d_int(opfunc, reffunc, dtype): ...@@ -703,13 +704,13 @@ def _test_pool2d_int(opfunc, reffunc, dtype):
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
def _test_global_pool2d(opfunc, reffunc): def _test_global_pool2d(opfunc, reffunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), 224, 224
x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
y = opfunc(x, layout="NHWC") y = opfunc(x, layout="NHWC")
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32") assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32")
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = opfunc(x) y = opfunc(x)
yy = run_infer_type(y) yy = run_infer_type(y)
...@@ -768,7 +769,7 @@ def test_pool1d(): ...@@ -768,7 +769,7 @@ def test_pool1d():
def test_pool3d(): def test_pool3d():
def _test_pool3d(opfunc): def _test_pool3d(opfunc):
n, c, d, h, w = tvm.var("n"), 10, 5, 224, 224 n, c, d, h, w = tvm.size_var("n"), 10, 5, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = opfunc(x, pool_size=(1, 1, 1)) y = opfunc(x, pool_size=(1, 1, 1))
assert "pool_size=" in y.astext() assert "pool_size=" in y.astext()
...@@ -828,7 +829,7 @@ def test_avg_pool2d_no_count_pad(): ...@@ -828,7 +829,7 @@ def test_avg_pool2d_no_count_pad():
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
def test_flatten_infer_type(): def test_flatten_infer_type():
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") d1, d2, d3, d4 = tvm.size_var("d1"), tvm.size_var("d2"), tvm.size_var("d3"), tvm.size_var("d4")
x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32")) x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))
y = relay.nn.batch_flatten(x) y = relay.nn.batch_flatten(x)
yy = run_infer_type(y) yy = run_infer_type(y)
...@@ -873,7 +874,7 @@ def test_pad_infer_type(): ...@@ -873,7 +874,7 @@ def test_pad_infer_type():
assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32") assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32")
# some symbolic values # some symbolic values
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") n, c, h, w = tvm.size_var("n"), 2, 3, tvm.size_var("w")
t = relay.var("t", relay.TensorType((n, c, h, w), "float32")) t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))) y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
yy = run_infer_type(y) yy = run_infer_type(y)
...@@ -896,7 +897,7 @@ def test_pad_run(): ...@@ -896,7 +897,7 @@ def test_pad_run():
_test_run('int32') _test_run('int32')
def test_lrn(): def test_lrn():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", shape=(n, c , h, w)) x = relay.var("x", shape=(n, c , h, w))
y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75) y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)
"alpha=" in y.astext() "alpha=" in y.astext()
...@@ -927,7 +928,7 @@ def test_lrn(): ...@@ -927,7 +928,7 @@ def test_lrn():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_l2_normalize(): def test_l2_normalize():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", shape=(n, c , h, w)) x = relay.var("x", shape=(n, c , h, w))
y = relay.nn.l2_normalize(x, eps=0.001, axis=[1]) y = relay.nn.l2_normalize(x, eps=0.001, axis=[1])
"axis=" in y.astext() "axis=" in y.astext()
...@@ -977,7 +978,7 @@ def test_batch_flatten(): ...@@ -977,7 +978,7 @@ def test_batch_flatten():
def _test_upsampling(layout, method, align_corners=False): def _test_upsampling(layout, method, align_corners=False):
n, c, h, w = tvm.var("n"), 16, 32, 32 n, c, h, w = tvm.size_var("n"), 16, 32, 32
scale_h = 2.0 scale_h = 2.0
scale_w = 2.0 scale_w = 2.0
dtype = "float32" dtype = "float32"
...@@ -1016,7 +1017,7 @@ def test_upsampling(): ...@@ -1016,7 +1017,7 @@ def test_upsampling():
_test_upsampling("NHWC", "bilinear", True) _test_upsampling("NHWC", "bilinear", True)
def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"): def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"):
n, c, d, h, w = tvm.var("n"), 8, 16, 16, 16 n, c, d, h, w = tvm.size_var("n"), 8, 16, 16, 16
scale_d = 2.0 scale_d = 2.0
scale_h = 2.0 scale_h = 2.0
scale_w = 2.0 scale_w = 2.0
...@@ -1183,7 +1184,7 @@ def test_conv2d_int8_intrinsics(): ...@@ -1183,7 +1184,7 @@ def test_conv2d_int8_intrinsics():
def test_bitserial_conv2d_infer_type(): def test_bitserial_conv2d_infer_type():
# Basic shape test with ambiguous batch. # Basic shape test with ambiguous batch.
n, c, h, w = tvm.var("n"), 32, 224, 224 n, c, h, w = tvm.size_var("n"), 32, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16")) x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16"))
w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16")) w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16"))
y = relay.nn.bitserial_conv2d( y = relay.nn.bitserial_conv2d(
......
...@@ -171,7 +171,7 @@ def test_squeeze(): ...@@ -171,7 +171,7 @@ def test_squeeze():
def test_transpose_infer_type(): def test_transpose_infer_type():
n, t, d = tvm.var("n"), tvm.var("t"), 100 n, t, d = tvm.size_var("n"), tvm.size_var("t"), 100
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.transpose(x, axes=(1, 0, 2)) y = relay.transpose(x, axes=(1, 0, 2))
assert "axes=" in y.astext() assert "axes=" in y.astext()
...@@ -279,7 +279,7 @@ def test_reshape_like_infer_type(): ...@@ -279,7 +279,7 @@ def test_reshape_like_infer_type():
assert zz.checked_type == relay.TensorType((1, 6), "float32") assert zz.checked_type == relay.TensorType((1, 6), "float32")
# symbolic shape # symbolic shape
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") n, c, h, w = tvm.size_var("n"), 2, 3, tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((1, 8, 8), "float32")) y = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
z = relay.reshape_like(x, y) z = relay.reshape_like(x, y)
...@@ -452,7 +452,7 @@ def test_full_like_infer_type(): ...@@ -452,7 +452,7 @@ def test_full_like_infer_type():
assert yy.checked_type == relay.TensorType((1, 2, 3), "float32") assert yy.checked_type == relay.TensorType((1, 2, 3), "float32")
# symbolic shape # symbolic shape
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") n, c, h, w = tvm.size_var("n"), 2, 3, tvm.size_var("w")
base = relay.var("base", relay.TensorType((n, c, h, w), "float32")) base = relay.var("base", relay.TensorType((n, c, h, w), "float32"))
fill = relay.var("fill", relay.TensorType((), "float32")) fill = relay.var("fill", relay.TensorType((), "float32"))
y = relay.full_like(base, fill) y = relay.full_like(base, fill)
...@@ -480,7 +480,7 @@ def test_full_like(): ...@@ -480,7 +480,7 @@ def test_full_like():
def test_infer_type_leaky_relu(): def test_infer_type_leaky_relu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1) y = relay.nn.leaky_relu(x, alpha=0.1)
"alpha=0.1" in y.astext() "alpha=0.1" in y.astext()
...@@ -544,7 +544,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): ...@@ -544,7 +544,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
def test_infer_type_prelu(): def test_infer_type_prelu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w)) verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c)) verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w)) verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
......
...@@ -29,7 +29,7 @@ def run_infer_type(expr): ...@@ -29,7 +29,7 @@ def run_infer_type(expr):
def test_binary_op(): def test_binary_op():
def check_binary_op(opfunc, ref): def check_binary_op(opfunc, ref):
n = tvm.var("n") n = tvm.size_var("n")
t1 = relay.TensorType((5, n, 5)) t1 = relay.TensorType((5, n, 5))
t2 = relay.TensorType((n, 1)) t2 = relay.TensorType((n, 1))
x = relay.var("x", t1) x = relay.var("x", t1)
......
...@@ -31,7 +31,7 @@ def run_infer_type(expr): ...@@ -31,7 +31,7 @@ def run_infer_type(expr):
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_resize_infer_type(): def test_resize_infer_type():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
th, tw = tvm.var("th"), tvm.var("tw") th, tw = tvm.var("th"), tvm.var("tw")
z = relay.image.resize(x, (th, tw)) z = relay.image.resize(x, (th, tw))
...@@ -187,7 +187,7 @@ def test_multibox_prior(): ...@@ -187,7 +187,7 @@ def test_multibox_prior():
x = relay.var("x", relay.TensorType(dshape, "float32")) x = relay.var("x", relay.TensorType(dshape, "float32"))
verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets, verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets,
check_size=True) check_size=True)
y = relay.var("y", relay.TensorType((tvm.var("n"), 3, 56, 56), "float32")) y = relay.var("y", relay.TensorType((tvm.size_var("n"), 3, 56, 56), "float32"))
verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets, verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets,
check_size=True, check_type_only=True) check_size=True, check_type_only=True)
...@@ -195,7 +195,7 @@ def test_multibox_prior(): ...@@ -195,7 +195,7 @@ def test_multibox_prior():
ref_res = get_ref_result(dshape, clip=False) ref_res = get_ref_result(dshape, clip=False)
x = relay.var("x", relay.TensorType(dshape, "float32")) x = relay.var("x", relay.TensorType(dshape, "float32"))
verify_multibox_prior(x, dshape, ref_res, clip=False) verify_multibox_prior(x, dshape, ref_res, clip=False)
y = relay.var("y", relay.TensorType((tvm.var("n"), 24, 32, 32), "float32")) y = relay.var("y", relay.TensorType((tvm.size_var("n"), 24, 32, 32), "float32"))
verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True) verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True)
...@@ -287,7 +287,7 @@ def test_non_max_suppression(): ...@@ -287,7 +287,7 @@ def test_non_max_suppression():
np_indices_result = np.array([[3, 0, -1, -1, -1]]) np_indices_result = np.array([[3, 0, -1, -1, -1]])
num_anchors = 5 num_anchors = 5
dshape = (tvm.var("n"), num_anchors, 6) dshape = (tvm.size_var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=True) force_suppress=True, top_k=2, check_type_only=True)
dshape = (1, num_anchors, 6) dshape = (1, num_anchors, 6)
...@@ -298,7 +298,7 @@ def test_non_max_suppression(): ...@@ -298,7 +298,7 @@ def test_non_max_suppression():
[1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1], [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]]) [-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, 1, -1, -1]]) np_indices_result = np.array([[3, 0, 1, -1, -1]])
dshape = (tvm.var("n"), num_anchors, 6) dshape = (tvm.size_var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, verify_nms(np_data, np_valid_count, dshape, np_result,
np_indices_result, check_type_only=True) np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6) dshape = (1, num_anchors, 6)
...@@ -361,7 +361,7 @@ def test_multibox_transform_loc(): ...@@ -361,7 +361,7 @@ def test_multibox_transform_loc():
def test_threshold(): def test_threshold():
num_anchors = 5 num_anchors = 5
num_classes = 5 num_classes = 5
n = tvm.var("n") n = tvm.size_var("n")
cls_prob = relay.var( cls_prob = relay.var(
"cls_prob", "cls_prob",
relay.ty.TensorType((n, num_anchors, num_classes), "float32")) relay.ty.TensorType((n, num_anchors, num_classes), "float32"))
...@@ -527,7 +527,7 @@ def test_yolo_reorg_infer_shape(): ...@@ -527,7 +527,7 @@ def test_yolo_reorg_infer_shape():
assert "stride=" in z.astext() assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
idxd = tvm.indexdiv idxd = tvm.indexdiv
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2))) verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
......
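Across the test updates above, symbolic shape dimensions are now declared with tvm.size_var instead of tvm.var, so bound analysis can assume each dimension is non-negative. A minimal sketch of the pattern, using the same calls that appear in the updated tests:

    import tvm
    from tvm import relay

    # A symbolic batch dimension; size_var marks it as a non-negative tensor size.
    n = tvm.size_var("n")
    y = relay.var("y", relay.TensorType((n, 3, 56, 56), "float32"))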
...@@ -275,6 +275,14 @@ def test_mix_index_bound(): ...@@ -275,6 +275,14 @@ def test_mix_index_bound():
assert bd.max_value == (23 // 7) * 7 + 6 assert bd.max_value == (23 // 7) * 7 + 6
def test_size_var_bound():
analyzer = tvm.arith.Analyzer()
x = tvm.size_var("x")
bd = analyzer.const_int_bound(x)
assert bd.min_value == 0
assert bd.max_value == bd.POS_INF
if __name__ == "__main__": if __name__ == "__main__":
test_dtype_bound() test_dtype_bound()
test_cast_bound() test_cast_bound()
...@@ -288,3 +296,4 @@ if __name__ == "__main__": ...@@ -288,3 +296,4 @@ if __name__ == "__main__":
test_select_bound() test_select_bound()
test_shift_and_bound() test_shift_and_bound()
test_mix_index_bound() test_mix_index_bound()
test_size_var_bound()
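The new test_size_var_bound case documents the analyzer behavior the rest of the patch relies on: a SizeVar is known to be at least zero without any extra hints. For reference, a self-contained version of that check, using the same tvm.arith.Analyzer API as the test:

    import tvm

    analyzer = tvm.arith.Analyzer()
    n = tvm.size_var("n")              # a tensor-shape size, hence assumed >= 0
    bd = analyzer.const_int_bound(n)
    assert bd.min_value == 0           # lower bound known up front
    assert bd.max_value == bd.POS_INF  # upper bound stays unbounded
    # A plain tvm.var("n") carries no such non-negativity guarantee.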
...@@ -60,6 +60,7 @@ def test_add_sub(): ...@@ -60,6 +60,7 @@ def test_add_sub():
def test_mul_div(): def test_mul_div():
ck = IntSetChecker() ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv tdiv = tvm.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
......
...@@ -20,7 +20,7 @@ def test_stmt_simplify(): ...@@ -20,7 +20,7 @@ def test_stmt_simplify():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C") C = ib.pointer("float32", name="C")
n = tvm.var("n") n = tvm.size_var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.if_scope(i < 12): with ib.if_scope(i < 12):
A[i] = C[i] A[i] = C[i]
...@@ -34,7 +34,7 @@ def test_thread_extent_simplify(): ...@@ -34,7 +34,7 @@ def test_thread_extent_simplify():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C") C = ib.pointer("float32", name="C")
n = tvm.var("n") n = tvm.size_var("n")
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y") ty = tvm.thread_axis("threadIdx.y")
ib.scope_attr(tx, "thread_extent", n) ib.scope_attr(tx, "thread_extent", n)
...@@ -48,7 +48,7 @@ def test_thread_extent_simplify(): ...@@ -48,7 +48,7 @@ def test_thread_extent_simplify():
def test_basic_likely_elimination(): def test_basic_likely_elimination():
n = tvm.var('n') n = tvm.size_var('n')
X = tvm.placeholder(shape=(n,), name="x") X = tvm.placeholder(shape=(n,), name="x")
W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w") W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w")
...@@ -87,7 +87,8 @@ def test_complex_likely_elimination(): ...@@ -87,7 +87,8 @@ def test_complex_likely_elimination():
return tvm.compute(oshape, sls) return tvm.compute(oshape, sls)
m, n, d, i, l = tvm.var('m'), tvm.var('n'), tvm.var('d'), tvm.var('i'), tvm.var('l') m, n, d, i, l = tvm.size_var('m'), tvm.size_var('n'), tvm.size_var('d'),\
tvm.size_var('i'), tvm.size_var('l')
data_ph = tvm.placeholder((m, d * 32), name="data") data_ph = tvm.placeholder((m, d * 32), name="data")
indices_ph = tvm.placeholder((i,), name="indices", dtype="int32") indices_ph = tvm.placeholder((i,), name="indices", dtype="int32")
lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32") lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32")
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import tvm import tvm
def test_lower_rfactor(): def test_lower_rfactor():
n = tvm.var("n") n = tvm.size_var("n")
m = tvm.var("m") m = tvm.size_var("m")
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k") k = tvm.reduce_axis((0, m), "k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B") B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
...@@ -33,7 +33,7 @@ def test_lower_rfactor(): ...@@ -33,7 +33,7 @@ def test_lower_rfactor():
fapi = tvm.lower(s, [A, B]) fapi = tvm.lower(s, [A, B])
def test_dependent_output_shape(): def test_dependent_output_shape():
n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x') n, m, x = tvm.size_var('n'), tvm.size_var('m'), tvm.size_var('x')
A = tvm.placeholder((n, m)) A = tvm.placeholder((n, m))
B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B') B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
......
...@@ -47,7 +47,7 @@ def test_vmlal_s16(): ...@@ -47,7 +47,7 @@ def test_vmlal_s16():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon' target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def check_correct_assembly(N): def check_correct_assembly(N):
K = tvm.var("K") K = tvm.size_var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A') A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K, N), dtype="int8", name='B') B = tvm.placeholder((K, N), dtype="int8", name='B')
k = tvm.reduce_axis((0, K)) k = tvm.reduce_axis((0, K))
...@@ -67,7 +67,7 @@ def test_vmlal_s16(): ...@@ -67,7 +67,7 @@ def test_vmlal_s16():
check_correct_assembly(64) check_correct_assembly(64)
def check_broadcast_correct_assembly(N): def check_broadcast_correct_assembly(N):
K = tvm.var("K") K = tvm.size_var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A') A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K,), dtype="int8", name='B') B = tvm.placeholder((K,), dtype="int8", name='B')
k = tvm.reduce_axis((0, K)) k = tvm.reduce_axis((0, K))
......
...@@ -67,7 +67,7 @@ def test_add_pipeline(): ...@@ -67,7 +67,7 @@ def test_add_pipeline():
# Specifically allow offset to test codepath when offset is available # Specifically allow offset to test codepath when offset is available
Ab = tvm.decl_buffer( Ab = tvm.decl_buffer(
A.shape, A.dtype, A.shape, A.dtype,
elem_offset=tvm.var('Aoffset'), elem_offset=tvm.size_var('Aoffset'),
offset_factor=8, offset_factor=8,
name='A') name='A')
binds = {A : Ab} binds = {A : Ab}
......
...@@ -45,7 +45,7 @@ def test_large_uint_imm(): ...@@ -45,7 +45,7 @@ def test_large_uint_imm():
def test_add_pipeline(): def test_add_pipeline():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((), name='B') B = tvm.placeholder((), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
......
...@@ -79,7 +79,7 @@ def test_llvm_import(): ...@@ -79,7 +79,7 @@ def test_llvm_import():
def test_llvm_lookup_intrin(): def test_llvm_lookup_intrin():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var("m") m = tvm.size_var("m")
A = ib.pointer("uint8x8", name="A") A = ib.pointer("uint8x8", name="A")
x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A) x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
ib.emit(x) ib.emit(x)
...@@ -131,7 +131,7 @@ def test_llvm_add_pipeline(): ...@@ -131,7 +131,7 @@ def test_llvm_add_pipeline():
# Specifically allow offset to test codepath when offset is available # Specifically allow offset to test codepath when offset is available
Ab = tvm.decl_buffer( Ab = tvm.decl_buffer(
A.shape, A.dtype, A.shape, A.dtype,
elem_offset=tvm.var('Aoffset'), elem_offset=tvm.size_var('Aoffset'),
offset_factor=8, offset_factor=8,
name='A') name='A')
binds = {A : Ab} binds = {A : Ab}
......
...@@ -26,8 +26,8 @@ by = tvm.thread_axis("blockIdx.y") ...@@ -26,8 +26,8 @@ by = tvm.thread_axis("blockIdx.y")
@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..") @unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
def test_rocm_cross_thread_reduction(): def test_rocm_cross_thread_reduction():
# based on the reduction tutorial # based on the reduction tutorial
n = tvm.var("n") n = tvm.size_var("n")
m = tvm.var("m") m = tvm.size_var("m")
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k") k = tvm.reduce_axis((0, m), "k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B") B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
......
...@@ -20,9 +20,9 @@ import numpy as np ...@@ -20,9 +20,9 @@ import numpy as np
def test_static_callback(): def test_static_callback():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i') i = tvm.size_var('i')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab) A = ib.buffer_ptr(Ab)
cp = tvm.thread_axis((0, 1), "cop") cp = tvm.thread_axis((0, 1), "cop")
...@@ -41,9 +41,9 @@ def test_static_callback(): ...@@ -41,9 +41,9 @@ def test_static_callback():
def test_static_init(): def test_static_init():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i') i = tvm.size_var('i')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
handle = tvm.call_intrin("handle", "tvm_static_handle") handle = tvm.call_intrin("handle", "tvm_static_handle")
ib.emit( ib.emit(
......
...@@ -32,7 +32,7 @@ def test_stack_vm_basic(): ...@@ -32,7 +32,7 @@ def test_stack_vm_basic():
print(shape0) print(shape0)
assert shape0 == a.shape[0] assert shape0 == a.shape[0]
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), tvm.float32) Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
...@@ -47,9 +47,9 @@ def tvm_stack_vm_print(*x): ...@@ -47,9 +47,9 @@ def tvm_stack_vm_print(*x):
def test_stack_vm_loop(): def test_stack_vm_loop():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i') i = tvm.size_var('i')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab) A = ib.buffer_ptr(Ab)
...@@ -69,7 +69,7 @@ def test_stack_vm_loop(): ...@@ -69,7 +69,7 @@ def test_stack_vm_loop():
def test_stack_vm_cond(): def test_stack_vm_cond():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
...@@ -93,9 +93,9 @@ def test_stack_vm_cond(): ...@@ -93,9 +93,9 @@ def test_stack_vm_cond():
def test_vm_parallel(): def test_vm_parallel():
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i') i = tvm.size_var('i')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab) A = ib.buffer_ptr(Ab)
with ib.for_range(0, n, "i", for_type="parallel") as i: with ib.for_range(0, n, "i", for_type="parallel") as i:
......
...@@ -98,8 +98,8 @@ def outer_product(n, m, a, b): ...@@ -98,8 +98,8 @@ def outer_product(n, m, a, b):
#Test global function #Test global function
#Test bridge between frontend and backend #Test bridge between frontend and backend
def test_outer_product(): def test_outer_product():
n = tvm.var('n') n = tvm.size_var('n')
m = tvm.var('m') m = tvm.size_var('m')
a = tvm.placeholder((n, ), name='a') a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((m, ), name='b') b = tvm.placeholder((m, ), name='b')
...@@ -167,7 +167,7 @@ def test_fanout(): ...@@ -167,7 +167,7 @@ def test_fanout():
b[i] = sigma b[i] = sigma
return b return b
n = tvm.var('n') n = tvm.size_var('n')
a = tvm.placeholder((n, ), 'float32', name='a') a = tvm.placeholder((n, ), 'float32', name='a')
try: try:
b = fanout(n, a) b = fanout(n, a)
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
def test_for(): def test_for():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
A = ib.allocate("float32", n, name="A", scope="global") A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1 A[i] = A[i] + 1
...@@ -39,7 +39,7 @@ def test_for(): ...@@ -39,7 +39,7 @@ def test_for():
def test_if(): def test_if():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
tmod = tvm.truncmod tmod = tvm.truncmod
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
...@@ -60,7 +60,7 @@ def test_if(): ...@@ -60,7 +60,7 @@ def test_if():
def test_prefetch(): def test_prefetch():
A = tvm.placeholder((10, 20), name="A") A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
ib.emit( ib.emit(
...@@ -105,7 +105,7 @@ def test_cpu(): ...@@ -105,7 +105,7 @@ def test_cpu():
check_target("llvm") check_target("llvm")
def test_gpu(): def test_gpu():
n = tvm.var('n') n = tvm.size_var('n')
dtype = "float32" dtype = "float32"
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
......
...@@ -19,9 +19,9 @@ from tvm.schedule import Buffer ...@@ -19,9 +19,9 @@ from tvm.schedule import Buffer
import numpy as np import numpy as np
def test_buffer(): def test_buffer():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
l = tvm.var('l') l = tvm.size_var('l')
Ab = tvm.decl_buffer((m, n), tvm.float32) Ab = tvm.decl_buffer((m, n), tvm.float32)
Bb = tvm.decl_buffer((n, l), tvm.float32) Bb = tvm.decl_buffer((n, l), tvm.float32)
...@@ -31,8 +31,8 @@ def test_buffer(): ...@@ -31,8 +31,8 @@ def test_buffer():
def test_buffer_access_ptr(): def test_buffer_access_ptr():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1]) Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1])
aptr = Ab.access_ptr("rw") aptr = Ab.access_ptr("rw")
assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m) assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
...@@ -43,14 +43,14 @@ def test_buffer_access_ptr(): ...@@ -43,14 +43,14 @@ def test_buffer_access_ptr():
def test_buffer_access_ptr_offset(): def test_buffer_access_ptr_offset():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32) Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw", offset=100) aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.ir_pass.Simplify(aptr.args[2]) offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100) assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
v = tvm.var('int32') v = tvm.size_var('int32')
aptr = Ab.access_ptr("rw", offset=100 + 100 + v) aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
offset = tvm.ir_pass.Simplify(aptr.args[2]) offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 200 + v) assert tvm.ir_pass.Equal(offset, 200 + v)
...@@ -62,8 +62,8 @@ def test_buffer_access_ptr_offset(): ...@@ -62,8 +62,8 @@ def test_buffer_access_ptr_offset():
def test_buffer_access_ptr_extent(): def test_buffer_access_ptr_extent():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32) Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw") aptr = Ab.access_ptr("rw")
assert tvm.ir_pass.Equal(aptr.args[3], m * n) assert tvm.ir_pass.Equal(aptr.args[3], m * n)
...@@ -75,8 +75,8 @@ def test_buffer_access_ptr_extent(): ...@@ -75,8 +75,8 @@ def test_buffer_access_ptr_extent():
def test_buffer_vload(): def test_buffer_vload():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100) Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
load = Ab.vload([2, 3]) load = Ab.vload([2, 3])
offset = tvm.ir_pass.Simplify(load.index) offset = tvm.ir_pass.Simplify(load.index)
...@@ -84,11 +84,11 @@ def test_buffer_vload(): ...@@ -84,11 +84,11 @@ def test_buffer_vload():
def test_buffer_index_merge_mult_mod(): def test_buffer_index_merge_mult_mod():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
s = tvm.var('s') s = tvm.size_var('s')
k0 = tvm.var('k0') k0 = tvm.size_var('k0')
k1 = tvm.var('k1') k1 = tvm.size_var('k1')
A = tvm.decl_buffer((m, n), tvm.float32) A = tvm.decl_buffer((m, n), tvm.float32)
A_stride = tvm.decl_buffer((m, n), tvm.float32, strides=(s, 1)) A_stride = tvm.decl_buffer((m, n), tvm.float32, strides=(s, 1))
def assert_simplified_equal(index_simplified, index_direct): def assert_simplified_equal(index_simplified, index_direct):
...@@ -123,9 +123,9 @@ def test_buffer_index_merge_mult_mod(): ...@@ -123,9 +123,9 @@ def test_buffer_index_merge_mult_mod():
def test_buffer_broadcast(): def test_buffer_broadcast():
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") m0, m1, m2 = tvm.size_var("m0"), tvm.size_var("m1"), tvm.size_var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") n0, n1, n2 = tvm.size_var("n0"), tvm.size_var("n1"), tvm.size_var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") o0, o1, o2 = tvm.size_var("o0"), tvm.size_var("o1"), tvm.size_var("o2")
A = tvm.placeholder((m0, m1, m2), name='A') A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B') B = tvm.placeholder((n0, n1, n2), name='B')
...@@ -151,9 +151,9 @@ def test_buffer_broadcast(): ...@@ -151,9 +151,9 @@ def test_buffer_broadcast():
def test_buffer_broadcast_expr(): def test_buffer_broadcast_expr():
n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x') n0, m0, x = tvm.size_var('n0'), tvm.size_var('m0'), tvm.size_var('x')
n1, m1 = tvm.var('n1'), tvm.var('m1') n1, m1 = tvm.size_var('n1'), tvm.size_var('m1')
o0, o1 = tvm.var('o0'), tvm.var('o1') o0, o1 = tvm.size_var('o0'), tvm.size_var('o1')
A = tvm.placeholder((m0, n0), name='A') A = tvm.placeholder((m0, n0), name='A')
B = tvm.placeholder((m1, n1), name='B') B = tvm.placeholder((m1, n1), name='B')
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
import tvm import tvm
def test_scan_group(): def test_scan_group():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n)) s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i]) s_init = tvm.compute((1, n), lambda _, i: x[0, i])
...@@ -50,8 +50,8 @@ def test_scan_group(): ...@@ -50,8 +50,8 @@ def test_scan_group():
pass pass
def test_compute_group(): def test_compute_group():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1") x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2") x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
...@@ -64,8 +64,8 @@ def test_compute_group(): ...@@ -64,8 +64,8 @@ def test_compute_group():
assert g.num_child_stages == 2 assert g.num_child_stages == 2
def test_nest_group(): def test_nest_group():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1") x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2") x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
......
...@@ -19,9 +19,9 @@ import tvm ...@@ -19,9 +19,9 @@ import tvm
import pickle as pkl import pickle as pkl
def test_schedule_create(): def test_schedule_create():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
AA = tvm.compute((m, l), lambda i, j: A[i, j]) AA = tvm.compute((m, l), lambda i, j: A[i, j])
...@@ -49,7 +49,7 @@ def test_schedule_create(): ...@@ -49,7 +49,7 @@ def test_schedule_create():
def test_reorder(): def test_reorder():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute(m, lambda i: A[i+1]) T = tvm.compute(m, lambda i: A[i+1])
...@@ -69,7 +69,7 @@ def test_reorder(): ...@@ -69,7 +69,7 @@ def test_reorder():
pass pass
def test_split(): def test_split():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i]) T = tvm.compute((m,), lambda i: A[i])
...@@ -79,8 +79,8 @@ def test_split(): ...@@ -79,8 +79,8 @@ def test_split():
def test_tile(): def test_tile():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j]) T = tvm.compute((m, n), lambda i, j: A[i, j])
...@@ -90,8 +90,8 @@ def test_tile(): ...@@ -90,8 +90,8 @@ def test_tile():
def test_fuse(): def test_fuse():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j]) T = tvm.compute((m, n), lambda i, j: A[i, j])
...@@ -119,8 +119,8 @@ def test_singleton(): ...@@ -119,8 +119,8 @@ def test_singleton():
print("test singleton fin") print("test singleton fin")
def test_vectorize(): def test_vectorize():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j]) T = tvm.compute((m, n), lambda i, j: A[i, j])
...@@ -156,7 +156,7 @@ def test_pragma(): ...@@ -156,7 +156,7 @@ def test_pragma():
def test_rfactor(): def test_rfactor():
n = tvm.var('n') n = tvm.size_var('n')
k1 = tvm.reduce_axis((0, n), name="k1") k1 = tvm.reduce_axis((0, n), name="k1")
k2 = tvm.reduce_axis((0, n), name="k2") k2 = tvm.reduce_axis((0, n), name="k2")
A = tvm.placeholder((n, n, n), name='A') A = tvm.placeholder((n, n, n), name='A')
...@@ -214,10 +214,10 @@ def test_tensor_intrin(): ...@@ -214,10 +214,10 @@ def test_tensor_intrin():
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized) assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
def test_tensor_intrin_scalar_params(): def test_tensor_intrin_scalar_params():
n = tvm.var("n") n = tvm.size_var("n")
x = tvm.placeholder((n,), name='x') x = tvm.placeholder((n,), name='x')
v = tvm.var("v") v = tvm.size_var("v")
w = tvm.var("w") w = tvm.size_var("w")
z = tvm.compute((n,), lambda i: x[i]*v + w, name='z') z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')
def intrin_func(ins, outs, sp): def intrin_func(ins, outs, sp):
......
...@@ -33,9 +33,9 @@ def compute_conv(data, weight): ...@@ -33,9 +33,9 @@ def compute_conv(data, weight):
axis=[ic, dh, dw])) axis=[ic, dh, dw]))
def test_with(): def test_with():
n = tvm.var('n') n = tvm.size_var('n')
m = tvm.var('m') m = tvm.size_var('m')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((n, l), name='A') A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B') B = tvm.placeholder((m, l), name='B')
...@@ -56,12 +56,12 @@ def test_with(): ...@@ -56,12 +56,12 @@ def test_with():
def test_decorator(): def test_decorator():
n = tvm.var('n') n = tvm.size_var('n')
c = tvm.var('c') c = tvm.size_var('c')
h = tvm.var('h') h = tvm.size_var('h')
w = tvm.var('w') w = tvm.size_var('w')
kh = tvm.var('kh') kh = tvm.size_var('kh')
kw = tvm.var('kw') kw = tvm.size_var('kw')
A = tvm.placeholder((n, c, h, w), name='A') A = tvm.placeholder((n, c, h, w), name='A')
B = tvm.placeholder((c, c, kh, kw), name='B') B = tvm.placeholder((c, c, kh, kw), name='B')
...@@ -70,12 +70,12 @@ def test_decorator(): ...@@ -70,12 +70,12 @@ def test_decorator():
assert len(C.op.attrs) == 0 assert len(C.op.attrs) == 0
def test_nested(): def test_nested():
n = tvm.var('n') n = tvm.size_var('n')
c = tvm.var('c') c = tvm.size_var('c')
h = tvm.var('h') h = tvm.size_var('h')
w = tvm.var('w') w = tvm.size_var('w')
kh = tvm.var('kh') kh = tvm.size_var('kh')
kw = tvm.var('kw') kw = tvm.size_var('kw')
A = tvm.placeholder((n, c, h, w), name='A') A = tvm.placeholder((n, c, h, w), name='A')
B = tvm.placeholder((c, c, kh, kw), name='B') B = tvm.placeholder((c, c, kh, kw), name='B')
......
...@@ -18,9 +18,9 @@ import tvm ...@@ -18,9 +18,9 @@ import tvm
from topi.nn.pooling import pool from topi.nn.pooling import pool
def test_tensor(): def test_tensor():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
...@@ -37,7 +37,7 @@ def test_tensor(): ...@@ -37,7 +37,7 @@ def test_tensor():
def test_rank_zero(): def test_rank_zero():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
scale = tvm.placeholder((), name='s') scale = tvm.placeholder((), name='s')
k = tvm.reduce_axis((0, m), name="k") k = tvm.reduce_axis((0, m), name="k")
...@@ -48,7 +48,7 @@ def test_rank_zero(): ...@@ -48,7 +48,7 @@ def test_rank_zero():
def test_conv1d(): def test_conv1d():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n+2), name='A') A = tvm.placeholder((n+2), name='A')
def computeB(ii): def computeB(ii):
i = ii + 1 i = ii + 1
...@@ -57,14 +57,14 @@ def test_conv1d(): ...@@ -57,14 +57,14 @@ def test_conv1d():
def test_tensor_slice(): def test_tensor_slice():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.compute((n, n), lambda i, j: 1) A = tvm.compute((n, n), lambda i, j: 1)
B = tvm.compute((n,), lambda i: A[0][i] + A[0][i]) B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
def test_tensor_reduce_multi_axis(): def test_tensor_reduce_multi_axis():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
k1 = tvm.reduce_axis((0, n), "k") k1 = tvm.reduce_axis((0, n), "k")
k2 = tvm.reduce_axis((0, m), "k") k2 = tvm.reduce_axis((0, m), "k")
...@@ -73,23 +73,23 @@ def test_tensor_reduce_multi_axis(): ...@@ -73,23 +73,23 @@ def test_tensor_reduce_multi_axis():
def test_tensor_comm_reducer(): def test_tensor_comm_reducer():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n), "k") k = tvm.reduce_axis((0, n), "k")
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t)) mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k)) C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k))
def test_tensor_comm_reducer_overload(): def test_tensor_comm_reducer_overload():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t)) mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
sum_res = mysum(m, n) sum_res = mysum(m, n)
def test_tensor_reduce(): def test_tensor_reduce():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
...@@ -175,8 +175,8 @@ def test_tensor_compute2(): ...@@ -175,8 +175,8 @@ def test_tensor_compute2():
assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate) assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate)
def test_tensor_scan(): def test_tensor_scan():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
x = tvm.placeholder((m, n)) x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n)) s = tvm.placeholder((m, n))
res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]), res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]),
...@@ -185,8 +185,8 @@ def test_tensor_scan(): ...@@ -185,8 +185,8 @@ def test_tensor_scan():
assert tuple(res.shape) == (m, n) assert tuple(res.shape) == (m, n)
def test_scan_multi_out(): def test_scan_multi_out():
m = tvm.var("m") m = tvm.size_var("m")
n = tvm.var("n") n = tvm.size_var("n")
x1 = tvm.placeholder((m, n)) x1 = tvm.placeholder((m, n))
s1 = tvm.placeholder((m, n)) s1 = tvm.placeholder((m, n))
x2 = tvm.placeholder((m, n)) x2 = tvm.placeholder((m, n))
...@@ -206,7 +206,7 @@ def test_scan_multi_out(): ...@@ -206,7 +206,7 @@ def test_scan_multi_out():
assert isinstance(zz, tvm.tensor.ScanOp) assert isinstance(zz, tvm.tensor.ScanOp)
def test_extern(): def test_extern():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
def extern_func(ins, outs): def extern_func(ins, outs):
...@@ -217,7 +217,7 @@ def test_extern(): ...@@ -217,7 +217,7 @@ def test_extern():
def test_extern_multi_out(): def test_extern_multi_out():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i] * 10) B = tvm.compute((m,), lambda i: A[i] * 10)
...@@ -230,8 +230,8 @@ def test_extern_multi_out(): ...@@ -230,8 +230,8 @@ def test_extern_multi_out():
assert(res[1].value_index == 1) assert(res[1].value_index == 1)
def test_tuple_inputs(): def test_tuple_inputs():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A0 = tvm.placeholder((m, n), name='A0') A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1') A1 = tvm.placeholder((m, n), name='A1')
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T') T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
...@@ -244,8 +244,8 @@ def test_tuple_inputs(): ...@@ -244,8 +244,8 @@ def test_tuple_inputs():
assert(T1.value_index == 1) assert(T1.value_index == 1)
def test_tuple_with_different_deps(): def test_tuple_with_different_deps():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
A0 = tvm.placeholder((m, n), name='A1') A0 = tvm.placeholder((m, n), name='A1')
A1 = tvm.placeholder((m, n), name='A2') A1 = tvm.placeholder((m, n), name='A2')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B') B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
......
...@@ -87,7 +87,7 @@ def test_combination(): ...@@ -87,7 +87,7 @@ def test_combination():
def verify_tensor_scalar_bop(shape, typ="add"): def verify_tensor_scalar_bop(shape, typ="add"):
"""Verify non-constant Tensor and scalar binary operations.""" """Verify non-constant Tensor and scalar binary operations."""
sh = [tvm.var('n%d' % i) for i in range(0, len(shape))] sh = [tvm.size_var('n%d' % i) for i in range(0, len(shape))]
k = tvm.var('k') k = tvm.var('k')
A = tvm.placeholder(sh, name='A') A = tvm.placeholder(sh, name='A')
if typ == "add": if typ == "add":
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import tvm import tvm
def test_verify_compute(): def test_verify_compute():
n = tvm.var("n") n = tvm.size_var("n")
m = tvm.var("m") m = tvm.size_var("m")
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k") k = tvm.reduce_axis((0, m), "k")
k_ = tvm.reduce_axis((0, m-1), "k_") k_ = tvm.reduce_axis((0, m-1), "k_")
......
...@@ -46,7 +46,7 @@ def test_dso_module_load(): ...@@ -46,7 +46,7 @@ def test_dso_module_load():
temp = util.tempdir() temp = util.tempdir()
def save_object(names): def save_object(names):
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i') i = tvm.var('i')
# for i in 0 to n-1: # for i in 0 to n-1:
......
...@@ -46,7 +46,7 @@ def lower(sch, args): ...@@ -46,7 +46,7 @@ def lower(sch, args):
@pytest.mark.xfail @pytest.mark.xfail
def test_out_of_bounds_llvm(index_a, index_b): def test_out_of_bounds_llvm(index_a, index_b):
n = tvm.var("n") n = tvm.size_var("n")
A = tvm.placeholder ((n,), name='A') A = tvm.placeholder ((n,), name='A')
B = tvm.placeholder ((n,), name='B') B = tvm.placeholder ((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name='C') C = tvm.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name='C')
...@@ -63,7 +63,7 @@ def test_out_of_bounds_llvm(index_a, index_b): ...@@ -63,7 +63,7 @@ def test_out_of_bounds_llvm(index_a, index_b):
fadd (a, b, c) fadd (a, b, c)
def test_in_bounds_llvm(): def test_in_bounds_llvm():
n = tvm.var("n") n = tvm.size_var("n")
A = tvm.placeholder ((n,), name='A') A = tvm.placeholder ((n,), name='A')
B = tvm.placeholder ((n,), name='B') B = tvm.placeholder ((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
...@@ -128,7 +128,7 @@ def test_in_bounds_vectorize_llvm(): ...@@ -128,7 +128,7 @@ def test_in_bounds_vectorize_llvm():
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1) tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
def test_in_bounds_loop_partition_basic_llvm(): def test_in_bounds_loop_partition_basic_llvm():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B') B = tvm.placeholder((n, ), name='B')
...@@ -147,7 +147,7 @@ def test_in_bounds_loop_partition_basic_llvm(): ...@@ -147,7 +147,7 @@ def test_in_bounds_loop_partition_basic_llvm():
@pytest.mark.xfail @pytest.mark.xfail
def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B') B = tvm.placeholder((n, ), name='B')
...@@ -331,9 +331,9 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False ...@@ -331,9 +331,9 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False
f(data_input, kernel_input, conv_out) f(data_input, kernel_input, conv_out)
def test_in_bounds_tensors_with_same_shapes1D_llvm(): def test_in_bounds_tensors_with_same_shapes1D_llvm():
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((k, ), name='B') B = tvm.placeholder((k, ), name='B')
...@@ -351,9 +351,9 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm(): ...@@ -351,9 +351,9 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm():
@pytest.mark.xfail @pytest.mark.xfail
def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape): def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((k, ), name='B') B = tvm.placeholder((k, ), name='B')
...@@ -370,9 +370,9 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape ...@@ -370,9 +370,9 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape
f(a, b, t) f(a, b, t)
def test_in_bounds_tensors_with_same_shapes2D_llvm(): def test_in_bounds_tensors_with_same_shapes2D_llvm():
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, n), name='A') A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((k, k), name='B') B = tvm.placeholder((k, k), name='B')
...@@ -390,9 +390,9 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm(): ...@@ -390,9 +390,9 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm():
@pytest.mark.xfail @pytest.mark.xfail
def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape): def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, n), name='A') A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((k, k), name='B') B = tvm.placeholder((k, k), name='B')
...@@ -409,9 +409,9 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape ...@@ -409,9 +409,9 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape
f(a, b, t) f(a, b, t)
def test_in_bounds_tensors_with_same_shapes3D_llvm(): def test_in_bounds_tensors_with_same_shapes3D_llvm():
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, n, n), name='A') A = tvm.placeholder((n, n, n), name='A')
B = tvm.placeholder((k, k, k), name='B') B = tvm.placeholder((k, k, k), name='B')
...@@ -429,9 +429,9 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm(): ...@@ -429,9 +429,9 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm():
@pytest.mark.xfail @pytest.mark.xfail
def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape): def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape):
n = tvm.var('n') n = tvm.size_var('n')
k = tvm.var('k') k = tvm.size_var('k')
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, n, n), name='A') A = tvm.placeholder((n, n, n), name='A')
B = tvm.placeholder((k, k, k), name='B') B = tvm.placeholder((k, k, k), name='B')
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import tvm import tvm
def test_decorate_device(): def test_decorate_device():
m = tvm.var('m') m = tvm.size_var('m')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import tvm import tvm
def test_inline(): def test_inline():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100]) stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
...@@ -36,7 +36,7 @@ def test_inline(): ...@@ -36,7 +36,7 @@ def test_inline():
pass pass
def test_inline2(): def test_inline2():
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100]) stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
......
...@@ -52,7 +52,7 @@ def lower(sch, args): ...@@ -52,7 +52,7 @@ def lower(sch, args):
return stmt return stmt
def test_basic(): def test_basic():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n, ), name='A') A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B') B = tvm.placeholder((n, ), name='B')
...@@ -65,6 +65,7 @@ def test_basic(): ...@@ -65,6 +65,7 @@ def test_basic():
stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0])) assert('if' not in str(stmt.body.body.body[0]))
assert('if' in str(stmt.body.body.body[1]))
def test_const_loop(): def test_const_loop():
n = 21 n = 21
...@@ -83,8 +84,8 @@ def test_const_loop(): ...@@ -83,8 +84,8 @@ def test_const_loop():
def test_multi_loop(): def test_multi_loop():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
with ib.for_range(0, 4, "i") as i: with ib.for_range(0, 4, "i") as i:
with ib.for_range(0, n, "j") as j: with ib.for_range(0, n, "j") as j:
with ib.for_range(0, m, "k") as k: with ib.for_range(0, m, "k") as k:
...@@ -99,8 +100,8 @@ def test_multi_loop(): ...@@ -99,8 +100,8 @@ def test_multi_loop():
def test_multi_if(): def test_multi_if():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
with ib.for_range(0, 4, 'i') as i: with ib.for_range(0, 4, 'i') as i:
with ib.for_range(0, n, 'j') as j: with ib.for_range(0, n, 'j') as j:
with ib.for_range(0, m, 'k') as k: with ib.for_range(0, m, 'k') as k:
...@@ -118,8 +119,8 @@ def test_multi_if(): ...@@ -118,8 +119,8 @@ def test_multi_if():
assert('if' not in str(stmt.body[0])) assert('if' not in str(stmt.body[0]))
def test_thread_axis(): def test_thread_axis():
m = tvm.var('m') m = tvm.size_var('m')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B') B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
...@@ -137,11 +138,11 @@ def test_thread_axis(): ...@@ -137,11 +138,11 @@ def test_thread_axis():
assert('if' not in str(stmt.body.body.body[0])) assert('if' not in str(stmt.body.body.body[0]))
def test_vectorize(): def test_vectorize():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32") bias = tvm.size_var("bias", dtype="float32")
scale = tvm.var("scale", dtype="float32") scale = tvm.size_var("scale", dtype="float32")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
# schedule # schedule
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
...@@ -160,8 +161,8 @@ def test_vectorize(): ...@@ -160,8 +161,8 @@ def test_vectorize():
def test_condition(): def test_condition():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i: with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
with ib.for_range(0, 4, 'j') as j: with ib.for_range(0, 4, 'j') as j:
ib.emit(tvm.make.Evaluate( ib.emit(tvm.make.Evaluate(
...@@ -173,8 +174,8 @@ def test_condition(): ...@@ -173,8 +174,8 @@ def test_condition():
def test_condition_EQ(): def test_condition_EQ():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
with ib.for_range(0, 10, 'i') as i: with ib.for_range(0, 10, 'i') as i:
ib.emit(tvm.make.Evaluate( ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n))) tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
...@@ -185,7 +186,7 @@ def test_condition_EQ(): ...@@ -185,7 +186,7 @@ def test_condition_EQ():
def test_thread_axis2(): def test_thread_axis2():
n = tvm.convert(4096) n = tvm.convert(4096)
m = tvm.var('m') m = tvm.size_var('m')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
...@@ -201,8 +202,8 @@ def test_thread_axis2(): ...@@ -201,8 +202,8 @@ def test_thread_axis2():
assert('threadIdx' not in str(for_body.extent)) assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction(): def test_everything_during_deduction():
m = tvm.var('m') m = tvm.size_var('m')
n = tvm.var('n') n = tvm.size_var('n')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
with ib.for_range(0, n, 'i') as i: with ib.for_range(0, n, 'i') as i:
with ib.for_range(0, 32, 'j') as j: with ib.for_range(0, 32, 'j') as j:
...@@ -252,7 +253,7 @@ def test_multi_likely(): ...@@ -252,7 +253,7 @@ def test_multi_likely():
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_oneD_pool(): def test_oneD_pool():
m = tvm.var('m') m = tvm.size_var('m')
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
#data = tvm.placeholder((16,), name = 'data') #data = tvm.placeholder((16,), name = 'data')
data = ib.pointer("float32", name="A") data = ib.pointer("float32", name="A")
......
...@@ -19,7 +19,7 @@ import numpy ...@@ -19,7 +19,7 @@ import numpy
def test_makeapi(): def test_makeapi():
"""Not yet working, mock design""" """Not yet working, mock design"""
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
......
...@@ -19,7 +19,7 @@ import tvm ...@@ -19,7 +19,7 @@ import tvm
@pytest.mark.xfail @pytest.mark.xfail
def test_loop_dependent_allocate(): def test_loop_dependent_allocate():
N = tvm.var("N") N = tvm.size_var("N")
A = tvm.placeholder((2*N,), "float32", "A") A = tvm.placeholder((2*N,), "float32", "A")
C = tvm.compute((N, ), lambda i: A[2*i] + A[i+1], name='C') C = tvm.compute((N, ), lambda i: A[2*i] + A[i+1], name='C')
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import tvm import tvm
def test_flatten2(): def test_flatten2():
m = tvm.var('m') m = tvm.size_var('m')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
...@@ -38,8 +38,8 @@ def test_flatten2(): ...@@ -38,8 +38,8 @@ def test_flatten2():
def test_flatten_prefetch(): def test_flatten_prefetch():
A = tvm.placeholder((25, 100, 4), name = 'A') A = tvm.placeholder((25, 100, 4), name = 'A')
_A= tvm.decl_buffer(A.shape, A.dtype, name = 'A'); _A= tvm.decl_buffer(A.shape, A.dtype, name = 'A');
i = tvm.var('i') i = tvm.size_var('i')
j = tvm.var('j') j = tvm.size_var('j')
region = [tvm.make.range_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] region = [tvm.make.range_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region) stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
import tvm import tvm
def test_storage_sync(): def test_storage_sync():
m = tvm.var('m') m = tvm.size_var('m')
l = tvm.var('l') l = tvm.size_var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
...@@ -54,7 +54,7 @@ def test_coproc_sync(): ...@@ -54,7 +54,7 @@ def test_coproc_sync():
max_num_bits=128, max_num_bits=128,
head_address=tvm.call_extern("handle", "global_cache")) head_address=tvm.call_extern("handle", "global_cache"))
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
cp = tvm.thread_axis((0, 1), "cop") cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache") A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
...@@ -76,7 +76,7 @@ def test_coproc_sync(): ...@@ -76,7 +76,7 @@ def test_coproc_sync():
def test_coproc_sync2(): def test_coproc_sync2():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
cp = tvm.thread_axis((0, 1), "cop") cp = tvm.thread_axis((0, 1), "cop")
ty = tvm.thread_axis("cthread") ty = tvm.thread_axis("cthread")
A = ib.allocate("float32", 128, name="A") A = ib.allocate("float32", 128, name="A")
...@@ -102,7 +102,7 @@ def test_coproc_sync3(): ...@@ -102,7 +102,7 @@ def test_coproc_sync3():
return True return True
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.size_var("n")
cp = tvm.thread_axis((0, 1), "cop") cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache") A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
......
...@@ -21,7 +21,7 @@ import os ...@@ -21,7 +21,7 @@ import os
def test_unroll_loop(): def test_unroll_loop():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
dtype = 'int64' dtype = 'int64'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
Aptr = ib.buffer_ptr(Ab) Aptr = ib.buffer_ptr(Ab)
# for i in 0 to n-1: # for i in 0 to n-1:
...@@ -54,7 +54,7 @@ def test_unroll_loop(): ...@@ -54,7 +54,7 @@ def test_unroll_loop():
def test_unroll_fake_loop(): def test_unroll_fake_loop():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
dtype = 'int32' dtype = 'int32'
n = tvm.var('n') n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), dtype) Ab = tvm.decl_buffer((n, ), dtype)
Aptr = ib.buffer_ptr(Ab) Aptr = ib.buffer_ptr(Ab)
# for i in 0 to n-1: # for i in 0 to n-1:
...@@ -68,7 +68,7 @@ def test_unroll_fake_loop(): ...@@ -68,7 +68,7 @@ def test_unroll_fake_loop():
assert isinstance(ret[0], tvm.stmt.Store) assert isinstance(ret[0], tvm.stmt.Store)
def test_unroll_single_count_loops(): def test_unroll_single_count_loops():
n = tvm.var('n') n = tvm.size_var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i: A(*i), name='B') B = tvm.compute((n,), lambda *i: A(*i), name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
......
...@@ -142,18 +142,18 @@ def conv2d_infer_layout(workload, cfg): ...@@ -142,18 +142,18 @@ def conv2d_infer_layout(workload, cfg):
def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """ """ Get the workload structure. """
if data_layout == 'NCHW': if data_layout == 'NCHW':
_, CI, IH, IW = [x.value for x in data.shape] _, CI, IH, IW = get_const_tuple(data.shape)
elif data_layout == 'NHWC': elif data_layout == 'NHWC':
_, IH, IW, CI = [x.value for x in data.shape] _, IH, IW, CI = get_const_tuple(data.shape)
elif data_layout == 'HWCN': elif data_layout == 'HWCN':
IH, IW, CI, _ = [x.value for x in data.shape] IH, IW, CI, _ = get_const_tuple(data.shape)
else: else:
raise ValueError("not support this layout {} yet".format(data_layout)) raise ValueError("not support this layout {} yet".format(data_layout))
if data_layout == 'NCHW': if data_layout == 'NCHW':
CO, CIG, KH, KW = [x.value for x in kernel.shape] CO, CIG, KH, KW = get_const_tuple(kernel.shape)
else: else:
KH, KW, CIG, CO = [x.value for x in kernel.shape] KH, KW, CIG, CO = get_const_tuple(kernel.shape)
HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW))) HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
GRPS = CI // CIG GRPS = CI // CIG
......
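The conv2d workload change above swaps the direct x.value accesses for get_const_tuple, presumably because a symbolic dimension such as a SizeVar is not a constant IntImm and has no value field to read. A hedged sketch of the difference, assuming topi.util.get_const_tuple converts constant dimensions to Python ints and passes symbolic ones through:

    import tvm
    from topi.util import get_const_tuple  # import path assumed for this TVM version

    n = tvm.size_var("n")                  # symbolic batch dimension
    data = tvm.placeholder((n, 3, 224, 224), name="data")
    # [x.value for x in data.shape] would fail on the symbolic dim.
    shape = get_const_tuple(data.shape)    # constants -> int; symbolic dims kept (assumed)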