Commit 154104b3 by Tianqi Chen Committed by GitHub

[PASS] Remap thread axis. (#1122)

parent d0f40112
...@@ -32,6 +32,11 @@ class TargetNode : public Node { ...@@ -32,6 +32,11 @@ class TargetNode : public Node {
int max_num_threads = 1; int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1; int thread_warp_size = 1;
/*!
 * \brief The thread index that is the lowest (corresponds to the warp dimension).
 * In CUDA it is threadIdx.x, but it can be different on some platforms.
*/
int thread_warp_index = 0;
/*! \brief Keys for this target */ /*! \brief Keys for this target */
Array<Expr> keys_array; Array<Expr> keys_array;
/*! \brief Options for this target */ /*! \brief Options for this target */
...@@ -48,6 +53,7 @@ class TargetNode : public Node { ...@@ -48,6 +53,7 @@ class TargetNode : public Node {
v->Visit("device_type", &device_type); v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads); v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size); v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("thread_warp_index", &thread_warp_index);
v->Visit("keys_array", &keys_array); v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array); v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array); v->Visit("libs_array", &libs_array);
......
...@@ -417,6 +417,20 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); ...@@ -417,6 +417,20 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
/*! /*!
* \brief Remap the thread axis
*
 * This can be used to get an equivalent program which uses
 * threadIdx.y in place of threadIdx.x by passing
 * {"threadIdx.x": thread_axis("threadIdx.y")}
*
*
* \param f The device function to be lowered.
 * \param axis_map The map from StringImm -> IterVar
* \return Transformed function.
*/
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
/*!
* \brief Lower packed function call. * \brief Lower packed function call.
* \param f The function to be lowered. * \param f The function to be lowered.
* \return Transformed function. * \return Transformed function.
......
...@@ -98,6 +98,7 @@ class DumpIR(object): ...@@ -98,6 +98,7 @@ class DumpIR(object):
schedule.ScheduleOps = self._old_sgpass schedule.ScheduleOps = self._old_sgpass
DumpIR.scope_level -= 1 DumpIR.scope_level -= 1
@register_node @register_node
class BuildConfig(NodeBase): class BuildConfig(NodeBase):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
...@@ -469,6 +470,13 @@ def build(sch, ...@@ -469,6 +470,13 @@ def build(sch,
for i, func in enumerate(fdevice): for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
warp_index = target.thread_warp_index
if warp_index != 0:
assert warp_index == 2
# swap z and x
tmap = {api.convert("threadIdx.z"): api.thread_axis("threadIdx.x"),
api.convert("threadIdx.x"): api.thread_axis("threadIdx.z")}
fdevice[i] = ir_pass.RemapThreadAxis(func, tmap)
if "gpu" in target.keys and not fdevice: if "gpu" in target.keys and not fdevice:
warnings.warn( warnings.warn(
......
...@@ -109,6 +109,7 @@ class Target(NodeBase): ...@@ -109,6 +109,7 @@ class Target(NodeBase):
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
_api_internal._ExitTargetScope() _api_internal._ExitTargetScope()
@register_node @register_node
class GenericFunc(NodeBase): class GenericFunc(NodeBase):
"""GenericFunc node reference. This represents a generic function """GenericFunc node reference. This represents a generic function
......
...@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope); ...@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess); REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce); REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory); REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin); REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall); REGISTER_PASS1(CombineContextCall);
......
...@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name, ...@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name,
t->max_num_threads = 256; t->max_num_threads = 256;
if (t->device_name == "intel_gpu") { if (t->device_name == "intel_gpu") {
t->thread_warp_size = 16; t->thread_warp_size = 16;
// use threadIdx.z for index
t->thread_warp_index = 2;
} }
} else if (target_name == "metal" || target_name == "vulkan") { } else if (target_name == "metal" || target_name == "vulkan") {
if (target_name == "metal") { if (target_name == "metal") {
......
/*!
* Copyright (c) 2018 by Contributors
* \file remap_thread_axis.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
namespace tvm {
namespace ir {
// Mutator to change the read pattern
// Rewrites thread_extent attributes and their bound loop variables
// according to a {thread_tag -> IterVar} substitution table.
class ThreadAxisRewriter : private IRMutator {
 public:
  explicit ThreadAxisRewriter(
      const std::unordered_map<std::string, IterVar>& tmap)
      : tmap_(tmap) {
  }

  // Apply the remapping to a statement tree and return the result.
  Stmt Rewrite(Stmt stmt) {
    return Mutate(stmt);
  }

 private:
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key != attr::thread_extent) {
      return IRMutator::Mutate_(op, stmt);
    }
    IterVar iv(op->node.node_);
    CHECK_NE(iv->thread_tag.length(), 0U);
    auto entry = tmap_.find(iv->thread_tag);
    if (entry == tmap_.end()) {
      // Tag not being remapped; leave this attribute untouched.
      return IRMutator::Mutate_(op, stmt);
    }
    const IterVar& new_iv = entry->second;
    const Variable* old_var = iv->var.get();
    auto bound = vmap_.find(old_var);
    if (bound == vmap_.end()) {
      // First time we see this variable: record its replacement.
      vmap_[old_var] = new_iv->var;
    } else {
      // Same tag must always map to the same variable.
      CHECK(bound->second.same_as(new_iv->var));
    }
    Stmt body = this->Mutate(op->body);
    return AttrStmt::make(
        new_iv, op->attr_key, op->value, body);
  }

  Expr Mutate_(const Variable* op, const Expr& expr) final {
    auto found = vmap_.find(op);
    if (found == vmap_.end()) {
      return IRMutator::Mutate_(op, expr);
    }
    return found->second;
  }

  // Substitution table: thread tag -> replacement IterVar.
  const std::unordered_map<std::string, IterVar>& tmap_;
  // Variables already rewritten: old Variable* -> new Var.
  std::unordered_map<const Variable*, Var> vmap_;
};
// Remap the thread axes of a device function.
//
// Both the function-level thread_axis list and every thread_extent
// attribute (and variable use) inside the body are rewritten
// according to thread_map, whose keys are StringImm thread tags.
LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
  // Flatten the Map<Expr, IterVar> into a tag-keyed hash map.
  std::unordered_map<std::string, IterVar> tmap;
  for (const auto& kv : thread_map) {
    const StringImm* tag = kv.first.as<StringImm>();
    CHECK(tag != nullptr);
    tmap[tag->value] = kv.second;
  }

  // Only device functions carry launch thread axes.
  CHECK_EQ(f->func_type, kDeviceFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());

  // Replace remapped entries in the function's thread axis list.
  for (size_t idx = 0; idx < n->thread_axis.size(); ++idx) {
    auto hit = tmap.find(n->thread_axis[idx]->thread_tag);
    if (hit != tmap.end()) {
      n->thread_axis.Set(idx, hit->second);
    }
  }

  // Rewrite thread_extent attributes and variable uses in the body.
  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
  return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
...@@ -34,9 +34,10 @@ def test_exp(): ...@@ -34,9 +34,10 @@ def test_exp():
np.testing.assert_allclose( np.testing.assert_allclose(
b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5) b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
check_device("opencl -device=intel_gpu")
check_device("cuda", "llvm") check_device("cuda", "llvm")
check_device("vulkan") check_device("vulkan")
check_device("opencl")
def test_log_pow_llvm(): def test_log_pow_llvm():
...@@ -196,8 +197,8 @@ def try_warp_memory(): ...@@ -196,8 +197,8 @@ def try_warp_memory():
if __name__ == "__main__": if __name__ == "__main__":
test_exp()
try_warp_memory() try_warp_memory()
test_add() test_add()
test_log_pow_llvm() test_log_pow_llvm()
test_exp()
test_popcount() test_popcount()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment