/*!
 *  Copyright (c) 2017 by Contributors
 *  Combine calls into context related function into one.
 *
 * \file combine_context_call.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <map>
#include <memory>
#include <string>
#include <utility>

namespace tvm {
namespace ir {

/*!
 * \brief Mutator that replaces tvm_thread_context intrinsic calls with cached variables.
 *
 *  Every call to the tvm_thread_context intrinsic is rewritten into a Var.
 *  Calls whose argument compares equal under CompareExpr share the same Var.
 *  The Vars collected inside a scope are bound by LetStmt nodes wrapped
 *  around that scope, with the intrinsic's argument as the bound value.
 *
 *  A new scope is opened at:
 *   - AttrStmt nodes keyed by attr::thread_extent or attr::coproc_uop_scope,
 *   - For nodes whose for_type is ForType::Parallel,
 *   - the statement handed to Combine().
 */
class ContextCallCombiner final : public IRMutator {
 public:
  /*! \brief Strict ordering on Expr built on Compare, used as the map comparator. */
  struct CompareExpr {
    bool operator()(const Expr& lhs, const Expr& rhs) const {
      return Compare(lhs, rhs) < 0;
    }
  };

  /*!
   * \brief Replace a tvm_thread_context call by its cache variable.
   * \param op The call node being visited.
   * \param e The expression that owns op.
   * \return The cache Var for tvm_thread_context calls; otherwise the result
   *         of the default IRMutator traversal.
   */
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (!op->is_intrinsic(intrinsic::tvm_thread_context)) {
      return IRMutator::Mutate_(op, e);
    }
    CHECK_EQ(op->args.size(), 1U);
    Expr ctx = op->args[0];
    auto it = ctx_map_.find(ctx);
    if (it != ctx_map_.end()) {
      // Same context expression already seen in this scope: reuse its Var.
      return it->second;
    }
    CHECK(ctx.type().is_handle());
    // Name the variable after the producing call when there is one.
    std::string name;
    if (const Call* call = ctx.as<Call>()) {
      name = call->name + "_cache";
    } else {
      name = "ctx_cache_";
    }
    Var ctx_var(name, ctx.type());
    ctx_map_[ctx] = ctx_var;
    return ctx_var;
  }

  /*!
   * \brief Open a new context scope at thread_extent / coproc_uop_scope attributes.
   * \param op The attribute statement being visited.
   * \param s The statement that owns op.
   * \return The mutated statement; for the scoped attribute keys it is wrapped
   *         in the LetStmt bindings collected inside it.
   */
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::coproc_uop_scope) {
      return MutateInNewScope(op, s);
    }
    return IRMutator::Mutate_(op, s);
  }

  /*!
   * \brief Open a new context scope at parallel loops.
   * \param op The loop being visited.
   * \param s The statement that owns op.
   * \return The mutated statement; for parallel loops it is wrapped in the
   *         LetStmt bindings collected inside it.
   */
  Stmt Mutate_(const For* op, const Stmt& s) final {
    if (op->for_type == ForType::Parallel) {
      return MutateInNewScope(op, s);
    }
    return IRMutator::Mutate_(op, s);
  }

  /*!
   * \brief Run the pass over a statement.
   * \param stmt The statement to rewrite.
   * \return The mutated statement wrapped in the bindings still held in the
   *         map after the traversal.
   */
  Stmt Combine(Stmt stmt) {
    // ctx_map_ is bound by reference, so BuildContext observes the entries
    // added by Mutate regardless of argument evaluation order.
    return BuildContext(ctx_map_, this->Mutate(stmt));
  }

 private:
  /*! \brief Map from context expression to the variable that caches it. */
  using ContextMap = std::map<Expr, Var, CompareExpr>;

  /*!
   * \brief Mutate op with an empty context map, then bind what was collected.
   *
   *  The enclosing scope's map is set aside for the duration of the default
   *  traversal and restored afterwards, so variables created inside op are
   *  bound around op only.
   *
   * \param op The node that opens the scope.
   * \param s The statement that owns op.
   * \return The mutated statement wrapped in its LetStmt bindings.
   */
  template <typename NodeT>
  Stmt MutateInNewScope(const NodeT* op, const Stmt& s) {
    ContextMap scope;
    std::swap(scope, ctx_map_);  // ctx_map_ is now empty, scope holds the outer map
    Stmt stmt = IRMutator::Mutate_(op, s);
    std::swap(scope, ctx_map_);  // ctx_map_ restored, scope holds the inner entries
    return BuildContext(scope, stmt);
  }

  /*!
   * \brief Wrap body in one LetStmt per map entry.
   * \param cmap The context expression to variable map.
   * \param body The statement to wrap.
   * \return body nested inside the generated LetStmt nodes.
   */
  static Stmt BuildContext(const ContextMap& cmap, Stmt body) {
    for (const auto& kv : cmap) {
      body = LetStmt::make(kv.second, kv.first, body);
    }
    return body;
  }

  // Map of comparison expression to variable
  ContextMap ctx_map_;
};

/*!
 * \brief Apply ContextCallCombiner to the body of a lowered function.
 * \param f The function to transform.
 * \return A new LoweredFunc whose node is a copy of f's node with the
 *         rewritten body.
 */
LoweredFunc CombineContextCall(LoweredFunc f) {
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = ContextCallCombiner().Combine(n->body);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm