Commit 71334483 by Tianqi Chen Committed by GitHub

[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)

* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor

* retrigger

* [IRFunctor] Migrated CodegenC

* [IRFUNCTOR] Migrate CodeGenLLVM

* [IRFunctor] Migrate canonical

* [IRFunctor] Migrate vectorize

* [IRFunctor] migrate CodeGenStackVM
parent e4387940
...@@ -59,7 +59,6 @@ after_failure: ...@@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh - tests/travis/travis_after_failure.sh
notifications: notifications:
# Emails are sent to the committer's git-configured email address by default,
email: email:
on_success: change on_success: change
on_failure: always on_failure: always
...@@ -55,59 +55,23 @@ class IRMutator { ...@@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*) static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions // Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance // The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e); virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e); virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e); virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e); virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e); virtual Expr Mutate_(const Sub* op, const Expr& e);
...@@ -130,38 +94,12 @@ class IRMutator { ...@@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e); virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e); virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e); virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e); virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e); virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e);
}; };
/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MUTATOR_H_ #endif // TVM_IR_MUTATOR_H_
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef TVM_IR_VISITOR_H_ #ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "./ir.h" #include "./ir.h"
namespace tvm { namespace tvm {
...@@ -17,7 +18,51 @@ namespace ir { ...@@ -17,7 +18,51 @@ namespace ir {
* This IRVisitor is implemented via IRFunctor * This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new Node. * This enables extensions of possible new Node.
* *
* \sa IRFunctor, PostOrderVisit * \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
* - If it is mutaion of the IR, use IRMutator
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
*
* \code
*
* // This is an example code to show cases for traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const NodeRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
* ret_ = count(op->a) + count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
* // Think what is the result of the following count
* counter.count(Max::make(x, x));
* // The result is actually 1
* // This is because Visit is not overriden for Max
* // so it simply calls Visit for the left and right children
* // and because Count is not called, ret_ is not cleared.
* // There can also be cases where ret_ is forgetten to be set.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems.
* \encode
*
*/ */
class IRVisitor { class IRVisitor {
public: public:
......
...@@ -274,33 +274,51 @@ def sum(expr, axis): ...@@ -274,33 +274,51 @@ def sum(expr, axis):
return x return x
def min(expr, axis): def min(lhs, rhs=None, axis=None):
"""Create a min expression over axis """Create a min expression.
Parameters Parameters
---------- ----------
expr : Expr lhs : Expr
The source expression. The left hand expression.
axis : IterVar rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis The reduction IterVar axis
""" """
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis] axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis) x = _make.Reduce("Min", expr, axis)
return x return x
def max(expr, axis): def max(lhs, rhs=None, axis=None):
"""Create a min expression over axis """Create a max expression.
Parameters Parameters
---------- ----------
expr : Expr lhs : Expr
The source expression. The left hand expression.
axis : IterVar rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis The reduction IterVar axis
""" """
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis] axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis) x = _make.Reduce("Max", expr, axis)
return x return x
......
...@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs ...@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import _api_internal from . import _api_internal
@register_node
class IntSet(NodeBase): class IntSet(NodeBase):
"""Represent a set of integer in one dimension.""" """Represent a set of integer in one dimension."""
def is_nothing(self): def is_nothing(self):
...@@ -33,3 +32,8 @@ class IntervalSet(IntSet): ...@@ -33,3 +32,8 @@ class IntervalSet(IntSet):
class StrideSet(IntSet): class StrideSet(IntSet):
"""Represent set of strided integers""" """Represent set of strided integers"""
pass pass
@register_node
class ModularSet(IntSet):
"""Represent range of (coeff * x + base) for x in Z """
pass
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include "../arithmetic/int_set.h" #include "../arithmetic/int_set.h"
#include "../arithmetic/modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval) ...@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval)
*ret = IntSet::interval(args[0], args[1]); *ret = IntSet::interval(args[0], args[1]);
}); });
TVM_REGISTER_API(_arith_EvalModular)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>());
});
TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]); *ret = DeduceBound(args[0], args[1], args[2]);
......
...@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator { ...@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator {
return stmt; return stmt;
} }
Expr MutateExpr_(Expr expr) { Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
stack_.push_back(StackEntry()); stack_.push_back(StackEntry());
expr = (f.can_dispatch(expr) ? expr = IRMutator::Mutate(expr);
f(expr, expr, this) : IRMutator::Mutate(expr));
// update result of parent automatically during pop // update result of parent automatically during pop
if (stack_.size() > 1) { if (stack_.size() > 1) {
StackEntry& back = stack_[stack_.size() - 1]; StackEntry& back = stack_[stack_.size() - 1];
...@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator { ...@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator {
return (t.lanes() == 1 && (t.is_int() || t.is_uint())); return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
} }
// Add // Add
Expr Mutate_(const Add* op, const Expr& e) { Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator { ...@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, +1); return SumAdd(a, b, +1);
} }
// Sub // Sub
Expr Mutate_(const Sub* op, const Expr& e) { Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator { ...@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, -1); return SumAdd(a, b, -1);
} }
// Mul // Mul
Expr Mutate_(const Mul* op, const Expr& e) { Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator { ...@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal; using CInternal = Canonical::Internal;
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, const Expr& e, IRMutator* p) { \
return static_cast<CInternal*>(p)->Mutate_(op, e); })
TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr)
.DISPATCH_EXPR(Add)
.DISPATCH_EXPR(Sub)
.DISPATCH_EXPR(Mul)
.DISPATCH_EXPR(LT);
Canonical::Canonical() Canonical::Canonical()
: ptr_(std::make_shared<Internal>()) {} : ptr_(std::make_shared<Internal>()) {}
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./int_set.h" #include "./int_set.h"
#include "./modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode { ...@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
}; };
/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
/*! \brief Internal modular entry */
ModularEntry e;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &(e.base));
v->Visit("coeff", &(e.coeff));
}
static constexpr const char* _type_key = "ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <limits>
#include "./modular.h"
#include "./int_set_internal.h"
namespace tvm {
namespace arith {
using namespace ir;
class ModularEvaluator
: public ExprFunctor<ModularEntry(const Expr&)> {
public:
explicit ModularEvaluator(
const std::unordered_map<
const Variable*, ModularEntry>& mod_map)
: mod_map_(mod_map) {
}
ModularEntry Eval(const Expr& e) {
return VisitExpr(e);
}
// default
ModularEntry VisitExprDefault_(const Node*) final {
return ModularEntry::everything();
}
// override combination rules.
ModularEntry VisitExpr_(const IntImm* op) final {
if (op->value < std::numeric_limits<int>::max()) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const UIntImm* op) final {
if (op->value < static_cast<uint64_t>(
std::numeric_limits<int>::max())) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Variable* op) final {
auto it = mod_map_.find(op);
if (it != mod_map_.end()) {
return it->second;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Add* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Sub* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Mul* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
// Simplification rule, x, y, z are in Z
// (p x + n) (q y + m)
// -> pq xy + pm x + qn y + mn
// -> pq z + pm x + qn y + mn
int pq = a.coeff * b.coeff;
int pm = a.coeff * b.base;
int qn = a.base * b.coeff;
ModularEntry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Div* op) final {
// a c x / c -> a x
// We cannot do cases where offset is non-zero
// because of different integer rounding in pos/neg
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
if (b.coeff == 0 &&
a.base == 0) {
CHECK_NE(b.base, 0);
if (a.coeff % b.base == 0) {
ModularEntry ret;
ret.coeff = a.coeff / b.base;
ret.base = 0;
return ret;
}
}
return ModularEntry::everything();
}
private:
const std::unordered_map<
const Variable*, ModularEntry>& mod_map_;
// simplify the base by putting it in range.
static int BaseSimplify(int base, int coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}
static int ZeroAwareGCD(int a, int b) {
CHECK_GE(a, 0);
CHECK_GE(b, 0);
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
};
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
return ModularEvaluator(mod_map)(e);
}
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map) {
std::unordered_map<const Variable*, ModularEntry> mmap;
for (auto& kv : mod_map) {
const ModularSet* m = kv.second.as<ModularSet>();
CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e;
}
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>();
n->e = ModularEvaluator(mmap)(e);
return IntSet(n);
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_
#include <tvm/expr.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
};
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_CODEGEN_CODEGEN_C_H_ #define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <string> #include <string>
...@@ -16,12 +17,15 @@ ...@@ -16,12 +17,15 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir;
/*! /*!
* \brief A base class to generate C code. * \brief A base class to generate C code.
* *
* CodeGenC have two modes: generate SSA formed C code or normal form. * CodeGenC have two modes: generate SSA formed C code or normal form.
*/ */
class CodeGenC { class CodeGenC :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Initialize the code generator. * \brief Initialize the code generator.
...@@ -42,13 +46,15 @@ class CodeGenC { ...@@ -42,13 +46,15 @@ class CodeGenC {
* \brief Print the Stmt n to CodeGenC->stream * \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed. * \param n The statement to be printed.
*/ */
void PrintStmt(const Stmt& n); void PrintStmt(const Stmt& n) {
VisitStmt(n);
}
/*! /*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os * \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed. * \param n The expression to be printed.
* \param os The output stream * \param os The output stream
*/ */
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*) void PrintExpr(const Expr& n, std::ostream& os);
/*! /*!
* \brief Same as PrintExpr, but simply returns result string * \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed. * \param n The expression to be printed.
...@@ -84,6 +90,46 @@ class CodeGenC { ...@@ -84,6 +90,46 @@ class CodeGenC {
* \param f The function to be compiled. * \param f The function to be compiled.
*/ */
virtual void InitFuncState(LoweredFunc f); virtual void InitFuncState(LoweredFunc f);
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*! /*!
* Print Type represetnation of type t. * Print Type represetnation of type t.
* \param t The type representation. * \param t The type representation.
...@@ -97,50 +143,37 @@ class CodeGenC { ...@@ -97,50 +143,37 @@ class CodeGenC {
*/ */
virtual void PrintThreadIndexExpr( virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*) std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(* virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*) virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::For* op);
virtual void PrintStmt(const ir::IfThenElse* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op. // Binary vector op.
virtual void PrintVecBinaryOp( virtual void PrintVecBinaryOp(
const std::string&op, Type op_type, const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer, virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
std::ostream& os); // NOLINT(*) std::ostream& os); // NOLINT(*)
// print vector store
virtual void PrintVecStore(const Variable* buffer, virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
const std::string& value); // NOLINT(*) const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad( virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*) const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore( virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value); const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
protected: protected:
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer, void PrintBufferRef(const Variable* buffer,
Type t, Expr index, Type t, Expr index,
...@@ -158,13 +191,6 @@ class CodeGenC { ...@@ -158,13 +191,6 @@ class CodeGenC {
* \return The returned name. * \return The returned name.
*/ */
std::string GetUniqueName(std::string prefix); std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*! /*!
* \brief mark the beginning of a new scope * \brief mark the beginning of a new scope
* \return The scope id. * \return The scope id.
...@@ -209,6 +235,8 @@ class CodeGenC { ...@@ -209,6 +235,8 @@ class CodeGenC {
std::unordered_map<const Variable*, Type> handle_data_type_; std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */ /*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_; std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent{0};
}; };
} // namespace codegen } // namespace codegen
......
...@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { ...@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f); CodeGenC::AddFunction(f);
} }
void CodeGenCUDA::PrintStmt(const ir::For* op) { void CodeGenCUDA::VisitStmt_(const ir::For* op) {
int ext; int ext;
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) && if (arith::GetConstInt(op->extent, &ext) &&
...@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) { ...@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
} }
CodeGenC::PrintStmt(op); CodeGenC::VisitStmt_(op);
} }
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
...@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC { ...@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC {
public: public:
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
// override behavior // override behavior
void PrintStmt(const ir::For* op) final; void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final; void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp( void PrintVecBinaryOp(
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -23,7 +23,9 @@ using namespace ir; ...@@ -23,7 +23,9 @@ using namespace ir;
/*! /*!
* \brief A base class to generate a LLVM. * \brief A base class to generate a LLVM.
*/ */
class CodeGenLLVM : public IRVisitor { class CodeGenLLVM :
public ExprFunctor<llvm::Value* (const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Initialize the code generator with given context * \brief Initialize the code generator with given context
...@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor { ...@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor {
* \return created value. * \return created value.
*/ */
llvm::Value* MakeValue(const Expr& e) { llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr; return VisitExpr(e);
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
} }
// Short hande code to get a constant int 32 // Short hande code to get a constant int 32
llvm::Constant* ConstInt32(unsigned value) const { llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value); return llvm::ConstantInt::get(t_int32_, value);
} }
// override codegen // override codegen
void Visit_(const Variable* op) final; llvm::Value* VisitExpr_(const Variable* op) override;
void Visit_(const Cast* op) final; llvm::Value* VisitExpr_(const Cast* op) override;
void Visit_(const IntImm* op) final; llvm::Value* VisitExpr_(const IntImm* op) override;
void Visit_(const UIntImm* op) final; llvm::Value* VisitExpr_(const UIntImm* op) override;
void Visit_(const FloatImm* op) final; llvm::Value* VisitExpr_(const FloatImm* op) override;
void Visit_(const StringImm* op) final; llvm::Value* VisitExpr_(const StringImm* op) override;
void Visit_(const Add* op) final; llvm::Value* VisitExpr_(const Add* op) override;
void Visit_(const Sub* op) final; llvm::Value* VisitExpr_(const Sub* op) override;
void Visit_(const Mul* op) final; llvm::Value* VisitExpr_(const Mul* op) override;
void Visit_(const Div* op) final; llvm::Value* VisitExpr_(const Div* op) override;
void Visit_(const Mod* op) final; llvm::Value* VisitExpr_(const Mod* op) override;
void Visit_(const Min* op) final; llvm::Value* VisitExpr_(const Min* op) override;
void Visit_(const Max* op) final; llvm::Value* VisitExpr_(const Max* op) override;
void Visit_(const LT* op) final; llvm::Value* VisitExpr_(const LT* op) override;
void Visit_(const LE* op) final; llvm::Value* VisitExpr_(const LE* op) override;
void Visit_(const GT* op) final; llvm::Value* VisitExpr_(const GT* op) override;
void Visit_(const GE* op) final; llvm::Value* VisitExpr_(const GE* op) override;
void Visit_(const EQ* op) final; llvm::Value* VisitExpr_(const EQ* op) override;
void Visit_(const NE* op) final; llvm::Value* VisitExpr_(const NE* op) override;
void Visit_(const And* op) final; llvm::Value* VisitExpr_(const And* op) override;
void Visit_(const Or* op) final; llvm::Value* VisitExpr_(const Or* op) override;
void Visit_(const Not* op) final; llvm::Value* VisitExpr_(const Not* op) override;
void Visit_(const Select* op) final; llvm::Value* VisitExpr_(const Select* op) override;
void Visit_(const Let* op) final; llvm::Value* VisitExpr_(const Let* op) override;
void Visit_(const Load* op) final; llvm::Value* VisitExpr_(const Load* op) override;
void Visit_(const Call* op) final; llvm::Value* VisitExpr_(const Call* op) override;
void Visit_(const Ramp* op) final; llvm::Value* VisitExpr_(const Ramp* op) override;
void Visit_(const Broadcast* op) final; llvm::Value* VisitExpr_(const Broadcast* op) override;
// stmt // stmt
void Visit_(const Store* op) final; void VisitStmt_(const Store* op) override;
void Visit_(const For* op) final; void VisitStmt_(const For* op) override;
void Visit_(const IfThenElse* op) final; void VisitStmt_(const IfThenElse* op) override;
void Visit_(const Allocate* op) final; void VisitStmt_(const Allocate* op) override;
void Visit_(const AttrStmt* op) override; void VisitStmt_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final; void VisitStmt_(const AssertStmt* op) override;
void Visit_(const LetStmt* op) final; void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call // create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op); virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call // create extern function call
...@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor { ...@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor {
llvm::Function* f_tvm_parallel_for_{nullptr}; llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
llvm::Value* value_{nullptr};
private: private:
// comparison op // comparison op
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <string> #include <string>
...@@ -18,12 +19,15 @@ ...@@ -18,12 +19,15 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir;
/*! /*!
* \brief A base class to generate a stack VM. * \brief A base class to generate a stack VM.
* This module is used to generate host wrapper * This module is used to generate host wrapper
* into device function when only device JIT is available. * into device function when only device JIT is available.
*/ */
class CodeGenStackVM { class CodeGenStackVM
: public ExprFunctor<void(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Generate a stack VM representing * \brief Generate a stack VM representing
...@@ -36,7 +40,9 @@ class CodeGenStackVM { ...@@ -36,7 +40,9 @@ class CodeGenStackVM {
/*! \brief Push stmt to generate new code */ /*! \brief Push stmt to generate new code */
void Push(const Stmt& n); void Push(const Stmt& n);
/*! \brief Push expr to generate new code */ /*! \brief Push expr to generate new code */
void Push(const Expr& n); void Push(const Expr& n) {
VisitExpr(n);
}
/*! /*!
* \brief Push the opcode to the code. * \brief Push the opcode to the code.
* \param opcode The code to be pushed. * \param opcode The code to be pushed.
...@@ -84,16 +90,53 @@ class CodeGenStackVM { ...@@ -84,16 +90,53 @@ class CodeGenStackVM {
* \return the heap index of the var. * \return the heap index of the var.
*/ */
int GetVarID(const Variable* v) const; int GetVarID(const Variable* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b);
// push cast;
void PushCast(Type dst, Type src);
// overloadable functions // overloadable functions
virtual void Push_(const ir::Load* op); // expression
virtual void Push_(const ir::Store* op); void VisitExpr_(const Variable* op) final;
virtual void Push_(const ir::Allocate* op); void VisitExpr_(const Load* op) final;
virtual void Push_(const ir::Call* op); void VisitExpr_(const Let* op) final;
virtual void HandleUnknownCall(const ir::Call* op); void VisitExpr_(const Call* op) final;
/*! \brief function to to print normal code */ void VisitExpr_(const Add* op) final;
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>; void VisitExpr_(const Sub* op) final;
// vtable to print code void VisitExpr_(const Mul* op) final;
static FType& vtable(); // NOLINT(*) void VisitExpr_(const Div* op) final;
void VisitExpr_(const Mod* op) final;
void VisitExpr_(const Min* op) final;
void VisitExpr_(const Max* op) final;
void VisitExpr_(const EQ* op) final;
void VisitExpr_(const NE* op) final;
void VisitExpr_(const LT* op) final;
void VisitExpr_(const LE* op) final;
void VisitExpr_(const GT* op) final;
void VisitExpr_(const GE* op) final;
void VisitExpr_(const And* op) final;
void VisitExpr_(const Or* op) final;
void VisitExpr_(const Cast* op) final;
void VisitExpr_(const Not* op) final;
void VisitExpr_(const Select* op) final;
void VisitExpr_(const Ramp* op) final;
void VisitExpr_(const Broadcast* op) final;
void VisitExpr_(const IntImm* op) final;
void VisitExpr_(const UIntImm* op) final;
void VisitExpr_(const FloatImm* op) final;
void VisitExpr_(const StringImm* op) final;
// statment
void VisitStmt_(const LetStmt* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitStmt_(const Allocate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
private: private:
bool debug_{false}; bool debug_{false};
......
...@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { ...@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
...@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { ...@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
} }
} }
#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \ return s;
return s; \ }
}
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Variable)
.DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse) .DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate) .DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Load)
.DISPATCH_TO_MUTATE_STMT(Store) .DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Let)
.DISPATCH_TO_MUTATE_STMT(Free) .DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(Call)
.DISPATCH_TO_MUTATE_STMT(Add)
.DISPATCH_TO_MUTATE_STMT(Sub)
.DISPATCH_TO_MUTATE_STMT(Mul)
.DISPATCH_TO_MUTATE_STMT(Div)
.DISPATCH_TO_MUTATE_STMT(Mod)
.DISPATCH_TO_MUTATE_STMT(Min)
.DISPATCH_TO_MUTATE_STMT(Max)
.DISPATCH_TO_MUTATE_STMT(EQ)
.DISPATCH_TO_MUTATE_STMT(NE)
.DISPATCH_TO_MUTATE_STMT(LT)
.DISPATCH_TO_MUTATE_STMT(LE)
.DISPATCH_TO_MUTATE_STMT(GT)
.DISPATCH_TO_MUTATE_STMT(GE)
.DISPATCH_TO_MUTATE_STMT(And)
.DISPATCH_TO_MUTATE_STMT(Or)
.DISPATCH_TO_MUTATE_STMT(Reduce)
.DISPATCH_TO_MUTATE_STMT(Cast)
.DISPATCH_TO_MUTATE_STMT(Not)
.DISPATCH_TO_MUTATE_STMT(Select)
.DISPATCH_TO_MUTATE_STMT(Ramp)
.DISPATCH_TO_MUTATE_STMT(Broadcast)
.DISPATCH_TO_MUTATE_STMT(AssertStmt) .DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer) .DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide) .DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block) .DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate) .DISPATCH_TO_MUTATE_STMT(Evaluate);
.DISPATCH_TO_MUTATE_STMT(IntImm)
.DISPATCH_TO_MUTATE_STMT(UIntImm)
.DISPATCH_TO_MUTATE_STMT(FloatImm)
.DISPATCH_TO_MUTATE_STMT(StringImm);
// Mutate Expr // Mutate Expr
...@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { ...@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
return e; \ return e; \
} }
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
...@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) ...@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable) .DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(LetStmt)
.DISPATCH_TO_MUTATE_EXPR(AttrStmt)
.DISPATCH_TO_MUTATE_EXPR(IfThenElse)
.DISPATCH_TO_MUTATE_EXPR(For)
.DISPATCH_TO_MUTATE_EXPR(Allocate)
.DISPATCH_TO_MUTATE_EXPR(Load) .DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Store)
.DISPATCH_TO_MUTATE_EXPR(Let) .DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Free)
.DISPATCH_TO_MUTATE_EXPR(Call) .DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add) .DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub) .DISPATCH_TO_MUTATE_EXPR(Sub)
...@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Select) .DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp) .DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast) .DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(AssertStmt)
.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
.DISPATCH_TO_MUTATE_EXPR(Provide)
.DISPATCH_TO_MUTATE_EXPR(Realize)
.DISPATCH_TO_MUTATE_EXPR(Block)
.DISPATCH_TO_MUTATE_EXPR(Evaluate)
.DISPATCH_TO_MUTATE_EXPR(IntImm) .DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm)
......
...@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator { ...@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator {
} }
// user mutate from parent. // user mutate from parent.
using IRMutator::Mutate; using IRMutator::Mutate;
// override mutate
Expr Mutate(Expr expr) final { Expr Mutate_(const Add* op, const Expr &e) final {
static const FMutateExpr& f = Vectorizer::vtable_expr(); return AddSubVec(op, e);
return (f.can_dispatch(expr) ? }
f(expr, expr, this) : IRMutator::Mutate(expr)); Expr Mutate_(const Sub* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Div* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Max* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const EQ* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const NE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const LT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const And* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
Expr Mutate_(const Cast *op, const Expr& e) final {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
} }
// Variable // Variable
Expr Mutate_(const Variable* v, const Expr& e) final { Expr Mutate_(const Variable* v, const Expr& e) final {
...@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator { ...@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator {
stmt = Substitute(stmt, {{var_, idx}}); stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
} }
// The overloads for vectorize.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private: private:
// variable to be replaced // variable to be replaced
...@@ -273,13 +329,10 @@ class Vectorizer : public IRMutator { ...@@ -273,13 +329,10 @@ class Vectorizer : public IRMutator {
if (!changed) return arr; if (!changed) return arr;
return Array<Expr>(new_arr); return Array<Expr>(new_arr);
} }
}; template<typename T>
Expr BinaryVec(const T* op, const Expr& e) {
// binary vectorize Expr a = this->Mutate(op->a);
template<typename T> Expr b = this->Mutate(op->b);
inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) && if (a.same_as(op->a) &&
b.same_as(op->b)) { b.same_as(op->b)) {
return e; return e;
...@@ -287,12 +340,11 @@ inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) { ...@@ -287,12 +340,11 @@ inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
int lanes = std::max(a.type().lanes(), b.type().lanes()); int lanes = std::max(a.type().lanes(), b.type().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
} }
} }
template<typename T>
template<typename T> Expr AddSubVec(const T* op, const Expr& e) {
inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) { Expr a = this->Mutate(op->a);
Expr a = m->Mutate(op->a); Expr b = this->Mutate(op->b);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) && if (a.same_as(op->a) &&
b.same_as(op->b)) { b.same_as(op->b)) {
return e; return e;
...@@ -312,51 +364,8 @@ inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) { ...@@ -312,51 +364,8 @@ inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
} }
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
} }
}
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Add>(AddSubVec<Add>)
.set_dispatch<Sub>(AddSubVec<Sub>)
.set_dispatch<Mul>(BinaryVec<Mul>)
.set_dispatch<Div>(BinaryVec<Div>)
.set_dispatch<Mod>(BinaryVec<Mod>)
.set_dispatch<Min>(BinaryVec<Min>)
.set_dispatch<Max>(BinaryVec<Max>)
.set_dispatch<EQ>(BinaryVec<EQ>)
.set_dispatch<NE>(BinaryVec<NE>)
.set_dispatch<LT>(BinaryVec<LT>)
.set_dispatch<LE>(BinaryVec<LE>)
.set_dispatch<GT>(BinaryVec<GT>)
.set_dispatch<GE>(BinaryVec<GE>)
.set_dispatch<And>(BinaryVec<And>)
.set_dispatch<Or>(BinaryVec<Or>);
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
} }
}) };
.set_dispatch<Cast>([](const Cast *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
});
class LoopVectorizer : public IRMutator { class LoopVectorizer : public IRMutator {
public: public:
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_functor.h> #include <tvm/ir_functor.h>
#include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) { TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm; using namespace tvm;
using namespace tvm::ir;
Var x("x"); Var x("x");
auto z = x + 1; auto z = x + 1;
...@@ -21,6 +22,65 @@ TEST(IRF, Basic) { ...@@ -21,6 +22,65 @@ TEST(IRF, Basic) {
CHECK_EQ(f(z, 2), 4); CHECK_EQ(f(z, 2), 4);
} }
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyExprFunctor
: public ir::ExprFunctor<int(const Expr&, int)> {
public:
int VisitExpr_(const Variable* op, int b) final {
return b;
}
int VisitExpr_(const IntImm* op, int b) final {
return op->value;
}
int VisitExpr_(const Add* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
MyExprFunctor f;
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 3);
try {
f(z - 1, 2);
LOG(FATAL) << "should fail";
} catch(dmlc::Error) {
}
}
TEST(IRF, ExprVisit) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyVisitor
: public ir::ExprFunctor<void(const Expr&)>,
public ir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
void VisitExpr_(const IntImm* op) final {
}
void VisitExpr_(const Add* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
void VisitStmt_(const Evaluate* op) final {
VisitExpr(op->value);
}
};
MyVisitor v;
v(Evaluate::make(z));
CHECK_EQ(v.count, 1);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -25,6 +25,11 @@ def test_deduce(): ...@@ -25,6 +25,11 @@ def test_deduce():
ans1 = (c-b)/4+(-2) ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
def test_check(): def test_check():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
......
import tvm
def test_basic():
a = tvm.Var()
b = tvm.Var()
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3))
assert m.coeff == 4
assert m.base == 3
m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3))
assert m.coeff == 1
assert m.base == 0
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4))
assert m.coeff == 2
assert m.base == 0
m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2))
assert m.coeff == 3
assert m.base == 2
m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2))
assert m.coeff == 1
assert m.base == 0
if __name__ == "__main__":
test_basic()
...@@ -16,7 +16,7 @@ def test_llvm_add_pipeline(): ...@@ -16,7 +16,7 @@ def test_llvm_add_pipeline():
f = tvm.build(s, [A, B, C], "llvm") f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
# launch the kernel. # launch the kernel.
n = 10270 * 2460 n = 1027 * 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment