/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file codegen_hybrid.cc */ #include <tvm/runtime/registry.h> #include <iomanip> #include <cctype> #include "codegen_hybrid.h" namespace tvm { namespace contrib { using runtime::TVMArgs; using runtime::TVMRetValue; using namespace tir; std::string dot_to_underscore(std::string s) { for (auto &ch : s) if (ch == '.') ch = '_'; return s; } std::string CodeGenHybrid::GetUniqueName(std::string prefix) { prefix = dot_to_underscore(prefix); auto it = ids_allocated_.find(prefix); if (it != ids_allocated_.end()) { while (true) { std::ostringstream os; os << prefix << (++it->second); std::string name = os.str(); if (ids_allocated_.count(name) == 0) { prefix = name; break; } } } ids_allocated_[prefix] = 0; return prefix; } std::string CodeGenHybrid::Finish() { return stream.str(); } void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { if (t.is_float()) { os << "float"; CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); } else if (t.is_int()) { os << "int"; CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); } else { CHECK(t.is_uint()) << "Unsupported type " << t; os << "uint"; CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); } os << t.bits(); } void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } template<typename T> inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); os << ", "; p->PrintExpr(op->b, os); os << ')'; } else { os << '('; p->PrintExpr(op->a, os); if (!strcmp(opstr, "&&")) opstr = "and"; if (!strcmp(opstr, "||")) opstr = "or"; os << ' ' << opstr << ' '; p->PrintExpr(op->b, os); os << ')'; } } inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); os << opstr; p->PrintExpr(op->args[1], os); os << ')'; } void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype == op->value.dtype()) { PrintExpr(op->value, stream); } else { PrintType(op->dtype, os); os << "("; PrintExpr(op->value, os); os << ")"; } } void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "==", os, this); } void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "!=", os, this); } void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<", os, this); } void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<=", os, this); } void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">", os, this); } void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">=", os, this); } void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "&&", os, this); } void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "||", os, this); } void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) os << "not "; PrintExpr(op->a, os); } void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->call_type == CallNode::Halide) { os << GetTensorID(op->func, op->value_index); os << "["; for (size_t i = 0; i < op->args.size(); ++i) { if (i) os << ", "; std::stringstream idx; PrintExpr(op->args[i], idx); os << idx.str(); } os << "]"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsitc(op, "&", os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsitc(op, "^", os, this); } else if (op->is_intrinsic(CallNode::bitwise_or)) { PrintBinaryIntrinsitc(op, "|", os, this); } else if (op->is_intrinsic(CallNode::shift_left)) { PrintBinaryIntrinsitc(op, "<<", os, this); } else if (op->is_intrinsic(CallNode::shift_right)) { PrintBinaryIntrinsitc(op, ">>", os, this); } else if (op->is_intrinsic(CallNode::bitwise_not)) { CHECK_EQ(op->args.size(), 1U); os << "(~"; PrintExpr(op->args[0], os); os << ')'; } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { PrintExpr(op->args[1], os); os << " if "; PrintExpr(op->args[0], os); os << " else "; PrintExpr(op->args[2], os); } else { os << op->name << "("; for (size_t i = 0; i < op->args.size(); i++) { PrintExpr(op->args[i], os); if (i < op->args.size() - 1) { os << ", "; } } os << ")"; } } void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Load(s)!"; } void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; } void CodeGenHybrid::VisitStmt_(const AllocateNode* op) { LOG(FATAL) << "Phase 0 has no Allocate(s)!"; } void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Ramp to be supported yet"; } void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->true_value, os); os << " if "; PrintExpr(op->condition, os); os << " else "; PrintExpr(op->false_value, os); os << "\n"; } void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) { std::string value = PrintExpr(op->value); stream << GetVarID(op->var.get()) << " = " << value << ";\n"; PrintStmt(op->body); } void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::thread_extent) { auto iter_var = op->node.as<IterVarNode>(); CHECK(iter_var); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); PrintIndent(); stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint << "', "; PrintExpr(op->value, stream); stream << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == tir::attr::realize_scope) { auto v = Downcast<FunctionRef>(op->node); alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value; PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); } } void CodeGenHybrid::VisitStmt_(const RealizeNode* op) { CHECK(alloc_storage_scope_.count(op->func)); if (!alloc_storage_scope_[op->func].empty()) { PrintIndent(); stream << GetTensorID(op->func, op->value_index) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { if (i) stream << ", "; stream << PrintExpr(op->bounds[i]->extent); } if (op->bounds.size() == 1) stream << ", "; stream << "), '"; PrintType(op->dtype, stream); stream << "', '"; stream << alloc_storage_scope_[op->func] << "')\n"; } PrintStmt(op->body); } void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) { PrintIndent(); stream << "assert "; PrintExpr(op->condition, stream); stream << ", "; PrintExpr(op->message, stream); stream << "\n"; PrintStmt(op->body); } void CodeGenHybrid::VisitStmt_(const ProvideNode* op) { PrintIndent(); stream << GetTensorID(op->func, op->value_index); stream << "["; for (size_t i = 0; i < op->args.size(); ++i) { if (i) stream << ", "; PrintExpr(op->args[i], stream); } stream << "] = "; PrintExpr(op->value, stream); stream << "\n"; } void CodeGenHybrid::VisitStmt_(const ForNode* op) { std::string extent = PrintExpr(op->extent); PrintIndent(); std::string vid = GetVarID(op->loop_var.get()); stream << "for " << vid << " in " << "range(" << extent << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } bool is_noop(const Stmt &stmt) { if (!stmt.defined()) return true; if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value); return false; } void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); stream << "if " << cond << ":\n"; indent_ += tab_; PrintStmt(op->then_case); indent_ -= tab_; if (!is_noop(op->else_case)) { PrintIndent(); stream << "else:\n"; indent_ += tab_; PrintStmt(op->else_case); indent_ -= tab_; } } void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { PrintStmt(stmt); } } void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); if (!str.empty()) stream << str << "\n"; } void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } std::string CodeGenHybrid::GetVarID(const VarNode *v) { if (binds_.count(v)) return binds_[v]; auto key = std::make_pair(static_cast<const Object*>(v), 0); if (id_map_.count(key)) { return id_map_[key]; } return id_map_[key] = GetUniqueName(v->name_hint); } std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) { auto key = std::make_pair(func.get(), value_index); if (id_map_.count(key)) { return id_map_[key]; } std::string name_hint = func->func_name(); if (func->num_outputs() > 1) { name_hint += "_v" + std::to_string(value_index); } return id_map_[key] = GetUniqueName(name_hint); } void CodeGenHybrid::ReserveKeywords() { GetUniqueName("def"); GetUniqueName("for"); GetUniqueName("in"); GetUniqueName("range"); GetUniqueName("True"); GetUniqueName("False"); GetUniqueName("unroll"); GetUniqueName("const_range"); GetUniqueName("parallel"); GetUniqueName("vectorize"); GetUniqueName("bind"); GetUniqueName("threadIdx.x"); GetUniqueName("threadIdx.y"); GetUniqueName("threadIdx.z"); GetUniqueName("blockIdx.x"); GetUniqueName("blockIdx.y"); GetUniqueName("blockIdx.z"); GetUniqueName("vthread"); GetUniqueName("allocate"); GetUniqueName("output_tensor"); GetUniqueName("sqrt"); GetUniqueName("log"); GetUniqueName("tanh"); GetUniqueName("power"); GetUniqueName("exp"); GetUniqueName("sigmoid"); GetUniqueName("popcount"); GetUniqueName("likely"); GetUniqueName("int8"); GetUniqueName("int16"); GetUniqueName("int32"); GetUniqueName("int64"); GetUniqueName("uint8"); GetUniqueName("uint16"); GetUniqueName("uint32"); GetUniqueName("uint64"); GetUniqueName("float16"); GetUniqueName("float32"); GetUniqueName("float64"); GetUniqueName("ceil_div"); GetUniqueName("max_num_threads"); } void CodeGenHybrid::DumpStmt(const Stmt &stmt, const Array<ObjectRef> &inputs, const Array<Tensor> &outputs, const std::string &name) { ReserveKeywords(); GetUniqueName(name); stream << "def " << name << "("; for (size_t i = 0; i < inputs.size(); ++i) { if (i) stream << ", "; if (auto tensor = inputs[i].as<TensorNode>()) { stream << GetTensorID(tensor->op, tensor->value_index); } else { auto var = inputs[i].as<VarNode>(); CHECK(var) << "Input should either be a tensor or a variable!"; stream << GetVarID(var); } } stream << "):\n"; indent_ += tab_; for (size_t i = 0; i < outputs.size(); ++i) { PrintIndent(); stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) << " = output_tensor(("; for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { if (j) stream << ", "; PrintExpr(outputs[i]->shape[j], stream); } if (outputs[i]->shape.size() == 1) stream << ", "; stream << "), '" << outputs[i]->dtype << "')\n"; } PrintStmt(stmt); PrintIndent(); stream << "return "; for (size_t i = 0; i < outputs.size(); ++i) { if (i) stream << ", "; stream << GetTensorID(outputs[i]->op, outputs[i]->value_index); } stream << "\n"; } TVM_REGISTER_GLOBAL("hybrid._Dump") .set_body([](TVMArgs args, TVMRetValue* rv) { CodeGenHybrid codegen; if (args.size() == 4) codegen.DumpStmt(args[0], args[1], args[2], args[3]); else codegen.DumpStmt(args[0], args[1], args[2]); *rv = codegen.Finish(); }); } // namespace contrib } // namespace tvm