/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file split_host_device.cc * \brief Split device function from host. */ #include <tvm/ir.h> #include <tvm/lowered_func.h> #include <tvm/channel.h> #include <tvm/ir_pass.h> #include <tvm/ir_mutator.h> #include <tvm/runtime/module.h> #include <unordered_map> namespace tvm { namespace ir { // use/def analysis, also delete unreferenced lets class IRUseDefAnalysis : public IRMutator { public: Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent) { IterVar iv(op->node.node_); CHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. if (!use_count_.count(iv->var.get())) { this->HandleDef(iv->var.get()); thread_axis_.push_back(iv); thread_extent_.push_back(op->value); } Expr value = op->value; if (visit_thread_extent_) { value = this->Mutate(value); } Stmt body = this->Mutate(op->body); if (value.same_as(op->value) && body.same_as(op->body)) return s; return AttrStmt::make(op->node, op->attr_key, value, body); } else if (op->attr_key == attr::channel_write_scope || op->attr_key == attr::channel_read_scope) { Channel ch(op->node.node_); if (!use_count_.count(ch->handle_var.get())) { this->HandleDef(ch->handle_var.get()); } return IRMutator::Mutate_(op, s); } else { return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const LetStmt *op, const Stmt& s) final { this->HandleDef(op->var.get()); Stmt body = this->Mutate(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { Expr value = this->Mutate(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { return s; } else { return LetStmt::make(op->var, value, body); } } } Stmt Mutate_(const For *op, const Stmt& s) final { this->HandleDef(op->loop_var.get()); return IRMutator::Mutate_(op, s); } Stmt Mutate_(const Allocate *op, const Stmt& s) final { this->HandleDef(op->buffer_var.get()); return IRMutator::Mutate_(op, s); } Stmt Mutate_(const Store *op, const Stmt& s) final { this->HandleUse(op->buffer_var); return IRMutator::Mutate_(op, s); } Expr Mutate_(const Let *op, const Expr& e) final { this->HandleDef(op->var.get()); Expr body = this->Mutate(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { Expr value = this->Mutate(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { return e; } else { return Let::make(op->var, value, body); } } } Expr Mutate_(const Variable *op, const Expr& e) final { this->HandleUse(e); return IRMutator::Mutate_(op, e); } Expr Mutate_(const Load *op, const Expr& e) final { this->HandleUse(op->buffer_var); return IRMutator::Mutate_(op, e); } void HandleDef(const Variable* v) { CHECK(!def_count_.count(v)) << "variable " << v->name_hint << " has already been defined, the Stmt is not SSA"; CHECK(!use_count_.count(v)) << "variable " << v->name_hint << " has been used before definition!"; use_count_[v] = 0; def_count_[v] = 1; } void HandleUse(const Expr& v) { CHECK(v.as<Variable>()); Var var(v.node_); auto it = use_count_.find(var.get()); if (it != use_count_.end()) { if (it->second >= 0) { ++it->second; } } else { undefined_.push_back(var); use_count_[var.get()] = -1; } } // The fields are publically readible to // be accessible to the users. bool visit_thread_extent_{true}; Array<Var> undefined_; Array<IterVar> thread_axis_; Array<Expr> thread_extent_; std::unordered_map<const Variable*, int> use_count_; std::unordered_map<const Variable*, int> def_count_; }; class HostDeviceSplitter : public IRMutator { public: Stmt Mutate_(const Allocate* op, const Stmt& s) final { handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0); return IRMutator::Mutate_(op, s); } Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { return SplitDeviceFunc(s); } return IRMutator::Mutate_(op, s); } Array<LoweredFunc> Split(LoweredFunc f) { CHECK_EQ(f->func_type, kMixedFunc); for (auto kv : f->handle_data_type) { handle_data_type_[kv.first.get()] = kv.second; } name_ = f->name; NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>(*f.operator->()); n->body = this->Mutate(f->body); n->func_type = kHostFunc; Array<LoweredFunc> ret{LoweredFunc(n)}; for (LoweredFunc x : device_funcs_) { ret.push_back(x); } return ret; } private: Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; os << name_ << "_kernel" << device_funcs_.size(); NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>(); // isolate the device function. IRUseDefAnalysis m; m.visit_thread_extent_ = false; n->body = m.Mutate(body); n->name = os.str(); n->func_type = kDeviceFunc; n->thread_axis = m.thread_axis_; // Strictly order the arguments: Var pointers, positional arguments. for (Var v : m.undefined_) { if (v.type().is_handle()) { n->args.push_back(v); // mark handle data type. auto it = handle_data_type_.find(v.get()); if (it != handle_data_type_.end()) { n->handle_data_type.Set(v, it->second); } } } for (Var v : m.undefined_) { if (!v.type().is_handle()) { n->args.push_back(v); } } LoweredFunc f_device(n); Array<Expr> call_args; call_args.push_back(StringImm::make(f_device->name)); for (Var arg : n->args) { call_args.push_back(arg); } for (Expr ext : m.thread_extent_) { call_args.push_back(ext); } device_funcs_.emplace_back(f_device); return Evaluate::make(Call::make( Int(32), intrinsic::tvm_call_packed, call_args, Call::Intrinsic)); } // function name std::string name_; // the device functions std::vector<LoweredFunc> device_funcs_; std::unordered_map<const Variable*, Expr> handle_data_type_; }; Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) { IRUseDefAnalysis m; for (Var arg : args) { m.use_count_[arg.get()] = 0; } m.Mutate(stmt); return m.undefined_; } Array<LoweredFunc> SplitHostDevice(LoweredFunc func) { return HostDeviceSplitter().Split(func); } } // namespace ir } // namespace tvm