Unverified Commit 869b718a by Tianqi Chen Committed by GitHub

[TIR] Fix perf regression of tir refactor (#5258)

parent 7902f762
......@@ -198,7 +198,7 @@ def lower(sch,
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
if cfg.restricted_func:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI()(mod)
......
......@@ -199,7 +199,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
......
......@@ -214,7 +214,7 @@ IRModule lower(te::Schedule sch,
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
return tir::transform::MakePackedAPI(0)(mod);
......
......@@ -26,7 +26,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
......
......@@ -33,7 +33,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
......
......@@ -36,7 +36,7 @@ def test_makeapi():
num_unpacked_args = 2
f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
"tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
"tir.noalias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
mod = tvm.IRModule.from_expr(f)
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment