Commit 92f82c8e by kun-zh Committed by Tianqi Chen

[PASS] add a pass for the specific hardware accelarator when it is not binded (#1999)

parent b9e8826f
...@@ -238,6 +238,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; ...@@ -238,6 +238,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
constexpr const char* opengl_stage_scope = "opengl_stage_scope"; constexpr const char* opengl_stage_scope = "opengl_stage_scope";
/*! /*!
* \brief Mark that it is in the device scope.
*/
constexpr const char* device_scope = "device_scope";
/*!
* \brief Check if attr_key is a pragma key extension * \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared * \param attr_key The attr key to be compared
* \return true if it is a pragma key * \return true if it is a pragma key
......
...@@ -327,6 +327,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt); ...@@ -327,6 +327,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
Stmt LowerStorageAccessInfo(Stmt stmt); Stmt LowerStorageAccessInfo(Stmt stmt);
/*! /*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt DecorateDeviceScope(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc. * \brief Make an user callable API LoweredFunc.
* *
* The main task of this function is to create code to : * The main task of this function is to create code to :
......
...@@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin); ...@@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall); REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode); REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file detect_device.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h"
namespace tvm {
namespace ir {
Stmt DecorateDeviceScope(Stmt stmt) {
Stmt body = AttrStmt::make(make_zero(Int(32)),
ir::attr::device_scope,
0,
stmt);
return body;
}
} // namespace ir
} // namespace tvm
...@@ -153,7 +153,8 @@ class HostDeviceSplitter : public IRMutator { ...@@ -153,7 +153,8 @@ class HostDeviceSplitter : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope) { op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s); return SplitDeviceFunc(s);
} }
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
......
import tvm
def test_decorate_device():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt1 = tvm.ir_pass.Simplify(stmt)
stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1)
assert isinstance(stmt2, tvm.stmt.AttrStmt)
assert stmt2.attr_key == "device_scope"
assert stmt1 == stmt2.body
if __name__ == "__main__":
test_decorate_device()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment