/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <gtest/gtest.h>
#include "../src/arithmetic/pattern_match.h"

// Exercises the core pattern-matching building blocks: arithmetic, comparison
// and logical operators, bit operators, select / if_then_else, cast, ramp and
// broadcast patterns. After a successful Match, each pattern variable's
// Eval() yields the subexpression it was bound to.
TEST(Pattern, Basic) {
  using namespace tvm;
  using namespace tvm::arith;
  Var x("x"), y("y"), z("z");
  arith::PVar<Expr> px, py, pz;
  arith::PVar<Type> pt;
  // Pattern variable for the lane count of ramp/broadcast nodes.
  arith::PVar<int> planes;

  // arithmetics
  auto r = 1 + (y + 1);
  // A PVar that appears more than once must bind to the same subexpression
  // everywhere, so px cannot be both `1` and `y`.
  CHECK(!(px + (px + px)).Match(r));
  CHECK(!(px + (py + py)).Match(r));
  CHECK((px + (py + pz)).Match(r));
  // Patterns can be stored and matched later.
  auto pattern = px + (py + pz);
  CHECK(pattern.Match(r));
  {
    // Here the repeated px binds to the literal 1 in both positions.
    CHECK((px + (py + px)).Match(r));
    auto rr = (px + py).Eval();
    CHECK(ir::Equal(rr, 1 + y));
    CHECK(ir::Equal(px.Eval() + py.Eval(), 1 + y));
  }
  {
    CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
    CHECK(ir::Equal(px.Eval(), x + 1));
  }
  // min pattern must not match a max node.
  CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
  CHECK((px + min(py, px)).Match(z + min(y, z)));
  CHECK((px + py / (px * py)).Match(x + 2 / (x * 2)));
  CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2)));
  CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2)));

  // logicals
  CHECK((px == pz).Match(x == 1));
  CHECK((px != pz).Match(x != 1));
  CHECK((px > py).Match(x > y));
  CHECK((px < py).Match(x < y));
  CHECK((px <= py).Match(x <= y));
  CHECK((px >= py).Match(x >= y));
  CHECK((px >= py && px < pz).Match(x >= y && x < z));
  CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
  {
    CHECK(select(px >= pz, py, py + pz).Match(
        ir::Select::make((x + 1) >= 1, y, y + 1)));
    CHECK(ir::Equal(px.Eval(), x + 1));
  }
  // bit intrinsics
  {
    CHECK((px >> pz).Match(x >> 1));
    CHECK(is_const_int(pz.Eval(), 1));
  }
  CHECK(!(px >> pz).Match(x << 1));
  CHECK((px << pz).Match(x << 1));
  CHECK((px & pz).Match(x & 1));
  CHECK((px | pz).Match(x | 1));
  CHECK((px ^ pz).Match(x ^ 1));
  CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));
  // select
  {
    CHECK(select(px > pz, py, py + pz).Match(
        ir::Select::make(x > 1, y, y + 1)));
    CHECK(is_const_int(pz.Eval(), 1));
  }
  // pz is bound to 2 by the condition, so `y + pz` cannot match `y + 1`.
  CHECK(!select(px > pz, py, py + pz).Match(
      ir::Select::make(x > 2, y, y + 1)));
  CHECK(!select(px > pz, py, py).Match(
      ir::Select::make(x > 2, y, y + 1)));
  {
    // A bare PVar can match a whole condition expression.
    CHECK(select(px, py, pz).Match(
        ir::Select::make(x > 2, y, y + 1)));
    CHECK(ir::Equal(pz.Eval(), y + 1));
  }
  // if_then_else
  {
    CHECK(if_then_else(px > pz, py, py + pz).Match(
        if_then_else(x > 1, y, y + 1)));
    CHECK(is_const_int(pz.Eval(), 1));
  }
  // cast pattern
  {
    CHECK(!cast(PConst<Type>(Int(32)), px).Match(ir::Cast::make(Float(64), x)));
    CHECK(cast(pt, px).Match(ir::Cast::make(Float(64), x)));
    CHECK(pt.Eval() == Float(64));
    // Rebuilding the cast from the bound variables must reproduce the input.
    CHECK(ir::Equal(cast(pt, px).Eval(), ir::Cast::make(Float(64), x)));
    CHECK((cast(pt, px) - cast(pt, py)).Match(
        ir::Cast::make(Float(64), x) - ir::Cast::make(Int(64), x)));
    // pt would have to be both Int(32) and Float(64), so this must fail.
    auto expr = ir::Cast::make(Int(32), ir::Cast::make(Float(64), x));
    CHECK(!(cast(pt, cast(pt, px))).Match(expr));
  }
  // ramp pattern
  {
    CHECK(ramp(px, PConst<Expr>(1), planes).Match(
        ir::Ramp::make(x, 1, 10)));
    CHECK(planes.Eval() == 10);
    // Stride is pinned to the constant 1, so a stride-2 ramp is rejected.
    CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
        ir::Ramp::make(x, 2, 10)));
  }
  // broadcast pattern
  {
    CHECK(broadcast(px, planes).Match(
        ir::Broadcast::make(x, 10)));
    CHECK(planes.Eval() == 10);
    CHECK(broadcast(px * py, planes).Match(
        ir::Broadcast::make(x * 10, 10)));
  }
}

// Checks that pattern variables can be typed with the Expr subclasses
// Integer and Var, and that such typed variables reject subexpressions of a
// different node kind.
TEST(Pattern, Integer) {
  using namespace tvm;
  tvm::Var var_a, var_b;
  arith::PVar<Integer> pint;
  arith::PVar<Var> pvar;

  // Positive cases: both an Integer and a Var binding succeed.
  {
    CHECK((pvar * pint).Match(var_a * 3));
    CHECK_EQ(pint.Eval()->value, 3);
    // A plain C++ integer on the right-hand side also forms a valid pattern.
    CHECK((pvar * 3).Match(var_a * 3));
  }

  // Negative cases: the typed variables refuse mismatching node kinds.
  {
    // A Var operand cannot bind to the Integer-typed pattern variable.
    CHECK(!(pvar * pint).Match(var_a * var_b));
    // An Add node cannot bind to the Var-typed pattern variable.
    CHECK(!(pvar * pint).Match((var_a + 1) * 3));
  }
}

int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}