Unverified Commit f1438813 by Tianqi Chen Committed by GitHub

[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (#5335)

parent d81b006b
......@@ -16,6 +16,8 @@
# under the License.
"""Function defintiions."""
from enum import IntEnum
import tvm.runtime
from .expr import RelayExpr
from . import _ffi_api
......@@ -34,3 +36,32 @@ class BaseFunc(RelayExpr):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)
def with_attr(self, attr_key_or_dict, attr_value=None):
"""Create a new copy of the function and update the attribute.
Parameters
----------
attr_key_or_dict : Union[str, dict]
The attribute key to use or a dict containing multiple key value pairs.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
# make sure we first copy so that we can safely do copy on write
# for multiple updates.
res = _ffi_api.BaseFuncCopy(self)
if isinstance(attr_key_or_dict, dict):
for key, val in attr_key_or_dict.items():
res = _ffi_api.BaseFuncWithAttr(
res._move(), key, tvm.runtime.convert(val))
return res
return _ffi_api.BaseFuncWithAttr(
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
......@@ -65,22 +65,3 @@ class Function(BaseFunc):
Arguments.
"""
return Call(self, args, None, None)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
......@@ -168,41 +168,4 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
x_name, grad.shape, dist, max_diff, avg_diff)
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to build a Module from statement.
Used for migrating existing test cases only.
Parameters
----------
stmt: Stmt
The input statement.
name: str
The name of the funciton.
args: list of Buffer or Vars
The function arguments
num_unpacked_args: int
Number of unpacked arguments.
nolias: bool
Whether allow noalias.
Returns
-------
mod : IRModule
The created IRModule.
"""
assert num_unpacked_args == 0
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return mod
tvm._ffi._init_api("testing", __name__)
......@@ -67,22 +67,3 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.PrimFuncWithAttr(
self, attr_key, tvm.runtime.convert(attr_value))
......@@ -23,6 +23,14 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
// NOTE: reverse dependency on relay, tir/
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We calls into the type specific WithAttr function
#include <tvm/tir/function.h>
#include <tvm/relay/function.h>
namespace tvm {
......@@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
return func->attrs;
});
TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
.set_body_typed([](BaseFunc func) {
return func;
});
TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
} else {
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
return func;
}
});
} // namespace tvm
......@@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
std::string gv_name = "main";
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}
} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
auto main_gv = GlobalVar(gv_name);
mod->Add(main_gv, func);
return mod;
}
......
......@@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ")";
});
TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace relay
} // namespace tvm
......@@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
return PrimFunc(params, body, ret_type, buffer_map, attrs);
});
TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace tir
} // namespace tvm
......@@ -39,7 +39,8 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()
mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
f = tvm.build(mod, target="stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
......
......@@ -57,8 +57,11 @@ def test_dso_module_load():
tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(m, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr(
"global_symbol", "main")
)
m = tvm.driver.build(mod, target="llvm")
for name in names:
m.save(name)
......
......@@ -36,8 +36,11 @@ def test_llvm_intrin():
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr(
"global_symbol", "prefetch")
)
fcode = tvm.build(mod, None, "llvm")
def test_llvm_overloaded_intrin():
......@@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
fcode = tvm.build(mod, None, "llvm")
def test_llvm_large_uintimm():
......
......@@ -20,17 +20,6 @@ import ctypes
import numpy as np
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
def test_static_callback():
dtype = 'int64'
n = te.size_var('n')
......@@ -44,8 +33,11 @@ def test_static_callback():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
)
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
f(a)
......@@ -67,8 +59,9 @@ def test_static_init():
return sh
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
......@@ -26,18 +26,6 @@ def run_jit(fapi, check):
s = f.get_source()
check(f)
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
......@@ -48,8 +36,11 @@ def test_stack_vm_basic():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a))
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))
run_jit(mod, lambda f: f(a))
@tvm.register_func
......@@ -69,12 +60,13 @@ def test_stack_vm_loop():
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)
def test_stack_vm_cond():
......@@ -91,14 +83,15 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check)
run_jit(mod, check)
def test_vm_parallel():
dtype = 'int64'
......@@ -110,12 +103,13 @@ def test_vm_parallel():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)
if __name__ == "__main__":
......
......@@ -277,7 +277,7 @@ def test_prim_func():
assert func.buffer_map[func.params[2]].same_as(b)
assert len(func.buffer_map) == 1
f2 = func.with_attr("calling_conv", 1)
f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
assert f2.attrs["calling_conv"].value == 1
assert func.attrs is None
......
......@@ -92,7 +92,9 @@ def test_flatten_double_buffer():
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
......
......@@ -43,7 +43,7 @@ def test_lower_warp_memory_local_scope():
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
......
......@@ -35,11 +35,11 @@ def test_makeapi():
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
num_unpacked_args = 2
f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt)
f = f.with_attr("global_symbol", "myadd")
f = f.with_attr("target", tvm.target.create("llvm"))
mod = tvm.IRModule.from_expr(f)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
"global_symbol": "main",
"target": tvm.target.create("llvm")
}))
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
......
......@@ -39,12 +39,15 @@ def test_thread_storage_sync():
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
cuda_target = tvm.target.create("cuda")
mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
"global_symbol": "test", "target": cuda_target}))
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
cuda_target = tvm.target.create("cuda")
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment