Commit 71334483 by Tianqi Chen Committed by GitHub

[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)

* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor

* retrigger

* [IRFunctor] Migrated CodegenC

* [IRFUNCTOR] Migrate CodeGenLLVM

* [IRFunctor] Migrate canonical

* [IRFunctor] Migrate vectorize

* [IRFunctor] migrate CodeGenStackVM
parent e4387940
......@@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh
notifications:
# Emails are sent to the committer's git-configured email address by default,
email:
on_success: change
on_failure: always
......@@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
......@@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
};
/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -6,6 +6,7 @@
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "./ir.h"
namespace tvm {
......@@ -17,7 +18,51 @@ namespace ir {
* This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new Node.
*
* \sa IRFunctor, PostOrderVisit
* \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
* - If it is mutaion of the IR, use IRMutator
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
*
* \code
*
* // This is an example code to show cases for traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const NodeRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
* ret_ = count(op->a) + count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
* // Think what is the result of the following count
* counter.count(Max::make(x, x));
* // The result is actually 1
* // This is because Visit is not overriden for Max
* // so it simply calls Visit for the left and right children
* // and because Count is not called, ret_ is not cleared.
* // There can also be cases where ret_ is forgetten to be set.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems.
* \encode
*
*/
class IRVisitor {
public:
......
......@@ -274,33 +274,51 @@ def sum(expr, axis):
return x
def min(expr, axis):
"""Create a min expression over axis
def min(lhs, rhs=None, axis=None):
"""Create a min expression.
Parameters
----------
expr : Expr
The source expression.
lhs : Expr
The left hand expression.
axis : IterVar
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x
def max(expr, axis):
"""Create a min expression over axis
def max(lhs, rhs=None, axis=None):
"""Create a max expression.
Parameters
----------
expr : Expr
The source expression.
lhs : Expr
The left hand expression.
axis : IterVar
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x
......
......@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
......@@ -33,3 +32,8 @@ class IntervalSet(IntSet):
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass
@register_node
class ModularSet(IntSet):
"""Represent range of (coeff * x + base) for x in Z """
pass
......@@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
#include "../arithmetic/modular.h"
namespace tvm {
namespace arith {
......@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval)
*ret = IntSet::interval(args[0], args[1]);
});
TVM_REGISTER_API(_arith_EvalModular)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>());
});
TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
......
......@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator {
return stmt;
}
Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
stack_.push_back(StackEntry());
expr = (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
expr = IRMutator::Mutate(expr);
// update result of parent automatically during pop
if (stack_.size() > 1) {
StackEntry& back = stack_[stack_.size() - 1];
......@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator {
return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
}
// Add
Expr Mutate_(const Add* op, const Expr& e) {
Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, +1);
}
// Sub
Expr Mutate_(const Sub* op, const Expr& e) {
Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, -1);
}
// Mul
Expr Mutate_(const Mul* op, const Expr& e) {
Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal;
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, const Expr& e, IRMutator* p) { \
return static_cast<CInternal*>(p)->Mutate_(op, e); })
TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr)
.DISPATCH_EXPR(Add)
.DISPATCH_EXPR(Sub)
.DISPATCH_EXPR(Mul)
.DISPATCH_EXPR(LT);
Canonical::Canonical()
: ptr_(std::make_shared<Internal>()) {}
......
......@@ -9,6 +9,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./int_set.h"
#include "./modular.h"
namespace tvm {
namespace arith {
......@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
/*! \brief Internal modular entry */
ModularEntry e;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &(e.base));
v->Visit("coeff", &(e.coeff));
}
static constexpr const char* _type_key = "ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};
} // namespace arith
} // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <limits>
#include "./modular.h"
#include "./int_set_internal.h"
namespace tvm {
namespace arith {
using namespace ir;
class ModularEvaluator
: public ExprFunctor<ModularEntry(const Expr&)> {
public:
explicit ModularEvaluator(
const std::unordered_map<
const Variable*, ModularEntry>& mod_map)
: mod_map_(mod_map) {
}
ModularEntry Eval(const Expr& e) {
return VisitExpr(e);
}
// default
ModularEntry VisitExprDefault_(const Node*) final {
return ModularEntry::everything();
}
// override combination rules.
ModularEntry VisitExpr_(const IntImm* op) final {
if (op->value < std::numeric_limits<int>::max()) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const UIntImm* op) final {
if (op->value < static_cast<uint64_t>(
std::numeric_limits<int>::max())) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Variable* op) final {
auto it = mod_map_.find(op);
if (it != mod_map_.end()) {
return it->second;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Add* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Sub* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Mul* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
// Simplification rule, x, y, z are in Z
// (p x + n) (q y + m)
// -> pq xy + pm x + qn y + mn
// -> pq z + pm x + qn y + mn
int pq = a.coeff * b.coeff;
int pm = a.coeff * b.base;
int qn = a.base * b.coeff;
ModularEntry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Div* op) final {
// a c x / c -> a x
// We cannot do cases where offset is non-zero
// because of different integer rounding in pos/neg
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
if (b.coeff == 0 &&
a.base == 0) {
CHECK_NE(b.base, 0);
if (a.coeff % b.base == 0) {
ModularEntry ret;
ret.coeff = a.coeff / b.base;
ret.base = 0;
return ret;
}
}
return ModularEntry::everything();
}
private:
const std::unordered_map<
const Variable*, ModularEntry>& mod_map_;
// simplify the base by putting it in range.
static int BaseSimplify(int base, int coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}
static int ZeroAwareGCD(int a, int b) {
CHECK_GE(a, 0);
CHECK_GE(b, 0);
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
};
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
return ModularEvaluator(mod_map)(e);
}
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map) {
std::unordered_map<const Variable*, ModularEntry> mmap;
for (auto& kv : mod_map) {
const ModularSet* m = kv.second.as<ModularSet>();
CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e;
}
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>();
n->e = ModularEvaluator(mmap)(e);
return IntSet(n);
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_
#include <tvm/expr.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
};
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
......@@ -7,6 +7,7 @@
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <string>
......@@ -16,12 +17,15 @@
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class to generate C code.
*
* CodeGenC have two modes: generate SSA formed C code or normal form.
*/
class CodeGenC {
class CodeGenC :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Initialize the code generator.
......@@ -42,13 +46,15 @@ class CodeGenC {
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt& n);
void PrintStmt(const Stmt& n) {
VisitStmt(n);
}
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*)
void PrintExpr(const Expr& n, std::ostream& os);
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
......@@ -84,6 +90,46 @@ class CodeGenC {
* \param f The function to be compiled.
*/
virtual void InitFuncState(LoweredFunc f);
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
* Print Type represetnation of type t.
* \param t The type representation.
......@@ -97,50 +143,37 @@ class CodeGenC {
*/
virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::For* op);
virtual void PrintStmt(const ir::IfThenElse* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
// print vector store
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
protected:
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
// print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
......@@ -158,13 +191,6 @@ class CodeGenC {
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
......@@ -209,6 +235,8 @@ class CodeGenC {
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent{0};
};
} // namespace codegen
......
......@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f);
}
void CodeGenCUDA::PrintStmt(const ir::For* op) {
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
int ext;
CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) &&
......@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) {
PrintIndent();
stream << "#pragma unroll\n";
}
CodeGenC::PrintStmt(op);
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
......@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC {
public:
void AddFunction(LoweredFunc f);
// override behavior
void PrintStmt(const ir::For* op) final;
void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
......
......@@ -8,7 +8,7 @@
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <memory>
#include <vector>
......@@ -23,7 +23,9 @@ using namespace ir;
/*!
* \brief A base class to generate a LLVM.
*/
class CodeGenLLVM : public IRVisitor {
class CodeGenLLVM :
public ExprFunctor<llvm::Value* (const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Initialize the code generator with given context
......@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor {
* \return created value.
*/
llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr;
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
return VisitExpr(e);
}
// Short hande code to get a constant int 32
llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value);
}
// override codegen
void Visit_(const Variable* op) final;
void Visit_(const Cast* op) final;
void Visit_(const IntImm* op) final;
void Visit_(const UIntImm* op) final;
void Visit_(const FloatImm* op) final;
void Visit_(const StringImm* op) final;
void Visit_(const Add* op) final;
void Visit_(const Sub* op) final;
void Visit_(const Mul* op) final;
void Visit_(const Div* op) final;
void Visit_(const Mod* op) final;
void Visit_(const Min* op) final;
void Visit_(const Max* op) final;
void Visit_(const LT* op) final;
void Visit_(const LE* op) final;
void Visit_(const GT* op) final;
void Visit_(const GE* op) final;
void Visit_(const EQ* op) final;
void Visit_(const NE* op) final;
void Visit_(const And* op) final;
void Visit_(const Or* op) final;
void Visit_(const Not* op) final;
void Visit_(const Select* op) final;
void Visit_(const Let* op) final;
void Visit_(const Load* op) final;
void Visit_(const Call* op) final;
void Visit_(const Ramp* op) final;
void Visit_(const Broadcast* op) final;
llvm::Value* VisitExpr_(const Variable* op) override;
llvm::Value* VisitExpr_(const Cast* op) override;
llvm::Value* VisitExpr_(const IntImm* op) override;
llvm::Value* VisitExpr_(const UIntImm* op) override;
llvm::Value* VisitExpr_(const FloatImm* op) override;
llvm::Value* VisitExpr_(const StringImm* op) override;
llvm::Value* VisitExpr_(const Add* op) override;
llvm::Value* VisitExpr_(const Sub* op) override;
llvm::Value* VisitExpr_(const Mul* op) override;
llvm::Value* VisitExpr_(const Div* op) override;
llvm::Value* VisitExpr_(const Mod* op) override;
llvm::Value* VisitExpr_(const Min* op) override;
llvm::Value* VisitExpr_(const Max* op) override;
llvm::Value* VisitExpr_(const LT* op) override;
llvm::Value* VisitExpr_(const LE* op) override;
llvm::Value* VisitExpr_(const GT* op) override;
llvm::Value* VisitExpr_(const GE* op) override;
llvm::Value* VisitExpr_(const EQ* op) override;
llvm::Value* VisitExpr_(const NE* op) override;
llvm::Value* VisitExpr_(const And* op) override;
llvm::Value* VisitExpr_(const Or* op) override;
llvm::Value* VisitExpr_(const Not* op) override;
llvm::Value* VisitExpr_(const Select* op) override;
llvm::Value* VisitExpr_(const Let* op) override;
llvm::Value* VisitExpr_(const Load* op) override;
llvm::Value* VisitExpr_(const Call* op) override;
llvm::Value* VisitExpr_(const Ramp* op) override;
llvm::Value* VisitExpr_(const Broadcast* op) override;
// stmt
void Visit_(const Store* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Allocate* op) final;
void Visit_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final;
void Visit_(const LetStmt* op) final;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call
......@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor {
llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
llvm::Value* value_{nullptr};
private:
// comparison op
......
......@@ -7,6 +7,7 @@
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>
#include <tvm/codegen.h>
#include <string>
......@@ -18,12 +19,15 @@
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class to generate a stack VM.
* This module is used to generate host wrapper
* into device function when only device JIT is available.
*/
class CodeGenStackVM {
class CodeGenStackVM
: public ExprFunctor<void(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Generate a stack VM representing
......@@ -36,7 +40,9 @@ class CodeGenStackVM {
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
void Push(const Expr& n);
void Push(const Expr& n) {
VisitExpr(n);
}
/*!
* \brief Push the opcode to the code.
* \param opcode The code to be pushed.
......@@ -84,16 +90,53 @@ class CodeGenStackVM {
* \return the heap index of the var.
*/
int GetVarID(const Variable* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b);
// push cast;
void PushCast(Type dst, Type src);
// overloadable functions
virtual void Push_(const ir::Load* op);
virtual void Push_(const ir::Store* op);
virtual void Push_(const ir::Allocate* op);
virtual void Push_(const ir::Call* op);
virtual void HandleUnknownCall(const ir::Call* op);
/*! \brief function to to print normal code */
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>;
// vtable to print code
static FType& vtable(); // NOLINT(*)
// expression
void VisitExpr_(const Variable* op) final;
void VisitExpr_(const Load* op) final;
void VisitExpr_(const Let* op) final;
void VisitExpr_(const Call* op) final;
void VisitExpr_(const Add* op) final;
void VisitExpr_(const Sub* op) final;
void VisitExpr_(const Mul* op) final;
void VisitExpr_(const Div* op) final;
void VisitExpr_(const Mod* op) final;
void VisitExpr_(const Min* op) final;
void VisitExpr_(const Max* op) final;
void VisitExpr_(const EQ* op) final;
void VisitExpr_(const NE* op) final;
void VisitExpr_(const LT* op) final;
void VisitExpr_(const LE* op) final;
void VisitExpr_(const GT* op) final;
void VisitExpr_(const GE* op) final;
void VisitExpr_(const And* op) final;
void VisitExpr_(const Or* op) final;
void VisitExpr_(const Cast* op) final;
void VisitExpr_(const Not* op) final;
void VisitExpr_(const Select* op) final;
void VisitExpr_(const Ramp* op) final;
void VisitExpr_(const Broadcast* op) final;
void VisitExpr_(const IntImm* op) final;
void VisitExpr_(const UIntImm* op) final;
void VisitExpr_(const FloatImm* op) final;
void VisitExpr_(const StringImm* op) final;
// statment
void VisitStmt_(const LetStmt* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitStmt_(const Allocate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
private:
bool debug_{false};
......
......@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}
Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
......@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
}
}
#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \
Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \
return s; \
}
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return s;
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Variable)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Load)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Let)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(Call)
.DISPATCH_TO_MUTATE_STMT(Add)
.DISPATCH_TO_MUTATE_STMT(Sub)
.DISPATCH_TO_MUTATE_STMT(Mul)
.DISPATCH_TO_MUTATE_STMT(Div)
.DISPATCH_TO_MUTATE_STMT(Mod)
.DISPATCH_TO_MUTATE_STMT(Min)
.DISPATCH_TO_MUTATE_STMT(Max)
.DISPATCH_TO_MUTATE_STMT(EQ)
.DISPATCH_TO_MUTATE_STMT(NE)
.DISPATCH_TO_MUTATE_STMT(LT)
.DISPATCH_TO_MUTATE_STMT(LE)
.DISPATCH_TO_MUTATE_STMT(GT)
.DISPATCH_TO_MUTATE_STMT(GE)
.DISPATCH_TO_MUTATE_STMT(And)
.DISPATCH_TO_MUTATE_STMT(Or)
.DISPATCH_TO_MUTATE_STMT(Reduce)
.DISPATCH_TO_MUTATE_STMT(Cast)
.DISPATCH_TO_MUTATE_STMT(Not)
.DISPATCH_TO_MUTATE_STMT(Select)
.DISPATCH_TO_MUTATE_STMT(Ramp)
.DISPATCH_TO_MUTATE_STMT(Broadcast)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(IntImm)
.DISPATCH_TO_MUTATE_STMT(UIntImm)
.DISPATCH_TO_MUTATE_STMT(FloatImm)
.DISPATCH_TO_MUTATE_STMT(StringImm);
.DISPATCH_TO_MUTATE_STMT(Evaluate);
// Mutate Expr
......@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
return e; \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
......@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(LetStmt)
.DISPATCH_TO_MUTATE_EXPR(AttrStmt)
.DISPATCH_TO_MUTATE_EXPR(IfThenElse)
.DISPATCH_TO_MUTATE_EXPR(For)
.DISPATCH_TO_MUTATE_EXPR(Allocate)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Store)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Free)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub)
......@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(AssertStmt)
.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
.DISPATCH_TO_MUTATE_EXPR(Provide)
.DISPATCH_TO_MUTATE_EXPR(Realize)
.DISPATCH_TO_MUTATE_EXPR(Block)
.DISPATCH_TO_MUTATE_EXPR(Evaluate)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
......
......@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator {
}
// user mutate from parent.
using IRMutator::Mutate;
// override mutate
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = Vectorizer::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
Expr Mutate_(const Add* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Sub* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Div* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Max* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const EQ* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const NE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const LT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const And* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
Expr Mutate_(const Cast *op, const Expr& e) final {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
}
// Variable
Expr Mutate_(const Variable* v, const Expr& e) final {
......@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator {
stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
// The overloads for vectorize.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private:
// variable to be replaced
......@@ -273,13 +329,10 @@ class Vectorizer : public IRMutator {
if (!changed) return arr;
return Array<Expr>(new_arr);
}
};
// binary vectorize
template<typename T>
inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
template<typename T>
Expr BinaryVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
......@@ -287,12 +340,11 @@ inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
int lanes = std::max(a.type().lanes(), b.type().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
}
template<typename T>
Expr AddSubVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
......@@ -312,51 +364,8 @@ inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Add>(AddSubVec<Add>)
.set_dispatch<Sub>(AddSubVec<Sub>)
.set_dispatch<Mul>(BinaryVec<Mul>)
.set_dispatch<Div>(BinaryVec<Div>)
.set_dispatch<Mod>(BinaryVec<Mod>)
.set_dispatch<Min>(BinaryVec<Min>)
.set_dispatch<Max>(BinaryVec<Max>)
.set_dispatch<EQ>(BinaryVec<EQ>)
.set_dispatch<NE>(BinaryVec<NE>)
.set_dispatch<LT>(BinaryVec<LT>)
.set_dispatch<LE>(BinaryVec<LE>)
.set_dispatch<GT>(BinaryVec<GT>)
.set_dispatch<GE>(BinaryVec<GE>)
.set_dispatch<And>(BinaryVec<And>)
.set_dispatch<Or>(BinaryVec<Or>);
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
})
.set_dispatch<Cast>([](const Cast *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
});
};
class LoopVectorizer : public IRMutator {
public:
......
......@@ -2,10 +2,11 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_functor.h>
#include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
......@@ -21,6 +22,65 @@ TEST(IRF, Basic) {
CHECK_EQ(f(z, 2), 4);
}
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyExprFunctor
: public ir::ExprFunctor<int(const Expr&, int)> {
public:
int VisitExpr_(const Variable* op, int b) final {
return b;
}
int VisitExpr_(const IntImm* op, int b) final {
return op->value;
}
int VisitExpr_(const Add* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
MyExprFunctor f;
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 3);
try {
f(z - 1, 2);
LOG(FATAL) << "should fail";
} catch(dmlc::Error) {
}
}
TEST(IRF, ExprVisit) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyVisitor
: public ir::ExprFunctor<void(const Expr&)>,
public ir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
void VisitExpr_(const IntImm* op) final {
}
void VisitExpr_(const Add* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
void VisitStmt_(const Evaluate* op) final {
VisitExpr(op->value);
}
};
MyVisitor v;
v(Evaluate::make(z));
CHECK_EQ(v.count, 1);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
......@@ -25,6 +25,11 @@ def test_deduce():
ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
def test_check():
a = tvm.Var('a')
b = tvm.Var('b')
......
import tvm
def test_basic():
a = tvm.Var()
b = tvm.Var()
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3))
assert m.coeff == 4
assert m.base == 3
m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3))
assert m.coeff == 1
assert m.base == 0
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4))
assert m.coeff == 2
assert m.base == 0
m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2))
assert m.coeff == 3
assert m.base == 2
m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2))
assert m.coeff == 1
assert m.base == 0
if __name__ == "__main__":
test_basic()
......@@ -16,7 +16,7 @@ def test_llvm_add_pipeline():
f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
n = 10270 * 2460
n = 1027 * 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment