Unverified Commit 255c187b by Tianqi Chen Committed by GitHub

[EXPR] Expression-template based pattern matching. (#2589)

parent f6be4d69
......@@ -7,6 +7,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/pattern_match.h"
namespace tvm {
namespace ir {
......@@ -35,27 +36,8 @@ class CopyIntrinInjector : public IRMutator {
}
private:
bool MatchCondition(Expr expr,
Expr* cond,
Expr* true_value,
Expr* false_value) {
if (const auto* op = expr.as<Select>()) {
*cond = op->condition;
*true_value = op->true_value;
*false_value = op->false_value;
return true;
} else if (const auto* op = expr.as<Call>()) {
if (op->name == intrinsic::tvm_if_then_else) {
*cond = op->args[0];
*true_value = op->args[1];
*false_value = op->args[2];
return true;
}
}
return false;
}
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
using namespace arith;
Stmt body = stmt;
bool is_single_point_copy = false;
......@@ -68,11 +50,13 @@ class CopyIntrinInjector : public IRMutator {
}
const Store* store = body.as<Store>();
if (store == nullptr) return false;
Expr sel_cond, sel_true_value, sel_false_value;
bool has_cond = MatchCondition(store->value,
&sel_cond,
&sel_true_value,
&sel_false_value);
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
PVar<Expr> sel_cond, sel_true_value, sel_false_value;
bool has_cond =
if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
......@@ -81,7 +65,7 @@ class CopyIntrinInjector : public IRMutator {
}
// for now only support true condition matching
if (has_cond) {
load = sel_true_value.as<Load>();
load = sel_true_value.Eval().as<Load>();
}
// cast can be part of the pattern
if (cast != nullptr) {
......@@ -114,8 +98,8 @@ class CopyIntrinInjector : public IRMutator {
Expr src_elem_offset = load_strides[loop_var_size];
if (has_cond) {
Array<Expr> clip_bound =
arith::DetectClipBound(sel_cond, loop_vars);
pad_value = sel_false_value;
arith::DetectClipBound(sel_cond.Eval(), loop_vars);
pad_value = sel_false_value.Eval();
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
......
#include <gtest/gtest.h>
#include "../src/arithmetic/pattern_match.h"
TEST(Pattern, Basic) {
using namespace tvm;
using namespace tvm::arith;
Var x("x"), y("y"), z("z");
arith::PVar<Expr> px, py, pz;
arith::PVar<Type> pt;
arith::PVar<int> planes;
// arithmetics
auto r = 1 + (y + 1);
CHECK(!(px + (px + px)).Match(r));
CHECK(!(px + (py + py)).Match(r));
CHECK((px + (py + pz)).Match(r));
auto pattern = px + (py + pz);
CHECK(pattern.Match(r));
{
CHECK((px + (py + px)).Match(r));
auto rr = (px + py).Eval();
CHECK(ir::Equal(rr, 1 + y));
CHECK(ir::Equal(px.Eval() + py.Eval(), 1 + y));
}
{
CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK(ir::Equal(px.Eval(), x + 1));
}
CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK((px + min(py, px)).Match(z + min(y, z)));
CHECK((px + py / (px * py)).Match(x + 2 / (x * 2)));
CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2)));
CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2)));
// logicals
CHECK((px == pz).Match(x == 1));
CHECK((px != pz).Match(x != 1));
CHECK((px > py).Match(x > y));
CHECK((px < py).Match(x < y));
CHECK((px <= py).Match(x <= y));
CHECK((px >= py).Match(x >= y));
CHECK((px >= py && px < pz).Match(x >= y && x < z));
CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
{
CHECK(select(px >= pz, py, py + pz).Match(
ir::Select::make((x + 1) >= 1, y, y + 1)));
CHECK(ir::Equal(px.Eval(), x + 1));
}
// bit intrinsics
{
CHECK((px >> pz).Match(x >> 1));
CHECK(is_const_int(pz.Eval(), 1));
}
CHECK(!(px >> pz).Match(x << 1));
CHECK((px << pz).Match(x << 1));
CHECK((px & pz).Match(x & 1));
CHECK((px | pz).Match(x | 1));
CHECK((px ^ pz).Match(x ^ 1));
CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));
// select
{
CHECK(select(px > pz, py, py + pz).Match(
ir::Select::make(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
CHECK(!select(px > pz, py, py + pz).Match(
ir::Select::make(x > 2, y, y + 1)));
CHECK(!select(px > pz, py, py).Match(
ir::Select::make(x > 2, y, y + 1)));
{
CHECK(select(px, py, pz).Match(
ir::Select::make(x > 2, y, y + 1)));
CHECK(ir::Equal(pz.Eval(), y + 1));
}
// if_then_else
{
CHECK(if_then_else(px > pz, py, py + pz).Match(
if_then_else(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
// cast pattern
{
CHECK(!cast(PConst<Type>(Int(32)), px).Match(ir::Cast::make(Float(64), x)));
CHECK(cast(pt, px).Match(ir::Cast::make(Float(64), x)));
CHECK(pt.Eval() == Float(64));
auto zz = cast(pt, px).Eval();
CHECK((cast(pt, px) - cast(pt, py)).Match(
ir::Cast::make(Float(64), x) - ir::Cast::make(Int(64), x)));
auto expr = ir::Cast::make(Int(32), ir::Cast::make(Float(64), x));
CHECK(!(cast(pt, cast(pt, px))).Match(expr));
}
// ramp pattern
{
CHECK(ramp(px, PConst<Expr>(1), planes).Match(
ir::Ramp::make(x, 1, 10)));
CHECK(planes.Eval() == 10);
CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
ir::Ramp::make(x, 2, 10)));
}
// broadcast pattern
{
CHECK(broadcast(px, planes).Match(
ir::Broadcast::make(x, 10)));
CHECK(planes.Eval() == 10);
CHECK(broadcast(px * py , planes).Match(
ir::Broadcast::make(x * 10, 10)));
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment