Commit 3c1020df by Tianqi Chen Committed by GitHub

[CODEGEN] Add CodeGenC (#22)

parent 5b408d1d
Subproject commit b6637f611f91dd075dc251438f72ad38901d17fb
Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
......@@ -8,6 +8,7 @@ from . import expr
from . import stmt
from . import make
from . import ir_pass
from . import codegen
from . import collections
from . import schedule
......
......@@ -281,6 +281,7 @@ def _init_function_module(root_namespace):
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
}
......
"""Code generation related functions"""
......@@ -30,7 +30,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 0, 0);
return Type(Type::Handle, 32, 1);
} else {
LOG(FATAL) << "unknown type " << s;
}
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include "./c_api_registry.h"
#include "../codegen/codegen_c.h"
namespace tvm {
namespace codegen {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = CodeGenC().Compile(
args.at(0), args.at(1), args.at(2), args.at(3));
});
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include "./codegen_c.h"
namespace tvm {
namespace codegen {
using namespace ir;
std::string CodeGenC::Compile(
Stmt stmt, std::string fun_name,
Array<Var> args, bool output_ssa) {
print_ssa_form_ = output_ssa;
// skip the first underscore, so SSA variable starts from _1
if (print_ssa_form_) GetUniqueName("_");
this->indent += 2;
this->stream << "void " << fun_name << "(";
for (size_t i = 0; i < args.size(); ++i) {
Var v = args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
PrintType(v.type(), stream);
stream << ' ' << vid;
}
stream << ") {\n";
this->PrintStmt(stmt);
this->indent -= 2;
this->PrintIndent();
this->stream << "}\n";
return stream.str();
}
void CodeGenC::PrintStmt(const Stmt& n) {
static const FPrintStmt& f = vtable_print_stmt();
f(n, this);
}
std::string CodeGenC::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
return it->second;
} else {
this->PrintIndent();
std::string id = GetUniqueName("_");
ssa_assign_map_[src] = id;
if (src.length() > 3 &&
src[0] == '(' && src[src.length() - 1] == ')') {
src = src.substr(1, src.length() - 2);
}
PrintType(t, stream);
stream << ' ' << id << " = " << src << ";\n";
return id;
}
}
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
static const FPrintExpr& f = vtable_print_expr();
if (print_ssa_form_) {
std::ostringstream temp;
f(n, temp, this);
os << SSAGetID(temp.str(), n.type());
} else {
f(n, os, this);
}
}
std::string CodeGenC::GetUniqueName(std::string prefix) {
auto it = name_alloc_map_.find(prefix);
if (it != name_alloc_map_.end()) {
while (true) {
std::ostringstream os;
os << prefix << (++it->second);
std::string name = os.str();
if (name_alloc_map_.count(name) == 0) {
prefix = name;
break;
}
}
}
name_alloc_map_[prefix] = 0;
return prefix;
}
std::string CodeGenC::AllocVarID(const Variable* v) {
CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
for (size_t i = 0; i < key.size(); ++i) {
if (key[i] == '.') key[i] = '_';
}
std::string vid = GetUniqueName(key);
var_idmap_[v] = vid;
return vid;
}
std::string CodeGenC::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
return it->second;
}
bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const {
auto it = alloc_buf_type_.find(buf_var);
if (it == alloc_buf_type_.end()) return false;
return it->second == t;
}
void CodeGenC::PrintIndent() {
for (int i = 0; i < this->indent; ++i) {
this->stream << ' ';
}
}
void CodeGenC::MarkConst(std::string vid) {
if (print_ssa_form_) {
auto it = ssa_assign_map_.find(vid);
if (it == ssa_assign_map_.end()) {
ssa_assign_map_[vid] = vid;
} else {
CHECK_EQ(it->second, vid);
}
}
}
void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types";
if (t.is_handle()) {
os << "void*"; return;
}
if (t.is_float()) {
if (t.bits() == 32) {
os << "float"; return;
}
if (t.bits() == 64) {
os << "double"; return;
}
} else if (t.is_uint()) {
switch (t.bits()) {
case 8: case 16: case 32: case 64: {
os << "uint" << t.bits() << "_t"; return;
}
case 1: os << "int"; return;
}
} else if (t.is_int()) {
switch (t.bits()) {
case 8: case 16: case 32: case 64: {
os << "int" << t.bits() << "_t"; return;
}
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() { // NOLINT(*)
static FPrintStmt inst; return inst;
}
CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() { // NOLINT(*)
static FPrintExpr inst; return inst;
}
inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->type == Int(32)) {
std::ostringstream temp;
temp << op->value;
p->MarkConst(temp.str());
os << temp.str();
} else {
os << "(";
p->PrintType(op->type, os);
os << ")" << op->value;
}
}
inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->type == UInt(32)) {
std::ostringstream temp;
temp << op->value << "U";
p->MarkConst(temp.str());
os << temp.str();
} else {
os << "(";
p->PrintType(op->type, os);
os << ")" << op->value;
}
}
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << op->value;
if (op->type.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
break;
}
case 16: {
os << '(';
p->PrintType(op->type, os);
os << ')' << op->value << 'f';
break;
}
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
}
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<IntImm>([](const IntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<UIntImm>([](const UIntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<FloatImm>([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
});
template<typename T>
inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
os << '(';
p->PrintExpr(op->a, os);
os << opstr;
p->PrintExpr(op->b, os);
os << ')';
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
p->PrintType(op->type, os);
os << '(';
p->PrintExpr(op->value, os);
os << ')';
})
.set_dispatch<Variable>([](const Variable *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << p->GetVarID(op);
})
.set_dispatch<Add>([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " + ", os, p);
})
.set_dispatch<Sub>([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " - ", os, p);
})
.set_dispatch<Mul>([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " * ", os, p);
})
.set_dispatch<Div>([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " / ", os, p);
})
.set_dispatch<Mod>([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " % ", os, p);
})
.set_dispatch<Min>([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << "min(";
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ")";
})
.set_dispatch<Max>([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << "max(";
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ")";
})
.set_dispatch<EQ>([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " == ", os, p);
})
.set_dispatch<NE>([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " != ", os, p);
})
.set_dispatch<LT>([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " < ", os, p);
})
.set_dispatch<LE>([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " <= ", os, p);
})
.set_dispatch<GT>([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " > ", os, p);
})
.set_dispatch<GE>([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " >= ", os, p);
})
.set_dispatch<And>([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " && ", os, p);
})
.set_dispatch<Or>([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, " || ", os, p);
})
.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << '!';
p->PrintExpr(op->a, os);
})
.set_dispatch<Call>([](const Call *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
p->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
}
os << ")";
});
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) {
std::string cond = p->PrintExpr(op->condition);
p->PrintIndent();
p->stream << "assert(" << cond << ");\n";
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
p->PrintStmt(op->body);
})
.set_dispatch<For>([](const For *op, CodeGenC* p) {
std::string extent = p->PrintExpr(op->extent);
p->PrintIndent();
std::string vid = p->AllocVarID(op->loop_var.get());
CHECK(is_zero(op->min));
p->stream << "for (";
p->PrintType(op->loop_var.type(), p->stream);
p->stream << ' ' << vid << " = 0; "
<< vid << " < " << extent
<< "; ++" << vid << ") {\n";
p->indent += 2;
p->PrintStmt(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
})
.set_dispatch<Block>([](const Block *op, CodeGenC* p) {
p->PrintStmt(op->first);
if (op->rest.defined()) p->PrintStmt(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenC* p) {
if (is_const(op->value)) return;
std::string vid = p->PrintExpr(op->value);
p->PrintIndent();
p->stream << "(void)" << vid << ";\n";
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) {
std::string cond = p->PrintExpr(op->condition);
p->PrintIndent();
p->stream << "if (" << cond << ") {\n";
p->indent += 2;
p->PrintStmt(op->then_case);
p->indent -= 2;
if (op->else_case.defined()) {
p->PrintIndent();
p->stream << "} else {\n";
p->indent += 2;
p->PrintStmt(op->else_case);
p->indent -= 2;
}
p->PrintIndent();
p->stream << "}\n";
});
#define DISPATCH_EXPR(OP) \
set_dispatch<OP>([](const OP *op, std::ostream&os, CodeGenC* p) { \
p->PrintExpr(op, os); })
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.DISPATCH_EXPR(Load)
.DISPATCH_EXPR(Let)
.DISPATCH_EXPR(Ramp)
.DISPATCH_EXPR(Broadcast)
.DISPATCH_EXPR(Select);
void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
std::string vid = GetVarID(op->buffer_var.get());
if (!BufferTypeMatch(op->buffer_var.get(), op->type)) {
os << "((const ";
PrintType(op->type, os);
os << "*)" << vid << ')';
} else {
os << vid;
}
os << '[';
PrintExpr(op->index, os);
os << ']';
}
void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
CHECK(print_ssa_form_)
<< "LetExpr is only supported by print SSA form";
std::string value = PrintExpr(op->value);
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
}
void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "not supported ";
}
void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "not supported ";
}
void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "not supported ";
}
// Disoatch back to member functions
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); });
void CodeGenC::PrintStmt(const LetStmt* op) {
std::string value = PrintExpr(op->value);
if (print_ssa_form_) {
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
PrintType(op->var.type(), this->stream);
this->stream << ' '
<< AllocVarID(op->var.get())
<< " = " << value << ";\n";
}
PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const Store* op) {
std::string index = this->PrintExpr(op->index);
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
std::string vid = GetVarID(op->buffer_var.get());
if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) {
this->stream << "((";
PrintType(op->value.type(), this->stream);
this->stream << "*)" << vid << ')';
} else {
this->stream << vid;
}
this->stream << '[' << index
<< "] = " << value
<< ";\n";
}
void CodeGenC::PrintStmt(const Allocate* op) {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
std::string vid = AllocVarID(op->buffer_var.get());
CHECK(!op->new_expr.defined());
CHECK(!is_zero(op->condition));
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
PrintType(op->type, stream);
stream << ' '<< vid << '['
<< constant_size << "]\n;";
this->PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const AttrStmt* op) {
if (op->type_key == "scope") {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
this->PrintIndent();
PrintType(iv->var.type(), stream);
stream << ' '
<< AllocVarID(iv->var.get())
<< " = " << iv->thread_tag << ";\n";
}
}
this->PrintStmt(op->body);
}
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file codegen_c.h
* \brief Common utilities to generated C style code.
*/
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <string>
#include <unordered_map>
namespace tvm {
namespace codegen {
/*!
* \brief A base class to generate C code.
*
* CodeGenC have two modes: generate SSA formed C code or normal form.
*/
class CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param body The body of the function.
* \param fun_name The name of the function.
* \param args The arguments to the function.
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(Stmt body,
std::string fun_name,
Array<Var> args,
bool output_ssa);
/*!
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt& n);
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*)
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
inline std::string PrintExpr(const Expr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
/*! \brief print the current indented value */
void PrintIndent();
/*!
* \brief Register constant value appeared in expresion tree
* This avoid generated a ssa id for each appearance of the value
* \param value The constant value.
*/
void MarkConst(std::string value);
/*!
* \brief Allocate a variable name for a newly defined var.
* \param v The variable.
* \return the variable name.
*/
std::string AllocVarID(const Variable* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
std::string GetVarID(const Variable* v) const;
/*!
* Print Type represetnation of type t.
* \param t The type representation.
* \return os The stream to print the ctype into
*/
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
// The following parts are overloadable print operations.
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
private:
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
* \param src The source expression
* \param t The type of the expression.
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool BufferTypeMatch(const Variable* buf_var, Type t) const;
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> alloc_buf_type_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
std::unordered_map<std::string, std::string> ssa_assign_map_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_C_H_
......@@ -19,10 +19,18 @@ def mock_test_add():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
print(stmt)
output_ssa = False
code = tvm.codegen.CompileToC(stmt, "myadd",
[Ab.ptr, Bb.ptr, Cb.ptr, n],
output_ssa)
print(code)
def codegen():
# generate host/device code
host_code, device_code = tvm.codegen.GenCUDA(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment