codegen_common.h 1.67 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
/*!
 *  Copyright (c) 2018 by Contributors
 * \file codegen_common.h
 * \brief Common utility for codegen.
 */
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_

#include <tvm/arithmetic.h>
#include <limits>
#include <unordered_map>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {

/*!
 * \brief Visit AssertStmt recursively, update align_map from condition.
 *
 *  If the assert condition has the form `var % c1 == c2` (var a Variable,
 *  c1 and c2 integer constants, 0 < c1 <= INT_MAX) and c1 is larger than the
 *  coeff currently recorded for var, op->body is visited with align_map
 *  temporarily holding {coeff = c1, base = c2 normalized into [0, c1)}.
 *  Afterwards the previous state of align_map for var is restored,
 *  including the absence of an entry. In every other case op->body is
 *  visited with align_map untouched.
 *
 * \param op The AssertStmt
 * \param align_map The alignment map, temporarily updated while visiting op->body.
 * \param fvisit The recursive visitor, invoked exactly once on op->body.
 * \tparam FVisit the recursive visitor type, callable with op->body.
 */
template<typename FVisit>
inline void VisitAssert(
    const ir::AssertStmt* op,
    std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
    FVisit fvisit) {
  using namespace ir;
  // Detect useful invariant pattern and use them to visit child.
  // Pattern: Var % const == const
  // TODO(tqchen) merge these pattern to a generic scope info visitor.
  if (const EQ* eq = op->condition.as<EQ>()) {
    const Mod* mod = eq->a.as<Mod>();
    int64_t factor = 0, offset = 0;
    if (mod && arith::GetConst(eq->b, &offset)) {
      const Variable* var = mod->a.as<Variable>();
      // ModularEntry stores coeff/base as int: reject moduli that are not
      // strictly positive or that do not fit, rather than truncating them.
      if (var && arith::GetConst(mod->b, &factor) &&
          factor > 0 && factor <= std::numeric_limits<int>::max()) {
        // Use find() instead of operator[] so a miss does not leave a
        // default-constructed entry behind in the caller's map.
        auto it = align_map->find(var);
        const bool had_entry = (it != align_map->end());
        const arith::ModularEntry old =
            had_entry ? it->second : arith::ModularEntry();
        if (factor > old.coeff) {
          // Fold offset into [0, factor). If Mod truncates like C, a negative
          // var can satisfy the assert with a negative offset; since
          // var == factor * k + offset, var == base (mod factor) as well.
          int64_t base = offset % factor;
          if (base < 0) base += factor;
          arith::ModularEntry e;
          e.coeff = static_cast<int>(factor);
          e.base = static_cast<int>(base);
          // New alignment info is only valid inside the asserted body.
          (*align_map)[var] = e;
          fvisit(op->body);
          // Restore the previous state. Look var up again instead of reusing
          // `it`: fvisit may insert other entries and rehash the map.
          if (had_entry) {
            (*align_map)[var] = old;
          } else {
            align_map->erase(var);
          }
          return;
        }
      }
    }
  }
  fvisit(op->body);
}

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_CODEGEN_CODEGEN_COMMON_H_