Commit b0e41b9a by Tianqi Chen Committed by GitHub

[CODEGEN] Concise typecast for threadIdx (#208)

parent bf97724b
...@@ -245,9 +245,17 @@ void CodeGenC::PrintVecStore(const Variable* buffer, ...@@ -245,9 +245,17 @@ void CodeGenC::PrintVecStore(const Variable* buffer,
stream << ref << " = " << value << ";\n"; stream << ref << " = " << value << ";\n";
} }
/*!
 * \brief Wrap an expression string in a C-style cast when the types differ.
 * \param value The source text of the expression to convert.
 * \param from The type the expression currently has.
 * \param target The type the expression should be converted to.
 * \return `value` itself when `from` equals `target`; otherwise the text
 *  "((T)value)" where T is whatever PrintType writes for `target`.
 */
std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) {
  if (!(from == target)) {
    std::ostringstream cast_expr;
    // Opening parenthesis of the whole expression plus the one of the cast.
    cast_expr << "((";
    this->PrintType(target, cast_expr);
    cast_expr << ")";
    cast_expr << value;
    cast_expr << ")";
    return cast_expr.str();
  }
  // Types already agree: no cast is emitted.
  return value;
}
void CodeGenC::BindThreadIndex(const IterVar& iv) { void CodeGenC::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get())); LOG(FATAL) << "not implemented";
var_idmap_[iv->var.get()] = iv->thread_tag;
} }
void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*) void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
......
...@@ -150,6 +150,8 @@ class CodeGenC : ...@@ -150,6 +150,8 @@ class CodeGenC :
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
std::string GetBufferRef( std::string GetBufferRef(
Type t, const Variable* buffer, Expr index); Type t, const Variable* buffer, Expr index);
// Cast value from type `from` to type `target`.
std::string CastFromTo(std::string value, Type from, Type target);
/*! /*!
* \brief If buffer is allocated as type t. * \brief If buffer is allocated as type t.
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
......
...@@ -35,6 +35,12 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) { ...@@ -35,6 +35,12 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
/*!
 * \brief Register the expression used to refer to a thread index variable.
 *
 * The thread tag is passed to CastFromTo as a UInt(32) expression, so a cast
 * is only emitted when the iteration variable has a different type.
 *
 * \param iv The thread iteration variable; it must not be bound already.
 */
void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
  // Binding the same variable twice is a hard error.
  CHECK(!var_idmap_.count(iv->var.get()));
  std::string thread_expr =
      CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
  var_idmap_[iv->var.get()] = thread_expr;
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
......
...@@ -30,7 +30,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -30,7 +30,7 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore( void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final; const std::string& vec, Type t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const Evaluate *op) final;
......
...@@ -126,6 +126,12 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { ...@@ -126,6 +126,12 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
this->stream << "}\n\n"; this->stream << "}\n\n";
} }
/*!
 * \brief Register the expression used to refer to a thread index variable.
 *
 * The thread tag is passed to CastFromTo as a UInt(16) expression, so a cast
 * is only emitted when the iteration variable has a different type.
 *
 * \param iv The thread iteration variable; it must not be bound already.
 */
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
  // Binding the same variable twice is a hard error.
  CHECK(!var_idmap_.count(iv->var.get()));
  std::string thread_expr =
      CastFromTo(iv->thread_tag, UInt(16), iv->var.type());
  var_idmap_[iv->var.get()] = thread_expr;
}
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
......
...@@ -24,7 +24,7 @@ class CodeGenMetal final : public CodeGenC { ...@@ -24,7 +24,7 @@ class CodeGenMetal final : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
}; };
......
...@@ -35,7 +35,8 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { ...@@ -35,7 +35,8 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
} else { } else {
os << "get_group_id(" << ts.dim_index << ")"; os << "get_group_id(" << ts.dim_index << ")";
} }
var_idmap_[iv->var.get()] = os.str(); var_idmap_[iv->var.get()] =
CastFromTo(os.str(), UInt(64), iv->var.type());
} }
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
...@@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") ...@@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
->GetTimeEvaluator(args[1], ctx, args[4]); ->GetTimeEvaluator(args[1], ctx, args[4]);
} else { } else {
*rv = WrapTimeEvaluator( *rv = WrapTimeEvaluator(
m.GetFunction(args[1], false), ctx, args[3]); m.GetFunction(args[1], false), ctx, args[4]);
} }
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment