Unverified Commit 59448fed by Tianqi Chen Committed by GitHub

[ARITH] Refactor: Remove un-necessary usage of ComputeExpr (#3503)

parent 54f9d20a
...@@ -18,10 +18,8 @@ ...@@ -18,10 +18,8 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h * \file compute_expr.h
* \brief Utility integer expression with quick eager simplification. * \brief Utility to invoke certan compute operations.
* This is weaker than Simplify but can be done Eagerly.
*/ */
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_ #ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
...@@ -41,7 +39,7 @@ namespace arith { ...@@ -41,7 +39,7 @@ namespace arith {
* \return The result. * \return The result.
*/ */
template<typename OP> template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) { inline Expr Compute(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs); return OP::make(lhs, rhs);
} }
...@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) { ...@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
} }
template<> template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) { inline Expr Compute<ir::Add>(Expr a, Expr b) {
return a + b; return a + b;
} }
template<> template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) { inline Expr Compute<ir::Sub>(Expr a, Expr b) {
return a - b; return a - b;
} }
template<> template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) { inline Expr Compute<ir::Mul>(Expr a, Expr b) {
return a * b; return a * b;
} }
template<> template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) { inline Expr Compute<ir::Div>(Expr a, Expr b) {
return a / b; return a / b;
} }
template<> template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) { inline Expr Compute<ir::Mod>(Expr a, Expr b) {
return a % b; return a % b;
} }
template<> template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) { inline Expr Compute<ir::Max>(Expr a, Expr b) {
return max(a, b); return max(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) { inline Expr Compute<ir::Min>(Expr a, Expr b) {
return min(a, b); return min(a, b);
} }
...@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) { ...@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
} }
Expr res = values[0]; Expr res = values[0];
for (size_t i = 1; i < values.size(); ++i) { for (size_t i = 1; i < values.size(); ++i) {
res = ComputeExpr<Op>(res, values[i]); res = Compute<Op>(res, values[i]);
} }
return res; return res;
} }
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include "compute_expr.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -127,18 +126,18 @@ class LinearEqDetector ...@@ -127,18 +126,18 @@ class LinearEqDetector
Expr AddCombine(Expr a, Expr b) { Expr AddCombine(Expr a, Expr b) {
if (!a.defined()) return b; if (!a.defined()) return b;
if (!b.defined()) return a; if (!b.defined()) return a;
return ComputeExpr<Add>(a, b); return a + b;
} }
Expr SubCombine(Expr a, Expr b) { Expr SubCombine(Expr a, Expr b) {
// Check b first in case they are both undefined // Check b first in case they are both undefined
if (!b.defined()) return a; if (!b.defined()) return a;
if (!a.defined()) return -b; if (!a.defined()) return -b;
return ComputeExpr<Sub>(a, b); return a - b;
} }
Expr MulCombine(Expr a, Expr b) { Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a; if (!a.defined()) return a;
if (!b.defined()) return b; if (!b.defined()) return b;
return ComputeExpr<Mul>(a, b); return a * b;
} }
}; };
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "codegen_cuda.h" #include "codegen_cuda.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
......
...@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e, ...@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f) { std::function<void(int i, llvm::Value* v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) { if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) { for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>( Expr offset = ramp->base + (ramp->stride * i);
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
f(i, MakeValue(offset)); f(i, MakeValue(offset));
} }
} else { } else {
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <string> #include <string>
#include "../../arithmetic/compute_expr.h"
#include "codegen_spirv.h" #include "codegen_spirv.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) { ...@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
spirv::Value v = base; spirv::Value v = base;
if (i != 0) { if (i != 0) {
spirv::Value offset = MakeValue( spirv::Value offset = MakeValue(
arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride)); make_const(op->stride.type(), i) * op->stride);
v = builder_->Add(v, offset); v = builder_->Add(v, offset);
} }
values.push_back(v); values.push_back(v);
...@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e, ...@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f) { std::function<void(int i, spirv::Value v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) { if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) { for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>( Expr offset = ramp->base + ramp->stride * i;
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
f(i, MakeValue(offset)); f(i, MakeValue(offset));
} }
} else { } else {
......
...@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr ...@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
extent = make_const(self->DefaultIndexType(), 1); extent = make_const(self->DefaultIndexType(), 1);
} else if (self->strides.size() == self->shape.size()) { } else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0; int highest_dim = 0;
extent = arith::ComputeExpr<ir::Mul>( extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
self->strides[highest_dim], self->shape[highest_dim]) - offset;
} else { } else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset; extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc * \file arg_binder.cc
* \brief Helper utility to match and bind arguments. * \brief Helper utility to match and bind arguments.
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include "ir_util.h" #include "ir_util.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
...@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator { ...@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul> it->second.stride = arith::ComputeReduce<Mul>(
(op->extents, Expr()) * op->type.lanes(); op->extents, Expr()) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)}; Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
...@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator { ...@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator {
<< "It is better to split with multiple of 2"; << "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min)); CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min; Expr zero = old_loop->min;
Expr new_ext = arith::ComputeExpr<Sub>( Expr new_ext =
old_loop->extent, make_const(old_loop->loop_var.type(), 1)); old_loop->extent - make_const(old_loop->loop_var.type(), 1);
Expr factor = make_const(new_ext.type(), split_loop_); Expr factor = make_const(new_ext.type(), split_loop_);
Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor); Expr outer_ext = new_ext / factor;
Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor); Expr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
std::unordered_map<const Variable*, Expr> vmap; std::unordered_map<const Variable*, Expr> vmap;
std::vector<Stmt> loop_seq; std::vector<Stmt> loop_seq;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc * \file inject_virtual_thread.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
...@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor { ...@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor {
explicit ExprTouched(const std::unordered_set<const Variable*> &touched, explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
bool check_write) bool check_write)
: touched_var_(touched), check_write_(check_write) {} : touched_var_(touched), check_write_(check_write) {}
void Visit(const NodeRef& n) final { void Visit(const NodeRef& n) final {
// early stopping // early stopping
if (expr_touched_ && !check_write_) return; if (expr_touched_ && !check_write_) return;
...@@ -241,8 +241,8 @@ class VTInjector : public IRMutator { ...@@ -241,8 +241,8 @@ class VTInjector : public IRMutator {
visit_touched_var_ = true; visit_touched_var_ = true;
Expr offset = Mutate(op->args[2]); Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]); Expr extent = Mutate(op->args[3]);
Expr stride = arith::ComputeExpr<Div>( Expr stride =
it->second, make_const(offset.type(), dtype.lanes())); it->second / make_const(offset.type(), dtype.lanes());
offset = stride * var_ + offset; offset = stride * var_ + offset;
return Call::make( return Call::make(
op->type, op->name, op->type, op->name,
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
*
* Lower warp memory to use local memory * Lower warp memory to use local memory
* and shuffle intrinsics. * and shuffle intrinsics.
* *
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
#include "ir_util.h" #include "ir_util.h"
#include "arg_binder.h" #include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator { ...@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
stride = ir::Simplify(stride); stride = ir::Simplify(stride);
} }
rstrides.push_back(stride); rstrides.push_back(stride);
stride = arith::ComputeExpr<Mul>(stride, shape[dim]); stride = stride * shape[dim];
} }
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend()); strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
} }
...@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator { ...@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator {
int first_dim = 0; int first_dim = 0;
ret = Allocate::make( ret = Allocate::make(
e.buffer->data, storage_type, e.buffer->data, storage_type,
{arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])}, {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
make_const(Bool(e.buffer->dtype.lanes()), true), body); make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else { } else {
shape = e.buffer->shape; shape = e.buffer->shape;
...@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator { ...@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator {
if (be.bounds.size() != 0) { if (be.bounds.size() != 0) {
CHECK_EQ(tuple->args.size(), be.bounds.size() * 2); CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
for (size_t i = 0; i < be.buffer->shape.size(); ++i) { for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
begins.push_back( begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
extents.push_back(tuple->args[2 * i + 1]); extents.push_back(tuple->args[2 * i + 1]);
} }
} else { } else {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* Loop unrolling as in Halide pipeline. * Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc * \file unroll_loop.cc
*/ */
...@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator { ...@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator {
} }
Stmt Unroll(const For* op) { Stmt Unroll(const For* op) {
using arith::ComputeExpr;
int value = GetExtent(op); int value = GetExtent(op);
// For loop must have a constant integer extent // For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
...@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator { ...@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator {
Stmt unrolled; Stmt unrolled;
for (int i = 0; i < value; ++i) { for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_); Var lv(op->loop_var.node_);
vmap.Set(lv, vmap.Set(lv, op->min + make_const(op->loop_var.type(), i));
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
Stmt step = Substitute(body, vmap); Stmt step = Substitute(body, vmap);
if (unrolled.defined()) { if (unrolled.defined()) {
unrolled = Block::make(unrolled, step); unrolled = Block::make(unrolled, step);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file vectorize_loop.cc * \file vectorize_loop.cc
*/ */
// Loop vectorizer as in Halide pipeline. // Loop vectorizer as in Halide pipeline.
...@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator { ...@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator {
const Ramp* a_ramp = a.as<Ramp>(); const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) { if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make( return Ramp::make(
arith::ComputeExpr<T>(a, b_ramp->base), arith::Compute<T>(a, b_ramp->base),
arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride), arith::Compute<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
b_ramp->lanes); b_ramp->lanes);
} }
if (b.type().lanes() == 1 && a_ramp) { if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make( return Ramp::make(
arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
} }
} }
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file message_passing.cc * \file message_passing.cc
* \brief The message passing domain. * \brief The message passing domain.
*/ */
...@@ -32,12 +31,11 @@ namespace tvm { ...@@ -32,12 +31,11 @@ namespace tvm {
namespace schedule { namespace schedule {
using namespace ir; using namespace ir;
using namespace arith;
void Update(std::unordered_map<IterVar, Range>* p_state, void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv, const IterVar& iv,
Range r, Range r,
Analyzer* analyzer) { arith::Analyzer* analyzer) {
auto it = p_state->find(iv); auto it = p_state->find(iv);
if (it == p_state->end()) { if (it == p_state->end()) {
(*p_state)[iv] = r; (*p_state)[iv] = r;
...@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage, ...@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage,
Expr factor = dom_map.at(s->inner)->extent; Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min; Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min; Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = ComputeExpr<Div>(value, factor); state[s->outer] = value / factor;
state[s->inner] = ComputeExpr<Mod>(value, factor); state[s->inner] = value % factor;
// add min if they exist // add min if they exist
if (!is_zero(outer_min)) { if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min; state[s->outer] = state[s->outer] + outer_min;
...@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage, ...@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage,
CHECK(is_zero(r->min)); CHECK(is_zero(r->min));
Expr parent = state.at(s->parent); Expr parent = state.at(s->parent);
Expr factor = r->extent; Expr factor = r->extent;
state[s->outer] = ComputeExpr<Div>(parent, factor); state[s->outer] = parent / factor;
state[s->inner] = ComputeExpr<Mod>(parent, factor); state[s->inner] = parent % factor;
} else if (const FuseNode* s = rel.as<FuseNode>()) { } else if (const FuseNode* s = rel.as<FuseNode>()) {
if (!state.count(s->inner) && !state.count(s->outer)) { if (!state.count(s->inner) && !state.count(s->outer)) {
CHECK(allow_missing); CHECK(allow_missing);
...@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s, ...@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s,
CHECK(outer.defined()); CHECK(outer.defined());
CHECK(inner.defined()); CHECK(inner.defined());
CHECK(factor.defined()); CHECK(factor.defined());
*parent = EvalSet( *parent = arith::EvalSet(
s->outer->var * factor + s->inner->var + parent_min, s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}}); {{s->outer, outer}, {s->inner, inner}});
} }
...@@ -290,8 +288,8 @@ void PassUpDomain(const RebaseNode* s, ...@@ -290,8 +288,8 @@ void PassUpDomain(const RebaseNode* s,
return; return;
} }
Expr parent_min = dom_map.at(s->parent)->min; Expr parent_min = dom_map.at(s->parent)->min;
*parent = EvalSet(s->rebased->var + parent_min, *parent = arith::EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}}); {{s->rebased, rebased}});
} }
void PassUpDomain(const Stage& stage, void PassUpDomain(const Stage& stage,
...@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck( ...@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck(
const std::unordered_map<IterVar, Expr>& value_map, const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain, bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) { const std::unordered_set<IterVar>& skip_iter) {
Analyzer analyzer; arith::Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state; std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) { for (IterVar iv : stage->leaf_iter_vars) {
...@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck( ...@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck(
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) { if (bound_state.at(iv)) {
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min); Expr value = value_map.at(iv) - dom->min;
Expr vmax = EvalSet(value, iset_dmap).max(); Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) { if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent); preds.emplace_back(value < dom->extent);
...@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck( ...@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck(
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) { if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min); Expr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap); IntSet s = EvalSet(value, iset_dmap);
Expr vmin = s.min(); Expr vmin = s.min();
Expr vmax = s.max(); Expr vmax = s.max();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment