Commit 71334483 by Tianqi Chen Committed by GitHub

[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)

* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor

* retrigger

* [IRFunctor] Migrated CodegenC

* [IRFUNCTOR] Migrate CodeGenLLVM

* [IRFunctor] Migrate canonical

* [IRFunctor] Migrate vectorize

* [IRFunctor] migrate CodeGenStackVM
parent e4387940
......@@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh
notifications:
# Emails are sent to the committer's git-configured email address by default,
email:
on_success: change
on_failure: always
/*!
* Copyright (c) 2017 by Contributors
* \file ir_functor_ext.h
* \brief More powerful Visitor that allows define function signatures.
*/
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_
#include <tvm/ir_functor.h>
#include "./ir.h"
namespace tvm {
namespace ir {
/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \code
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
* : public ir::ExprFunctor<int(const Expr&, int)> {
* public:
* int VisitExpr_(const Variable* op, int b) final {
* return b;
* }
* int VisitExpr_(const IntImm* op, int b) final {
* return op->value;
* }
* int VisitExpr_(const Add* op, int b) final {
* return Visit(op->a, b) + Visit(op->b, b);
* }
* };
* MyExprFunctor f;
* Var x("x");
* CHECK_EQ(f(x + 1, 2), 3);
* \endcode
*
* \note Why do we need this more powerful Functor:
*
* We often need to implement a transformer tasks.
* Say we want to take Expr and transform it to some analysis result,
* This easily be done incorrectly using plain Visitor. See IRVisitor's
* document for possible error cases.
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture R(const Expr&, Args...)
*/
template<typename FType>
class ExprFunctor;
/*!
* \brief Same as ExprFunctor except it is applied on statements
* \tparam FType The function signature.
*/
template<typename FType>
class StmtFunctor;
// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT { \
return VisitExprDefault_(op, std::forward<Args>(args)...); \
}
#define STMT_FUNCTOR_DEFAULT { \
return VisitStmtDefault_(op, std::forward<Args>(args)...); \
}
#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
#define IR_STMT_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~ExprFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(Variable);
IR_EXPR_FUNCTOR_DISPATCH(Load);
IR_EXPR_FUNCTOR_DISPATCH(Let);
IR_EXPR_FUNCTOR_DISPATCH(Call);
IR_EXPR_FUNCTOR_DISPATCH(Add);
IR_EXPR_FUNCTOR_DISPATCH(Sub);
IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ);
IR_EXPR_FUNCTOR_DISPATCH(NE);
IR_EXPR_FUNCTOR_DISPATCH(LT);
IR_EXPR_FUNCTOR_DISPATCH(LE);
IR_EXPR_FUNCTOR_DISPATCH(GT);
IR_EXPR_FUNCTOR_DISPATCH(GE);
IR_EXPR_FUNCTOR_DISPATCH(And);
IR_EXPR_FUNCTOR_DISPATCH(Or);
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
IR_EXPR_FUNCTOR_DISPATCH(Cast);
IR_EXPR_FUNCTOR_DISPATCH(Not);
IR_EXPR_FUNCTOR_DISPATCH(Select);
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
IR_EXPR_FUNCTOR_DISPATCH(StringImm);
return vtable;
}
};
template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
private:
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~StmtFunctor() {}
/*!
* \brief Same as call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitStmt(const Stmt& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
IR_STMT_FUNCTOR_DISPATCH(LetStmt);
IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
IR_STMT_FUNCTOR_DISPATCH(For);
IR_STMT_FUNCTOR_DISPATCH(Allocate);
IR_STMT_FUNCTOR_DISPATCH(Store);
IR_STMT_FUNCTOR_DISPATCH(Free);
IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
};
#undef IR_STMT_FUNCTOR_DISPATCH
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT
} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
......@@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
......@@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
};
/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -6,6 +6,7 @@
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "./ir.h"
namespace tvm {
......@@ -17,7 +18,51 @@ namespace ir {
* This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new Node.
*
* \sa IRFunctor, PostOrderVisit
* \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
* - If it is mutaion of the IR, use IRMutator
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
*
* \code
*
* // This is an example code to show cases for traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const NodeRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
* ret_ = count(op->a) + count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
* // Think what is the result of the following count
* counter.count(Max::make(x, x));
* // The result is actually 1
* // This is because Visit is not overriden for Max
* // so it simply calls Visit for the left and right children
* // and because Count is not called, ret_ is not cleared.
* // There can also be cases where ret_ is forgetten to be set.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems.
* \encode
*
*/
class IRVisitor {
public:
......
......@@ -274,33 +274,51 @@ def sum(expr, axis):
return x
def min(expr, axis):
"""Create a min expression over axis
def min(lhs, rhs=None, axis=None):
"""Create a min expression.
Parameters
----------
expr : Expr
The source expression.
lhs : Expr
The left hand expression.
axis : IterVar
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x
def max(expr, axis):
"""Create a min expression over axis
def max(lhs, rhs=None, axis=None):
"""Create a max expression.
Parameters
----------
expr : Expr
The source expression.
lhs : Expr
The left hand expression.
axis : IterVar
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_collections.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x
......
......@@ -5,7 +5,6 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
......@@ -33,3 +32,8 @@ class IntervalSet(IntSet):
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass
@register_node
class ModularSet(IntSet):
"""Represent range of (coeff * x + base) for x in Z """
pass
......@@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
#include "../arithmetic/modular.h"
namespace tvm {
namespace arith {
......@@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval)
*ret = IntSet::interval(args[0], args[1]);
});
TVM_REGISTER_API(_arith_EvalModular)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>());
});
TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
......
......@@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator {
return stmt;
}
Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
stack_.push_back(StackEntry());
expr = (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
expr = IRMutator::Mutate(expr);
// update result of parent automatically during pop
if (stack_.size() > 1) {
StackEntry& back = stack_[stack_.size() - 1];
......@@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator {
return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
}
// Add
Expr Mutate_(const Add* op, const Expr& e) {
Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, +1);
}
// Sub
Expr Mutate_(const Sub* op, const Expr& e) {
Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator {
return SumAdd(a, b, -1);
}
// Mul
Expr Mutate_(const Mul* op, const Expr& e) {
Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) {
return Binary(op, e, this);
}
......@@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator {
using CInternal = Canonical::Internal;
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, const Expr& e, IRMutator* p) { \
return static_cast<CInternal*>(p)->Mutate_(op, e); })
TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr)
.DISPATCH_EXPR(Add)
.DISPATCH_EXPR(Sub)
.DISPATCH_EXPR(Mul)
.DISPATCH_EXPR(LT);
Canonical::Canonical()
: ptr_(std::make_shared<Internal>()) {}
......
......@@ -9,6 +9,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./int_set.h"
#include "./modular.h"
namespace tvm {
namespace arith {
......@@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
/*! \brief Internal modular entry */
ModularEntry e;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &(e.base));
v->Visit("coeff", &(e.coeff));
}
static constexpr const char* _type_key = "ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};
} // namespace arith
} // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <limits>
#include "./modular.h"
#include "./int_set_internal.h"
namespace tvm {
namespace arith {
using namespace ir;
class ModularEvaluator
: public ExprFunctor<ModularEntry(const Expr&)> {
public:
explicit ModularEvaluator(
const std::unordered_map<
const Variable*, ModularEntry>& mod_map)
: mod_map_(mod_map) {
}
ModularEntry Eval(const Expr& e) {
return VisitExpr(e);
}
// default
ModularEntry VisitExprDefault_(const Node*) final {
return ModularEntry::everything();
}
// override combination rules.
ModularEntry VisitExpr_(const IntImm* op) final {
if (op->value < std::numeric_limits<int>::max()) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const UIntImm* op) final {
if (op->value < static_cast<uint64_t>(
std::numeric_limits<int>::max())) {
ModularEntry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Variable* op) final {
auto it = mod_map_.find(op);
if (it != mod_map_.end()) {
return it->second;
} else {
return ModularEntry::everything();
}
}
ModularEntry VisitExpr_(const Add* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Sub* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ModularEntry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Mul* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
// Simplification rule, x, y, z are in Z
// (p x + n) (q y + m)
// -> pq xy + pm x + qn y + mn
// -> pq z + pm x + qn y + mn
int pq = a.coeff * b.coeff;
int pm = a.coeff * b.base;
int qn = a.base * b.coeff;
ModularEntry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
}
ModularEntry VisitExpr_(const Div* op) final {
// a c x / c -> a x
// We cannot do cases where offset is non-zero
// because of different integer rounding in pos/neg
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
if (b.coeff == 0 &&
a.base == 0) {
CHECK_NE(b.base, 0);
if (a.coeff % b.base == 0) {
ModularEntry ret;
ret.coeff = a.coeff / b.base;
ret.base = 0;
return ret;
}
}
return ModularEntry::everything();
}
private:
const std::unordered_map<
const Variable*, ModularEntry>& mod_map_;
// simplify the base by putting it in range.
static int BaseSimplify(int base, int coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}
static int ZeroAwareGCD(int a, int b) {
CHECK_GE(a, 0);
CHECK_GE(b, 0);
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
};
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
return ModularEvaluator(mod_map)(e);
}
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map) {
std::unordered_map<const Variable*, ModularEntry> mmap;
for (auto& kv : mod_map) {
const ModularSet* m = kv.second.as<ModularSet>();
CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e;
}
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>();
n->e = ModularEvaluator(mmap)(e);
return IntSet(n);
}
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_
#include <tvm/expr.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
};
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
......@@ -67,10 +67,6 @@ std::string CodeGenC::Finish() {
return stream.str();
}
void CodeGenC::PrintStmt(const Stmt& n) {
static const FPrintStmt& f = vtable_print_stmt();
f(n, this);
}
std::string CodeGenC::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src;
......@@ -96,13 +92,12 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) {
}
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
static const FPrintExpr& f = vtable_print_expr();
if (print_ssa_form_) {
std::ostringstream temp;
f(n, temp, this);
VisitExpr(n, temp);
os << SSAGetID(temp.str(), n.type());
} else {
f(n, os, this);
VisitExpr(n, os);
}
}
......@@ -178,6 +173,102 @@ void CodeGenC::MarkConst(std::string vid) {
}
}
int CodeGenC::BeginScope() {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent += 2;
return sid;
}
void CodeGenC::EndScope(int scope_id) {
scope_mark_[scope_id] = false;
indent -= 2;
}
// Print a reference expression to a buffer.
void CodeGenC::PrintBufferRef(
const Variable* buffer,
Type t, Expr index,
std::ostream& os) { // NOLINT(*)
std::string vid = GetVarID(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t)) {
os << "((";
PrintType(t, os);
os << "*)" << vid << ')';
} else {
os << vid;
}
os << '[';
PrintExpr(index, os);
os << ']';
} else {
// Buffer declared as vector type.
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t)) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return;
}
}
os << "((";
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
PrintType(t.element_of(), os);
os << "*)";
}
os << vid << " + ";
PrintExpr(index, os);
os << "))[0]";
}
}
void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
Type t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i
<< " = " << value << ";\n";
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
PrintBufferRef(buffer, t, base, os);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
this->PrintIndent();
PrintBufferRef(buffer, t, base, stream);
stream << " = " << value << ";\n";
}
void CodeGenC::PrintThreadIndexExpr(
std::string thread_tag, std::ostream& os) { // NOLINT(*)
os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*)
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_EQ(scope, "global");
}
void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types";
......@@ -208,13 +299,6 @@ void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() { // NOLINT(*)
static FPrintStmt inst; return inst;
}
CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() { // NOLINT(*)
static FPrintExpr inst; return inst;
}
inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->type == Int(32)) {
......@@ -262,19 +346,18 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
}
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<IntImm>([](const IntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<UIntImm>([](const UIntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<FloatImm>([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<StringImm>([](const StringImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
os << "\"" << op->value << "\"";
});
}
template<typename T>
inline void PrintBinaryExpr(const T* op,
......@@ -315,137 +398,99 @@ inline void PrintBinaryIntrinsitc(const Call* op,
p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os);
}
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
p->PrintType(op->type, os);
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
this->PrintType(op->type, os);
os << '(';
p->PrintExpr(op->value, os);
this->PrintExpr(op->value, os);
os << ')';
})
.set_dispatch<Variable>([](const Variable *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << p->GetVarID(op);
})
.set_dispatch<Add>([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, p);
})
.set_dispatch<Sub>([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, p);
})
.set_dispatch<Mul>([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, p);
})
.set_dispatch<Div>([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, p);
})
.set_dispatch<Mod>([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, p);
})
.set_dispatch<Min>([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, p);
})
.set_dispatch<Max>([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, p);
})
.set_dispatch<EQ>([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, p);
})
.set_dispatch<NE>([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, p);
})
.set_dispatch<LT>([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, p);
})
.set_dispatch<LE>([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, p);
})
.set_dispatch<GT>([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, p);
})
.set_dispatch<GE>([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, p);
})
.set_dispatch<And>([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, p);
})
.set_dispatch<Or>([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, p);
})
.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
}
void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenC::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
os << '!';
p->PrintExpr(op->a, os);
});
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
p->PrintStmt(op->body);
})
.set_dispatch<Block>([](const Block *op, CodeGenC* p) {
p->PrintStmt(op->first);
if (op->rest.defined()) p->PrintStmt(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenC* p) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
p->PrintStorageSync(call->args[0].as<StringImm>()->value);
} else {
std::string vid = p->PrintExpr(op->value);
p->PrintIndent();
p->stream << "(void)" << vid << ";\n";
}
});
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, std::ostream&os, CodeGenC* p) { \
p->PrintExpr(op, os); })
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.DISPATCH_EXPR(Load)
.DISPATCH_EXPR(Call)
.DISPATCH_EXPR(Let)
.DISPATCH_EXPR(Ramp)
.DISPATCH_EXPR(Broadcast)
.DISPATCH_EXPR(Select);
PrintExpr(op->a, os);
}
void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
CodeGenC* p = this;
void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, " & ", os, p);
PrintBinaryIntrinsitc(op, " & ", os, this);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, " ^ ", os, p);
PrintBinaryIntrinsitc(op, " ^ ", os, this);
} else if (op->is_intrinsic(Call::bitwise_or)) {
PrintBinaryIntrinsitc(op, " | ", os, p);
PrintBinaryIntrinsitc(op, " | ", os, this);
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
p->PrintExpr(op->args[0], os);
this->PrintExpr(op->args[0], os);
os << ')';
} else if (op->is_intrinsic(Call::shift_left)) {
PrintBinaryIntrinsitc(op, " << ", os, p);
PrintBinaryIntrinsitc(op, " << ", os, this);
} else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, p);
PrintBinaryIntrinsitc(op, " >> ", os, this);
} else if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
os << "((";
p->PrintType(l->type.element_of(), os);
os << " *)" << p->GetVarID(l->buffer_var.get())
this->PrintType(l->type.element_of(), os);
os << " *)" << this->GetVarID(l->buffer_var.get())
<< " + ";
p->PrintExpr(l->index, os);
this->PrintExpr(l->index, os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
CHECK_EQ(op->args.size(), 3U);
if (!op->type.is_handle()) {
os << '(';
p->PrintType(op->type, os);
this->PrintType(op->type, os);
os << ')';
}
os << "(((TVMArg*)";
p->PrintExpr(op->args[0], os);
this->PrintExpr(op->args[0], os);
os << ")[" << op->args[2] << "].";
if (op->type.is_handle()) {
os << "v_handle";
......@@ -460,7 +505,7 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
os << "(((TVMArray*)";
p->PrintExpr(op->args[0], os);
this->PrintExpr(op->args[0], os);
os << ")->";
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: os << "data"; break;
......@@ -476,12 +521,12 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
os << "(";
p->PrintExpr(op->args[0], os);
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
p->PrintExpr(op->args[i], os);
this->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
......@@ -517,51 +562,7 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
return true;
}
// Print a reference expression to a buffer.
void CodeGenC::PrintBufferRef(
const Variable* buffer,
Type t, Expr index,
std::ostream& os) { // NOLINT(*)
std::string vid = GetVarID(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t)) {
os << "((";
PrintType(t, os);
os << "*)" << vid << ')';
} else {
os << vid;
}
os << '[';
PrintExpr(index, os);
os << ']';
} else {
// Buffer declared as vector type.
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t)) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return;
}
}
os << "((";
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
PrintType(t.element_of(), os);
os << "*)";
}
os << vid << " + ";
PrintExpr(index, os);
os << "))[0]";
}
}
void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
if (op->type.lanes() == 1) {
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os);
......@@ -600,7 +601,7 @@ void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
}
}
void CodeGenC::PrintStmt(const Store* op) {
void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
this->PrintIndent();
......@@ -637,35 +638,7 @@ void CodeGenC::PrintStmt(const Store* op) {
}
}
void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
Type t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i
<< " = " << value << ";\n";
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
PrintBufferRef(buffer, t, base, os);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
this->PrintIndent();
PrintBufferRef(buffer, t, base, stream);
stream << " = " << value << ";\n";
}
void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
CHECK(print_ssa_form_)
<< "LetExpr is only supported by print SSA form";
std::string value = PrintExpr(op->value);
......@@ -673,41 +646,19 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
var_idmap_[op->var.get()] = value;
}
void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp: not supported ";
}
void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Select: not supported ";
}
// Disoatch back to member functions
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<For>([](const For *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) { p->PrintStmt(op); });
void CodeGenC::PrintThreadIndexExpr(
std::string thread_tag, std::ostream& os) { // NOLINT(*)
os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*)
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_EQ(scope, "global");
}
void CodeGenC::PrintStmt(const LetStmt* op) {
void CodeGenC::VisitStmt_(const LetStmt* op) {
std::string value = PrintExpr(op->value);
if (print_ssa_form_) {
CHECK(!var_idmap_.count(op->var.get()));
......@@ -732,7 +683,7 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const Allocate* op) {
void CodeGenC::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
......@@ -758,7 +709,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
this->PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const AttrStmt* op) {
void CodeGenC::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
......@@ -780,7 +731,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
this->PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const AssertStmt* op) {
void CodeGenC::VisitStmt_(const AssertStmt* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (op->message.as<StringImm>()) {
......@@ -792,19 +743,7 @@ void CodeGenC::PrintStmt(const AssertStmt* op) {
}
}
int CodeGenC::BeginScope() {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent += 2;
return sid;
}
void CodeGenC::EndScope(int scope_id) {
scope_mark_[scope_id] = false;
indent -= 2;
}
void CodeGenC::PrintStmt(const For* op) {
void CodeGenC::VisitStmt_(const For* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
......@@ -821,7 +760,7 @@ void CodeGenC::PrintStmt(const For* op) {
stream << "}\n";
}
void CodeGenC::PrintStmt(const IfThenElse* op) {
void CodeGenC::VisitStmt_(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (" << cond << ") {\n";
......@@ -840,6 +779,27 @@ void CodeGenC::PrintStmt(const IfThenElse* op) {
stream << "}\n";
}
void CodeGenC::VisitStmt_(const Block *op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
}
void CodeGenC::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call->args[0].as<StringImm>()->value);
} else {
std::string vid = this->PrintExpr(op->value);
this->PrintIndent();
this->stream << "(void)" << vid << ";\n";
}
}
void CodeGenC::VisitStmt_(const ProducerConsumer *op) {
PrintStmt(op->body);
}
} // namespace codegen
} // namespace tvm
......@@ -7,6 +7,7 @@
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <string>
......@@ -16,12 +17,15 @@
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class to generate C code.
*
* CodeGenC have two modes: generate SSA formed C code or normal form.
*/
class CodeGenC {
class CodeGenC :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Initialize the code generator.
......@@ -42,13 +46,15 @@ class CodeGenC {
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt& n);
void PrintStmt(const Stmt& n) {
VisitStmt(n);
}
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*)
void PrintExpr(const Expr& n, std::ostream& os);
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
......@@ -84,6 +90,46 @@ class CodeGenC {
* \param f The function to be compiled.
*/
virtual void InitFuncState(LoweredFunc f);
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
* Print Type represetnation of type t.
* \param t The type representation.
......@@ -97,50 +143,37 @@ class CodeGenC {
*/
virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::For* op);
virtual void PrintStmt(const ir::IfThenElse* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
// print vector store
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
protected:
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
// print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
......@@ -158,13 +191,6 @@ class CodeGenC {
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
......@@ -209,6 +235,8 @@ class CodeGenC {
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent{0};
};
} // namespace codegen
......
......@@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
CodeGenC::AddFunction(f);
}
void CodeGenCUDA::PrintStmt(const ir::For* op) {
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
int ext;
CHECK(is_zero(op->min));
if (arith::GetConstInt(op->extent, &ext) &&
......@@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) {
PrintIndent();
stream << "#pragma unroll\n";
}
CodeGenC::PrintStmt(op);
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
......@@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC {
public:
void AddFunction(LoweredFunc f);
// override behavior
void PrintStmt(const ir::For* op) final;
void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
......
......@@ -130,7 +130,7 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
this->Visit(f->body);
this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0));
}
......@@ -222,247 +222,376 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
return ret;
}
void CodeGenLLVM::Visit_(const Variable* op) {
value_ = GetVarValue(op);
}
void CodeGenLLVM::Visit_(const Cast* op) {
value_ = CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
void CodeGenLLVM::Visit_(const IntImm* op) {
value_ = llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const UIntImm* op) {
value_ = llvm::ConstantInt::get(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const FloatImm* op) {
value_ = llvm::ConstantFP::get(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const StringImm* op) {
value_ = GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value *b) { \
if (t.is_float()) { \
return builder_->CreateF ## OP (a, b); \
} else if (t.is_int() && t.bits() >= 32) { \
return builder_->CreateNSW ## OP (a, b); \
} else { \
return builder_->Create ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
void CodeGenLLVM::Visit_(const Add* op) {
value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Sub* op) {
value_ = CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Mul* op) {
value_ = CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "call_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "call_end", function_);
llvm::Value* succ = builder_->CreateICmpEQ(
retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
}
void CodeGenLLVM::Visit_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
int shift;
if (op->type.is_float()) {
value_ = builder_->CreateFDiv(a, MakeValue(op->b));
} else if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
value_ = builder_->CreateAShr(a, shift);
} else {
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
value_ = builder_->CreateSDiv(a, b);
} else {
CHECK(op->type.is_uint());
value_ = builder_->CreateUDiv(a, b);
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index) {
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
const Ramp* ramp = index.as<Ramp>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
void CodeGenLLVM::Visit_(const Mod* op) {
CHECK(!op->type.is_float())
<< "Cannot do mod for float";
if (op->type.is_int()) {
value_ = builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
} else {
CHECK(op->type.is_uint());
value_ = builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
}
}
void CodeGenLLVM::Visit_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateLT(op->a.type(), a, b);
value_ = builder_->CreateSelect(cond, a, b);
}
void CodeGenLLVM::Visit_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
value_ = builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
void CodeGenLLVM::Visit_(const LT* op) {
value_ = CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const LE* op) {
value_ = CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GT* op) {
value_ = CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GE* op) {
value_ = CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const EQ* op) {
if (op->a.type().is_float()) {
value_ = builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
} else {
value_ = builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
if (arith::GetConstInt(index, &base)) width = 1;
}
}
void CodeGenLLVM::Visit_(const NE* op) {
if (op->a.type().is_float()) {
value_ = builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
} else {
value_ = builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
void CodeGenLLVM::Visit_(const And* op) {
value_ = builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Or* op) {
value_ = builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Not* op) {
value_ = builder_->CreateNot(MakeValue(op->a));
}
void CodeGenLLVM::Visit_(const Select* op) {
value_ = builder_->CreateSelect(
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
void CodeGenLLVM::Visit_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
value_ = MakeValue(op->body);
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Constant* init = llvm::UndefValue::get(
llvm::VectorType::get(value->getType(), lanes));
llvm::Constant* zero = ConstInt32(0);
value = builder_->CreateInsertElement(init, value, zero);
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
return builder_->CreateShuffleVector(value, init, mask);
}
void CodeGenLLVM::Visit_(const Broadcast* op) {
value_ = CreateBroadcast(MakeValue(op->value), op->lanes);
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
llvm::Type* elem_type = buffer->getType();
unsigned address_space = elem_type->getPointerAddressSpace();
llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
void CodeGenLLVM::Visit_(const Ramp* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
if (load_type != elem_type) {
buffer = builder_->CreatePointerCast(buffer, load_type);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
if (cindex && cindex->isZeroValue()) {
return buffer;
}
value_ = value;
return builder_->CreateInBoundsGEP(buffer, index);
}
void CodeGenLLVM::Visit_(const Load* op) {
Type t = op->type;
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
value_ = inst;
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
if (from.is_handle() && from.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
return builder_->CreateFPToSI(value, target);
} else if (from.is_float() && to.is_uint()) {
if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
return builder_->CreateIntCast(value, target, false);
} else {
LOG(FATAL) << "not yet supported";
return builder_->CreateFPToUI(value, target);
}
}
void CodeGenLLVM::Visit_(const Store* op) {
llvm::Value* value = MakeValue(op->value);
Type t = op->value.type();
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
} else if (from.is_int() && to.is_float()) {
return builder_->CreateSIToFP(value, target);
} else if (from.is_uint() && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
LOG(FATAL) << "not yet supported";
CHECK(from.is_float() && to.is_float());
return builder_->CreateFPCast(value, target);
}
}
void CodeGenLLVM::Visit_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
value_ = CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
value_ = CreateIntrinstic(op);
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
value_ = CreateCallExtern(op);
hptr = it->second;
}
}
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
MakeValue(op->args[0]), MakeValue(op->args[1]));
// create emit codes that checks and load the function.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
CHECK_GE(op->args.size(), 1U);
std::string func_name = op->args[0].as<StringImm>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
t_tvm_value_, ConstInt32(nargs));
llvm::Value* tcodes = builder_->CreateAlloca(
t_int_, ConstInt32(nargs));
for (unsigned i = 0; i < nargs; ++i) {
Expr expr = op->args[i + 1];
Type t = expr.type();
CHECK_EQ(t.lanes(), 1);
// Always pass via 64 bit value.
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
Type api_type = t.with_bits(64);
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
llvm::Value* store_ptr = builder_->CreatePointerCast(
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
LLVMType(api_type)->getPointerTo());
builder_->CreateAlignedStore(value, store_ptr, 8);
builder_->CreateAlignedStore(
ConstInt32(t.code()),
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Type r_type = op->type;
Type r_api_type = op->type.with_bits(64);
llvm::Value* rvalue =
builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
ret_value, LLVMType(r_api_type)->getPointerTo()), 8);
rvalue = CreateCast(r_api_type, r_type, rvalue);
return rvalue;
}
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
}
if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return builder_->CreateCall(f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return CreateScalarizedCall(op, f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
}
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
for (int i = 0; i < op->type.lanes(); ++i) {
std::vector<llvm::Value*> sargs(args.size());
for (size_t j = 0; j < args.size(); ++j) {
if (args[j]->getType()->isVectorTy()) {
sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
} else {
sargs[j] = args[j];
}
}
llvm::CallInst* call = builder_->CreateCall(f, sargs);
if (op->is_pure()) {
call->setDoesNotAccessMemory();
}
call->setDoesNotThrow();
if (!call->getType()->isVoidTy()) {
value = builder_->CreateInsertElement(value, call, ConstInt32(i));
}
}
return value;
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
<< "Cannot find " << v->name_hint << " in the var map";
return it->second;
}
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
auto it = str_map_.find(str);
if (it == str_map_.end()) {
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
// useful constant value
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = sptr;
return sptr;
} else {
return it->second;
}
}
void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->VisitStmt(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_xor)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateXor(
......@@ -555,27 +684,249 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return nullptr;
}
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "call_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "call_end", function_);
llvm::Value* succ = builder_->CreateICmpEQ(
retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
// visitor overrides
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
return GetVarValue(op);
}
void CodeGenLLVM::Visit_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
return CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
return llvm::ConstantInt::get(LLVMType(op->type), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
return llvm::ConstantFP::get(LLVMType(op->type), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
return GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value *b) { \
if (t.is_float()) { \
return builder_->CreateF ## OP (a, b); \
} else if (t.is_int() && t.bits() >= 32) { \
return builder_->CreateNSW ## OP (a, b); \
} else { \
return builder_->Create ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
llvm::Value* CodeGenLLVM::VisitExpr_(const Add* op) {
return CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Sub* op) {
return CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mul* op) {
return CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
int shift;
if (op->type.is_float()) {
return builder_->CreateFDiv(a, MakeValue(op->b));
} else if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
return builder_->CreateAShr(a, shift);
} else {
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
return builder_->CreateSDiv(a, b);
} else {
CHECK(op->type.is_uint());
return builder_->CreateUDiv(a, b);
}
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
CHECK(!op->type.is_float())
<< "Cannot do mod for float";
if (op->type.is_int()) {
return builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
} else {
CHECK(op->type.is_uint());
return builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateLT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
return builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const LT* op) {
return CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LE* op) {
return CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GT* op) {
return CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const GE* op) {
return CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
} else {
return builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
if (op->a.type().is_float()) {
return builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
} else {
return builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
return builder_->CreateNot(MakeValue(op->a));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
return builder_->CreateSelect(
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
return MakeValue(op->body);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
}
return value;
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
Type t = op->type;
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
return inst;
} else {
LOG(FATAL) << "not yet supported";
return nullptr;
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinstic(op);
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
return CreateCallExtern(op);
}
}
// stmts
void CodeGenLLVM::VisitStmt_(const Store* op) {
llvm::Value* value = MakeValue(op->value);
Type t = op->value.type();
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
} else {
LOG(FATAL) << "not yet supported";
}
}
void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
......@@ -584,7 +935,7 @@ void CodeGenLLVM::Visit_(const For* op) {
}
}
void CodeGenLLVM::Visit_(const IfThenElse* op) {
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
using llvm::BasicBlock;
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
......@@ -605,18 +956,18 @@ void CodeGenLLVM::Visit_(const IfThenElse* op) {
}
// then case.
builder_->SetInsertPoint(then_block);
this->Visit(op->then_case);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
// else case.
if (op->else_case.defined()) {
builder_->SetInsertPoint(else_block);
this->Visit(op->else_case);
this->VisitStmt(op->else_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::Visit_(const Allocate* op) {
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
......@@ -634,11 +985,11 @@ void CodeGenLLVM::Visit_(const Allocate* op) {
var_map_[op->buffer_var.get()] = buf;
}
void CodeGenLLVM::Visit_(const AttrStmt* op) {
this->Visit(op->body);
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
this->VisitStmt(op->body);
}
void CodeGenLLVM::Visit_(const AssertStmt* op) {
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
......@@ -660,359 +1011,23 @@ void CodeGenLLVM::Visit_(const AssertStmt* op) {
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::Visit_(const LetStmt* op) {
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
this->Visit(op->body);
this->VisitStmt(op->body);
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index) {
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
const Ramp* ramp = index.as<Ramp>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
void CodeGenLLVM::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) VisitStmt(op->rest);
}
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Constant* init = llvm::UndefValue::get(
llvm::VectorType::get(value->getType(), lanes));
llvm::Constant* zero = ConstInt32(0);
value = builder_->CreateInsertElement(init, value, zero);
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
return builder_->CreateShuffleVector(value, init, mask);
void CodeGenLLVM::VisitStmt_(const Evaluate *op) {
MakeValue(op->value);
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
llvm::Type* elem_type = buffer->getType();
unsigned address_space = elem_type->getPointerAddressSpace();
llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
if (load_type != elem_type) {
buffer = builder_->CreatePointerCast(buffer, load_type);
}
llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
if (cindex && cindex->isZeroValue()) {
return buffer;
}
return builder_->CreateInBoundsGEP(buffer, index);
void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
VisitStmt(op->body);
}
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
if (from.is_handle() && from.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
return builder_->CreateFPToSI(value, target);
} else if (from.is_float() && to.is_uint()) {
if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
return builder_->CreateIntCast(value, target, false);
} else {
return builder_->CreateFPToUI(value, target);
}
} else if (from.is_int() && to.is_float()) {
return builder_->CreateSIToFP(value, target);
} else if (from.is_uint() && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
CHECK(from.is_float() && to.is_float());
return builder_->CreateFPCast(value, target);
}
}
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// create emit codes that checks and load the function.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
CHECK_GE(op->args.size(), 1U);
std::string func_name = op->args[0].as<StringImm>()->value;
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
t_tvm_value_, ConstInt32(nargs));
llvm::Value* tcodes = builder_->CreateAlloca(
t_int_, ConstInt32(nargs));
for (unsigned i = 0; i < nargs; ++i) {
Expr expr = op->args[i + 1];
Type t = expr.type();
CHECK_EQ(t.lanes(), 1);
// Always pass via 64 bit value.
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
Type api_type = t.with_bits(64);
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
llvm::Value* store_ptr = builder_->CreatePointerCast(
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
LLVMType(api_type)->getPointerTo());
builder_->CreateAlignedStore(value, store_ptr, 8);
builder_->CreateAlignedStore(
ConstInt32(t.code()),
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Type r_type = op->type;
Type r_api_type = op->type.with_bits(64);
llvm::Value* rvalue =
builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
ret_value, LLVMType(r_api_type)->getPointerTo()), 8);
rvalue = CreateCast(r_api_type, r_type, rvalue);
return rvalue;
}
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
}
if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return builder_->CreateCall(f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return CreateScalarizedCall(op, f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
}
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
for (int i = 0; i < op->type.lanes(); ++i) {
std::vector<llvm::Value*> sargs(args.size());
for (size_t j = 0; j < args.size(); ++j) {
if (args[j]->getType()->isVectorTy()) {
sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
} else {
sargs[j] = args[j];
}
}
llvm::CallInst* call = builder_->CreateCall(f, sargs);
if (op->is_pure()) {
call->setDoesNotAccessMemory();
}
call->setDoesNotThrow();
if (!call->getType()->isVoidTy()) {
value = builder_->CreateInsertElement(value, call, ConstInt32(i));
}
}
return value;
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
<< "Cannot find " << v->name_hint << " in the var map";
return it->second;
}
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
auto it = str_map_.find(str);
if (it == str_map_.end()) {
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
// useful constant value
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = sptr;
return sptr;
} else {
return it->second;
}
}
void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->Visit(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -8,7 +8,7 @@
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <memory>
#include <vector>
......@@ -23,7 +23,9 @@ using namespace ir;
/*!
* \brief A base class to generate a LLVM.
*/
class CodeGenLLVM : public IRVisitor {
class CodeGenLLVM :
public ExprFunctor<llvm::Value* (const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Initialize the code generator with given context
......@@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor {
* \return created value.
*/
llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr;
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
return VisitExpr(e);
}
// Short hande code to get a constant int 32
llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value);
}
// override codegen
void Visit_(const Variable* op) final;
void Visit_(const Cast* op) final;
void Visit_(const IntImm* op) final;
void Visit_(const UIntImm* op) final;
void Visit_(const FloatImm* op) final;
void Visit_(const StringImm* op) final;
void Visit_(const Add* op) final;
void Visit_(const Sub* op) final;
void Visit_(const Mul* op) final;
void Visit_(const Div* op) final;
void Visit_(const Mod* op) final;
void Visit_(const Min* op) final;
void Visit_(const Max* op) final;
void Visit_(const LT* op) final;
void Visit_(const LE* op) final;
void Visit_(const GT* op) final;
void Visit_(const GE* op) final;
void Visit_(const EQ* op) final;
void Visit_(const NE* op) final;
void Visit_(const And* op) final;
void Visit_(const Or* op) final;
void Visit_(const Not* op) final;
void Visit_(const Select* op) final;
void Visit_(const Let* op) final;
void Visit_(const Load* op) final;
void Visit_(const Call* op) final;
void Visit_(const Ramp* op) final;
void Visit_(const Broadcast* op) final;
llvm::Value* VisitExpr_(const Variable* op) override;
llvm::Value* VisitExpr_(const Cast* op) override;
llvm::Value* VisitExpr_(const IntImm* op) override;
llvm::Value* VisitExpr_(const UIntImm* op) override;
llvm::Value* VisitExpr_(const FloatImm* op) override;
llvm::Value* VisitExpr_(const StringImm* op) override;
llvm::Value* VisitExpr_(const Add* op) override;
llvm::Value* VisitExpr_(const Sub* op) override;
llvm::Value* VisitExpr_(const Mul* op) override;
llvm::Value* VisitExpr_(const Div* op) override;
llvm::Value* VisitExpr_(const Mod* op) override;
llvm::Value* VisitExpr_(const Min* op) override;
llvm::Value* VisitExpr_(const Max* op) override;
llvm::Value* VisitExpr_(const LT* op) override;
llvm::Value* VisitExpr_(const LE* op) override;
llvm::Value* VisitExpr_(const GT* op) override;
llvm::Value* VisitExpr_(const GE* op) override;
llvm::Value* VisitExpr_(const EQ* op) override;
llvm::Value* VisitExpr_(const NE* op) override;
llvm::Value* VisitExpr_(const And* op) override;
llvm::Value* VisitExpr_(const Or* op) override;
llvm::Value* VisitExpr_(const Not* op) override;
llvm::Value* VisitExpr_(const Select* op) override;
llvm::Value* VisitExpr_(const Let* op) override;
llvm::Value* VisitExpr_(const Load* op) override;
llvm::Value* VisitExpr_(const Call* op) override;
llvm::Value* VisitExpr_(const Ramp* op) override;
llvm::Value* VisitExpr_(const Broadcast* op) override;
// stmt
void Visit_(const Store* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Allocate* op) final;
void Visit_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final;
void Visit_(const LetStmt* op) final;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call
......@@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor {
llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
llvm::Value* value_{nullptr};
private:
// comparison op
......
......@@ -12,10 +12,6 @@ namespace codegen {
using namespace ir;
CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*)
static FType inst; return inst;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
......@@ -27,18 +23,12 @@ StackVM CodeGenStackVM::Compile(LoweredFunc f) {
}
void CodeGenStackVM::Push(const Stmt& n) {
static const FType& f = vtable();
f(n, this);
VisitStmt(n);
if (debug_) {
this->PushOp(StackVM::ASSERT_SP, 0);
}
}
void CodeGenStackVM::Push(const Expr& n) {
static const FType& f = vtable();
f(n, this);
}
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
StackVM::Code code;
code.op_code = opcode;
......@@ -106,7 +96,7 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
return it->second;
}
void CodeGenStackVM::Push_(const ir::Load* op) {
void CodeGenStackVM::VisitExpr_(const Load* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
if (op->type == UInt(32) && op->index.as<IntImm>()) {
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
......@@ -118,7 +108,8 @@ void CodeGenStackVM::Push_(const ir::Load* op) {
this->PushOp(StackVM::GetLoad(Type2TVMType(op->type)));
}
}
void CodeGenStackVM::Push_(const ir::Store* op) {
void CodeGenStackVM::VisitStmt_(const Store* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
......@@ -128,7 +119,7 @@ void CodeGenStackVM::Push_(const ir::Store* op) {
this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type())));
}
void CodeGenStackVM::Push_(const ir::Allocate* op) {
void CodeGenStackVM::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
int vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
......@@ -141,7 +132,7 @@ void CodeGenStackVM::Push_(const ir::Allocate* op) {
}
}
void CodeGenStackVM::Push_(const ir::Call* op) {
void CodeGenStackVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
......@@ -211,37 +202,30 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_I64);
} else {
this->HandleUnknownCall(op);
LOG(FATAL) << "unknown function call " << op->name;
}
}
void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) {
LOG(FATAL) << "donot know how to handle call " << op->name;
}
inline void PushBinary(StackVM::OpCode op_int64,
void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b,
CodeGenStackVM* p) {
p->Push(a);
p->Push(b);
const Expr& b) {
this->Push(a);
this->Push(b);
Type t = a.type();
if (t.is_int()) {
p->PushOp(op_int64);
this->PushOp(op_int64);
} else if (t.is_uint()) {
if (t.bits() <= 32) {
p->PushOp(op_int64);
this->PushOp(op_int64);
} else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
}
} else {
p->PushOp(StackVM::CodeI64ToF64(op_int64));
this->PushOp(StackVM::CodeI64ToF64(op_int64));
}
}
inline void PushCast(Type dst,
Type src,
CodeGenStackVM* p) {
void CodeGenStackVM::PushCast(Type dst, Type src) {
if (dst.is_int()) {
if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return;
......@@ -254,211 +238,226 @@ inline void PushCast(Type dst,
LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
}
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<StringImm>([](const StringImm *op, CodeGenStackVM *p) {
int sid = p->GetStrID(op->value);
p->PushOp(StackVM::PUSH_I64, sid);
})
.set_dispatch<IntImm>([](const IntImm *op, CodeGenStackVM *p) {
void CodeGenStackVM::VisitExpr_(const StringImm *op) {
int sid = this->GetStrID(op->value);
this->PushOp(StackVM::PUSH_I64, sid);
}
void CodeGenStackVM::VisitExpr_(const IntImm *op) {
CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
})
.set_dispatch<UIntImm>([](const UIntImm *op, CodeGenStackVM *p) {
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
void CodeGenStackVM::VisitExpr_(const UIntImm *op) {
CHECK(op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
})
.set_dispatch<FloatImm>([](const FloatImm *op, CodeGenStackVM *p) {
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
void CodeGenStackVM::VisitExpr_(const FloatImm *op) {
LOG(FATAL) << "Float Imm is not supported";
});
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Variable>([](const Variable *op, CodeGenStackVM* p) {
int vid = p->GetVarID(op);
p->PushOp(StackVM::LOAD_HEAP, vid);
})
.set_dispatch<Cast>([](const Cast *op, CodeGenStackVM* p) {
p->Push(op->value);
PushCast(op->type, op->value.type(), p);
})
.set_dispatch<Add>([](const Add *op, CodeGenStackVM* p) {
PushBinary(StackVM::ADD_I64, op->a, op->b, p);
})
.set_dispatch<Sub>([](const Sub *op, CodeGenStackVM* p) {
PushBinary(StackVM::SUB_I64, op->a, op->b, p);
})
.set_dispatch<Mul>([](const Mul *op, CodeGenStackVM* p) {
PushBinary(StackVM::MUL_I64, op->a, op->b, p);
})
.set_dispatch<Div>([](const Div *op, CodeGenStackVM* p) {
PushBinary(StackVM::DIV_I64, op->a, op->b, p);
})
.set_dispatch<Mod>([](const Mod *op, CodeGenStackVM* p) {
PushBinary(StackVM::MOD_I64, op->a, op->b, p);
})
.set_dispatch<Min>([](const Min *op, CodeGenStackVM* p) {
p->Push(op->a);
p->Push(op->b);
p->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::PUSH_VALUE, -1);
p->PushOp(StackVM::LT_I64);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<Max>([](const Max *op, CodeGenStackVM* p) {
p->Push(op->a);
p->Push(op->b);
p->PushOp(StackVM::PUSH_VALUE, 0);
p->PushOp(StackVM::PUSH_VALUE, -2);
p->PushOp(StackVM::LT_I64);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<EQ>([](const EQ *op, CodeGenStackVM* p) {
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
})
.set_dispatch<LE>([](const LE *op, CodeGenStackVM* p) {
PushBinary(StackVM::LE_I64, op->a, op->b, p);
})
.set_dispatch<NE>([](const NE *op, CodeGenStackVM* p) {
PushBinary(StackVM::EQ_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<LT>([](const LT *op, CodeGenStackVM* p) {
PushBinary(StackVM::LT_I64, op->a, op->b, p);
})
.set_dispatch<GE>([](const GE *op, CodeGenStackVM* p) {
PushBinary(StackVM::LT_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<GT>([](const GT *op, CodeGenStackVM* p) {
PushBinary(StackVM::LE_I64, op->a, op->b, p);
p->PushOp(StackVM::NOT);
})
.set_dispatch<And>([](const And *op, CodeGenStackVM* p) {
p->Push(op->a);
int64_t pc_jump = p->GetPC();
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->b);
int64_t diff = p->GetPC() - pc_jump;
p->SetOperand(opr_index, diff);
})
.set_dispatch<Or>([](const Or *op, CodeGenStackVM* p) {
p->Push(op->a);
int64_t pc_jump = p->GetPC();
int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0);
p->Push(op->b);
int64_t diff = p->GetPC() - pc_jump;
p->SetOperand(opr_index, diff);
})
.set_dispatch<Not>([](const Not* op, CodeGenStackVM* p) {
p->PushOp(StackVM::NOT);
});
TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenStackVM* p) {
p->Push(op->body);
})
.set_dispatch<For>([](const For *op, CodeGenStackVM* p) {
}
void CodeGenStackVM::VisitExpr_(const Variable *op) {
int vid = this->GetVarID(op);
this->PushOp(StackVM::LOAD_HEAP, vid);
}
void CodeGenStackVM::VisitExpr_(const Cast *op) {
this->Push(op->value);
PushCast(op->type, op->value.type());
}
void CodeGenStackVM::VisitExpr_(const Add *op) {
PushBinary(StackVM::ADD_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Sub *op) {
PushBinary(StackVM::SUB_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Mul *op) {
PushBinary(StackVM::MUL_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Div *op) {
PushBinary(StackVM::DIV_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Mod *op) {
PushBinary(StackVM::MOD_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Min *op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, -1);
this->PushOp(StackVM::PUSH_VALUE, -1);
this->PushOp(StackVM::LT_I64);
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitExpr_(const Max *op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, 0);
this->PushOp(StackVM::PUSH_VALUE, -2);
this->PushOp(StackVM::LT_I64);
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitExpr_(const EQ *op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const LE *op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const NE *op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const LT *op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const GE *op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const GT *op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const And *op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->b);
int64_t diff = this->GetPC() - pc_jump;
this->SetOperand(opr_index, diff);
}
void CodeGenStackVM::VisitExpr_(const Or *op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
this->Push(op->b);
int64_t diff = this->GetPC() - pc_jump;
this->SetOperand(opr_index, diff);
}
void CodeGenStackVM::VisitExpr_(const Not* op) {
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitStmt_(const For *op) {
CHECK(is_zero(op->min));
int vid = p->AllocVarID(op->loop_var.get());
p->PushOp(StackVM::PUSH_I64, 0);
int64_t loop_head = p->GetPC();
p->PushOp(StackVM::STORE_HEAP, vid);
p->PushOp(StackVM::LOAD_HEAP, vid);
p->Push(op->extent);
p->PushOp(StackVM::LT_I64);
int64_t label_fjump = p->GetPC();
int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->body);
p->PushOp(StackVM::LOAD_HEAP, vid);
p->PushOp(StackVM::PUSH_I64, 1);
p->PushOp(StackVM::ADD_I64);
int64_t label_bjump = p->GetPC();
int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0);
int64_t loop_end = p->GetPC();
p->PushOp(StackVM::POP);
p->SetOperand(foward_jump, loop_end - label_fjump);
p->SetOperand(backward_jump, loop_head - label_bjump);
})
.set_dispatch<Block>([](const Block *op, CodeGenStackVM* p) {
p->Push(op->first);
if (op->rest.defined()) p->Push(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenStackVM* p) {
int vid = this->AllocVarID(op->loop_var.get());
this->PushOp(StackVM::PUSH_I64, 0);
int64_t loop_head = this->GetPC();
this->PushOp(StackVM::STORE_HEAP, vid);
this->PushOp(StackVM::LOAD_HEAP, vid);
this->Push(op->extent);
this->PushOp(StackVM::LT_I64);
int64_t label_fjump = this->GetPC();
int64_t foward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->body);
this->PushOp(StackVM::LOAD_HEAP, vid);
this->PushOp(StackVM::PUSH_I64, 1);
this->PushOp(StackVM::ADD_I64);
int64_t label_bjump = this->GetPC();
int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);
int64_t loop_end = this->GetPC();
this->PushOp(StackVM::POP);
this->SetOperand(foward_jump, loop_end - label_fjump);
this->SetOperand(backward_jump, loop_head - label_bjump);
}
void CodeGenStackVM::VisitStmt_(const Block *op) {
this->Push(op->first);
if (op->rest.defined()) this->Push(op->rest);
}
void CodeGenStackVM::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
p->Push(op->value);
p->PushOp(StackVM::POP);
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenStackVM* p) {
p->Push(op->condition);
int64_t label_ejump = p->GetPC();
int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
p->PushOp(StackVM::POP);
p->Push(op->then_case);
this->Push(op->value);
this->PushOp(StackVM::POP);
}
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
this->Push(op->condition);
int64_t label_ejump = this->GetPC();
int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->then_case);
if (op->else_case.defined()) {
int64_t label_then_jump = p->GetPC();
int64_t then_jump = p->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = p->GetPC();
p->SetOperand(else_jump, else_begin - label_ejump);
p->PushOp(StackVM::POP);
p->Push(op->else_case);
int64_t if_end = p->GetPC();
p->SetOperand(then_jump, if_end - label_then_jump);
int64_t label_then_jump = this->GetPC();
int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = this->GetPC();
this->SetOperand(else_jump, else_begin - label_ejump);
this->PushOp(StackVM::POP);
this->Push(op->else_case);
int64_t if_end = this->GetPC();
this->SetOperand(then_jump, if_end - label_then_jump);
} else {
int64_t if_end = p->GetPC();
p->SetOperand(else_jump, if_end - label_ejump);
p->PushOp(StackVM::POP);
int64_t if_end = this->GetPC();
this->SetOperand(else_jump, if_end - label_ejump);
this->PushOp(StackVM::POP);
}
})
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) {
p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
p->Push(op->body);
})
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) {
}
void CodeGenStackVM::VisitStmt_(const LetStmt *op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body);
}
void CodeGenStackVM::VisitExpr_(const Ramp *op) {
LOG(FATAL) << "Ramp is not supported";
})
.set_dispatch<Broadcast>([](const Broadcast *op, CodeGenStackVM* p) {
}
void CodeGenStackVM::VisitExpr_(const Broadcast *op) {
LOG(FATAL) << "Broadcast is not supported";
})
.set_dispatch<Select>([](const Select *op, CodeGenStackVM* p) {
p->Push(op->true_value);
p->Push(op->false_value);
p->Push(op->condition);
p->PushOp(StackVM::SELECT);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenStackVM* p) {
}
void CodeGenStackVM::VisitExpr_(const Select *op) {
this->Push(op->true_value);
this->Push(op->false_value);
this->Push(op->condition);
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
if (op->message.as<StringImm>()) {
int sid = p->GetStrID(op->message.as<StringImm>()->value);
p->Push(op->condition);
p->PushOp(StackVM::ASSERT, sid);
int sid = this->GetStrID(op->message.as<StringImm>()->value);
this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid);
}
})
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenStackVM* p) {
p->Push(op->body);
})
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) {
p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
p->Push(op->body);
})
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Store>([](const Store *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Allocate>([](const Allocate *op, CodeGenStackVM* p) {
p->Push_(op);
})
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
p->Push_(op);
});
}
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitExpr_(const Let *op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body);
}
} // namespace codegen
} // namespace tvm
......@@ -7,6 +7,7 @@
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>
#include <tvm/codegen.h>
#include <string>
......@@ -18,12 +19,15 @@
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class to generate a stack VM.
* This module is used to generate host wrapper
* into device function when only device JIT is available.
*/
class CodeGenStackVM {
class CodeGenStackVM
: public ExprFunctor<void(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Generate a stack VM representing
......@@ -36,7 +40,9 @@ class CodeGenStackVM {
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
void Push(const Expr& n);
void Push(const Expr& n) {
VisitExpr(n);
}
/*!
* \brief Push the opcode to the code.
* \param opcode The code to be pushed.
......@@ -84,16 +90,53 @@ class CodeGenStackVM {
* \return the heap index of the var.
*/
int GetVarID(const Variable* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
const Expr& b);
// push cast;
void PushCast(Type dst, Type src);
// overloadable functions
virtual void Push_(const ir::Load* op);
virtual void Push_(const ir::Store* op);
virtual void Push_(const ir::Allocate* op);
virtual void Push_(const ir::Call* op);
virtual void HandleUnknownCall(const ir::Call* op);
/*! \brief function to to print normal code */
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>;
// vtable to print code
static FType& vtable(); // NOLINT(*)
// expression
void VisitExpr_(const Variable* op) final;
void VisitExpr_(const Load* op) final;
void VisitExpr_(const Let* op) final;
void VisitExpr_(const Call* op) final;
void VisitExpr_(const Add* op) final;
void VisitExpr_(const Sub* op) final;
void VisitExpr_(const Mul* op) final;
void VisitExpr_(const Div* op) final;
void VisitExpr_(const Mod* op) final;
void VisitExpr_(const Min* op) final;
void VisitExpr_(const Max* op) final;
void VisitExpr_(const EQ* op) final;
void VisitExpr_(const NE* op) final;
void VisitExpr_(const LT* op) final;
void VisitExpr_(const LE* op) final;
void VisitExpr_(const GT* op) final;
void VisitExpr_(const GE* op) final;
void VisitExpr_(const And* op) final;
void VisitExpr_(const Or* op) final;
void VisitExpr_(const Cast* op) final;
void VisitExpr_(const Not* op) final;
void VisitExpr_(const Select* op) final;
void VisitExpr_(const Ramp* op) final;
void VisitExpr_(const Broadcast* op) final;
void VisitExpr_(const IntImm* op) final;
void VisitExpr_(const UIntImm* op) final;
void VisitExpr_(const FloatImm* op) final;
void VisitExpr_(const StringImm* op) final;
// statment
void VisitStmt_(const LetStmt* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitStmt_(const Allocate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
private:
bool debug_{false};
......
......@@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}
Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
......@@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
}
}
#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \
Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \
return s; \
}
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return s;
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Variable)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Load)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Let)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(Call)
.DISPATCH_TO_MUTATE_STMT(Add)
.DISPATCH_TO_MUTATE_STMT(Sub)
.DISPATCH_TO_MUTATE_STMT(Mul)
.DISPATCH_TO_MUTATE_STMT(Div)
.DISPATCH_TO_MUTATE_STMT(Mod)
.DISPATCH_TO_MUTATE_STMT(Min)
.DISPATCH_TO_MUTATE_STMT(Max)
.DISPATCH_TO_MUTATE_STMT(EQ)
.DISPATCH_TO_MUTATE_STMT(NE)
.DISPATCH_TO_MUTATE_STMT(LT)
.DISPATCH_TO_MUTATE_STMT(LE)
.DISPATCH_TO_MUTATE_STMT(GT)
.DISPATCH_TO_MUTATE_STMT(GE)
.DISPATCH_TO_MUTATE_STMT(And)
.DISPATCH_TO_MUTATE_STMT(Or)
.DISPATCH_TO_MUTATE_STMT(Reduce)
.DISPATCH_TO_MUTATE_STMT(Cast)
.DISPATCH_TO_MUTATE_STMT(Not)
.DISPATCH_TO_MUTATE_STMT(Select)
.DISPATCH_TO_MUTATE_STMT(Ramp)
.DISPATCH_TO_MUTATE_STMT(Broadcast)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(IntImm)
.DISPATCH_TO_MUTATE_STMT(UIntImm)
.DISPATCH_TO_MUTATE_STMT(FloatImm)
.DISPATCH_TO_MUTATE_STMT(StringImm);
.DISPATCH_TO_MUTATE_STMT(Evaluate);
// Mutate Expr
......@@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
return e; \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
......@@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(LetStmt)
.DISPATCH_TO_MUTATE_EXPR(AttrStmt)
.DISPATCH_TO_MUTATE_EXPR(IfThenElse)
.DISPATCH_TO_MUTATE_EXPR(For)
.DISPATCH_TO_MUTATE_EXPR(Allocate)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Store)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Free)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub)
......@@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(AssertStmt)
.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
.DISPATCH_TO_MUTATE_EXPR(Provide)
.DISPATCH_TO_MUTATE_EXPR(Realize)
.DISPATCH_TO_MUTATE_EXPR(Block)
.DISPATCH_TO_MUTATE_EXPR(Evaluate)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
......
......@@ -69,11 +69,71 @@ class Vectorizer : public IRMutator {
}
// user mutate from parent.
using IRMutator::Mutate;
// override mutate
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = Vectorizer::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
Expr Mutate_(const Add* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Sub* op, const Expr &e) final {
return AddSubVec(op, e);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Div* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Max* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const EQ* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const NE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const LT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GT* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const GE* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const And* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
Expr Mutate_(const Cast *op, const Expr& e) final {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
}
// Variable
Expr Mutate_(const Variable* v, const Expr& e) final {
......@@ -235,10 +295,6 @@ class Vectorizer : public IRMutator {
stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
// The overloads for vectorize.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private:
// variable to be replaced
......@@ -273,13 +329,10 @@ class Vectorizer : public IRMutator {
if (!changed) return arr;
return Array<Expr>(new_arr);
}
};
// binary vectorize
template<typename T>
inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
template<typename T>
Expr BinaryVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
......@@ -287,12 +340,11 @@ inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
int lanes = std::max(a.type().lanes(), b.type().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
}
template<typename T>
Expr AddSubVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
......@@ -312,51 +364,8 @@ inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Add>(AddSubVec<Add>)
.set_dispatch<Sub>(AddSubVec<Sub>)
.set_dispatch<Mul>(BinaryVec<Mul>)
.set_dispatch<Div>(BinaryVec<Div>)
.set_dispatch<Mod>(BinaryVec<Mod>)
.set_dispatch<Min>(BinaryVec<Min>)
.set_dispatch<Max>(BinaryVec<Max>)
.set_dispatch<EQ>(BinaryVec<EQ>)
.set_dispatch<NE>(BinaryVec<NE>)
.set_dispatch<LT>(BinaryVec<LT>)
.set_dispatch<LE>(BinaryVec<LE>)
.set_dispatch<GT>(BinaryVec<GT>)
.set_dispatch<GE>(BinaryVec<GE>)
.set_dispatch<And>(BinaryVec<And>)
.set_dispatch<Or>(BinaryVec<Or>);
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
})
.set_dispatch<Cast>([](const Cast *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
});
};
class LoopVectorizer : public IRMutator {
public:
......
......@@ -2,10 +2,11 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_functor.h>
#include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
......@@ -21,6 +22,65 @@ TEST(IRF, Basic) {
CHECK_EQ(f(z, 2), 4);
}
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyExprFunctor
: public ir::ExprFunctor<int(const Expr&, int)> {
public:
int VisitExpr_(const Variable* op, int b) final {
return b;
}
int VisitExpr_(const IntImm* op, int b) final {
return op->value;
}
int VisitExpr_(const Add* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
MyExprFunctor f;
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 3);
try {
f(z - 1, 2);
LOG(FATAL) << "should fail";
} catch(dmlc::Error) {
}
}
TEST(IRF, ExprVisit) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
auto z = x + 1;
class MyVisitor
: public ir::ExprFunctor<void(const Expr&)>,
public ir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
void VisitExpr_(const IntImm* op) final {
}
void VisitExpr_(const Add* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
void VisitStmt_(const Evaluate* op) final {
VisitExpr(op->value);
}
};
MyVisitor v;
v(Evaluate::make(z));
CHECK_EQ(v.count, 1);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
......@@ -25,6 +25,11 @@ def test_deduce():
ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s})
assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf"
def test_check():
a = tvm.Var('a')
b = tvm.Var('b')
......
import tvm
def test_basic():
a = tvm.Var()
b = tvm.Var()
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3))
assert m.coeff == 4
assert m.base == 3
m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3))
assert m.coeff == 1
assert m.base == 0
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4))
assert m.coeff == 2
assert m.base == 0
m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2))
assert m.coeff == 3
assert m.base == 2
m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2))
assert m.coeff == 1
assert m.base == 0
if __name__ == "__main__":
test_basic()
......@@ -16,7 +16,7 @@ def test_llvm_add_pipeline():
f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
n = 10270 * 2460
n = 1027 * 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment