Commit 71334483 by Tianqi Chen Committed by GitHub

[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)

* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor

* retrigger

* [IRFunctor] Migrated CodegenC

* [IRFUNCTOR] Migrate CodeGenLLVM

* [IRFunctor] Migrate canonical

* [IRFunctor] Migrate vectorize

* [IRFunctor] migrate CodeGenStackVM
parent e4387940
...@@ -59,7 +59,6 @@ after_failure: ...@@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh - tests/travis/travis_after_failure.sh
notifications: notifications:
# Emails are sent to the committer's git-configured email address by default,
email: email:
on_success: change on_success: change
on_failure: always on_failure: always
/*!
* Copyright (c) 2017 by Contributors
* \file ir_functor_ext.h
 * \brief More powerful visitors that allow defining function signatures.
*/
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_
#include <tvm/ir_functor.h>
#include "./ir.h"
namespace tvm {
namespace ir {
/*!
 * \brief A dynamic functor that dispatches on the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \code
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
* : public ir::ExprFunctor<int(const Expr&, int)> {
* public:
* int VisitExpr_(const Variable* op, int b) final {
* return b;
* }
* int VisitExpr_(const IntImm* op, int b) final {
* return op->value;
* }
* int VisitExpr_(const Add* op, int b) final {
* return Visit(op->a, b) + Visit(op->b, b);
* }
* };
* MyExprFunctor f;
* Var x("x");
* CHECK_EQ(f(x + 1, 2), 3);
* \endcode
*
* \note Why do we need this more powerful Functor:
*
 * We often need to implement transformer tasks.
 * Say we want to take an Expr and transform it to some analysis result;
 * this can easily be done incorrectly using a plain Visitor. See IRVisitor's
* document for possible error cases.
*
 * \tparam FType function signature
 *  This type is only defined for FType with function signature R(const Expr&, Args...)
*/
template<typename FType>
class ExprFunctor;
/*!
* \brief Same as ExprFunctor except it is applied on statements
* \tparam FType The function signature.
*/
template<typename FType>
class StmtFunctor;
// Functions to be overridden by subclasses.
// EXPR_FUNCTOR_DEFAULT / STMT_FUNCTOR_DEFAULT expand to a body that simply
// forwards to the corresponding default handler, so each visit overload can
// be declared (with a safe fallback) in a single line.
#define EXPR_FUNCTOR_DEFAULT {                                          \
    return VisitExprDefault_(op, std::forward<Args>(args)...);          \
  }

#define STMT_FUNCTOR_DEFAULT {                                          \
    return VisitStmtDefault_(op, std::forward<Args>(args)...);          \
  }

// Registers a vtable entry for node type OP that downcasts the generic
// NodeRef to the concrete node and invokes the matching VisitExpr_ overload.
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>(                                     \
      [](const NodeRef& n, TSelf* self, Args... args) {                 \
        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()),  \
                                std::forward<Args>(args)...);           \
      });                                                               \

// Same as above, but for statement nodes and VisitStmt_.
#define IR_STMT_FUNCTOR_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>(                                     \
      [](const NodeRef& n, TSelf* self, Args... args) {                 \
        return self->VisitStmt_(static_cast<const OP*>(n.node_.get()),  \
                                std::forward<Args>(args)...);           \
      });                                                               \
/*!
 * \brief Specialization of ExprFunctor for callables of signature
 *  R(const Expr&, Args...). Dispatches on the runtime node type of the
 *  expression through a lazily built, per-instantiation vtable.
 */
template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  // Shorthand for this specialization; used by the dispatch macros.
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
  // Vtable type: maps node type key -> thunk invoking the right VisitExpr_.
  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) {
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
    // The vtable is constructed once per template instantiation and shared
    // by all instances of this specialization.
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass.
  // Each defaults to VisitExprDefault_ via EXPR_FUNCTOR_DEFAULT.
  virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  /*! \brief Fallback invoked when no VisitExpr_ overload is overridden. */
  virtual R VisitExprDefault_(const Node* op, Args ...) {
    LOG(FATAL) << "Do not have a default for " << op->type_key();
    return R();  // unreachable; keeps the compiler's return check happy.
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    IR_EXPR_FUNCTOR_DISPATCH(Variable);
    IR_EXPR_FUNCTOR_DISPATCH(Load);
    IR_EXPR_FUNCTOR_DISPATCH(Let);
    IR_EXPR_FUNCTOR_DISPATCH(Call);
    IR_EXPR_FUNCTOR_DISPATCH(Add);
    IR_EXPR_FUNCTOR_DISPATCH(Sub);
    IR_EXPR_FUNCTOR_DISPATCH(Mul);
    IR_EXPR_FUNCTOR_DISPATCH(Div);
    IR_EXPR_FUNCTOR_DISPATCH(Mod);
    IR_EXPR_FUNCTOR_DISPATCH(Min);
    IR_EXPR_FUNCTOR_DISPATCH(Max);
    IR_EXPR_FUNCTOR_DISPATCH(EQ);
    IR_EXPR_FUNCTOR_DISPATCH(NE);
    IR_EXPR_FUNCTOR_DISPATCH(LT);
    IR_EXPR_FUNCTOR_DISPATCH(LE);
    IR_EXPR_FUNCTOR_DISPATCH(GT);
    IR_EXPR_FUNCTOR_DISPATCH(GE);
    IR_EXPR_FUNCTOR_DISPATCH(And);
    IR_EXPR_FUNCTOR_DISPATCH(Or);
    IR_EXPR_FUNCTOR_DISPATCH(Reduce);
    IR_EXPR_FUNCTOR_DISPATCH(Cast);
    IR_EXPR_FUNCTOR_DISPATCH(Not);
    IR_EXPR_FUNCTOR_DISPATCH(Select);
    IR_EXPR_FUNCTOR_DISPATCH(Ramp);
    IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
    IR_EXPR_FUNCTOR_DISPATCH(IntImm);
    IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
    IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
    IR_EXPR_FUNCTOR_DISPATCH(StringImm);
    return vtable;
  }
};
/*!
 * \brief Specialization of StmtFunctor for callables of signature
 *  R(const Stmt&, Args...). Dispatches on the runtime node type of the
 *  statement through a lazily built, per-instantiation vtable.
 */
template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
 private:
  // Shorthand for this specialization; used by the dispatch macros.
  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
  // Vtable type: maps node type key -> thunk invoking the right VisitStmt_.
  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~StmtFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Stmt& n, Args... args) {
    return VisitStmt(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitStmt(const Stmt& n, Args... args) {
    // The vtable is constructed once per template instantiation and shared
    // by all instances of this specialization.
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass.
  // Each defaults to VisitStmtDefault_ via STMT_FUNCTOR_DEFAULT.
  virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
  /*! \brief Fallback invoked when no VisitStmt_ overload is overridden. */
  virtual R VisitStmtDefault_(const Node* op, Args ...) {
    LOG(FATAL) << "Do not have a default for " << op->type_key();
    return R();  // unreachable; keeps the compiler's return check happy.
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    IR_STMT_FUNCTOR_DISPATCH(LetStmt);
    IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
    IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
    IR_STMT_FUNCTOR_DISPATCH(For);
    IR_STMT_FUNCTOR_DISPATCH(Allocate);
    IR_STMT_FUNCTOR_DISPATCH(Store);
    IR_STMT_FUNCTOR_DISPATCH(Free);
    IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
    IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
    IR_STMT_FUNCTOR_DISPATCH(Provide);
    IR_STMT_FUNCTOR_DISPATCH(Realize);
    IR_STMT_FUNCTOR_DISPATCH(Block);
    IR_STMT_FUNCTOR_DISPATCH(Evaluate);
    return vtable;
  }
};
#undef IR_STMT_FUNCTOR_DISPATCH
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT
} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
...@@ -55,59 +55,23 @@ class IRMutator { ...@@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*) static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions // Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance // The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e); virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e); virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e); virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e); virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e); virtual Expr Mutate_(const Sub* op, const Expr& e);
...@@ -130,38 +94,12 @@ class IRMutator { ...@@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e); virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e); virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e); virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e); virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e); virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e);
}; };
/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MUTATOR_H_ #endif // TVM_IR_MUTATOR_H_
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef TVM_IR_VISITOR_H_ #ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "./ir.h" #include "./ir.h"
namespace tvm { namespace tvm {
...@@ -17,7 +18,51 @@ namespace ir { ...@@ -17,7 +18,51 @@ namespace ir {
* This IRVisitor is implemented via IRFunctor * This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new Node. * This enables extensions of possible new Node.
* *
* \sa IRFunctor, PostOrderVisit * \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
 * - If it is mutation of the IR, use IRMutator
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
*
* \code
*
* // This is an example code to show cases for traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const NodeRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
 *    ret_ = Count(op->a) + Count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
 * // Think about what the result of the following count is
 * counter.Count(Max::make(x, x));
 * // The result is actually 1
 * // This is because Visit is not overridden for Max,
 * // so it simply calls Visit for the left and right children
 * // and, because Count is not called, ret_ is not cleared.
 * // There can also be cases where ret_ is forgotten to be set.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems.
 * \endcode
*
*/ */
class IRVisitor { class IRVisitor {
public: public:
......
...@@ -274,33 +274,51 @@ def sum(expr, axis): ...@@ -274,33 +274,51 @@ def sum(expr, axis):
return x return x
def min(expr, axis): def min(lhs, rhs=None, axis=None):
"""Create a min expression over axis """Create a min expression.
Parameters Parameters
---------- ----------
expr : Expr lhs : Expr
The source expression. The left hand expression.
axis : IterVar rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis The reduction IterVar axis
""" """
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis] axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis) x = _make.Reduce("Min", expr, axis)
return x return x
def max(expr, axis): def max(lhs, rhs=None, axis=None):
"""Create a min expression over axis """Create a max expression.
Parameters Parameters
---------- ----------
expr : Expr lhs : Expr
The source expression. The left hand expression.
axis : IterVar rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis The reduction IterVar axis
""" """
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis] axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis) x = _make.Reduce("Max", expr, axis)
return x return x
......
...@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs ...@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import _api_internal from . import _api_internal
@register_node
class IntSet(NodeBase): class IntSet(NodeBase):
"""Represent a set of integer in one dimension.""" """Represent a set of integer in one dimension."""
def is_nothing(self): def is_nothing(self):
...@@ -33,3 +32,8 @@ class IntervalSet(IntSet): ...@@ -33,3 +32,8 @@ class IntervalSet(IntSet):
class StrideSet(IntSet): class StrideSet(IntSet):
"""Represent set of strided integers""" """Represent set of strided integers"""
pass pass
@register_node
class ModularSet(IntSet):
    """Represent the set of values (coeff * x + base) for x in Z."""
    pass
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include "../arithmetic/int_set.h" #include "../arithmetic/int_set.h"
#include "../arithmetic/modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval) ...@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval)
*ret = IntSet::interval(args[0], args[1]); *ret = IntSet::interval(args[0], args[1]);
}); });
// FFI entry: run modular analysis on args[0] with an empty variable map
// (no prior modular information about any variable).
TVM_REGISTER_API(_arith_EvalModular)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = EvalModular(args[0], Map<Var, IntSet>());
  });
TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]); *ret = DeduceBound(args[0], args[1], args[2]);
......
...@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator { ...@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator {
return stmt; return stmt;
} }
Expr MutateExpr_(Expr expr) { Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
stack_.push_back(StackEntry()); stack_.push_back(StackEntry());
expr = (f.can_dispatch(expr) ? expr = IRMutator::Mutate(expr);
f(expr, expr, this) : IRMutator::Mutate(expr));
// update result of parent automatically during pop // update result of parent automatically during pop
if (stack_.size() > 1) { if (stack_.size() > 1) {
StackEntry& back = stack_[stack_.size() - 1]; StackEntry& back = stack_[stack_.size() - 1];
...@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator { ...@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator {
return (t.lanes() == 1 && (t.is_int() || t.is_uint())); return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
} }
// Add // Add
Expr Mutate_(const Add* op, const Expr& e) { Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator { ...@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, +1); return SumAdd(a, b, +1);
} }
// Sub // Sub
Expr Mutate_(const Sub* op, const Expr& e) { Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator { ...@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, -1); return SumAdd(a, b, -1);
} }
// Mul // Mul
Expr Mutate_(const Mul* op, const Expr& e) { Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e, this);
} }
...@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator { ...@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal; using CInternal = Canonical::Internal;
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, const Expr& e, IRMutator* p) { \
return static_cast<CInternal*>(p)->Mutate_(op, e); })
TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr)
.DISPATCH_EXPR(Add)
.DISPATCH_EXPR(Sub)
.DISPATCH_EXPR(Mul)
.DISPATCH_EXPR(LT);
Canonical::Canonical() Canonical::Canonical()
: ptr_(std::make_shared<Internal>()) {} : ptr_(std::make_shared<Internal>()) {}
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./int_set.h" #include "./int_set.h"
#include "./modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode { ...@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
}; };
/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
  /*! \brief Internal modular entry */
  ModularEntry e;
  // Expose base/coeff as node attributes for reflection and the FFI.
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("base", &(e.base));
    v->Visit("coeff", &(e.coeff));
  }
  static constexpr const char* _type_key = "ModularSet";
  TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <limits>
