Unverified Commit 255c187b by Tianqi Chen Committed by GitHub

[EXPR] Expression-template based pattern matching. (#2589)

parent f6be4d69
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/pattern_match.h
*
* \brief Internal tool for expression-template based pattern matching.
*
* It helps to simplify pattern matching and rewrites.
* All the patterns are generated via expression template during compile time,
* so the result code should be as efficient as manually written pattern match code.
*
* The code below shows how to use the pattern matcher.
*
* \code
*
* // max(x + z, y + z) => max(x, y) + z
* arith::PVar<Expr> x, y, z;
*
* // The following code tries to match the declared pattern.
* // Match will fill the result of match into PVar if successful.
* // Note that z occurs twice in the pattern,
* // an equality check is performed to ensure each occurance of z
* // is equivalent to each other.
* if (max(x + z, y + z).Match(expr)) {
* // Eval evaluates a pattern with the current matched value.
* // The filled value is valid until the next call to Match.
* return (max(x, y) + z).Eval();
* }
* \endcode
*
* \note The pattern matcher is not threadsafe,
* do not use the same PVar in multiple threads.
*
* Please be aware that the filled value in a PVar
* can be overriden in the next call to Match.
*/
#ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_
#define TVM_ARITHMETIC_PATTERN_MATCH_H_
#include <tvm/ir_pass.h>
#include <tuple>
namespace tvm {
namespace arith {
/*!
* \brief Base class of all the patterns.
*
* There are two major member functions supported by each pattern.
* - Match: checks if value matches the pattern.
* - Eval: construct a new value based on matched values in PVar.
*
* We use curiously recurring template pattern to construct
* expression templates.
*
* \tparam Derived The type of the derived class.
*/
template<typename Derived>
class Pattern {
public:
/*!
* \brief Nested storage type in the expression.
*
* Depending on the Derived class,
* Nested can be Derived (nest by value) or
* const Derived& (nest by reference).
*
* The trick of Nested typedef originates from Eigen.
*
* \note We use nest by value for intermediate expressions,
* and nest by reference for PVars.
*/
using Nested = Derived;
/*!
* \brief Check if value matches the current pattern.
*
* This call also populates the PVars with matched value.
* The values in PVars are valid until the next call to Match.
*
* \return whether value matches the pattern.
*/
template<typename NodeType>
bool Match(const NodeType& value) const {
derived().InitMatch_();
return derived().Match_(value);
}
/*! \return Derived instance of current class. */
const Derived& derived() const {
return *static_cast<const Derived*>(this);
}
};
/*!
* \brief Default deep equality checker
* \tparam T the comparison point.
*/
template<typename T>
class PEqualChecker {
public:
bool operator()(const T& lhs, const T& rhs) const {
return lhs == rhs;
}
};
template<>
class PEqualChecker<Expr> {
public:
bool operator()(const Expr& lhs, const Expr& rhs) const {
if (lhs.same_as(rhs)) return true;
return ir::Equal(lhs, rhs);
}
};
/*!
* \brief Pattern variable container.
*
* PVar is used as a "hole" in the pattern that can be matched.
*
* \tparam T the type of the hole.
*
* \note PVar is not thread safe.
* Do not use the same PVar in multiple threads.
*/
template<typename T>
class PVar : public Pattern<PVar<T> > {
public:
// Store PVars by reference in the expression.
using Nested = const PVar&;
void InitMatch_() const {
filled_ = false;
}
bool Match_(const T& value) const {
if (!filled_) {
value_ = value;
filled_ = true;
return true;
} else {
return PEqualChecker<T>()(value_, value);
}
}
T Eval() const {
CHECK(filled_);
return value_;
}
private:
/*! \brief The matched value */
mutable T value_;
/*! \brief whether the variable has been filled */
mutable bool filled_{false};
};
/*!
* \brief Constant Pattern variable container.
*
* \tparam T the type of the hole.
*/
template<typename T>
class PConst : public Pattern<PConst<T> > {
public:
PConst(T value) // NOLINT(*)
: value_(value) {}
void InitMatch_() const {}
bool Match_(const T& value) const {
return PEqualChecker<T>()(value_, value);
}
T Eval() const {
return value_;
}
private:
const T value_;
};
/*!
* \brief Pattern binary expression.
* \tparam NodeType The AST node type.
* \tparam TA The pattern type of the first operand.
* \tparam TB The pattern type of the second operand.
*/
template<typename NodeType, typename TA, typename TB>
class PBinaryExpr :
public Pattern<PBinaryExpr<NodeType, TA, TB> > {
public:
PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}
void InitMatch_() const {
a_.InitMatch_();
b_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const NodeType* ptr = node.as<NodeType>()) {
if (!a_.Match_(ptr->a)) return false;
if (!b_.Match_(ptr->b)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return NodeType::make(a_.Eval(), b_.Eval());
}
private:
typename TA::Nested a_;
typename TB::Nested b_;
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
}
// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
// logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
TVM_PATTERN_BINARY_OP(operator<, ir::LT);
TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
TVM_PATTERN_BINARY_OP(operator&&, ir::And);
TVM_PATTERN_BINARY_OP(operator||, ir::Or);
/*!
* \brief Pattern not expression.
* \tparam TA The pattern type of the true operand.
*/
template<typename TA>
class PNotExpr : public Pattern<PNotExpr<TA> > {
public:
explicit PNotExpr(const TA& value)
: value_(value) {}
void InitMatch_() const {
value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const ir::Not* ptr = node.as<ir::Not>()) {
if (!value_.Match_(ptr->a)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return ir::Not::make(value_.Eval());
}
private:
typename TA::Nested value_;
};
template<typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
return PNotExpr<TA>(value.derived());
}
// select
/*!
* \brief Pattern select expression.
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template<typename TCond, typename TA, typename TB>
class PSelectExpr :
public Pattern<PSelectExpr<TCond, TA, TB> > {
public:
PSelectExpr(const TCond& condition,
const TA& true_value,
const TB& false_value)
: condition_(condition),
true_value_(true_value),
false_value_(false_value) {}
void InitMatch_() const {
condition_.InitMatch_();
true_value_.InitMatch_();
false_value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const ir::Select* ptr = node.as<ir::Select>()) {
if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false;
if (!false_value_.Match_(ptr->false_value)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return ir::Select::make(
condition_.Eval(), true_value_.Eval(), false_value_.Eval());
}
private:
typename TCond::Nested condition_;
typename TA::Nested true_value_;
typename TB::Nested false_value_;
};
/*!
* \brief Construct a select pattern.
*
* \param condition The condition expression.
* \param true_value The value when condition is true.
* \param true_value The value when condition is false.
*
* \return The result pattern.
*
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template<typename TCond, typename TA, typename TB>
inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond>& condition,
const Pattern<TA>& true_value,
const Pattern<TB>& false_value) {
return PSelectExpr<TCond, TA, TB>(
condition.derived(), true_value.derived(), false_value.derived());
}
/*!
* \brief Pattern cast expression.
* \tparam DType The Pattern type of dtype.
* \tparam TA The pattern type of the first operand.
*/
template<typename DType, typename TA>
class PCastExpr :
public Pattern<PCastExpr<DType, TA> > {
public:
PCastExpr(const DType& dtype, const TA& value)
: dtype_(dtype), value_(value) {
}
void InitMatch_() const {
dtype_.InitMatch_();
value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const ir::Cast* ptr = node.as<ir::Cast>()) {
if (!dtype_.Match_(ptr->type)) return false;
if (!value_.Match_(ptr->value)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return ir::Cast::make(dtype_.Eval(), value_.Eval());
}
private:
typename DType::Nested dtype_;
typename TA::Nested value_;
};
/*!
* \brief Construct a cast pattern.
*
* \param dtype The target data type, can be PVar<Type> or PConst<Type>.
* \param value The input type.
*
* \return The result pattern.
*
* \tparam DType The pattern type of type.
* \tparam TA The pattern type of value.
*/
template<typename DType, typename TA>
inline PCastExpr<DType, TA>
cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
return PCastExpr<DType, TA>(dtype.derived(), value.derived());
}
/*!
* \brief Pattern ramp expression.
* \tparam TBase The pattern type of the base.
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
template<typename TBase, typename TStride, typename TLanes>
class PRampExpr :
public Pattern<PRampExpr<TBase, TStride, TLanes> > {
public:
PRampExpr(const TBase& base,
const TStride& stride,
const TLanes& lanes)
: base_(base), stride_(stride), lanes_(lanes) {
}
void InitMatch_() const {
base_.InitMatch_();
stride_.InitMatch_();
lanes_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
}
private:
typename TBase::Nested base_;
typename TStride::Nested stride_;
typename TLanes::Nested lanes_;
};
/*!
* \brief Construct a ramp pattern.
*
* \param base The base pattern.
* \param stride The stride pattern.
* \param lanes The lanes pattern.
*
* \return The result pattern.
*
* \tparam TBase The pattern type of the base.
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
template<typename TBase, typename TStride, typename TLanes>
inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase>& base,
const Pattern<TStride>& stride,
const Pattern<TLanes>& lanes) {
return PRampExpr<TBase, TStride, TLanes>(
base.derived(), stride.derived(), lanes.derived());
}
/*!
* \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
template<typename TA, typename TLanes>
class PBroadcastExpr :
public Pattern<PBroadcastExpr<TA, TLanes> > {
public:
PBroadcastExpr(const TA& value,
const TLanes& lanes)
: value_(value), lanes_(lanes) {
}
void InitMatch_() const {
value_.InitMatch_();
lanes_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
return true;
} else {
return false;
}
}
Expr Eval() const {
return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
}
private:
typename TA::Nested value_;
typename TLanes::Nested lanes_;
};
/*!
* \brief Construct a broadcast pattern.
*
* \param value The value pattern.
* \param lanes The lanes pattern.
*
* \return The result pattern.
*
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
template<typename TA, typename TLanes>
inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived());
}
// internal namespace
namespace detail {
// implementation details for CallExpr
template<bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
template<typename TTuple>
static void run(F& f, const TTuple& tuple) { // NOLINT(*)
f(I, std::get<I>(tuple));
tuple_for_each_dispatcher<
(I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>
::run(f, tuple);
}
};
template<std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
template<typename TTuple>
static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
};
template<typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*)
tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>
::run(f, tuple);
}
struct PCallExprInitMatchFunctor {
template<typename T>
void operator()(size_t i, const T& pattern) const {
pattern.InitMatch_();
}
};
struct PCallExprMatchFunctor {
const ir::Call* call_;
bool matched_{true};
explicit PCallExprMatchFunctor(const ir::Call* call)
: call_(call) {}
template<typename T>
void operator()(size_t i, const T& pattern) {
matched_ = matched_ && pattern.Match_(call_->args[i]);
}
};
struct PCallExprEvalArgsFunctor {
Array<Expr> args_;
template<typename T>
void operator()(size_t i, const T& pattern) {
args_.push_back(pattern.Eval());
}
};
} // namespace detail
/*!
* \brief Pattern CallExpr expression.
* \tparam Op The operator functor class.
* \tparam TArgs The arguments.
* \note Op functor contains the name of the function and
* the implementation of Eval.
*/
template<typename Op, typename ...TArgs>
class PCallExpr :
public Pattern<PCallExpr<Op, TArgs...> > {
public:
explicit PCallExpr(const TArgs&... args)
: args_(args...) {
}
void InitMatch_() const {
detail::PCallExprInitMatchFunctor finit;
detail::tuple_for_each(finit, args_);
}
bool Match_(const NodeRef& node) const {
if (const ir::Call* ptr = node.as<ir::Call>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false;
detail::PCallExprMatchFunctor fmatch(ptr);
detail::tuple_for_each(fmatch, args_);
return fmatch.matched_;
} else {
return false;
}
}
Expr Eval() const {
detail::PCallExprEvalArgsFunctor feval_args;
detail::tuple_for_each(feval_args, args_);
return Op::Eval(feval_args.args_);
}
private:
std::tuple<typename TArgs::Nested...> args_;
};
// arithemetic intrinsics
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
return ir::Call::make(args[0].type(), kName, args, \
ir::Call::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
}; \
template<typename TA, typename TB> \
inline PCallExpr<OpName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
}
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
// unary intrinsics
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
return ir::Call::make(args[0].type(), kName, args, \
ir::Call::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
}; \
template<typename TA> \
inline PCallExpr<OpName, TA> \
FuncName(const Pattern<TA>& a) { \
return PCallExpr<OpName, TA>(a.derived()); \
}
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
// if_then_else
struct PIfThenElseOp {
static Expr Eval(Array<Expr> args) {
return ir::Call::make(
args[1].type(), kName, args,
ir::Call::PureIntrinsic);
}
static constexpr const char* kName = "tvm_if_then_else";
};
/*!
* \brief Construct a if_then_else pattern.
*
* \param cond The condition expression.
* \param true_value The value when condition is true.
* \param true_value The value when condition is false.
*
* \return The result pattern.
*
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template<typename TCond, typename TA, typename TB>
inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
if_then_else(const Pattern<TCond>& cond,
const Pattern<TA>& true_value,
const Pattern<TB>& false_value) {
return PCallExpr<PIfThenElseOp, TCond, TA, TB>(
cond.derived(), true_value.derived(), false_value.derived());
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_PATTERN_MATCH_H_
......@@ -7,6 +7,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/pattern_match.h"
namespace tvm {
namespace ir {
......@@ -35,27 +36,8 @@ class CopyIntrinInjector : public IRMutator {
}
private:
bool MatchCondition(Expr expr,
Expr* cond,
Expr* true_value,
Expr* false_value) {
if (const auto* op = expr.as<Select>()) {
*cond = op->condition;
*true_value = op->true_value;
*false_value = op->false_value;
return true;
} else if (const auto* op = expr.as<Call>()) {
if (op->name == intrinsic::tvm_if_then_else) {
*cond = op->args[0];
*true_value = op->args[1];
*false_value = op->args[2];
return true;
}
}
return false;
}
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
using namespace arith;
Stmt body = stmt;
bool is_single_point_copy = false;
......@@ -68,11 +50,13 @@ class CopyIntrinInjector : public IRMutator {
}
const Store* store = body.as<Store>();
if (store == nullptr) return false;
Expr sel_cond, sel_true_value, sel_false_value;
bool has_cond = MatchCondition(store->value,
&sel_cond,
&sel_true_value,
&sel_false_value);
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
PVar<Expr> sel_cond, sel_true_value, sel_false_value;
bool has_cond =
if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
......@@ -81,7 +65,7 @@ class CopyIntrinInjector : public IRMutator {
}
// for now only support true condition matching
if (has_cond) {
load = sel_true_value.as<Load>();
load = sel_true_value.Eval().as<Load>();
}
// cast can be part of the pattern
if (cast != nullptr) {
......@@ -114,8 +98,8 @@ class CopyIntrinInjector : public IRMutator {
Expr src_elem_offset = load_strides[loop_var_size];
if (has_cond) {
Array<Expr> clip_bound =
arith::DetectClipBound(sel_cond, loop_vars);
pad_value = sel_false_value;
arith::DetectClipBound(sel_cond.Eval(), loop_vars);
pad_value = sel_false_value.Eval();
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
......
#include <gtest/gtest.h>
#include "../src/arithmetic/pattern_match.h"
TEST(Pattern, Basic) {
using namespace tvm;
using namespace tvm::arith;
Var x("x"), y("y"), z("z");
arith::PVar<Expr> px, py, pz;
arith::PVar<Type> pt;
arith::PVar<int> planes;
// arithmetics
auto r = 1 + (y + 1);
CHECK(!(px + (px + px)).Match(r));
CHECK(!(px + (py + py)).Match(r));
CHECK((px + (py + pz)).Match(r));
auto pattern = px + (py + pz);
CHECK(pattern.Match(r));
{
CHECK((px + (py + px)).Match(r));
auto rr = (px + py).Eval();
CHECK(ir::Equal(rr, 1 + y));
CHECK(ir::Equal(px.Eval() + py.Eval(), 1 + y));
}
{
CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK(ir::Equal(px.Eval(), x + 1));
}
CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK((px + min(py, px)).Match(z + min(y, z)));
CHECK((px + py / (px * py)).Match(x + 2 / (x * 2)));
CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2)));
CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2)));
// logicals
CHECK((px == pz).Match(x == 1));
CHECK((px != pz).Match(x != 1));
CHECK((px > py).Match(x > y));
CHECK((px < py).Match(x < y));
CHECK((px <= py).Match(x <= y));
CHECK((px >= py).Match(x >= y));
CHECK((px >= py && px < pz).Match(x >= y && x < z));
CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
{
CHECK(select(px >= pz, py, py + pz).Match(
ir::Select::make((x + 1) >= 1, y, y + 1)));
CHECK(ir::Equal(px.Eval(), x + 1));
}
// bit intrinsics
{
CHECK((px >> pz).Match(x >> 1));
CHECK(is_const_int(pz.Eval(), 1));
}
CHECK(!(px >> pz).Match(x << 1));
CHECK((px << pz).Match(x << 1));
CHECK((px & pz).Match(x & 1));
CHECK((px | pz).Match(x | 1));
CHECK((px ^ pz).Match(x ^ 1));
CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));
// select
{
CHECK(select(px > pz, py, py + pz).Match(
ir::Select::make(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
CHECK(!select(px > pz, py, py + pz).Match(
ir::Select::make(x > 2, y, y + 1)));
CHECK(!select(px > pz, py, py).Match(
ir::Select::make(x > 2, y, y + 1)));
{
CHECK(select(px, py, pz).Match(
ir::Select::make(x > 2, y, y + 1)));
CHECK(ir::Equal(pz.Eval(), y + 1));
}
// if_then_else
{
CHECK(if_then_else(px > pz, py, py + pz).Match(
if_then_else(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
// cast pattern
{
CHECK(!cast(PConst<Type>(Int(32)), px).Match(ir::Cast::make(Float(64), x)));
CHECK(cast(pt, px).Match(ir::Cast::make(Float(64), x)));
CHECK(pt.Eval() == Float(64));
auto zz = cast(pt, px).Eval();
CHECK((cast(pt, px) - cast(pt, py)).Match(
ir::Cast::make(Float(64), x) - ir::Cast::make(Int(64), x)));
auto expr = ir::Cast::make(Int(32), ir::Cast::make(Float(64), x));
CHECK(!(cast(pt, cast(pt, px))).Match(expr));
}
// ramp pattern
{
CHECK(ramp(px, PConst<Expr>(1), planes).Match(
ir::Ramp::make(x, 1, 10)));
CHECK(planes.Eval() == 10);
CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
ir::Ramp::make(x, 2, 10)));
}
// broadcast pattern
{
CHECK(broadcast(px, planes).Match(
ir::Broadcast::make(x, 10)));
CHECK(planes.Eval() == 10);
CHECK(broadcast(px * py , planes).Match(
ir::Broadcast::make(x * 10, 10)));
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment