split_host_device.cc 6.67 KB
Newer Older
1 2 3 4 5 6
/*!
 *  Copyright (c) 2017 by Contributors
 * \file split_host_device.cc
 * \brief Split device function from host.
 */
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
#include <tvm/channel.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/runtime/module.h>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
15
namespace ir {
16 17 18 19 20

// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
21
    if (op->attr_key == attr::thread_extent) {
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      // thread_extent can appear multiple times
      // use the first appearance as def.
      if (!use_count_.count(iv->var.get())) {
        this->HandleDef(iv->var.get());
        thread_axis_.push_back(iv);
        thread_extent_.push_back(op->value);
      }

      Expr value = op->value;
      if (visit_thread_extent_) {
        value = this->Mutate(value);
      }
      Stmt body = this->Mutate(op->body);
      if (value.same_as(value) && body.same_as(body)) return s;
38 39 40
      return AttrStmt::make(op->node, op->attr_key, value, body);
    } else if (op->attr_key == attr::channel_write_scope ||
               op->attr_key == attr::channel_read_scope) {
Tianqi Chen committed
41 42 43 44 45
      Channel ch(op->node.node_);
      if (!use_count_.count(ch->handle_var.get())) {
        this->HandleDef(ch->handle_var.get());
      }
      return IRMutator::Mutate_(op, s);
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
    this->HandleDef(op->var.get());
    Stmt body = this->Mutate(op->body);
    // eliminate unreferenced let
    if (use_count_.at(op->var.get()) == 0 &&
        !HasSideEffect(op->value)) {
      return body;
    } else {
      Expr value = this->Mutate(op->value);
      if (body.same_as(op->body) &&
          value.same_as(op->value)) {
        return s;
      } else {
        return LetStmt::make(op->var, value, body);
      }
    }
  }

  Stmt Mutate_(const For *op, const Stmt& s) final {
    this->HandleDef(op->loop_var.get());
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Allocate *op, const Stmt& s) final {
    this->HandleDef(op->buffer_var.get());
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt& s) final {
    this->HandleUse(op->buffer_var);
    return IRMutator::Mutate_(op, s);
  }

  Expr Mutate_(const Let *op, const Expr& e) final {
    this->HandleDef(op->var.get());
    Expr body = this->Mutate(op->body);
    // eliminate unreferenced let
    if (use_count_.at(op->var.get()) == 0 &&
        !HasSideEffect(op->value)) {
      return body;
    } else {
      Expr value = this->Mutate(op->value);
      if (body.same_as(op->body) &&
          value.same_as(op->value)) {
        return e;
      } else {
        return Let::make(op->var, value, body);
      }
    }
  }

  Expr Mutate_(const Variable *op, const Expr& e) final {
    this->HandleUse(e);
    return IRMutator::Mutate_(op, e);
  }

  Expr Mutate_(const Load *op, const Expr& e) final {
    this->HandleUse(op->buffer_var);
    return IRMutator::Mutate_(op, e);
  }

  void HandleDef(const Variable* v) {
113 114 115
    CHECK(!def_count_.count(v))
        << "variable " << v->name_hint
        << " has already been defined, the Stmt is not SSA";
116
    CHECK(!use_count_.count(v))
117 118
        << "variable " << v->name_hint
        << " has been used before definition!";
119
    use_count_[v] = 0;
120
    def_count_[v] = 1;
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  }

  void HandleUse(const Expr& v) {
    CHECK(v.as<Variable>());
    Var var(v.node_);
    auto it = use_count_.find(var.get());
    if (it != use_count_.end()) {
      if (it->second >= 0) {
        ++it->second;
      }
    } else {
      undefined_.push_back(var);
      use_count_[var.get()] = -1;
    }
  }

  // The fields are publically readible to
  // be accessible to the users.
  bool visit_thread_extent_{true};
  Array<Var> undefined_;
  Array<IterVar> thread_axis_;
  Array<Expr> thread_extent_;
  std::unordered_map<const Variable*, int> use_count_;
144
  std::unordered_map<const Variable*, int> def_count_;
145 146 147 148 149
};

// Splits a mixed host/device LoweredFunc into one host function plus one
// device function per outermost thread_extent / pipeline_exec_scope region.
// Each extracted region is replaced in the host body by a tvm_call_packed
// intrinsic that calls the generated kernel by name.
class HostDeviceSplitter : public IRMutator {
 public:
  // An outermost thread_extent or pipeline_exec_scope attribute starts a
  // device region: the whole attribute statement is moved into a new device
  // function, and the host-side call replaces it.
  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::pipeline_exec_scope) {
      return SplitDeviceFunc(s);
    }
    return IRMutator::Mutate_(op, s);
  }

  // Splits f, which must be a kMixedFunc.
  // Returns the host function (func_type kHostFunc) first, followed by the
  // extracted device functions in the order they were created.
  Array<LoweredFunc> Split(LoweredFunc f) {
    CHECK_EQ(f->func_type, kMixedFunc);
    // Reset per-call state so that reusing a splitter does not return
    // kernels (or misnumber new ones) left over from a previous function.
    device_funcs_.clear();
    handle_data_type_.clear();
    for (const auto& kv : f->handle_data_type) {
      handle_data_type_[kv.first.get()] = kv.second;
    }
    name_ = f->name;
    std::shared_ptr<LoweredFuncNode> n =
        std::make_shared<LoweredFuncNode>(*f.operator->());
    n->body = this->Mutate(f->body);
    n->func_type = kHostFunc;
    Array<LoweredFunc> ret{LoweredFunc(n)};
    for (LoweredFunc x : device_funcs_) {
      ret.push_back(x);
    }
    return ret;
  }

 private:
  // Moves `body` into a new device LoweredFunc named "<name_>__kernel<i>" and
  // returns the host statement that calls it. The call arguments are the
  // kernel name, the kernel's arguments, then the thread extents.
  Stmt SplitDeviceFunc(Stmt body) {
    std::ostringstream os;
    os << name_ << "__kernel" << device_funcs_.size();
    std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
    // Isolate the device function: collect its thread axes and free
    // variables. Thread extents are passed from the host separately, so
    // variables used only in them must not become kernel arguments.
    IRUseDefAnalysis m;
    m.visit_thread_extent_ = false;
    n->body = m.Mutate(body);
    n->name = os.str();
    n->func_type = kDeviceFunc;
    n->thread_axis = m.thread_axis_;
    // Strictly order the arguments: handle variables first, then all other
    // variables; each group keeps first-use order.
    for (Var v : m.undefined_) {
      if (v.type().is_handle()) {
        n->args.push_back(v);
        // mark handle data type.
        auto it = handle_data_type_.find(v.get());
        if (it != handle_data_type_.end()) {
          n->handle_data_type.Set(v, it->second);
        }
      }
    }
    for (Var v : m.undefined_) {
      if (!v.type().is_handle()) {
        n->args.push_back(v);
      }
    }
    LoweredFunc f_device(n);
    Array<Expr> call_args;
    call_args.push_back(StringImm::make(f_device->name));
    for (Var arg : n->args) {
      call_args.push_back(arg);
    }
    for (Expr ext : m.thread_extent_) {
      call_args.push_back(ext);
    }
    device_funcs_.emplace_back(f_device);
    return Evaluate::make(Call::make(
        Int(32), intrinsic::tvm_call_packed,
        call_args, Call::Intrinsic));
  }

  // Name of the function being split; prefix of every kernel name.
  std::string name_;
  // Device functions extracted so far during the current Split call.
  std::vector<LoweredFunc> device_funcs_;
  // Handle data types of the function being split, keyed by variable.
  std::unordered_map<const Variable*, Expr> handle_data_type_;
};


225
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
226
  IRUseDefAnalysis m;
227
  for (Var arg : args) {
228 229
    m.use_count_[arg.get()] = 0;
  }
230
  m.Mutate(stmt);
231 232 233 234 235 236 237
  return m.undefined_;
}

// Splits a mixed host/device function using a fresh HostDeviceSplitter.
// The result holds the host function first, followed by the device kernels.
Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
  HostDeviceSplitter splitter;
  return splitter.Split(func);
}

238
}  // namespace ir
239
}  // namespace tvm