#include "./modular.h"
#include "./int_set_internal.h"
namespace tvm {
namespace arith {
using namespace ir;
/*!
 * \brief Evaluator computing the ModularEntry abstraction
 *  (set = base + coeff * x, x in Z) of an integer expression,
 *  given modular information about the variables appearing in it.
 */
class ModularEvaluator
    : public ExprFunctor<ModularEntry(const Expr&)> {
 public:
  /*!
   * \param mod_map Modular info of known variables.
   *  Held by reference; must outlive this evaluator.
   */
  explicit ModularEvaluator(
      const std::unordered_map<
      const Variable*, ModularEntry>& mod_map)
      : mod_map_(mod_map) {
  }
  /*! \brief Evaluate the modular entry of expression e. */
  ModularEntry Eval(const Expr& e) {
    return VisitExpr(e);
  }
  // default: an unhandled node can take any integer value.
  ModularEntry VisitExprDefault_(const Node*) final {
    return ModularEntry::everything();
  }
  // override combination rules.
  ModularEntry VisitExpr_(const IntImm* op) final {
    // Fold constants that fit in int. Guard BOTH bounds: the previous
    // upper-bound-only check let int64 values below INT_MIN overflow the
    // narrowing cast.
    if (op->value < std::numeric_limits<int>::max() &&
        op->value >= std::numeric_limits<int>::min()) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const UIntImm* op) final {
    if (op->value < static_cast<uint64_t>(
            std::numeric_limits<int>::max())) {
      ModularEntry ret;
      ret.base = static_cast<int>(op->value);
      ret.coeff = 0;
      return ret;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const Variable* op) final {
    auto it = mod_map_.find(op);
    if (it != mod_map_.end()) {
      return it->second;
    } else {
      return ModularEntry::everything();
    }
  }
  ModularEntry VisitExpr_(const Add* op) final {
    // (p x + n) + (q y + m) -> gcd(p, q) z + (n + m)
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base + b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Sub* op) final {
    // (p x + n) - (q y + m) -> gcd(p, q) z + (n - m)
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base - b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Mul* op) final {
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    int pq = a.coeff * b.coeff;
    int pm = a.coeff * b.base;  // may be negative when a base is negative.
    int qn = a.base * b.coeff;
    ModularEntry ret;
    ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
    ret.base = BaseSimplify(a.base * b.base, ret.coeff);
    return ret;
  }
  ModularEntry VisitExpr_(const Div* op) final {
    // a c x / c -> a x
    // We cannot do cases where offset is non-zero
    // because of different integer rounding in pos/neg
    ModularEntry a = Eval(op->a);
    ModularEntry b = Eval(op->b);
    if (b.coeff == 0 &&
        a.base == 0) {
      CHECK_NE(b.base, 0);
      if (a.coeff % b.base == 0) {
        ModularEntry ret;
        ret.coeff = a.coeff / b.base;
        ret.base = 0;
        return ret;
      }
    }
    return ModularEntry::everything();
  }

 private:
  // Modular info of known variables (borrowed reference, not owned).
  const std::unordered_map<
      const Variable*, ModularEntry>& mod_map_;
  // simplify the base by putting it into the range [0, coeff).
  static int BaseSimplify(int base, int coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }
  // gcd that treats 0 as identity: gcd(a, 0) == a.
  // gcd is defined up to sign, so normalize the inputs to magnitudes first:
  // bases (and Mul's coeff*base products) can legitimately be negative,
  // which used to trip the CHECK_GE(.., 0) assertions here and abort.
  static int ZeroAwareGCD(int a, int b) {
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    if (a < b) std::swap(a, b);
    if (b == 0) return a;
    // perform GCD (greatest common divisor)
    // ax + by = gcd(a, b) z if a != 0, b != 0
    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }
};
/*!
 * \brief Evaluate the modular entry of e given variable modular info.
 * \param e The expression to analyze.
 * \param mod_map Modular statistics of known variables.
 * \return The ModularEntry covering all possible values of e.
 */
ModularEntry EvalModular(
    const Expr& e,
    const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
  // Build an evaluator over the known-variable table and run it on e.
  ModularEvaluator evaluator(mod_map);
  return evaluator.Eval(e);
}
/*!
 * \brief Front-end variant: unpack ModularSet-typed IntSets, run the
 *  analysis, and wrap the result back into an IntSet node.
 * \param e The expression to analyze.
 * \param mod_map Map from variables to ModularSet-backed IntSets.
 * \return An IntSet (ModularSet) covering all possible values of e.
 */
IntSet EvalModular(const Expr& e,
                   const Map<Var, IntSet>& mod_map) {
  // Unpack the ModularSet wrappers into the plain entry table the
  // evaluator consumes.
  std::unordered_map<const Variable*, ModularEntry> entries;
  for (const auto& kv : mod_map) {
    const ModularSet* mset = kv.second.as<ModularSet>();
    CHECK(mset) << "Need to pass ModularSet for Modular Analysis";
    entries[kv.first.get()] = mset->e;
  }
  // Wrap the evaluation result into a fresh ModularSet node.
  auto result = std::make_shared<ModularSet>();
  result->e = ModularEvaluator(entries).Eval(e);
  return IntSet(result);
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_
#include <tvm/expr.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
/*!
* \brief Range of a linear integer function.
 *  Used to specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
  /*! \brief The base */
  int base;
  /*! \brief linear co-efficient */
  int coeff;
  /*!
   * \return An entry that represents every integer.
   *  0 + 1 * x ranges over all of Z, so this is always a safe
   *  (most conservative) answer.
   */
  static ModularEntry everything() {
    ModularEntry e;
    e.base = 0;
    e.coeff = 1;
    return e;
  }
};
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
...@@ -67,10 +67,6 @@ std::string CodeGenC::Finish() { ...@@ -67,10 +67,6 @@ std::string CodeGenC::Finish() {
return stream.str(); return stream.str();
} }
void CodeGenC::PrintStmt(const Stmt& n) {
static const FPrintStmt& f = vtable_print_stmt();
f(n, this);
}
std::string CodeGenC::SSAGetID(std::string src, Type t) { std::string CodeGenC::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src; if (name_alloc_map_.count(src)) return src;
...@@ -96,13 +92,12 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) { ...@@ -96,13 +92,12 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) {
} }
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
static const FPrintExpr& f = vtable_print_expr();
if (print_ssa_form_) { if (print_ssa_form_) {
std::ostringstream temp; std::ostringstream temp;
f(n, temp, this); VisitExpr(n, temp);
os << SSAGetID(temp.str(), n.type()); os << SSAGetID(temp.str(), n.type());
} else { } else {
f(n, os, this); VisitExpr(n, os);
} }
} }
...@@ -178,6 +173,102 @@ void CodeGenC::MarkConst(std::string vid) { ...@@ -178,6 +173,102 @@ void CodeGenC::MarkConst(std::string vid) {
} }
} }
int CodeGenC::BeginScope() {
  // Open a lexical scope: register it as live in scope_mark_ and deepen
  // the indentation by one level (2 spaces). Returns the scope id that
  // must later be passed to EndScope().
  const int scope_id = static_cast<int>(scope_mark_.size());
  scope_mark_.push_back(true);
  indent += 2;
  return scope_id;
}
void CodeGenC::EndScope(int scope_id) {
  // Close the scope opened by BeginScope(): restore the previous
  // indentation level and mark the scope as no longer live.
  indent -= 2;
  scope_mark_[scope_id] = false;
}
// Print a reference expression to a buffer element, i.e. the C lvalue
// used for both loads and stores of `buffer[index]` viewed as type t.
// Casts are inserted whenever the declared handle type of `buffer`
// (HandleTypeMatch) does not already match the requested element type.
void CodeGenC::PrintBufferRef(
    const Variable* buffer,
    Type t, Expr index,
    std::ostream& os) {  // NOLINT(*)
  std::string vid = GetVarID(buffer);
  if (t.lanes() == 1) {
    // Scalar access: "vid[index]", with a pointer cast when the buffer's
    // declared type differs from t.
    if (!HandleTypeMatch(buffer, t)) {
      os << "((";
      PrintType(t, os);
      os << "*)" << vid << ')';
    } else {
      os << vid;
    }
    os << '[';
    PrintExpr(index, os);
    os << ']';
  } else {
    // Buffer declared as vector type.
    // optimize for case where it is in register,
    if (HandleTypeMatch(buffer, t)) {
      // optimize for constant access: a compile-time index that is a
      // multiple of the lane count maps directly to one vector element.
      int offset;
      if (arith::GetConstInt(index, &offset)) {
        CHECK_EQ(offset % t.lanes(), 0)
            << "Find unaligned vector load to a vector type";
        os << vid << '[' << (offset / t.lanes()) << ']';
        return;
      }
    }
    // General case: reinterpret (element-cast if needed) the address
    // vid + index as a pointer to the vector type and dereference it:
    // "((t*)((elem_t*)vid + index))[0]".
    os << "((";
    PrintType(t, os);
    os << "*)(";
    if (!HandleTypeMatch(buffer, t.element_of())) {
      os << '(';
      PrintType(t.element_of(), os);
      os << "*)";
    }
    os << vid << " + ";
    PrintExpr(index, os);
    os << "))[0]";
  }
}
void CodeGenC::PrintVecElemLoad(const std::string& vec,
                                Type t, int i,
                                std::ostream& os) {  // NOLINT(*)
  // Emit access to lane i of a vector value as "<vec>.s<digit>", with the
  // lane index printed in hexadecimal (lanes above 9 become a..f).
  // Save/restore the stream's format flags: std::hex is a sticky
  // manipulator and would otherwise leave `os` printing every later
  // integer in hexadecimal.
  std::ios_base::fmtflags saved = os.flags();
  os << vec << ".s" << std::hex << i;
  os.flags(saved);
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
                                 Type t, int i,
                                 const std::string& value) {
  // Emit "<vec>.s<digit> = <value>;" on its own indented line; the lane
  // index is printed in hexadecimal (lanes above 9 become a..f).
  this->PrintIndent();
  // Save/restore format flags: std::hex is sticky and would otherwise
  // switch the member `stream` into hexadecimal for all later integers.
  std::ios_base::fmtflags saved = stream.flags();
  stream << vec << ".s" << std::hex << i
         << " = " << value << ";\n";
  stream.flags(saved);
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
                            Type t, Expr base,
                            std::ostream& os) {
  // Default vector load: simply a vector-typed buffer reference at `base`.
  // Declared virtual in the header, so targets may override with a
  // dedicated load construct.
  PrintBufferRef(buffer, t, base, os);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
                             Type t, Expr base,
                             const std::string& value) {
  // Default vector store: "<buffer-ref> = <value>;" written to the main
  // output stream. Virtual, so targets may override.
  this->PrintIndent();
  PrintBufferRef(buffer, t, base, stream);
  stream << " = " << value << ";\n";
}
void CodeGenC::PrintThreadIndexExpr(
    std::string thread_tag, std::ostream& os) {  // NOLINT(*)
  // Default: emit the thread tag verbatim as an identifier. Targets may
  // override to map the tag to their builtin index expression.
  os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) {  // NOLINT(*)
  // Plain C has no barrier construct: the default is a deliberate no-op.
  // Targets with synchronization (e.g. CodeGenCUDA, which declares this
  // final) override it.
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  // Base C codegen only understands the "global" scope (nothing is
  // printed for it); targets with special memory scopes override this.
  CHECK_EQ(scope, "global");
}
void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
CHECK_EQ(t.lanes(), 1) CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types"; << "do not yet support vector types";
...@@ -208,13 +299,6 @@ void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) ...@@ -208,13 +299,6 @@ void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to C type"; LOG(FATAL) << "Cannot convert type " << t << " to C type";
} }
CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() {  // NOLINT(*)
  // Function-local static: the single shared statement-dispatch table,
  // lazily initialized on first use.
  static FPrintStmt inst; return inst;
}
CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() {  // NOLINT(*)
  // Function-local static: the single shared expression-dispatch table,
  // lazily initialized on first use.
  static FPrintExpr inst; return inst;
}
inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->type == Int(32)) { if (op->type == Int(32)) {
...@@ -262,19 +346,18 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N ...@@ -262,19 +346,18 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
} }
} }
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
.set_dispatch<IntImm>([](const IntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) PrintConst(op, os, this);
PrintConst(op, os, p); }
}) void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
.set_dispatch<UIntImm>([](const UIntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) PrintConst(op, os, this);
PrintConst(op, os, p); }
}) void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
.set_dispatch<FloatImm>([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) PrintConst(op, os, this);
PrintConst(op, os, p); }
}) void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
.set_dispatch<StringImm>([](const StringImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) os << "\"" << op->value << "\"";
os << "\"" << op->value << "\""; }
});
template<typename T> template<typename T>
inline void PrintBinaryExpr(const T* op, inline void PrintBinaryExpr(const T* op,
...@@ -315,137 +398,99 @@ inline void PrintBinaryIntrinsitc(const Call* op, ...@@ -315,137 +398,99 @@ inline void PrintBinaryIntrinsitc(const Call* op,
p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os); p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os);
} }
} }
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
this->PrintType(op->type, os);
os << '(';
this->PrintExpr(op->value, os);
os << ')';
}
void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) {  // NOLINT(*)
  // Emit the C identifier previously allocated for this variable.
  os << GetVarID(op);
}
void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "+".
  PrintBinaryExpr(op, "+", os, this);
}
void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "-".
  PrintBinaryExpr(op, "-", os, this);
}
void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "*".
  PrintBinaryExpr(op, "*", os, this);
}
void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "/".
  PrintBinaryExpr(op, "/", os, this);
}
void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "%".
  PrintBinaryExpr(op, "%", os, this);
}
void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
  // "min" is a named (non-symbol) op; how PrintBinaryExpr renders it
  // (presumably call syntax "min(a, b)") is decided by that helper.
  PrintBinaryExpr(op, "min", os, this);
}
void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) {  // NOLINT(*)
  // "max" is a named (non-symbol) op; rendering is decided by
  // PrintBinaryExpr (presumably call syntax "max(a, b)").
  PrintBinaryExpr(op, "max", os, this);
}
void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "==".
  PrintBinaryExpr(op, "==", os, this);
}
void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "!=".
  PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "<".
  PrintBinaryExpr(op, "<", os, this);
}
void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "<=".
  PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator ">".
  PrintBinaryExpr(op, ">", os, this);
}
void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator ">=".
  PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenC::VisitExpr_(const And *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "&&".
  PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) {  // NOLINT(*)
  // Delegate to the shared binary-expression printer with operator "||".
  PrintBinaryExpr(op, "||", os, this);
}
void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
os << '!';
PrintExpr(op->a, os);
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
.set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
p->PrintType(op->type, os);
os << '(';
p->PrintExpr(op->value, os);
os << ')';
})
.set_dispatch<Variable>([](const Variable *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << p->GetVarID(op);
})
.set_dispatch<Add>([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, p);
})
.set_dispatch<Sub>([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, p);
})
.set_dispatch<Mul>([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, p);
})
.set_dispatch<Div>([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, p);
})
.set_dispatch<Mod>([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, p);
})
.set_dispatch<Min>([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, p);
})
.set_dispatch<Max>([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, p);
})
.set_dispatch<EQ>([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, p);
})
.set_dispatch<NE>([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, p);
})
.set_dispatch<LT>([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, p);
})
.set_dispatch<LE>([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, p);
})
.set_dispatch<GT>([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, p);
})
.set_dispatch<GE>([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, p);
})
.set_dispatch<And>([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, p);
})
.set_dispatch<Or>([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, p);
})
.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << '!';
p->PrintExpr(op->a, os);
});
// Static-vtable registrations for statements that need no indentation
// or wrapping of their own: lambdas forward to simple emission logic.
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
    // Annotation-only node: just emit the wrapped body.
    p->PrintStmt(op->body);
  })
.set_dispatch<Block>([](const Block *op, CodeGenC* p) {
    // Sequencing node: first statement, then the optional remainder.
    p->PrintStmt(op->first);
    if (op->rest.defined()) p->PrintStmt(op->rest);
  })
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenC* p) {
    if (is_const(op->value)) return;  // constants have no side effects
    const Call* call = op->value.as<Call>();
    if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
      // NOTE(review): args[0] is dereferenced as StringImm unchecked --
      // assumes the lowering always supplies a string literal here.
      p->PrintStorageSync(call->args[0].as<StringImm>()->value);
    } else {
      // Cast the SSA id to void so generated C avoids unused-value warnings.
      std::string vid = p->PrintExpr(op->value);
      p->PrintIndent();
      p->stream << "(void)" << vid << ";\n";
    }
  });
// Helper macro: register OP in the expression vtable by forwarding to the
// member overload CodeGenC::PrintExpr(const OP*, std::ostream&).
// NOTE(review): the macro is never #undef'd and so leaks into the rest of
// this translation unit.
#define DISPATCH_EXPR(OP) \
  set_dispatch<OP>([](const OP *op, std::ostream&os, CodeGenC* p) { \
    p->PrintExpr(op, os); })

TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.DISPATCH_EXPR(Load)
.DISPATCH_EXPR(Call)
.DISPATCH_EXPR(Let)
.DISPATCH_EXPR(Ramp)
.DISPATCH_EXPR(Broadcast)
.DISPATCH_EXPR(Select);
void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
CodeGenC* p = this;
if (op->is_intrinsic(Call::bitwise_and)) { if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, " & ", os, p); PrintBinaryIntrinsitc(op, " & ", os, this);
} else if (op->is_intrinsic(Call::bitwise_xor)) { } else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, " ^ ", os, p); PrintBinaryIntrinsitc(op, " ^ ", os, this);
} else if (op->is_intrinsic(Call::bitwise_or)) { } else if (op->is_intrinsic(Call::bitwise_or)) {
PrintBinaryIntrinsitc(op, " | ", os, p); PrintBinaryIntrinsitc(op, " | ", os, this);
} else if (op->is_intrinsic(Call::bitwise_not)) { } else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
os << "(~"; os << "(~";
p->PrintExpr(op->args[0], os); this->PrintExpr(op->args[0], os);
os << ')'; os << ')';
} else if (op->is_intrinsic(Call::shift_left)) { } else if (op->is_intrinsic(Call::shift_left)) {
PrintBinaryIntrinsitc(op, " << ", os, p); PrintBinaryIntrinsitc(op, " << ", os, this);
} else if (op->is_intrinsic(Call::shift_right)) { } else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, p); PrintBinaryIntrinsitc(op, " >> ", os, this);
} else if (op->is_intrinsic(Call::address_of)) { } else if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
os << "(("; os << "((";
p->PrintType(l->type.element_of(), os); this->PrintType(l->type.element_of(), os);
os << " *)" << p->GetVarID(l->buffer_var.get()) os << " *)" << this->GetVarID(l->buffer_var.get())
<< " + "; << " + ";
p->PrintExpr(l->index, os); this->PrintExpr(l->index, os);
os << ')'; os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
if (!op->type.is_handle()) { if (!op->type.is_handle()) {
os << '('; os << '(';
p->PrintType(op->type, os); this->PrintType(op->type, os);
os << ')'; os << ')';
} }
os << "(((TVMArg*)"; os << "(((TVMArg*)";
p->PrintExpr(op->args[0], os); this->PrintExpr(op->args[0], os);
os << ")[" << op->args[2] << "]."; os << ")[" << op->args[2] << "].";
if (op->type.is_handle()) { if (op->type.is_handle()) {
os << "v_handle"; os << "v_handle";
...@@ -460,7 +505,7 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -460,7 +505,7 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U); CHECK_EQ(op->args.size(), 2U);
os << "(((TVMArray*)"; os << "(((TVMArray*)";
p->PrintExpr(op->args[0], os); this->PrintExpr(op->args[0], os);
os << ")->"; os << ")->";
switch (op->args[1].as<IntImm>()->value) { switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: os << "data"; break; case intrinsic::kData: os << "data"; break;
...@@ -476,12 +521,12 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -476,12 +521,12 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
os << "("; os << "(";
p->PrintExpr(op->args[0], os); this->PrintExpr(op->args[0], os);
os << " == NULL)"; os << " == NULL)";
} else { } else {
os << op->name << "("; os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) { for (size_t i = 0; i < op->args.size(); i++) {
p->PrintExpr(op->args[i], os); this->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) { if (i < op->args.size() - 1) {
os << ", "; os << ", ";
} }
...@@ -517,51 +562,7 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) { ...@@ -517,51 +562,7 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
return true; return true;
} }
// Print a reference expression to a buffer. void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
// Print the C lvalue for buffer[index] viewed as type t, inserting pointer
// casts whenever the buffer's declared handle type does not match t.
void CodeGenC::PrintBufferRef(
    const Variable* buffer,
    Type t, Expr index,
    std::ostream& os) {  // NOLINT(*)
  std::string vid = GetVarID(buffer);
  if (t.lanes() == 1) {
    // Scalar case: "vid[index]" with an optional pointer cast.
    if (!HandleTypeMatch(buffer, t)) {
      os << "((";
      PrintType(t, os);
      os << "*)" << vid << ')';
    } else {
      os << vid;
    }
    os << '[';
    PrintExpr(index, os);
    os << ']';
  } else {
    // Buffer declared as vector type.
    // optimize for case where it is in register,
    if (HandleTypeMatch(buffer, t)) {
      // optimize for constant access: a constant, lane-aligned index maps
      // directly to one vector element.
      int offset;
      if (arith::GetConstInt(index, &offset)) {
        CHECK_EQ(offset % t.lanes(), 0)
            << "Find unaligned vector load to a vector type";
        os << vid << '[' << (offset / t.lanes()) << ']';
        return;
      }
    }
    // General case: "((t*)((elem_t*)vid + index))[0]".
    os << "((";
    PrintType(t, os);
    os << "*)(";
    if (!HandleTypeMatch(buffer, t.element_of())) {
      os << '(';
      PrintType(t.element_of(), os);
      os << "*)";
    }
    os << vid << " + ";
    PrintExpr(index, os);
    os << "))[0]";
  }
}
void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes(); int lanes = op->type.lanes();
if (op->type.lanes() == 1) { if (op->type.lanes() == 1) {
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os); this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os);
...@@ -600,7 +601,7 @@ void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -600,7 +601,7 @@ void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
} }
} }
void CodeGenC::PrintStmt(const Store* op) { void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type(); Type t = op->value.type();
if (t.lanes() == 1) { if (t.lanes() == 1) {
this->PrintIndent(); this->PrintIndent();
...@@ -637,35 +638,7 @@ void CodeGenC::PrintStmt(const Store* op) { ...@@ -637,35 +638,7 @@ void CodeGenC::PrintStmt(const Store* op) {
} }
} }
void CodeGenC::PrintVecElemLoad(const std::string& vec, void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
Type t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
                                 Type t, int i,
                                 const std::string& value) {
  // Emit "<vec>.s<digit> = <value>;" with the lane index in hexadecimal.
  // NOTE(review): std::hex is sticky and leaves `stream` in hex mode.
  this->PrintIndent();
  stream << vec << ".s" << std::hex << i
         << " = " << value << ";\n";
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
                            Type t, Expr base,
                            std::ostream& os) {
  // Default vector load: a vector-typed buffer reference at `base`.
  PrintBufferRef(buffer, t, base, os);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
                             Type t, Expr base,
                             const std::string& value) {
  // Default vector store: "<buffer-ref> = <value>;" on an indented line.
  this->PrintIndent();
  PrintBufferRef(buffer, t, base, stream);
  stream << " = " << value << ";\n";
}
void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
CHECK(print_ssa_form_) CHECK(print_ssa_form_)
<< "LetExpr is only supported by print SSA form"; << "LetExpr is only supported by print SSA form";
std::string value = PrintExpr(op->value); std::string value = PrintExpr(op->value);
...@@ -673,41 +646,19 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) ...@@ -673,41 +646,19 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
var_idmap_[op->var.get()] = value; var_idmap_[op->var.get()] = value;
} }
void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp: not supported "; LOG(FATAL) << "Ramp: not supported ";
} }
void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported "; LOG(FATAL) << "Broadcast: not supported ";
} }
void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Select: not supported "; LOG(FATAL) << "Select: not supported ";
} }
// Disoatch back to member functions void CodeGenC::VisitStmt_(const LetStmt* op) {
// Dispatch back to the (virtual) PrintStmt member overloads so that
// subclasses can override per-node behavior.
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<For>([](const For *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) { p->PrintStmt(op); });
void CodeGenC::PrintThreadIndexExpr(
    std::string thread_tag, std::ostream& os) {  // NOLINT(*)
  // Default: emit the thread tag verbatim; targets may override.
  os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) {  // NOLINT(*)
  // Plain C has no barrier construct: deliberate no-op; targets override.
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  // Base C codegen only accepts the "global" scope; targets override.
  CHECK_EQ(scope, "global");
}
void CodeGenC::PrintStmt(const LetStmt* op) {
std::string value = PrintExpr(op->value); std::string value = PrintExpr(op->value);
if (print_ssa_form_) { if (print_ssa_form_) {
CHECK(!var_idmap_.count(op->var.get())); CHECK(!var_idmap_.count(op->var.get()));
...@@ -732,7 +683,7 @@ void CodeGenC::PrintStmt(const LetStmt* op) { ...@@ -732,7 +683,7 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
PrintStmt(op->body); PrintStmt(op->body);
} }
void CodeGenC::PrintStmt(const Allocate* op) { void CodeGenC::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition)); CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get()); std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
...@@ -758,7 +709,7 @@ void CodeGenC::PrintStmt(const Allocate* op) { ...@@ -758,7 +709,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenC::PrintStmt(const AttrStmt* op) { void CodeGenC::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::thread_extent) { if (op->type_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
...@@ -780,7 +731,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { ...@@ -780,7 +731,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenC::PrintStmt(const AssertStmt* op) { void CodeGenC::VisitStmt_(const AssertStmt* op) {
std::string cond = PrintExpr(op->condition); std::string cond = PrintExpr(op->condition);
PrintIndent(); PrintIndent();
if (op->message.as<StringImm>()) { if (op->message.as<StringImm>()) {
...@@ -792,19 +743,7 @@ void CodeGenC::PrintStmt(const AssertStmt* op) { ...@@ -792,19 +743,7 @@ void CodeGenC::PrintStmt(const AssertStmt* op) {
} }
} }
int CodeGenC::BeginScope() { void CodeGenC::VisitStmt_(const For* op) {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent += 2;
return sid;
}
void CodeGenC::EndScope(int scope_id) {
  // Close the scope opened by BeginScope(): mark it dead and restore
  // the previous indentation level.
  scope_mark_[scope_id] = false;
  indent -= 2;
}
void CodeGenC::PrintStmt(const For* op) {
std::string extent = PrintExpr(op->extent); std::string extent = PrintExpr(op->extent);
PrintIndent(); PrintIndent();
std::string vid = AllocVarID(op->loop_var.get()); std::string vid = AllocVarID(op->loop_var.get());
...@@ -821,7 +760,7 @@ void CodeGenC::PrintStmt(const For* op) { ...@@ -821,7 +760,7 @@ void CodeGenC::PrintStmt(const For* op) {
stream << "}\n"; stream << "}\n";
} }
void CodeGenC::PrintStmt(const IfThenElse* op) { void CodeGenC::VisitStmt_(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition); std::string cond = PrintExpr(op->condition);
PrintIndent(); PrintIndent();
stream << "if (" << cond << ") {\n"; stream << "if (" << cond << ") {\n";
...@@ -840,6 +779,27 @@ void CodeGenC::PrintStmt(const IfThenElse* op) { ...@@ -840,6 +779,27 @@ void CodeGenC::PrintStmt(const IfThenElse* op) {
stream << "}\n"; stream << "}\n";
} }
void CodeGenC::VisitStmt_(const Block *op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
}
void CodeGenC::VisitStmt_(const Evaluate *op) {
  // An Evaluate node runs an expression purely for its side effects.
  if (is_const(op->value)) return;  // constants have no effect; emit nothing
  const Call* call = op->value.as<Call>();
  if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
    // Storage-sync intrinsic maps to the target's barrier emission.
    // NOTE(review): args[0] is dereferenced as StringImm unchecked --
    // assumes lowering always supplies a string literal; confirm upstream.
    this->PrintStorageSync(call->args[0].as<StringImm>()->value);
  } else {
    // Evaluate into an SSA id and cast it to void so the generated C
    // compiler does not warn about an unused value.
    std::string vid = this->PrintExpr(op->value);
    this->PrintIndent();
    this->stream << "(void)" << vid << ";\n";
  }
}
void CodeGenC::VisitStmt_(const ProducerConsumer *op) {
PrintStmt(op->body);
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_CODEGEN_CODEGEN_C_H_ #define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <string> #include <string>
...@@ -16,12 +17,15 @@ ...@@ -16,12 +17,15 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir;
/*! /*!
* \brief A base class to generate C code. * \brief A base class to generate C code.
* *
* CodeGenC have two modes: generate SSA formed C code or normal form. * CodeGenC have two modes: generate SSA formed C code or normal form.
*/ */
class CodeGenC { class CodeGenC :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Initialize the code generator. * \brief Initialize the code generator.
...@@ -42,13 +46,15 @@ class CodeGenC { ...@@ -42,13 +46,15 @@ class CodeGenC {
* \brief Print the Stmt n to CodeGenC->stream * \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed. * \param n The statement to be printed.
*/ */
void PrintStmt(const Stmt& n); void PrintStmt(const Stmt& n) {
VisitStmt(n);
}
/*! /*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os * \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed. * \param n The expression to be printed.
* \param os The output stream * \param os The output stream
*/ */
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*) void PrintExpr(const Expr& n, std::ostream& os);
/*! /*!
* \brief Same as PrintExpr, but simply returns result string * \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed. * \param n The expression to be printed.
...@@ -84,6 +90,46 @@ class CodeGenC { ...@@ -84,6 +90,46 @@ class CodeGenC {
* \param f The function to be compiled. * \param f The function to be compiled.
*/ */
virtual void InitFuncState(LoweredFunc f); virtual void InitFuncState(LoweredFunc f);
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
  // statement
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*! /*!
* Print Type represetnation of type t. * Print Type represetnation of type t.
* \param t The type representation. * \param t The type representation.
...@@ -97,50 +143,37 @@ class CodeGenC { ...@@ -97,50 +143,37 @@ class CodeGenC {
*/ */
virtual void PrintThreadIndexExpr( virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*) std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(* virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*) virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::For* op);
virtual void PrintStmt(const ir::IfThenElse* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op. // Binary vector op.
virtual void PrintVecBinaryOp( virtual void PrintVecBinaryOp(
const std::string&op, Type op_type, const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer, virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
std::ostream& os); // NOLINT(*) std::ostream& os); // NOLINT(*)
// print vector store
virtual void PrintVecStore(const Variable* buffer, virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
const std::string& value); // NOLINT(*) const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad( virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*) const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore( virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value); const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
protected: protected:
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer, void PrintBufferRef(const Variable* buffer,
Type t, Expr index, Type t, Expr index,
...@@ -158,13 +191,6 @@ class CodeGenC { ...@@ -158,13 +191,6 @@ class CodeGenC {
* \return The returned name. * \return The returned name.
*/ */
std::string GetUniqueName(std::string prefix); std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*! /*!
* \brief mark the beginning of a new scope * \brief mark the beginning of a new scope
* \return The scope id. * \return The scope id.
...@@ -209,6 +235,8 @@ class CodeGenC { ...@@ -209,6 +235,8 @@ class CodeGenC {
std::unordered_map<const Variable*, Type> handle_data_type_; std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */ /*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_; std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent{0};
}; };
} // namespace codegen } // namespace codegen
......
...@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { ...@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f); CodeGenC::AddFunction(f);
} }
void CodeGenCUDA::PrintStmt(const ir::For* op) { void CodeGenCUDA::VisitStmt_(const ir::For* op) {
int ext; int ext;
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) && if (arith::GetConstInt(op->extent, &ext) &&
...@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) { ...@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
} }
CodeGenC::PrintStmt(op); CodeGenC::VisitStmt_(op);
} }
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
...@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC { ...@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC {
public: public:
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
// override behavior // override behavior
void PrintStmt(const ir::For* op) final; void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final; void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp( void PrintVecBinaryOp(
......
...@@ -130,7 +130,7 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -130,7 +130,7 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_); llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block); builder_->SetInsertPoint(block);
this->Visit(f->body); this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0)); builder_->CreateRet(ConstInt32(0));
} }
...@@ -222,240 +222,369 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { ...@@ -222,240 +222,369 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
return ret; return ret;
} }
void CodeGenLLVM::Visit_(const Variable* op) { llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
value_ = GetVarValue(op); // create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "call_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "call_end", function_);
llvm::Value* succ = builder_->CreateICmpEQ(
retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
} }
void CodeGenLLVM::Visit_(const Cast* op) { void CodeGenLLVM::AddAliasInfo(
value_ = CreateCast(op->value.type(), op->type, MakeValue(op->value)); llvm::Instruction* inst, const Variable* buffer, Expr index) {
} int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
const Ramp* ramp = index.as<Ramp>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
}
void CodeGenLLVM::Visit_(const IntImm* op) { llvm::MDNode* meta = md_tbaa_root_;
value_ = llvm::ConstantInt::getSigned(LLVMType(op->type), op->value); std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
} }
void CodeGenLLVM::Visit_(const UIntImm* op) { llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
value_ = llvm::ConstantInt::get(LLVMType(op->type), op->value); llvm::Constant* init = llvm::UndefValue::get(
llvm::VectorType::get(value->getType(), lanes));
llvm::Constant* zero = ConstInt32(0);
value = builder_->CreateInsertElement(init, value, zero);
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
return builder_->CreateShuffleVector(value, init, mask);
} }
void CodeGenLLVM::Visit_(const FloatImm* op) { llvm::Value* CodeGenLLVM::CreateBufferPtr(
value_ = llvm::ConstantFP::get(LLVMType(op->type), op->value); Type t, llvm::Value* buffer, llvm::Value* index) {
} llvm::Type* elem_type = buffer->getType();
unsigned address_space = elem_type->getPointerAddressSpace();
llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
void CodeGenLLVM::Visit_(const StringImm* op) { if (load_type != elem_type) {
value_ = GetConstString(op->value); buffer = builder_->CreatePointerCast(buffer, load_type);
}
llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
if (cindex && cindex->isZeroValue()) {
return buffer;
}
return builder_->CreateInBoundsGEP(buffer, index);
} }
#define DEFINE_CODEGEN_BINARY_OP(OP) \ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Value* CodeGenLLVM::Create ## OP( \ llvm::Type * target = LLVMType(to);
Type t, llvm::Value* a, llvm::Value *b) { \ if (value->getType() == target) return value;
if (t.is_float()) { \ if (from.is_handle() && from.is_handle()) {
return builder_->CreateF ## OP (a, b); \ return builder_->CreateBitCast(value, target);
} else if (t.is_int() && t.bits() >= 32) { \ } else if (!from.is_float() && !to.is_float()) {
return builder_->CreateNSW ## OP (a, b); \ return builder_->CreateIntCast(value, target, from.is_int());
} else { \ } else if (from.is_float() && to.is_int()) {
return builder_->Create ## OP (a, b); \ return builder_->CreateFPToSI(value, target);
} \ } else if (from.is_float() && to.is_uint()) {
} \ if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
DEFINE_CODEGEN_BINARY_OP(Add); return builder_->CreateIntCast(value, target, false);
DEFINE_CODEGEN_BINARY_OP(Sub); } else {
DEFINE_CODEGEN_BINARY_OP(Mul); return builder_->CreateFPToUI(value, target);
}
void CodeGenLLVM::Visit_(const Add* op) { } else if (from.is_int() && to.is_float()) {
value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b)); return builder_->CreateSIToFP(value, target);
} else if (from.is_uint() && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
CHECK(from.is_float() && to.is_float());
return builder_->CreateFPCast(value, target);
}
} }
void CodeGenLLVM::Visit_(const Sub* op) { llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
value_ = CreateSub(op->type, MakeValue(op->a), MakeValue(op->b)); using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// create emit codes that checks and load the function.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
} }
void CodeGenLLVM::Visit_(const Mul* op) { llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
value_ = CreateMul(op->type, MakeValue(op->a), MakeValue(op->b)); CHECK_GE(op->args.size(), 1U);
std::string func_name = op->args[0].as<StringImm>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
t_tvm_value_, ConstInt32(nargs));
llvm::Value* tcodes = builder_->CreateAlloca(
t_int_, ConstInt32(nargs));
for (unsigned i = 0; i < nargs; ++i) {
Expr expr = op->args[i + 1];
Type t = expr.type();
CHECK_EQ(t.lanes(), 1);
// Always pass via 64 bit value.
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
Type api_type = t.with_bits(64);
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
llvm::Value* store_ptr = builder_->CreatePointerCast(
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
LLVMType(api_type)->getPointerTo());
builder_->CreateAlignedStore(value, store_ptr, 8);
builder_->CreateAlignedStore(
ConstInt32(t.code()),
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Type r_type = op->type;
Type r_api_type = op->type.with_bits(64);
llvm::Value* rvalue =
builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
ret_value, LLVMType(r_api_type)->getPointerTo()), 8);
rvalue = CreateCast(r_api_type, r_type, rvalue);
return rvalue;
} }
void CodeGenLLVM::Visit_(const Div* op) { llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
llvm::Value* a = MakeValue(op->a); std::vector<llvm::Value*> arg_values(op->args.size());
int shift; for (size_t i = 0; i < op->args.size(); ++i) {
if (op->type.is_float()) { arg_values[i] = MakeValue(op->args[i]);
value_ = builder_->CreateFDiv(a, MakeValue(op->b)); }
} else if ((op->type.is_int() || op->type.is_uint()) && if (op->type.is_scalar()) {
is_const_power_of_two_integer(op->b, &shift)) { llvm::Function* f = module_->getFunction(op->name);
value_ = builder_->CreateAShr(a, shift); if (f) {
return builder_->CreateCall(f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
} else { } else {
llvm::Value* b = MakeValue(op->b); llvm::Function* f = module_->getFunction(op->name);
if (op->type.is_int()) { if (f) {
value_ = builder_->CreateSDiv(a, b); return CreateScalarizedCall(op, f, arg_values);
} else { } else {
CHECK(op->type.is_uint()); LOG(FATAL) << "cannot find function " << op->name;
value_ = builder_->CreateUDiv(a, b);
} }
} }
return nullptr;
} }
void CodeGenLLVM::Visit_(const Mod* op) { llvm::Value* CodeGenLLVM::CreateScalarizedCall(
CHECK(!op->type.is_float()) const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
<< "Cannot do mod for float"; llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
if (op->type.is_int()) { for (int i = 0; i < op->type.lanes(); ++i) {
value_ = builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b)); std::vector<llvm::Value*> sargs(args.size());
} else { for (size_t j = 0; j < args.size(); ++j) {
CHECK(op->type.is_uint()); if (args[j]->getType()->isVectorTy()) {
value_ = builder_->CreateURem(MakeValue(op->a), MakeValue(op->b)); sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
} else {
sargs[j] = args[j];
}
}
llvm::CallInst* call = builder_->CreateCall(f, sargs);
if (op->is_pure()) {
call->setDoesNotAccessMemory();
}
call->setDoesNotThrow();
if (!call->getType()->isVoidTy()) {
value = builder_->CreateInsertElement(value, call, ConstInt32(i));
}
} }
return value;
} }
void CodeGenLLVM::Visit_(const Min* op) { llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
llvm::Value* a = MakeValue(op->a); auto it = var_map_.find(v);
llvm::Value* b = MakeValue(op->b); CHECK(it != var_map_.end())
llvm::Value* cond = CreateLT(op->a.type(), a, b); << "Cannot find " << v->name_hint << " in the var map";
value_ = builder_->CreateSelect(cond, a, b); return it->second;
}
void CodeGenLLVM::Visit_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
value_ = builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
void CodeGenLLVM::Visit_(const LT* op) {
value_ = CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const LE* op) {
value_ = CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GT* op) {
value_ = CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GE* op) {
value_ = CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const EQ* op) {
if (op->a.type().is_float()) {
value_ = builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
} else {
value_ = builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
}
} }
void CodeGenLLVM::Visit_(const NE* op) { llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
if (op->a.type().is_float()) { auto it = str_map_.find(str);
value_ = builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b)); if (it == str_map_.end()) {
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
// useful constant value
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = sptr;
return sptr;
} else { } else {
value_ = builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b)); return it->second;
} }
} }
void CodeGenLLVM::Visit_(const And* op) { void CodeGenLLVM::CreateParallelFor(const For* op) {
value_ = builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); using llvm::BasicBlock;
} llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
void CodeGenLLVM::Visit_(const Or* op) { min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
value_ = builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
} // fields to be packed into closure.
Var loop_var(op->loop_var.node_);
void CodeGenLLVM::Visit_(const Not* op) { Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
value_ = builder_->CreateNot(MakeValue(op->a)); std::vector<llvm::Type*> fields;
} for (Var v : vfields) {
auto it = var_map_.find(v.get());
void CodeGenLLVM::Visit_(const Select* op) { CHECK(it != var_map_.end());
value_ = builder_->CreateSelect( fields.push_back(it->second->getType());
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
void CodeGenLLVM::Visit_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
value_ = MakeValue(op->body);
}
void CodeGenLLVM::Visit_(const Broadcast* op) {
value_ = CreateBroadcast(MakeValue(op->value), op->lanes);
}
void CodeGenLLVM::Visit_(const Ramp* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
} }
value_ = value; // closure data
} llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
void CodeGenLLVM::Visit_(const Load* op) { t_f_tvm_par_for_lambda_,
Type t = op->type; llvm::Function::PrivateLinkage,
CHECK(!t.is_vector()); "__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
if (t.is_scalar()) { for (size_t i = 0; i < vfields.size(); ++i) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad( builder_->CreateStore(
CreateBufferPtr( var_map_.at(vfields[i].get()),
t, builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
value_ = inst;
} else {
LOG(FATAL) << "not yet supported";
} }
} BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
void CodeGenLLVM::Visit_(const Store* op) { f_tvm_parallel_for_,
llvm::Value* value = MakeValue(op->value); {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
Type t = op->value.type(); // Setup the closure function.
CHECK(!t.is_vector()); BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
if (t.is_scalar()) { builder_->SetInsertPoint(lambda_entry);
llvm::StoreInst* inst = builder_->CreateAlignedStore( auto it = f->arg_begin();
value, llvm::Value* begin = &(*it++);
CreateBufferPtr( llvm::Value* end = &(*it++);
t, cdata = &(*it++);
GetVarValue(op->buffer_var.get()), begin = CreateCast(Int(64), op->loop_var.type(), begin);
MakeValue(op->index)), end = CreateCast(Int(64), op->loop_var.type(), end);
data_layout_->getTypeAllocSize(value->getType())); cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
AddAliasInfo(inst, op->buffer_var.get(), op->index); // setup new variable map, swap it with current var context.
} else { std::unordered_map<const Variable*, llvm::Value*> new_vmap;
LOG(FATAL) << "not yet supported"; for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
} }
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
} }
void CodeGenLLVM::Visit_(const Call* op) { void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
if (op->is_intrinsic(intrinsic::tvm_call_packed)) { const VarExpr& loop_var, const Stmt& body) {
value_ = CreateCallPacked(op); using llvm::BasicBlock;
} else if (op->call_type == Call::Intrinsic || Type t = loop_var.type();
op->call_type == Call::PureIntrinsic) { BasicBlock* for_head = BasicBlock::Create(
value_ = CreateIntrinstic(op); *ctx_, "for_head", function_);
} else { BasicBlock* for_body = BasicBlock::Create(
CHECK(op->call_type == Call::Extern || *ctx_, "for_body", function_);
op->call_type == Call::PureExtern); BasicBlock* for_end = BasicBlock::Create(
value_ = CreateCallExtern(op); *ctx_, "for_end", function_);
} BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->VisitStmt(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
} }
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
...@@ -555,70 +684,292 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -555,70 +684,292 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return nullptr; return nullptr;
} }
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) { // visitor overrides
// create emit codes that checks and load the function. llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
using llvm::BasicBlock; return GetVarValue(op);
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "call_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "call_end", function_);
llvm::Value* succ = builder_->CreateICmpEQ(
retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
} }
void CodeGenLLVM::Visit_(const For* op) {
CHECK(is_zero(op->min)); llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
if (op->for_type == ForType::Serial) { return CreateCast(op->value.type(), op->type, MakeValue(op->value));
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
} }
void CodeGenLLVM::Visit_(const IfThenElse* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
using llvm::BasicBlock; return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
if (!op->else_case.defined()) {
else_block = end_block;
}
// condition.
llvm::Value* cond = MakeValue(op->condition);
bool likely = true;
if (likely) {
builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_);
} else {
builder_->CreateCondBr(cond, then_block, else_block);
}
// then case.
builder_->SetInsertPoint(then_block);
this->Visit(op->then_case);
builder_->CreateBr(end_block);
// else case.
if (op->else_case.defined()) {
builder_->SetInsertPoint(else_block);
this->Visit(op->else_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
} }
void CodeGenLLVM::Visit_(const Allocate* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
CHECK(!is_zero(op->condition)); return llvm::ConstantInt::get(LLVMType(op->type), op->value);
llvm::Value* buf = nullptr; }
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
return llvm::ConstantFP::get(LLVMType(op->type), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
return GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value *b) { \
if (t.is_float()) { \
return builder_->CreateF ## OP (a, b); \
} else if (t.is_int() && t.bits() >= 32) { \
return builder_->CreateNSW ## OP (a, b); \
} else { \
return builder_->Create ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
llvm::Value* CodeGenLLVM::VisitExpr_(const Add* op) {
return CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Sub* op) {
return CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mul* op) {
return CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
int shift;
if (op->type.is_float()) {
return builder_->CreateFDiv(a, MakeValue(op->b));
} else if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
return builder_->CreateAShr(a, shift);
} else {
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
return builder_->CreateSDiv(a, b);
} else {
CHECK(op->type.is_uint());
return builder_->CreateUDiv(a, b);
}
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
CHECK(!op->type.is_float())
<< "Cannot do mod for float";
if (op->type.is_int()) {
return builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
} else {
CHECK(op->type.is_uint());
return builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateLT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const LT* op) {
return CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LE* op) {
return CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GT* op) {
return CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GE* op) {
return CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
} else {
return builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
} else {
return builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
return builder_->CreateNot(MakeValue(op->a));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
return builder_->CreateSelect(
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
return MakeValue(op->body);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
}
return value;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
Type t = op->type;
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
return inst;
} else {
LOG(FATAL) << "not yet supported";
return nullptr;
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinstic(op);
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
return CreateCallExtern(op);
}
}
// stmts
void CodeGenLLVM::VisitStmt_(const Store* op) {
llvm::Value* value = MakeValue(op->value);
Type t = op->value.type();
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
} else {
LOG(FATAL) << "not yet supported";
}
}
void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
}
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
using llvm::BasicBlock;
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
if (!op->else_case.defined()) {
else_block = end_block;
}
// condition.
llvm::Value* cond = MakeValue(op->condition);
bool likely = true;
if (likely) {
builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_);
} else {
builder_->CreateCondBr(cond, then_block, else_block);
}
// then case.
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
// else case.
if (op->else_case.defined()) {
builder_->SetInsertPoint(else_block);
this->VisitStmt(op->else_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop"); CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr); buf = MakeValue(op->new_expr);
...@@ -634,11 +985,11 @@ void CodeGenLLVM::Visit_(const Allocate* op) { ...@@ -634,11 +985,11 @@ void CodeGenLLVM::Visit_(const Allocate* op) {
var_map_[op->buffer_var.get()] = buf; var_map_[op->buffer_var.get()] = buf;
} }
void CodeGenLLVM::Visit_(const AttrStmt* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
this->Visit(op->body); this->VisitStmt(op->body);
} }
void CodeGenLLVM::Visit_(const AssertStmt* op) { void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
using llvm::BasicBlock; using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition); llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os; std::ostringstream os;
...@@ -660,359 +1011,23 @@ void CodeGenLLVM::Visit_(const AssertStmt* op) { ...@@ -660,359 +1011,23 @@ void CodeGenLLVM::Visit_(const AssertStmt* op) {
builder_->SetInsertPoint(end_block); builder_->SetInsertPoint(end_block);
} }
void CodeGenLLVM::Visit_(const LetStmt* op) { void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
llvm::Value* v = MakeValue(op->value); llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get())); CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v; var_map_[op->var.get()] = v;
this->Visit(op->body); this->VisitStmt(op->body);
} }
void CodeGenLLVM::VisitStmt_(const Block* op) {
void CodeGenLLVM::AddAliasInfo( VisitStmt(op->first);
llvm::Instruction* inst, const Variable* buffer, Expr index) { if (op->rest.defined()) VisitStmt(op->rest);
int base = 0, width = 0; }
// create meta-data for alias analysis void CodeGenLLVM::VisitStmt_(const Evaluate *op) {
// Use a group of binary tree ranges. MakeValue(op->value);
const Ramp* ramp = index.as<Ramp>(); }
if (ramp) { void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
int base, stride; VisitStmt(op->body);
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
} }
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
  // Splat a scalar across all lanes: insert the scalar into lane 0 of
  // an undef vector, then shuffle with an all-zero mask.
  llvm::Constant* lane_zero = ConstInt32(0);
  llvm::Constant* undef_vec = llvm::UndefValue::get(
      llvm::VectorType::get(value->getType(), lanes));
  llvm::Value* seeded = builder_->CreateInsertElement(undef_vec, value, lane_zero);
  llvm::Constant* splat_mask = llvm::ConstantVector::getSplat(lanes, lane_zero);
  return builder_->CreateShuffleVector(seeded, undef_vec, splat_mask);
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  // Cast the buffer pointer to the element type being accessed,
  // preserving the original address space, then index into it.
  llvm::Type* buf_type = buffer->getType();
  unsigned addrspace = buf_type->getPointerAddressSpace();
  llvm::Type* elem_ptr_type = LLVMType(t)->getPointerTo(addrspace);
  llvm::Value* ptr = buffer;
  if (elem_ptr_type != buf_type) {
    ptr = builder_->CreatePointerCast(ptr, elem_ptr_type);
  }
  // A constant-zero index needs no GEP at all.
  if (llvm::Constant* cidx = llvm::dyn_cast<llvm::Constant>(index)) {
    if (cidx->isZeroValue()) return ptr;
  }
  return builder_->CreateInBoundsGEP(ptr, index);
}
// Emit a cast of `value` from TVM type `from` to TVM type `to`,
// picking the appropriate LLVM cast instruction.
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
  llvm::Type * target = LLVMType(to);
  if (value->getType() == target) return value;
  // BUGFIX: the condition previously read `from.is_handle() &&
  // from.is_handle()`, which degenerates to `from.is_handle()` alone.
  // A bitcast is the pointer-to-pointer cast, so both sides must be
  // handle types.
  if (from.is_handle() && to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (!from.is_float() && !to.is_float()) {
    // integer <-> integer: sign-extend iff the source is signed.
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    // Sub-byte unsigned targets go through an 8-bit intermediate.
    if (to.bits() < 8) {
      value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    CHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}
// Return an llvm::Value holding the TVM packed-function handle for
// `fname`. The handle is cached in a private module-level global;
// the emitted code tests it for null and initializes it on first use.
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
  using llvm::BasicBlock;
  // We will store the packed function handle in global space.
  // Initialize it during the first call.
  llvm::DataLayout layout(module_.get());
  uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
  auto it = func_handle_map_.find(fname);
  llvm::GlobalVariable* hptr;
  if (it == func_handle_map_.end()) {
    // create global location for the handle
    // create the function handle
    hptr = new llvm::GlobalVariable(
        *module_, t_tvm_func_handle_, false,
        llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
    hptr->setAlignment(align);
    hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
    func_handle_map_[fname] = hptr;
  } else {
    hptr = it->second;
  }
  // create emit codes that checks and load the function.
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* init_block = BasicBlock::Create(
      *ctx_, "handle_init", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "handle_init_end", function_);
  // Fast path: the handle is already non-null (marked very likely).
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
  llvm::Value* handle_not_null =  builder_->CreateICmpNE(
      handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
  builder_->CreateCondBr(
      handle_not_null, end_block, init_block, md_very_likely_branch_);
  // Initialize the handle if needed.
  builder_->SetInsertPoint(init_block);
  llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
  llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
  llvm::Value* retcode = builder_->CreateCall(
      f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
  // NOTE(review): CheckCallSuccess appears to return the continuation
  // block after the status check (its definition is outside this chunk);
  // that block — not the original init_block — is the PHI predecessor
  // for the slow path below. Do not "simplify" this reassignment.
  init_block = CheckCallSuccess(retcode);
  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
  builder_->CreateBr(end_block);
  // end block
  builder_->SetInsertPoint(end_block);
  // Merge the fast-path handle with the freshly loaded one.
  llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
  phi->addIncoming(handle, pre_block);
  phi->addIncoming(loaded_handle, init_block);
  return phi;
}
// Emit a call through the TVM packed-function calling convention.
// op->args[0] is the function name (StringImm); args[1..] are the
// actual arguments, each passed as a 64-bit slot plus a type code.
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
  CHECK_GE(op->args.size(), 1U);
  std::string func_name = op->args[0].as<StringImm>()->value;
  // Resolve the (lazily initialized, cached) packed-function handle.
  llvm::Value* handle = GetPackedFuncHandle(func_name);
  // call the function
  unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
  // Stack arrays: one TVMValue slot and one int type code per argument.
  llvm::Value* targs = builder_->CreateAlloca(
      t_tvm_value_, ConstInt32(nargs));
  llvm::Value* tcodes = builder_->CreateAlloca(
      t_int_, ConstInt32(nargs));
  for (unsigned i = 0; i < nargs; ++i) {
    Expr expr = op->args[i + 1];
    Type t = expr.type();
    // Vector arguments are not supported by the packed ABI.
    CHECK_EQ(t.lanes(), 1);
    // Always pass via 64 bit value.
    // For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
    Type api_type = t.with_bits(64);
    llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
    // Store the widened value into slot i of the TVMValue array...
    llvm::Value* store_ptr = builder_->CreatePointerCast(
        builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
        LLVMType(api_type)->getPointerTo());
    builder_->CreateAlignedStore(value, store_ptr, 8);
    // ...and its type code into the parallel code array.
    builder_->CreateAlignedStore(
        ConstInt32(t.code()),
        builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
  }
  // Slots for the callee to write its return value and return type code.
  llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
  llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
  // Invoke, then emit the status check on the returned code.
  CheckCallSuccess(
      builder_->CreateCall(
          f_tvm_func_call_,
          {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
  // Reinterpret the 64-bit return slot and narrow to the call's type.
  // NOTE(review): ret_tcode is passed to the callee but never inspected
  // here — the return type is trusted to match op->type; confirm upstream.
  Type r_type = op->type;
  Type r_api_type = op->type.with_bits(64);
  llvm::Value* rvalue =
      builder_->CreateAlignedLoad(
          builder_->CreatePointerCast(
              ret_value, LLVMType(r_api_type)->getPointerTo()), 8);
  rvalue = CreateCast(r_api_type, r_type, rvalue);
  return rvalue;
}
// Emit a call to an extern symbol defined in the module.
// Scalar results call the symbol directly; vector results call the
// (scalar) symbol once per lane via CreateScalarizedCall.
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
  std::vector<llvm::Value*> arg_values(op->args.size());
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_values[i] = MakeValue(op->args[i]);
  }
  // Hoisted symbol lookup: the two branches previously duplicated both
  // the getFunction() call and the LOG(FATAL) failure path.
  llvm::Function* f = module_->getFunction(op->name);
  if (f == nullptr) {
    LOG(FATAL) << "cannot find function " << op->name;
    return nullptr;
  }
  if (op->type.is_scalar()) {
    return builder_->CreateCall(f, arg_values);
  }
  return CreateScalarizedCall(op, f, arg_values);
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
    const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
  // Build the vector result lane by lane: vector operands are split
  // with extractelement, scalar operands are reused as-is.
  llvm::Value* result = llvm::UndefValue::get(LLVMType(op->type));
  const int num_lanes = op->type.lanes();
  for (int lane = 0; lane < num_lanes; ++lane) {
    std::vector<llvm::Value*> lane_args(args.size());
    for (size_t k = 0; k < args.size(); ++k) {
      llvm::Value* arg = args[k];
      lane_args[k] = arg->getType()->isVectorTy()
          ? builder_->CreateExtractElement(arg, ConstInt32(lane))
          : arg;
    }
    llvm::CallInst* call = builder_->CreateCall(f, lane_args);
    // Pure calls may be marked readnone; all are assumed non-throwing.
    if (op->is_pure()) {
      call->setDoesNotAccessMemory();
    }
    call->setDoesNotThrow();
    if (!call->getType()->isVoidTy()) {
      result = builder_->CreateInsertElement(result, call, ConstInt32(lane));
    }
  }
  return result;
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
  // Look up the LLVM value currently bound to this IR variable.
  auto entry = var_map_.find(v);
  CHECK(entry != var_map_.end())
      << "Cannot find " << v->name_hint << " in the var map";
  return entry->second;
}
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
  // Return the cached pointer if this literal was emitted before.
  auto cached = str_map_.find(str);
  if (cached != str_map_.end()) {
    return cached->second;
  }
  // Emit the literal as a private, constant, null-terminated global array.
  llvm::Type* arr_type = llvm::ArrayType::get(t_char_, str.length() + 1);
  llvm::GlobalVariable* global = new llvm::GlobalVariable(
      *module_, arr_type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
  global->setAlignment(1);
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
  // Constant-fold the address of element [0][0].
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* gep_indices[] = {zero, zero};
  llvm::Constant* str_ptr = llvm::ConstantExpr::getGetElementPtr(
      arr_type, global, gep_indices);
  str_map_[str] = str_ptr;
  return str_ptr;
}
// Lower a parallel For by outlining its body into a closure lambda
// and invoking the TVM parallel-for runtime entry with (min, extent,
// lambda, closure-data).
void CodeGenLLVM::CreateParallelFor(const For* op) {
  using llvm::BasicBlock;
  llvm::Value* min = MakeValue(op->min);
  llvm::Value* extent = MakeValue(op->extent);
  // The runtime entry takes int64 bounds.
  min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
  extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
  // fields to be packed into closure.
  Var loop_var(op->loop_var.node_);
  Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
  std::vector<llvm::Type*> fields;
  for (Var v : vfields) {
    auto it = var_map_.find(v.get());
    CHECK(it != var_map_.end());
    fields.push_back(it->second->getType());
  }
  // closure data
  llvm::StructType* tcdata = llvm::StructType::create(fields);
  llvm::Function* f = llvm::Function::Create(
      t_f_tvm_par_for_lambda_,
      llvm::Function::PrivateLinkage,
      "__tvm_par_for_lambda", module_.get());
  // allocate and setup the closure, call the closure.
  // NOTE(review): the closure is stack-allocated in the caller; this
  // assumes the runtime does not retain it past the parallel_for call.
  llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
  llvm::Value* zero = ConstInt32(0);
  for (size_t i = 0; i < vfields.size(); ++i) {
    builder_->CreateStore(
        var_map_.at(vfields[i].get()),
        builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
  }
  // CheckCallSuccess yields the block where codegen resumes after the
  // runtime call's status check; we return to it at the end.
  BasicBlock* par_for_end = CheckCallSuccess(
      builder_->CreateCall(
          f_tvm_parallel_for_,
          {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
  // Setup the closure function.
  BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
  builder_->SetInsertPoint(lambda_entry);
  // Lambda arguments: (begin, end, void* cdata).
  auto it = f->arg_begin();
  llvm::Value* begin = &(*it++);
  llvm::Value* end = &(*it++);
  cdata = &(*it++);
  // Bounds arrive as int64; convert them to the loop variable's type.
  begin = CreateCast(Int(64), op->loop_var.type(), begin);
  end = CreateCast(Int(64), op->loop_var.type(), end);
  cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
  for (size_t i = 0; i < vfields.size(); ++i) {
    new_vmap[vfields[i].get()] =
        builder_->CreateLoad(builder_->CreateInBoundsGEP(
            cdata, {zero, ConstInt32(i)}));
  }
  // Redirect codegen into the lambda: swap both the current function
  // and the variable map, emit the serial loop, then swap back.
  std::swap(function_, f);
  std::swap(new_vmap, var_map_);
  CreateSerialFor(begin, end, op->loop_var, op->body);
  // The lambda returns status code 0 (success).
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(new_vmap, var_map_);
  std::swap(function_, f);
  builder_->SetInsertPoint(par_for_end);
}
// Emit a serial counting loop over [begin, end) with `loop_var` bound
// to the induction variable while `body` is generated.
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
                                  const VarExpr& loop_var, const Stmt& body) {
  using llvm::BasicBlock;
  Type t = loop_var.type();
  BasicBlock* for_head = BasicBlock::Create(
      *ctx_, "for_head", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  BasicBlock* pre_block = builder_->GetInsertBlock();
  builder_->CreateBr(for_head);
  builder_->SetInsertPoint(for_head);
  // Induction variable as a PHI: `begin` on entry from pre_block,
  // `next_index` on the back edge.
  llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
  index->addIncoming(begin, pre_block);
  llvm::Value* cond = CreateLT(t, index, end);
  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
  // body of for
  builder_->SetInsertPoint(for_body);
  var_map_[loop_var.get()] = index;
  // BUGFIX: dispatch through the StmtFunctor entry point. This class was
  // migrated from IRVisitor to ExprFunctor/StmtFunctor, so the old
  // `this->Visit(body)` no longer exists; every other call site in this
  // file already uses VisitStmt.
  this->VisitStmt(body);
  llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
  // The back edge must come from the block the body *ended* in, which
  // may differ from for_body if the body emitted new blocks.
  index->addIncoming(next_index, builder_->GetInsertBlock());
  builder_->CreateBr(for_head);
  // end of for
  builder_->SetInsertPoint(for_end);
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_LLVM_VERSION #endif // TVM_LLVM_VERSION
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -23,7 +23,9 @@ using namespace ir; ...@@ -23,7 +23,9 @@ using namespace ir;
/*! /*!
* \brief A base class to generate a LLVM. * \brief A base class to generate a LLVM.
*/ */
class CodeGenLLVM : public IRVisitor { class CodeGenLLVM :
public ExprFunctor<llvm::Value* (const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Initialize the code generator with given context * \brief Initialize the code generator with given context
...@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor { ...@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor {
* \return created value. * \return created value.
*/ */
llvm::Value* MakeValue(const Expr& e) { llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr; return VisitExpr(e);
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
} }
// Short hande code to get a constant int 32 // Short hande code to get a constant int 32
llvm::Constant* ConstInt32(unsigned value) const { llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value); return llvm::ConstantInt::get(t_int32_, value);
} }
// override codegen // override codegen
void Visit_(const Variable* op) final; llvm::Value* VisitExpr_(const Variable* op) override;
void Visit_(const Cast* op) final; llvm::Value* VisitExpr_(const Cast* op) override;
void Visit_(const IntImm* op) final; llvm::Value* VisitExpr_(const IntImm* op) override;
void Visit_(const UIntImm* op) final; llvm::Value* VisitExpr_(const UIntImm* op) override;
void Visit_(const FloatImm* op) final; llvm::Value* VisitExpr_(const FloatImm* op) override;
void Visit_(const StringImm* op) final; llvm::Value* VisitExpr_(const StringImm* op) override;
void Visit_(const Add* op) final; llvm::Value* VisitExpr_(const Add* op) override;
void Visit_(const Sub* op) final; llvm::Value* VisitExpr_(const Sub* op) override;
void Visit_(const Mul* op) final; llvm::Value* VisitExpr_(const Mul* op) override;
void Visit_(const Div* op) final; llvm::Value* VisitExpr_(const Div* op) override;
void Visit_(const Mod* op) final; llvm::Value* VisitExpr_(const Mod* op) override;
void Visit_(const Min* op) final; llvm::Value* VisitExpr_(const Min* op) override;
void Visit_(const Max* op) final; llvm::Value* VisitExpr_(const Max* op) override;
void Visit_(const LT* op) final; llvm::Value* VisitExpr_(const LT* op) override;
void Visit_(const LE* op) final; llvm::Value* VisitExpr_(const LE* op) override;
void Visit_(const GT* op) final; llvm::Value* VisitExpr_(const GT* op) override;
void Visit_(const GE* op) final; llvm::Value* VisitExpr_(const GE* op) override;
void Visit_(const EQ* op) final; llvm::Value* VisitExpr_(const EQ* op) override;
void Visit_(const NE* op) final; llvm::Value* VisitExpr_(const NE* op) override;
void Visit_(const And* op) final; llvm::Value* VisitExpr_(const And* op) override;
void Visit_(const Or* op) final; llvm::Value* VisitExpr_(const Or* op) override;
void Visit_(const Not* op) final; llvm::Value* VisitExpr_(const Not* op) override;
void Visit_(const Select* op) final; llvm::Value* VisitExpr_(const Select* op) override;
void Visit_(const Let* op) final; llvm::Value* VisitExpr_(const Let* op) override;
void Visit_(const Load* op) final; llvm::Value* VisitExpr_(const Load* op) override;
void Visit_(const Call* op) final; llvm::Value* VisitExpr_(const Call* op) override;
void Visit_(const Ramp* op) final; llvm::Value* VisitExpr_(const Ramp* op) override;
void Visit_(const Broadcast* op) final; llvm::Value* VisitExpr_(const Broadcast* op) override;
// stmt // stmt
void Visit_(const Store* op) final; void VisitStmt_(const Store* op) override;
void Visit_(const For* op) final; void VisitStmt_(const For* op) override;
void Visit_(const IfThenElse* op) final; void VisitStmt_(const IfThenElse* op) override;
void Visit_(const Allocate* op) final; void VisitStmt_(const Allocate* op) override;
void Visit_(const AttrStmt* op) override; void VisitStmt_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final; void VisitStmt_(const AssertStmt* op) override;
void Visit_(const LetStmt* op) final; void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call // create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op); virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call // create extern function call
...@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor { ...@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor {
llvm::Function* f_tvm_parallel_for_{nullptr}; llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
llvm::Value* value_{nullptr};
private: private:
// comparison op // comparison op
......
...@@ -12,10 +12,6 @@ namespace codegen { ...@@ -12,10 +12,6 @@ namespace codegen {
using namespace ir; using namespace ir;
CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*)
static FType inst; return inst;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) { StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) { for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i]; Var v = f->args[i];
...@@ -27,18 +23,12 @@ StackVM CodeGenStackVM::Compile(LoweredFunc f) { ...@@ -27,18 +23,12 @@ StackVM CodeGenStackVM::Compile(LoweredFunc f) {
} }
void CodeGenStackVM::Push(const Stmt& n) { void CodeGenStackVM::Push(const Stmt& n) {
static const FType& f = vtable(); VisitStmt(n);
f(n, this);
if (debug_) { if (debug_) {
this->PushOp(StackVM::ASSERT_SP, 0); this->PushOp(StackVM::ASSERT_SP, 0);
} }
} }
void CodeGenStackVM::Push(const Expr& n) {
static const FType& f = vtable();
f(n, this);
}
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) { void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
StackVM::Code code; StackVM::Code code;
code.op_code = opcode; code.op_code = opcode;
...@@ -106,7 +96,7 @@ int CodeGenStackVM::GetVarID(const Variable* v) const { ...@@ -106,7 +96,7 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
return it->second; return it->second;
} }
void CodeGenStackVM::Push_(const ir::Load* op) { void CodeGenStackVM::VisitExpr_(const Load* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
if (op->type == UInt(32) && op->index.as<IntImm>()) { if (op->type == UInt(32) && op->index.as<IntImm>()) {
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value); this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
...@@ -118,7 +108,8 @@ void CodeGenStackVM::Push_(const ir::Load* op) { ...@@ -118,7 +108,8 @@ void CodeGenStackVM::Push_(const ir::Load* op) {
this->PushOp(StackVM::GetLoad(Type2TVMType(op->type))); this->PushOp(StackVM::GetLoad(Type2TVMType(op->type)));
} }
} }
void CodeGenStackVM::Push_(const ir::Store* op) {
void CodeGenStackVM::VisitStmt_(const Store* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
this->Push(op->index); this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes()); this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
...@@ -128,7 +119,7 @@ void CodeGenStackVM::Push_(const ir::Store* op) { ...@@ -128,7 +119,7 @@ void CodeGenStackVM::Push_(const ir::Store* op) {
this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type()))); this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type())));
} }
void CodeGenStackVM::Push_(const ir::Allocate* op) { void CodeGenStackVM::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition)); CHECK(!is_zero(op->condition));
int vid = AllocVarID(op->buffer_var.get()); int vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
...@@ -141,7 +132,7 @@ void CodeGenStackVM::Push_(const ir::Allocate* op) { ...@@ -141,7 +132,7 @@ void CodeGenStackVM::Push_(const ir::Allocate* op) {
} }
} }
void CodeGenStackVM::Push_(const ir::Call* op) { void CodeGenStackVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(Call::address_of)) { if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
...@@ -211,37 +202,30 @@ void CodeGenStackVM::Push_(const ir::Call* op) { ...@@ -211,37 +202,30 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_I64); this->PushOp(StackVM::EQ_I64);
} else { } else {
this->HandleUnknownCall(op); LOG(FATAL) << "unknown function call " << op->name;
} }
} }
void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) { void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
LOG(FATAL) << "donot know how to handle call " << op->name; const Expr& a,
} const Expr& b) {
this->Push(a);
inline void PushBinary(StackVM::OpCode op_int64, this->Push(b);
const Expr& a,
const Expr& b,
CodeGenStackVM* p) {
p->Push(a);
p->Push(b);
Type t = a.type(); Type t = a.type();
if (t.is_int()) { if (t.is_int()) {
p->PushOp(op_int64); this->PushOp(op_int64);
} else if (t.is_uint()) { } else if (t.is_uint()) {
if (t.bits() <= 32) { if (t.bits() <= 32) {
p->PushOp(op_int64); this->PushOp(op_int64);
} else { } else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM"; LOG(FATAL) << "Cannot handle uint64_t in StackVM";
} }
} else { } else {
p->PushOp(StackVM::CodeI64ToF64(op_int64)); this->PushOp(StackVM::CodeI64ToF64(op_int64));
} }
} }
inline void PushCast(Type dst, void CodeGenStackVM::PushCast(Type dst, Type src) {
Type src,
CodeGenStackVM* p) {
if (dst.is_int()) { if (dst.is_int()) {
if (src.is_int()) return; if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return; if (src.is_uint() && src.bits() <= 32) return;
...@@ -254,211 +238,226 @@ inline void PushCast(Type dst, ...@@ -254,211 +238,226 @@ inline void PushCast(Type dst,
LOG(FATAL) << "Cannot handle cast " << src << " to " << dst; LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
} }
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) void CodeGenStackVM::VisitExpr_(const StringImm *op) {
.set_dispatch<StringImm>([](const StringImm *op, CodeGenStackVM *p) { int sid = this->GetStrID(op->value);
int sid = p->GetStrID(op->value); this->PushOp(StackVM::PUSH_I64, sid);
p->PushOp(StackVM::PUSH_I64, sid); }
})
.set_dispatch<IntImm>([](const IntImm *op, CodeGenStackVM *p) { void CodeGenStackVM::VisitExpr_(const IntImm *op) {
CHECK(op->value >= std::numeric_limits<int>::min() && CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max()) op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound"; << "Int constant exceed bound";
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value)); this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}) }
.set_dispatch<UIntImm>([](const UIntImm *op, CodeGenStackVM *p) {
CHECK(op->value <= std::numeric_limits<int>::max()) void CodeGenStackVM::VisitExpr_(const UIntImm *op) {
<< "Int constant exceed bound"; CHECK(op->value <= std::numeric_limits<int>::max())
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value)); << "Int constant exceed bound";
}) this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
.set_dispatch<FloatImm>([](const FloatImm *op, CodeGenStackVM *p) { }
LOG(FATAL) << "Float Imm is not supported";
}); void CodeGenStackVM::VisitExpr_(const FloatImm *op) {
LOG(FATAL) << "Float Imm is not supported";
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) }
.set_dispatch<Variable>([](const Variable *op, CodeGenStackVM* p) {
int vid = p->GetVarID(op); void CodeGenStackVM::VisitExpr_(const Variable *op) {
p->PushOp(StackVM::LOAD_HEAP, vid); int vid = this->GetVarID(op);
}) this->PushOp(StackVM::LOAD_HEAP, vid);
.set_dispatch<Cast>([](const Cast *op, CodeGenStackVM* p) { }
p->Push(op->value);
PushCast(op->type, op->value.type(), p); void CodeGenStackVM::VisitExpr_(const Cast *op) {
}) this->Push(op->value);
.set_dispatch<Add>([](const Add *op, CodeGenStackVM* p) { PushCast(op->type, op->value.type());
PushBinary(StackVM::ADD_I64, op->a, op->b, p); }
})
.set_dispatch<Sub>([](const Sub *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const Add *op) {
PushBinary(StackVM::SUB_I64, op->a, op->b, p); PushBinary(StackVM::ADD_I64, op->a, op->b);
}) }
.set_dispatch<Mul>([](const Mul *op, CodeGenStackVM* p) {
PushBinary(StackVM::MUL_I64, op->a, op->b, p); void CodeGenStackVM::VisitExpr_(const Sub *op) {
}) PushBinary(StackVM::SUB_I64, op->a, op->b);
.set_dispatch<Div>([](const Div *op, CodeGenStackVM* p) { }
PushBinary(StackVM::DIV_I64, op->a, op->b, p);
}) void CodeGenStackVM::VisitExpr_(const Mul *op) {
.set_dispatch<Mod>([](const Mod *op, CodeGenStackVM* p) { PushBinary(StackVM::MUL_I64, op->a, op->b);
PushBinary(StackVM::MOD_I64, op->a, op->b, p); }
})
.set_dispatch<Min>([](const Min *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const Div *op) {
p->Push(op->a); PushBinary(StackVM::DIV_I64, op->a, op->b);
p->Push(op->b); }
p->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::PUSH_VALUE, -1); void CodeGenStackVM::VisitExpr_(const Mod *op) {
p->PushOp(StackVM::LT_I64); PushBinary(StackVM::MOD_I64, op->a, op->b);
p->PushOp(StackVM::SELECT); }
})
.set_dispatch<Max>([](const Max *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const Min *op) {
p->Push(op->a); this->Push(op->a);
p->Push(op->b); this->Push(op->b);
p->PushOp(StackVM::PUSH_VALUE, 0); this->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::PUSH_VALUE, -2); this->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::LT_I64); this->PushOp(StackVM::LT_I64);
p->PushOp(StackVM::SELECT); this->PushOp(StackVM::SELECT);
}) }
.set_dispatch<EQ>([](const EQ *op, CodeGenStackVM* p) {
PushBinary(StackVM::EQ_I64, op->a, op->b, p); void CodeGenStackVM::VisitExpr_(const Max *op) {
}) this->Push(op->a);
.set_dispatch<LE>([](const LE *op, CodeGenStackVM* p) { this->Push(op->b);
PushBinary(StackVM::LE_I64, op->a, op->b, p); this->PushOp(StackVM::PUSH_VALUE, 0);
}) this->PushOp(StackVM::PUSH_VALUE, -2);
.set_dispatch<NE>([](const NE *op, CodeGenStackVM* p) { this->PushOp(StackVM::LT_I64);
PushBinary(StackVM::EQ_I64, op->a, op->b, p); this->PushOp(StackVM::SELECT);
p->PushOp(StackVM::NOT); }
})
.set_dispatch<LT>([](const LT *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const EQ *op) {
PushBinary(StackVM::LT_I64, op->a, op->b, p); PushBinary(StackVM::EQ_I64, op->a, op->b);
}) }
.set_dispatch<GE>([](const GE *op, CodeGenStackVM* p) {
PushBinary(StackVM::LT_I64, op->a, op->b, p); void CodeGenStackVM::VisitExpr_(const LE *op) {
p->PushOp(StackVM::NOT); PushBinary(StackVM::LE_I64, op->a, op->b);
}) }
.set_dispatch<GT>([](const GT *op, CodeGenStackVM* p) {
PushBinary(StackVM::LE_I64, op->a, op->b, p); void CodeGenStackVM::VisitExpr_(const NE *op) {
p->PushOp(StackVM::NOT); PushBinary(StackVM::EQ_I64, op->a, op->b);
}) this->PushOp(StackVM::NOT);
.set_dispatch<And>([](const And *op, CodeGenStackVM* p) { }
p->Push(op->a);
int64_t pc_jump = p->GetPC(); void CodeGenStackVM::VisitExpr_(const LT *op) {
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); PushBinary(StackVM::LT_I64, op->a, op->b);
p->PushOp(StackVM::POP); }
p->Push(op->b);
int64_t diff = p->GetPC() - pc_jump; void CodeGenStackVM::VisitExpr_(const GE *op) {
p->SetOperand(opr_index, diff); PushBinary(StackVM::LT_I64, op->a, op->b);
}) this->PushOp(StackVM::NOT);
.set_dispatch<Or>([](const Or *op, CodeGenStackVM* p) { }
p->Push(op->a);
int64_t pc_jump = p->GetPC(); void CodeGenStackVM::VisitExpr_(const GT *op) {
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0); PushBinary(StackVM::LE_I64, op->a, op->b);
p->Push(op->b); this->PushOp(StackVM::NOT);
int64_t diff = p->GetPC() - pc_jump; }
p->SetOperand(opr_index, diff);
}) void CodeGenStackVM::VisitExpr_(const And *op) {
.set_dispatch<Not>([](const Not* op, CodeGenStackVM* p) { this->Push(op->a);
p->PushOp(StackVM::NOT); int64_t pc_jump = this->GetPC();
}); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->b);
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) int64_t diff = this->GetPC() - pc_jump;
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenStackVM* p) { this->SetOperand(opr_index, diff);
p->Push(op->body); }
})
.set_dispatch<For>([](const For *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const Or *op) {
CHECK(is_zero(op->min)); this->Push(op->a);
int vid = p->AllocVarID(op->loop_var.get()); int64_t pc_jump = this->GetPC();
p->PushOp(StackVM::PUSH_I64, 0); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
int64_t loop_head = p->GetPC(); this->Push(op->b);
p->PushOp(StackVM::STORE_HEAP, vid); int64_t diff = this->GetPC() - pc_jump;
p->PushOp(StackVM::LOAD_HEAP, vid); this->SetOperand(opr_index, diff);
p->Push(op->extent); }
p->PushOp(StackVM::LT_I64);
int64_t label_fjump = p->GetPC(); void CodeGenStackVM::VisitExpr_(const Not* op) {
int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); this->PushOp(StackVM::NOT);
p->PushOp(StackVM::POP); }
p->Push(op->body);
p->PushOp(StackVM::LOAD_HEAP, vid); void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) {
p->PushOp(StackVM::PUSH_I64, 1); this->Push(op->body);
p->PushOp(StackVM::ADD_I64); }
int64_t label_bjump = p->GetPC();
int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0); void CodeGenStackVM::VisitStmt_(const For *op) {
int64_t loop_end = p->GetPC(); CHECK(is_zero(op->min));
p->PushOp(StackVM::POP); int vid = this->AllocVarID(op->loop_var.get());
p->SetOperand(foward_jump, loop_end - label_fjump); this->PushOp(StackVM::PUSH_I64, 0);
p->SetOperand(backward_jump, loop_head - label_bjump); int64_t loop_head = this->GetPC();
}) this->PushOp(StackVM::STORE_HEAP, vid);
.set_dispatch<Block>([](const Block *op, CodeGenStackVM* p) { this->PushOp(StackVM::LOAD_HEAP, vid);
p->Push(op->first); this->Push(op->extent);
if (op->rest.defined()) p->Push(op->rest); this->PushOp(StackVM::LT_I64);
}) int64_t label_fjump = this->GetPC();
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenStackVM* p) { int64_t foward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
if (is_const(op->value)) return; this->PushOp(StackVM::POP);
p->Push(op->value); this->Push(op->body);
p->PushOp(StackVM::POP); this->PushOp(StackVM::LOAD_HEAP, vid);
}) this->PushOp(StackVM::PUSH_I64, 1);
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenStackVM* p) { this->PushOp(StackVM::ADD_I64);
p->Push(op->condition); int64_t label_bjump = this->GetPC();
int64_t label_ejump = p->GetPC(); int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);
int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); int64_t loop_end = this->GetPC();
p->PushOp(StackVM::POP); this->PushOp(StackVM::POP);
p->Push(op->then_case); this->SetOperand(foward_jump, loop_end - label_fjump);
if (op->else_case.defined()) { this->SetOperand(backward_jump, loop_head - label_bjump);
int64_t label_then_jump = p->GetPC(); }
int64_t then_jump = p->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = p->GetPC(); void CodeGenStackVM::VisitStmt_(const Block *op) {
p->SetOperand(else_jump, else_begin - label_ejump); this->Push(op->first);
p->PushOp(StackVM::POP); if (op->rest.defined()) this->Push(op->rest);
p->Push(op->else_case); }
int64_t if_end = p->GetPC();
p->SetOperand(then_jump, if_end - label_then_jump); void CodeGenStackVM::VisitStmt_(const Evaluate *op) {
} else { if (is_const(op->value)) return;
int64_t if_end = p->GetPC(); this->Push(op->value);
p->SetOperand(else_jump, if_end - label_ejump); this->PushOp(StackVM::POP);
p->PushOp(StackVM::POP); }
}
}) void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) { this->Push(op->condition);
p->Push(op->value); int64_t label_ejump = this->GetPC();
int64_t vid = p->AllocVarID(op->var.get()); int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid)); this->PushOp(StackVM::POP);
p->Push(op->body); this->Push(op->then_case);
}) if (op->else_case.defined()) {
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) { int64_t label_then_jump = this->GetPC();
LOG(FATAL) << "Ramp is not supported"; int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
}) int64_t else_begin = this->GetPC();
.set_dispatch<Broadcast>([](const Broadcast *op, CodeGenStackVM* p) { this->SetOperand(else_jump, else_begin - label_ejump);
LOG(FATAL) << "Broadcast is not supported"; this->PushOp(StackVM::POP);
}) this->Push(op->else_case);
.set_dispatch<Select>([](const Select *op, CodeGenStackVM* p) { int64_t if_end = this->GetPC();
p->Push(op->true_value); this->SetOperand(then_jump, if_end - label_then_jump);
p->Push(op->false_value); } else {
p->Push(op->condition); int64_t if_end = this->GetPC();
p->PushOp(StackVM::SELECT); this->SetOperand(else_jump, if_end - label_ejump);
}) this->PushOp(StackVM::POP);
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenStackVM* p) { }
if (op->message.as<StringImm>()) { }
int sid = p->GetStrID(op->message.as<StringImm>()->value);
p->Push(op->condition); void CodeGenStackVM::VisitStmt_(const LetStmt *op) {
p->PushOp(StackVM::ASSERT, sid); this->Push(op->value);
} int64_t vid = this->AllocVarID(op->var.get());
}) this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenStackVM* p) { this->Push(op->body);
p->Push(op->body); }
})
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitExpr_(const Ramp *op) {
p->Push(op->value); LOG(FATAL) << "Ramp is not supported";
int64_t vid = p->AllocVarID(op->var.get()); }
p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
p->Push(op->body); void CodeGenStackVM::VisitExpr_(const Broadcast *op) {
}) LOG(FATAL) << "Broadcast is not supported";
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) { }
p->Push_(op);
}) void CodeGenStackVM::VisitExpr_(const Select *op) {
.set_dispatch<Store>([](const Store *op, CodeGenStackVM* p) { this->Push(op->true_value);
p->Push_(op); this->Push(op->false_value);
}) this->Push(op->condition);
.set_dispatch<Allocate>([](const Allocate *op, CodeGenStackVM* p) { this->PushOp(StackVM::SELECT);
p->Push_(op); }
})
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) { void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
p->Push_(op); if (op->message.as<StringImm>()) {
}); int sid = this->GetStrID(op->message.as<StringImm>()->value);
this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid);
}
}
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitExpr_(const Let *op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body);
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <string> #include <string>
...@@ -18,12 +19,15 @@ ...@@ -18,12 +19,15 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir;
/*! /*!
* \brief A base class to generate a stack VM. * \brief A base class to generate a stack VM.
* This module is used to generate host wrapper * This module is used to generate host wrapper
* into device function when only device JIT is available. * into device function when only device JIT is available.
*/ */
class CodeGenStackVM { class CodeGenStackVM
: public ExprFunctor<void(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public: public:
/*! /*!
* \brief Generate a stack VM representing * \brief Generate a stack VM representing
...@@ -35,8 +39,10 @@ class CodeGenStackVM { ...@@ -35,8 +39,10 @@ class CodeGenStackVM {
StackVM Compile(LoweredFunc f); StackVM Compile(LoweredFunc f);
/*! \brief Push stmt to generate new code */ /*! \brief Push stmt to generate new code */
void Push(const Stmt& n); void Push(const Stmt& n);
/*! \brief Push expr to generate new code */ /*! \brief Push expr to generate new code */
void Push(const Expr& n); void Push(const Expr& n) {
VisitExpr(n);
}
/*! /*!
* \brief Push the opcode to the code. * \brief Push the opcode to the code.
* \param opcode The code to be pushed. * \param opcode The code to be pushed.
...@@ -84,16 +90,53 @@ class CodeGenStackVM { ...@@ -84,16 +90,53 @@ class CodeGenStackVM {
* \return the heap index of the var. * \return the heap index of the var.
*/ */
int GetVarID(const Variable* v) const; int GetVarID(const Variable* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b);
// push cast;
void PushCast(Type dst, Type src);
// overloadable functions // overloadable functions
virtual void Push_(const ir::Load* op); // expression
virtual void Push_(const ir::Store* op); void VisitExpr_(const Variable* op) final;
virtual void Push_(const ir::Allocate* op); void VisitExpr_(const Load* op) final;
virtual void Push_(const ir::Call* op); void VisitExpr_(const Let* op) final;
virtual void HandleUnknownCall(const ir::Call* op); void VisitExpr_(const Call* op) final;
/*! \brief function to to print normal code */ void VisitExpr_(const Add* op) final;
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>; void VisitExpr_(const Sub* op) final;
// vtable to print code void VisitExpr_(const Mul* op) final;
static FType& vtable(); // NOLINT(*) void VisitExpr_(const Div* op) final;
void VisitExpr_(const Mod* op) final;
void VisitExpr_(const Min* op) final;
void VisitExpr_(const Max* op) final;
void VisitExpr_(const EQ* op) final;
void VisitExpr_(const NE* op) final;
void VisitExpr_(const LT* op) final;
void VisitExpr_(const LE* op) final;
void VisitExpr_(const GT* op) final;
void VisitExpr_(const GE* op) final;
void VisitExpr_(const And* op) final;
void VisitExpr_(const Or* op) final;
void VisitExpr_(const Cast* op) final;
void VisitExpr_(const Not* op) final;
void VisitExpr_(const Select* op) final;
void VisitExpr_(const Ramp* op) final;
void VisitExpr_(const Broadcast* op) final;
void VisitExpr_(const IntImm* op) final;
void VisitExpr_(const UIntImm* op) final;
void VisitExpr_(const FloatImm* op) final;
void VisitExpr_(const StringImm* op) final;
// statment
void VisitStmt_(const LetStmt* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitStmt_(const Allocate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
private: private:
bool debug_{false}; bool debug_{false};
......
...@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { ...@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
...@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { ...@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
} }
} }
#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \ return s;
return s; \ }
}
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Variable)
.DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse) .DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate) .DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Load)
.DISPATCH_TO_MUTATE_STMT(Store) .DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Let)
.DISPATCH_TO_MUTATE_STMT(Free) .DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(Call)
.DISPATCH_TO_MUTATE_STMT(Add)
.DISPATCH_TO_MUTATE_STMT(Sub)
.DISPATCH_TO_MUTATE_STMT(Mul)
.DISPATCH_TO_MUTATE_STMT(Div)
.DISPATCH_TO_MUTATE_STMT(Mod)
.DISPATCH_TO_MUTATE_STMT(Min)
.DISPATCH_TO_MUTATE_STMT(Max)
.DISPATCH_TO_MUTATE_STMT(EQ)
.DISPATCH_TO_MUTATE_STMT(NE)
.DISPATCH_TO_MUTATE_STMT(LT)
.DISPATCH_TO_MUTATE_STMT(LE)
.DISPATCH_TO_MUTATE_STMT(GT)
.DISPATCH_TO_MUTATE_STMT(GE)
.DISPATCH_TO_MUTATE_STMT(And)
.DISPATCH_TO_MUTATE_STMT(Or)
.DISPATCH_TO_MUTATE_STMT(Reduce)
.DISPATCH_TO_MUTATE_STMT(Cast)
.DISPATCH_TO_MUTATE_STMT(Not)
.DISPATCH_TO_MUTATE_STMT(Select)
.DISPATCH_TO_MUTATE_STMT(Ramp)
.DISPATCH_TO_MUTATE_STMT(Broadcast)
.DISPATCH_TO_MUTATE_STMT(AssertStmt) .DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer) .DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide) .DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block) .DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate) .DISPATCH_TO_MUTATE_STMT(Evaluate);
.DISPATCH_TO_MUTATE_STMT(IntImm)
.DISPATCH_TO_MUTATE_STMT(UIntImm)
.DISPATCH_TO_MUTATE_STMT(FloatImm)
.DISPATCH_TO_MUTATE_STMT(StringImm);
// Mutate Expr // Mutate Expr
...@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { ...@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
return e; \ return e; \
} }
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
...@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) ...@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable) .DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(LetStmt)
.DISPATCH_TO_MUTATE_EXPR(AttrStmt)
.DISPATCH_TO_MUTATE_EXPR(IfThenElse)
.DISPATCH_TO_MUTATE_EXPR(For)
.DISPATCH_TO_MUTATE_EXPR(Allocate)
.DISPATCH_TO_MUTATE_EXPR(Load) .DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Store)
.DISPATCH_TO_MUTATE_EXPR(Let) .DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Free)
.DISPATCH_TO_MUTATE_EXPR(Call) .DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add) .DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub) .DISPATCH_TO_MUTATE_EXPR(Sub)
...@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Select) .DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp) .DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast) .DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(AssertStmt)
.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
.DISPATCH_TO_MUTATE_EXPR(Provide)
.DISPATCH_TO_MUTATE_EXPR(Realize)
.DISPATCH_TO_MUTATE_EXPR(Block)
.DISPATCH_TO_MUTATE_EXPR(Evaluate)
.DISPATCH_TO_MUTATE_EXPR(IntImm) .DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm)
......
...@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator { ...@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator {
} }
// user mutate from parent. // user mutate from parent.
using IRMutator::Mutate; using IRMutator::Mutate;
// override mutate
Expr Mutate(Expr expr) final { Expr Mutate_(const Add* op, const Expr &e) final {
static const FMutateExpr& f = Vectorizer::vtable_expr(); return AddSubVec(op, e);
return (f.can_dispatch(expr) ? }
f(expr, expr, this) : IRMutator::Mutate(expr)); Expr Mutate_(const Sub* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Div* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Max* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const EQ* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const NE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const LT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const And* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
Expr Mutate_(const Cast *op, const Expr& e) final {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
} }
// Variable // Variable
Expr Mutate_(const Variable* v, const Expr& e) final { Expr Mutate_(const Variable* v, const Expr& e) final {
...@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator { ...@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator {
stmt = Substitute(stmt, {{var_, idx}}); stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
} }
// The overloads for vectorize.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private: private:
// variable to be replaced // variable to be replaced
...@@ -273,90 +329,43 @@ class Vectorizer : public IRMutator { ...@@ -273,90 +329,43 @@ class Vectorizer : public IRMutator {
if (!changed) return arr; if (!changed) return arr;
return Array<Expr>(new_arr); return Array<Expr>(new_arr);
} }
}; template<typename T>
Expr BinaryVec(const T* op, const Expr& e) {
// binary vectorize Expr a = this->Mutate(op->a);
template<typename T> Expr b = this->Mutate(op->b);
inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) { if (a.same_as(op->a) &&
Expr a = m->Mutate(op->a); b.same_as(op->b)) {
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
int lanes = std::max(a.type().lanes(), b.type().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
int lanes = std::max(a.type().lanes(), b.type().lanes());
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a, b_ramp->base), b_ramp->stride, b_ramp->lanes);
}
if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Add>(AddSubVec<Add>)
.set_dispatch<Sub>(AddSubVec<Sub>)
.set_dispatch<Mul>(BinaryVec<Mul>)
.set_dispatch<Div>(BinaryVec<Div>)
.set_dispatch<Mod>(BinaryVec<Mod>)
.set_dispatch<Min>(BinaryVec<Min>)
.set_dispatch<Max>(BinaryVec<Max>)
.set_dispatch<EQ>(BinaryVec<EQ>)
.set_dispatch<NE>(BinaryVec<NE>)
.set_dispatch<LT>(BinaryVec<LT>)
.set_dispatch<LE>(BinaryVec<LE>)
.set_dispatch<GT>(BinaryVec<GT>)
.set_dispatch<GE>(BinaryVec<GE>)
.set_dispatch<And>(BinaryVec<And>)
.set_dispatch<Or>(BinaryVec<Or>);
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e; return e;
} else { } else {
int lanes = std::max(std::max( int lanes = std::max(a.type().lanes(), b.type().lanes());
cond.type().lanes(), return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
} }
}) }
.set_dispatch<Cast>([](const Cast *op, const Expr& e, IRMutator* m) { template<typename T>
Expr value = m->Mutate(op->value); Expr AddSubVec(const T* op, const Expr& e) {
if (value.same_as(op->value)) { Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e; return e;
} else { } else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value); int lanes = std::max(a.type().lanes(), b.type().lanes());
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a, b_ramp->base), b_ramp->stride, b_ramp->lanes);
}
if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
} }
}); }
};
class LoopVectorizer : public IRMutator { class LoopVectorizer : public IRMutator {
public: public:
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_functor.h> #include <tvm/ir_functor.h>
#include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) { TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm; using namespace tvm;
using namespace tvm::ir;
Var x("x"); Var x("x");
auto z = x + 1; auto z = x + 1;
...@@ -21,6 +22,65 @@ TEST(IRF, Basic) { ...@@ -21,6 +22,65 @@ TEST(IRF, Basic) {
CHECK_EQ(f(z, 2), 4); CHECK_EQ(f(z, 2), 4);
} }
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyExprFunctor
: public ir::ExprFunctor<int(const Expr&, int)> {
public:
int VisitExpr_(const Variable* op, int b) final {
return b;
}
int VisitExpr_(const IntImm* op, int b) final {
return op->value;
}
int VisitExpr_(const Add* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
MyExprFunctor f;
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 3);
try {
f(z - 1, 2);
LOG(FATAL) << "should fail";
} catch(dmlc::Error) {
}
}
TEST(IRF, ExprVisit) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyVisitor
: public ir::ExprFunctor<void(const Expr&)>,
public ir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
void VisitExpr_(const IntImm* op) final {
}
void VisitExpr_(const Add* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
void VisitStmt_(const Evaluate* op) final {
VisitExpr(op->value);
}
};
MyVisitor v;
v(Evaluate::make(z));
CHECK_EQ(v.count, 1);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -25,6 +25,11 @@ def test_deduce(): ...@@ -25,6 +25,11 @@ def test_deduce():
ans1 = (c-b)/4+(-2) ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
def test_check(): def test_check():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
......
import tvm
def test_basic():
a = tvm.Var()
b = tvm.Var()
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3))
assert m.coeff == 4
assert m.base == 3
m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3))
assert m.coeff == 1
assert m.base == 0
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4))
assert m.coeff == 2
assert m.base == 0
m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2))
assert m.coeff == 3
assert m.base == 2
m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2))
assert m.coeff == 1
assert m.base == 0
if __name__ == "__main__":
test_basic()
...@@ -16,7 +16,7 @@ def test_llvm_add_pipeline(): ...@@ -16,7 +16,7 @@ def test_llvm_add_pipeline():
f = tvm.build(s, [A, B, C], "llvm") f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
# launch the kernel. # launch the kernel.
n = 10270 * 2460 n = 1027 * 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment