inject_copy_intrin.cc 6.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28
/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Replace copy loop nests marked by a pragma with copy intrinsics.
 * \file inject_copy_intrin.cc
 */
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
29
#include "../arithmetic/pattern_match.h"
30 31 32 33 34 35 36 37 38 39

namespace tvm {
namespace ir {

using runtime::PackedFunc;

class CopyIntrinInjector : public IRMutator {
 public:
  CopyIntrinInjector(const std::string& pragma_key,
                     const PackedFunc& flower_copy_fromto)
40
      : pragma_key_(attr::pragma_scope_prefix+  pragma_key),
41 42 43 44 45 46 47
        flower_copy_fromto_(flower_copy_fromto) {
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      storage_scope_[buf] = op->value.as<StringImm>()->value;
48 49 50 51 52
    } else if (op->attr_key == pragma_key_) {
      Stmt ret;
      CHECK(MatchCopyPattern(op->body, &ret))
          << "Cannot match copy pattern of " << op->body;
      return ret;
53 54 55 56 57 58
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  bool MatchCopyPattern(Stmt stmt, Stmt *out) {
59
    using namespace arith;
60 61 62 63 64 65 66 67 68 69 70
    Stmt body = stmt;

    // strip the loops
    std::vector<const For*> loops;
    while (const For* op = body.as<For>()) {
      if (!is_zero(op->min)) return false;
      loops.push_back(op);
      body = op->body;
    }
    const Store* store = body.as<Store>();
    if (store == nullptr) return false;
71 72 73 74 75 76 77
    // Expr sel_cond, sel_true_value, sel_false_value;
    // match select or if
    PVar<Expr> sel_cond, sel_true_value, sel_false_value;
    bool has_cond =
        if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
        select(sel_cond, sel_true_value, sel_false_value).Match(store->value);

78
    const Cast* cast = store->value.as<Cast>();
79
    const Load* load = store->value.as<Load>();
80
    if (0 == loops.size()) {
81
      CHECK(!has_cond);
82
    }
83
    // for now only support true condition matching
84
    if (has_cond) {
85
      load = sel_true_value.Eval().as<Load>();
86
    }
87 88 89 90
    // cast can be part of the pattern
    if (cast != nullptr) {
      load = cast->value.as<Load>();
    }
91 92 93 94 95 96 97 98 99 100 101 102
    if (load == nullptr) return false;
    if (load->type.lanes() != 1) return false;
    Array<Var> loop_vars;
    for (const For* op : loops) {
      loop_vars.push_back(Var(op->loop_var.node_));
    }
    Array<Expr> store_strides =
        arith::DetectLinearEquation(store->index, loop_vars);
    Array<Expr> load_strides =
        arith::DetectLinearEquation(load->index, loop_vars);
    if (load_strides.size()  == 0 || store_strides.size() == 0) return false;
    Array<Expr> dst_shape;
103 104
    const size_t loop_var_size = loop_vars.size();
    if (loop_var_size == 0) {
105 106 107 108 109
      dst_shape.push_back(make_const(Int(32), 1));
    } else {
      for (const For* op : loops) {
        dst_shape.push_back(op->extent);
      }
110 111 112 113
    }
    Array<Expr> src_shape = dst_shape;
    Array<Expr> pad_before, pad_after;
    Expr pad_value;
114
    Expr src_elem_offset = load_strides[loop_var_size];
115
    if (has_cond) {
116
      Array<Expr> clip_bound =
117 118
          arith::DetectClipBound(sel_cond.Eval(), loop_vars);
      pad_value = sel_false_value.Eval();
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
      if (clip_bound.size() == 0) return false;
      CHECK_EQ(src_shape.size(), loop_vars.size());
      CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
      for (size_t i = 0; i < src_shape.size(); ++i) {
        Expr min_value = clip_bound[2 * i];
        Expr max_value = clip_bound[2 * i + 1];
        Type t = loop_vars[i].type();
        Expr svalue = src_shape[i];
        if (min_value.defined()) {
          Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
          src_elem_offset = src_elem_offset + pbefore * load_strides[i];
          svalue = svalue - pbefore;
          pad_before.push_back(pbefore);
        } else {
          pad_before.push_back(make_zero(t));
        }
        if (max_value.defined()) {
          Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
                                           make_zero(t)));
          svalue = svalue - pafter;
          pad_after.push_back(pafter);
        } else {
          pad_after.push_back(make_zero(t));
        }
        src_shape.Set(i, Simplify(svalue));
      }
      src_elem_offset = Simplify(src_elem_offset);
    }
    CHECK_EQ(load_strides.size(), store_strides.size());
148 149 150
    CHECK_EQ(load_strides.size(), loop_var_size + 1);
    Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
    Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
151 152 153 154
    if (loop_var_size == 0) {
        src_strides.push_back(make_const(Int(32), 1));
        dst_strides.push_back(make_const(Int(32), 1));
    }
155 156
    Buffer dst = BufferNode::make(
        Var(store->buffer_var.node_),
157
        store->value.type(),
158 159
        dst_shape,
        dst_strides,
160
        store_strides[loop_var_size],
161 162
        store->buffer_var->name_hint,
        GetStorageScope(store->buffer_var.get()),
163
        0, 0, kDefault);
164 165 166 167 168 169 170 171
    Buffer src = BufferNode::make(
        Var(load->buffer_var.node_),
        load->type,
        src_shape,
        src_strides,
        src_elem_offset,
        load->buffer_var->name_hint,
        GetStorageScope(load->buffer_var.get()),
172
        0, 0, kDefault);
173 174 175 176 177 178 179 180 181 182 183 184 185 186
    *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
    CHECK(out->defined()) << "flower function did not return correct stmt";
    return true;
  }
  // Get storage scope
  std::string GetStorageScope(const Variable* var) const {
    auto it = storage_scope_.find(var);
    if (it != storage_scope_.end()) {
      return it->second;
    } else {
      return "";
    }
  }
  // pragma key
187
  std::string pragma_key_;
188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
  // function to lower copy intrinsics.
  const PackedFunc& flower_copy_fromto_;
  // Storage scope
  std::unordered_map<const Variable*, std::string> storage_scope_;
};

/*!
 * \brief Lower every copy in \p stmt annotated with \p pragma_key through
 *  \p flower_copy_fromto.
 * \param stmt Statement to transform.
 * \param pragma_key Pragma suffix; CopyIntrinInjector prepends
 *  attr::pragma_scope_prefix to it.
 * \param flower_copy_fromto Callback that builds each replacement statement.
 * \return The rewritten statement.
 */
Stmt InjectCopyIntrin(Stmt stmt,
                      const std::string& pragma_key,
                      const PackedFunc& flower_copy_fromto) {
  CopyIntrinInjector injector(pragma_key, flower_copy_fromto);
  Stmt rewritten = injector.Mutate(stmt);
  return rewritten;
}

}  // namespace ir
}  // namespace tvm