Commit 3c1020df by Tianqi Chen Committed by GitHub

[CODEGEN] Add CodeGenC (#22)

parent 5b408d1d
Subproject commit b6637f611f91dd075dc251438f72ad38901d17fb Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
...@@ -8,6 +8,7 @@ from . import expr ...@@ -8,6 +8,7 @@ from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import ir_pass from . import ir_pass
from . import codegen
from . import collections from . import collections
from . import schedule from . import schedule
......
...@@ -281,6 +281,7 @@ def _init_function_module(root_namespace): ...@@ -281,6 +281,7 @@ def _init_function_module(root_namespace):
namespace_match = { namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace], "_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace] "_schedule_": sys.modules["%s.schedule" % root_namespace]
} }
......
"""Code generation related functions"""
...@@ -30,7 +30,7 @@ inline Type String2Type(std::string s) { ...@@ -30,7 +30,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") { } else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5); code = Type::Float; s = s.substr(5);
} else if (s == "handle") { } else if (s == "handle") {
return Type(Type::Handle, 0, 0); return Type(Type::Handle, 32, 1);
} else { } else {
LOG(FATAL) << "unknown type " << s; LOG(FATAL) << "unknown type " << s;
} }
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include "./c_api_registry.h"
#include "../codegen/codegen_c.h"
namespace tvm {
namespace codegen {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = CodeGenC().Compile(
args.at(0), args.at(1), args.at(2), args.at(3));
});
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file codegen_c.h
* \brief Common utilities to generated C style code.
*/
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <string>
#include <unordered_map>
namespace tvm {
namespace codegen {
/*!
* \brief A base class to generate C code.
*
* CodeGenC have two modes: generate SSA formed C code or normal form.
*/
class CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param body The body of the function.
* \param fun_name The name of the function.
* \param args The arguments to the function.
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(Stmt body,
std::string fun_name,
Array<Var> args,
bool output_ssa);
/*!
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt& n);
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*)
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
inline std::string PrintExpr(const Expr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
/*! \brief print the current indented value */
void PrintIndent();
/*!
* \brief Register constant value appeared in expresion tree
* This avoid generated a ssa id for each appearance of the value
* \param value The constant value.
*/
void MarkConst(std::string value);
/*!
* \brief Allocate a variable name for a newly defined var.
* \param v The variable.
* \return the variable name.
*/
std::string AllocVarID(const Variable* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
std::string GetVarID(const Variable* v) const;
/*!
* Print Type represetnation of type t.
* \param t The type representation.
* \return os The stream to print the ctype into
*/
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
// The following parts are overloadable print operations.
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */
using FPrintStmt = IRFunctor<void(const NodeRef&, CodeGenC *)>;
// vtable to print code
static FPrintStmt& vtable_print_stmt();
// vtable to print code
static FPrintExpr& vtable_print_expr();
/*! \brief The current indentation value */
int indent{0};
/*! \brief the stream to be printed */
std::ostringstream stream;
private:
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
* \param src The source expression
* \param t The type of the expression.
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool BufferTypeMatch(const Variable* buf_var, Type t) const;
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> alloc_buf_type_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
std::unordered_map<std::string, std::string> ssa_assign_map_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_C_H_
...@@ -19,10 +19,18 @@ def mock_test_add(): ...@@ -19,10 +19,18 @@ def mock_test_add():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds) stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A') Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B') Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C') Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
print(stmt)
output_ssa = False
code = tvm.codegen.CompileToC(stmt, "myadd",
[Ab.ptr, Bb.ptr, Cb.ptr, n],
output_ssa)
print(code)
def codegen(): def codegen():
# generate host/device code # generate host/device code
host_code, device_code = tvm.codegen.GenCUDA( host_code, device_code = tvm.codegen.GenCUDA(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment