Commit 7ab574b0 by Tianqi Chen

[COMPILER] Upgrade to meet latest TVM IR pragma convention (#32)

parent 012697a3
......@@ -8,6 +8,21 @@ from topi import util as util
from .environment import get_env
def _match_pragma(stmt, key):
"""Internal helper to match stmt to pragma stmt.
Parameters
----------
stmt : Stmt
The AttrStmt
key : str
The pragma key
"""
return ((stmt.attr_key == "pragma_" + key) or
(stmt.attr_key == "pragma_scope" and stmt.value.value == key))
def fold_uop_loop(stmt_in):
"""Detect and fold uop loop.
......@@ -255,7 +270,7 @@ def inject_skip_copy(stmt_in):
Transformed statement
"""
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"):
if _match_pragma(stmt, "skip_dma_copy"):
return tvm.make.Evaluate(0)
return None
return tvm.ir_pass.IRTransform(
......@@ -277,12 +292,12 @@ def inject_coproc_sync(stmt_in):
"""
success = [False]
def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync":
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop":
elif _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.stmt.For)
return tvm.make.For(
......@@ -561,7 +576,7 @@ def annotate_alu_coproc_scope(stmt_in):
"""
env = get_env()
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
if _match_pragma(stmt, "alu"):
irb = tvm.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(env.dev.QID_COMPUTE))
......@@ -569,7 +584,7 @@ def annotate_alu_coproc_scope(stmt_in):
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
elif _match_pragma(stmt, "skip_alu"):
return tvm.make.Evaluate(0)
return stmt
......@@ -631,7 +646,7 @@ def inject_alu_intrin(stmt_in):
return rev_src_coeff, rev_dst_coeff, rev_extents
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
if _match_pragma(stmt, "alu"):
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment