Commit 154104b3 by Tianqi Chen, committed by GitHub

[PASS] Remap thread axis. (#1122)

parent d0f40112
......@@ -32,6 +32,11 @@ class TargetNode : public Node {
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*!
 * \brief The thread index that corresponds to the warp (the lowest-level thread dimension).
 *  In CUDA this is threadIdx.x (index 0), but it can differ on other platforms;
 *  for example, Intel GPUs use threadIdx.z (index 2).
 */
int thread_warp_index = 0;
/*! \brief Keys for this target */
Array<Expr> keys_array;
/*! \brief Options for this target */
......@@ -48,6 +53,7 @@ class TargetNode : public Node {
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("thread_warp_index", &thread_warp_index);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
......
......@@ -417,6 +417,20 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
/*!
 * \brief Remap the thread axis.
 *
 *  This can be used to get an equivalent program that uses
 *  threadIdx.y in place of threadIdx.x by passing
 *  {"threadIdx.x": thread_axis("threadIdx.y")} as the axis map.
 *
 * \param f The device function to be lowered.
 * \param axis_map The map from StringImm -> IterVar.
 * \return Transformed function.
*/
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
......
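A minimal Python-side sketch of invoking the new pass directly (mirroring its use in build_module.py further down). It assumes fdev is already a device-side LoweredFunc, e.g. obtained after host/device splitting, which is not shown here:

import tvm

# Keys are thread tags converted to string expressions; values are the
# replacement IterVars. Here every binding of threadIdx.x becomes threadIdx.y.
axis_map = {tvm.convert("threadIdx.x"): tvm.thread_axis("threadIdx.y")}
fdev = tvm.ir_pass.RemapThreadAxis(fdev, axis_map)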
......@@ -98,6 +98,7 @@ class DumpIR(object):
        schedule.ScheduleOps = self._old_sgpass
        DumpIR.scope_level -= 1

@register_node
class BuildConfig(NodeBase):
    """Configuration scope to set a build config option.
......@@ -469,6 +470,13 @@ def build(sch,
    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
        warp_index = target.thread_warp_index
        if warp_index != 0:
            assert warp_index == 2
            # swap z and x
            tmap = {api.convert("threadIdx.z"): api.thread_axis("threadIdx.x"),
                    api.convert("threadIdx.x"): api.thread_axis("threadIdx.z")}
            fdevice[i] = ir_pass.RemapThreadAxis(func, tmap)
    if "gpu" in target.keys and not fdevice:
        warnings.warn(
......
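For an end-to-end view of the intel_gpu path above, a rough sketch of building a trivial kernel and inspecting the generated device code. It assumes a TVM build with the OpenCL codegen enabled; the kernel and variable names are illustrative only:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

f = tvm.build(s, [A, B], "opencl -device=intel_gpu", name="add_one")
# After the remap in build(), the axis that was bound to threadIdx.x should be
# emitted on the z thread dimension (get_local_id(2)) in the OpenCL source.
print(f.imported_modules[0].get_source())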
......@@ -109,6 +109,7 @@ class Target(NodeBase):
    def __exit__(self, ptype, value, trace):
        _api_internal._ExitTargetScope()

@register_node
class GenericFunc(NodeBase):
    """GenericFunc node reference. This represents a generic function
......@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
......
......@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name,
    t->max_num_threads = 256;
    if (t->device_name == "intel_gpu") {
      t->thread_warp_size = 16;
      // use threadIdx.z as the warp index on Intel GPUs
      t->thread_warp_index = 2;
    }
  } else if (target_name == "metal" || target_name == "vulkan") {
    if (target_name == "metal") {
......
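Because TargetNode now visits thread_warp_index (see the header change above), the value set here is also visible from Python; a small sketch, with the expected values taken from this hunk and the header default:

import tvm

intel = tvm.target.create("opencl -device=intel_gpu")
print(intel.thread_warp_size)   # 16
print(intel.thread_warp_index)  # 2, i.e. the warp corresponds to threadIdx.z

cuda = tvm.target.create("cuda")
print(cuda.thread_warp_index)   # 0 (default), the warp corresponds to threadIdx.x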
/*!
* Copyright (c) 2018 by Contributors
* \file remap_thread_axis.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
namespace tvm {
namespace ir {
// Mutator that rewrites thread_extent annotations and the uses of the
// corresponding thread index variables.
class ThreadAxisRewriter : private IRMutator {
 public:
  explicit ThreadAxisRewriter(
      const std::unordered_map<std::string, IterVar>& tmap)
      : tmap_(tmap) {
  }

  Stmt Rewrite(Stmt stmt) {
    return Mutate(stmt);
  }

 private:
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key == attr::thread_extent) {
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      auto it = tmap_.find(iv->thread_tag);
      if (it != tmap_.end()) {
        const IterVar& new_iv = it->second;
        const Variable* v = iv->var.get();
        if (!vmap_.count(v)) {
          vmap_[v] = new_iv->var;
        } else {
          CHECK(vmap_[v].same_as(new_iv->var));
        }
        Stmt body = this->Mutate(op->body);
        return AttrStmt::make(
            new_iv, op->attr_key, op->value, body);
      }
    }
    return IRMutator::Mutate_(op, stmt);
  }

  Expr Mutate_(const Variable* op, const Expr& expr) final {
    auto it = vmap_.find(op);
    if (it != vmap_.end()) return it->second;
    return IRMutator::Mutate_(op, expr);
  }
  // The thread tag -> new IterVar map.
  const std::unordered_map<std::string, IterVar>& tmap_;
  // Maps the old thread index variable to the new one.
  std::unordered_map<const Variable*, Var> vmap_;
};

LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
  std::unordered_map<std::string, IterVar> tmap;
  for (const auto& kv : thread_map) {
    const StringImm* str = kv.first.as<StringImm>();
    CHECK(str != nullptr);
    tmap[str->value] = kv.second;
  }
  CHECK_EQ(f->func_type, kDeviceFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  // replace the thread axis in the function signature
  for (size_t i = 0; i < n->thread_axis.size(); ++i) {
    auto it = tmap.find(n->thread_axis[i]->thread_tag);
    if (it != tmap.end()) {
      n->thread_axis.Set(i, it->second);
    }
  }
  // rewrite thread_extent annotations and variable uses in the body
  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
  return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
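For intuition about which IR nodes the mutator touches, one can print a lowered schedule and look at its thread_extent annotations; a hedged sketch (the exact textual IR format may differ between versions):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=32)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
# The printed statement contains lines like
#   // attr [iter_var(threadIdx.x, ...)] thread_extent = 32
# These thread_extent AttrStmt nodes (and reads of their variables) are what
# ThreadAxisRewriter rewrites when a mapping for "threadIdx.x" is supplied.
print(tvm.lower(s, [A, B], simple_mode=True))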
......@@ -34,9 +34,10 @@ def test_exp():
        np.testing.assert_allclose(
            b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
    check_device("opencl -device=intel_gpu")
    check_device("cuda", "llvm")
    check_device("vulkan")
    check_device("opencl")
def test_log_pow_llvm():
......@@ -196,8 +197,8 @@ def try_warp_memory():
if __name__ == "__main__":
    test_exp()
    try_warp_memory()
    test_add()
    test_log_pow_llvm()
    test_exp()
    test_popcount()