Unverified Commit 59448fed by Tianqi Chen Committed by GitHub

[ARITH] Refactor: Remove un-necessary usage of ComputeExpr (#3503)

parent 54f9d20a
......@@ -18,10 +18,8 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
* \brief Utility to invoke certain compute operations.
*/
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
......@@ -41,7 +39,7 @@ namespace arith {
* \return The result.
*/
// Generic binary-operation helper: builds the result node by calling
// OP::make(lhs, rhs) for the IR node type given as the template argument.
// NOTE(review): this listing is a diff rendering without +/- markers; the
// ComputeExpr signature looks like the removed line and the Compute signature
// its renamed replacement -- confirm against the commit.
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
inline Expr Compute(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
......@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
}
// Specialization for ir::Add: the result is formed with operator+ on the two
// Exprs instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Add> signature looks
// like the removed line and Compute<ir::Add> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
inline Expr Compute<ir::Add>(Expr a, Expr b) {
return a + b;
}
// Specialization for ir::Sub: the result is formed with operator- on the two
// Exprs instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Sub> signature looks
// like the removed line and Compute<ir::Sub> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
inline Expr Compute<ir::Sub>(Expr a, Expr b) {
return a - b;
}
// Specialization for ir::Mul: the result is formed with operator* on the two
// Exprs instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Mul> signature looks
// like the removed line and Compute<ir::Mul> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
inline Expr Compute<ir::Mul>(Expr a, Expr b) {
return a * b;
}
// Specialization for ir::Div: the result is formed with operator/ on the two
// Exprs instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Div> signature looks
// like the removed line and Compute<ir::Div> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
inline Expr Compute<ir::Div>(Expr a, Expr b) {
return a / b;
}
// Specialization for ir::Mod: the result is formed with operator% on the two
// Exprs instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Mod> signature looks
// like the removed line and Compute<ir::Mod> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
inline Expr Compute<ir::Mod>(Expr a, Expr b) {
return a % b;
}
// Specialization for ir::Max: the result comes from the free function
// max(a, b) instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Max> signature looks
// like the removed line and Compute<ir::Max> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
inline Expr Compute<ir::Max>(Expr a, Expr b) {
return max(a, b);
}
// Specialization for ir::Min: the result comes from the free function
// min(a, b) instead of the generic OP::make(lhs, rhs) path.
// NOTE(review): diff rendering -- the ComputeExpr<ir::Min> signature looks
// like the removed line and Compute<ir::Min> its replacement; confirm.
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
inline Expr Compute<ir::Min>(Expr a, Expr b) {
return min(a, b);
}
......@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
}
Expr res = values[0];
for (size_t i = 1; i < values.size(); ++i) {
res = ComputeExpr<Op>(res, values[i]);
res = Compute<Op>(res, values[i]);
}
return res;
}
......
......@@ -27,7 +27,6 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include "compute_expr.h"
namespace tvm {
namespace arith {
......@@ -127,18 +126,18 @@ class LinearEqDetector
// Combines two expressions by addition. If one operand is undefined the other
// is returned unchanged, so the result is undefined only when both are.
// NOTE(review): diff rendering -- the ComputeExpr<Add> return looks like the
// removed line and `a + b` its replacement; confirm against the commit.
Expr AddCombine(Expr a, Expr b) {
if (!a.defined()) return b;
if (!b.defined()) return a;
return ComputeExpr<Add>(a, b);
return a + b;
}
// Combines two expressions by subtraction (a - b). An undefined b leaves a
// unchanged; an undefined a with a defined b yields the negation -b.
// NOTE(review): diff rendering -- the ComputeExpr<Sub> return looks like the
// removed line and `a - b` its replacement; confirm against the commit.
Expr SubCombine(Expr a, Expr b) {
// Check b first in case they are both undefined
// (that way the undefined a is returned instead of evaluating -b).
if (!b.defined()) return a;
if (!a.defined()) return -b;
return ComputeExpr<Sub>(a, b);
return a - b;
}
// Combines two expressions by multiplication. Unlike the additive helpers, an
// undefined operand is returned as-is, so the product is undefined whenever
// either factor is undefined.
// NOTE(review): diff rendering -- the ComputeExpr<Mul> return looks like the
// removed line and `a * b` its replacement; confirm against the commit.
Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a;
if (!b.defined()) return b;
return ComputeExpr<Mul>(a, b);
return a * b;
}
};
......
......@@ -27,7 +27,6 @@
#include <vector>
#include <string>
#include "codegen_cuda.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
......
......@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>(
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
Expr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
}
} else {
......
......@@ -25,8 +25,8 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <string>
#include "../../arithmetic/compute_expr.h"
#include "codegen_spirv.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
......@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
spirv::Value v = base;
if (i != 0) {
spirv::Value offset = MakeValue(
arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride));
make_const(op->stride.type(), i) * op->stride);
v = builder_->Add(v, offset);
}
values.push_back(v);
......@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>(
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
Expr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
} else {
......
......@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
extent = make_const(self->DefaultIndexType(), 1);
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = arith::ComputeExpr<ir::Mul>(
self->strides[highest_dim], self->shape[highest_dim]) - offset;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
......
......@@ -26,6 +26,7 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
......@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>
(op->extents, Expr()) * op->type.lanes();
it->second.stride = arith::ComputeReduce<Mul>(
op->extents, Expr()) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
......@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator {
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min;
Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Expr new_ext =
old_loop->extent - make_const(old_loop->loop_var.type(), 1);
Expr factor = make_const(new_ext.type(), split_loop_);
Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
Expr outer_ext = new_ext / factor;
Expr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
std::unordered_map<const Variable*, Expr> vmap;
std::vector<Stmt> loop_seq;
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc
*/
#include <tvm/ir.h>
......@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor {
// Constructs the visitor, initializing touched_var_ from `touched` and
// check_write_ from `check_write`. The body is empty; no other work is done.
explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
void Visit(const NodeRef& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
......@@ -241,8 +241,8 @@ class VTInjector : public IRMutator {
visit_touched_var_ = true;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
Expr stride = arith::ComputeExpr<Div>(
it->second, make_const(offset.type(), dtype.lanes()));
Expr stride =
it->second / make_const(offset.type(), dtype.lanes());
offset = stride * var_ + offset;
return Call::make(
op->type, op->name,
......
......@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
*
* Lower warp memory to use local memory
* and shuffle intrinsics.
*
......
......@@ -33,7 +33,6 @@
#include "ir_util.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
......
......@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
stride = stride * shape[dim];
}
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
}
......@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator {
int first_dim = 0;
ret = Allocate::make(
e.buffer->data, storage_type,
{arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
{e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else {
shape = e.buffer->shape;
......@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator {
if (be.bounds.size() != 0) {
CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
begins.push_back(
arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
extents.push_back(tuple->args[2 * i + 1]);
}
} else {
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc
*/
......@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator {
}
Stmt Unroll(const For* op) {
using arith::ComputeExpr;
int value = GetExtent(op);
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
......@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator {
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
vmap.Set(lv,
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
vmap.Set(lv, op->min + make_const(op->loop_var.type(), i));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
......@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator {
const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a, b_ramp->base),
arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
arith::Compute<T>(a, b_ramp->base),
arith::Compute<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
b_ramp->lanes);
}
if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file message_passing.cc
* \brief The message passing domain.
*/
......@@ -32,12 +31,11 @@ namespace tvm {
namespace schedule {
using namespace ir;
using namespace arith;
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r,
Analyzer* analyzer) {
arith::Analyzer* analyzer) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
......@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage,
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = ComputeExpr<Div>(value, factor);
state[s->inner] = ComputeExpr<Mod>(value, factor);
state[s->outer] = value / factor;
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
......@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage,
CHECK(is_zero(r->min));
Expr parent = state.at(s->parent);
Expr factor = r->extent;
state[s->outer] = ComputeExpr<Div>(parent, factor);
state[s->inner] = ComputeExpr<Mod>(parent, factor);
state[s->outer] = parent / factor;
state[s->inner] = parent % factor;
} else if (const FuseNode* s = rel.as<FuseNode>()) {
if (!state.count(s->inner) && !state.count(s->outer)) {
CHECK(allow_missing);
......@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s,
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = EvalSet(
*parent = arith::EvalSet(
s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}});
}
......@@ -290,7 +288,7 @@ void PassUpDomain(const RebaseNode* s,
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = EvalSet(s->rebased->var + parent_min,
*parent = arith::EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}});
}
......@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck(
const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
Analyzer analyzer;
arith::Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
......@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck(
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
Expr value = value_map.at(iv) - dom->min;
Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
......@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck(
Range dom = dom_map.at(iv);
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
Expr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min();
Expr vmax = s.max();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment