Commit b0e41b9a by Tianqi Chen Committed by GitHub

[CODEGEN] Concise typecast for threadIdx (#208)

parent bf97724b
......@@ -245,9 +245,17 @@ void CodeGenC::PrintVecStore(const Variable* buffer,
stream << ref << " = " << value << ";\n";
}
std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")" << value << ")";
return os.str();
}
void CodeGenC::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = iv->thread_tag;
LOG(FATAL) << "not implemented";
}
void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
......
......@@ -150,6 +150,8 @@ class CodeGenC :
// print reference to a buffer as type t in index.
std::string GetBufferRef(
Type t, const Variable* buffer, Expr index);
// Get a cast type from to
std::string CastFromTo(std::string value, Type from, Type target);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
......
......@@ -35,6 +35,12 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
......
......@@ -30,7 +30,7 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const Evaluate *op) final;
......
......@@ -126,6 +126,12 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
this->stream << "}\n\n";
}
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, UInt(16), iv->var.type());
}
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
......
......@@ -24,7 +24,7 @@ class CodeGenMetal final : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
};
......
......@@ -35,7 +35,8 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
} else {
os << "get_group_id(" << ts.dim_index << ")";
}
var_idmap_[iv->var.get()] = os.str();
var_idmap_[iv->var.get()] =
CastFromTo(os.str(), UInt(64), iv->var.type());
}
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
......@@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
->GetTimeEvaluator(args[1], ctx, args[4]);
} else {
*rv = WrapTimeEvaluator(
m.GetFunction(args[1], false), ctx, args[3]);
m.GetFunction(args[1], false), ctx, args[4]);
}
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment