Unverified commit f1438813 by Tianqi Chen, committed by GitHub

[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (#5335)

parent d81b006b
@@ -16,6 +16,8 @@
# under the License.
"""Function definitions."""
from enum import IntEnum
+import tvm.runtime

from .expr import RelayExpr
from . import _ffi_api
@@ -34,3 +36,32 @@ class BaseFunc(RelayExpr):
        """Return the attrs member of the function.
        """
        return _ffi_api.BaseFunc_Attrs(self)
+    def with_attr(self, attr_key_or_dict, attr_value=None):
+        """Create a new copy of the function and update the attribute.
+
+        Parameters
+        ----------
+        attr_key_or_dict : Union[str, dict]
+            The attribute key to use or a dict containing multiple key value pairs.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        # Make sure we first copy, so that we can safely do copy-on-write
+        # for multiple updates.
+        res = _ffi_api.BaseFuncCopy(self)
+
+        if isinstance(attr_key_or_dict, dict):
+            for key, val in attr_key_or_dict.items():
+                res = _ffi_api.BaseFuncWithAttr(
+                    res._move(), key, tvm.runtime.convert(val))
+            return res
+
+        return _ffi_api.BaseFuncWithAttr(
+            res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
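A minimal usage sketch of the enhanced API (the attribute keys below are illustrative, not mandated by this change):

    import tvm
    from tvm import relay

    # A trivial identity function, used only for illustration.
    x = relay.var("x", shape=(1,))
    func = relay.Function([x], x)

    # Single key/value form returns a new copy; `func` itself is left untouched.
    f1 = func.with_attr("Primitive", 1)

    # Dict form applies several attribute updates through a single copy.
    f2 = func.with_attr({"Primitive": 1, "global_symbol": "main"})
    assert f2.attrs["Primitive"].value == 1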
@@ -65,22 +65,3 @@ class Function(BaseFunc):
            Arguments.
        """
        return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute.
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
@@ -168,41 +168,4 @@ def check_numerical_grads(function, input_values, grad_values, function_value=None,
                x_name, grad.shape, dist, max_diff, avg_diff)
-
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to build a Module from a statement.
-
-    Used for migrating existing test cases only.
-
-    Parameters
-    ----------
-    stmt: Stmt
-        The input statement.
-
-    name: str
-        The name of the function.
-
-    args: list of Buffer or Vars
-        The function arguments.
-
-    num_unpacked_args: int
-        Number of unpacked arguments.
-
-    noalias: bool
-        Whether to allow noalias.
-
-    Returns
-    -------
-    mod : IRModule
-        The created IRModule.
-    """
-    assert num_unpacked_args == 0
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: f})
-    return mod

tvm._ffi._init_api("testing", __name__)
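The test migrations below all follow the same pattern: instead of calling `tvm.testing.MakeAPILegacy`, construct a `tir.PrimFunc` directly, attach a `global_symbol` attribute, and wrap it with `IRModule.from_expr`. A sketch of the equivalent pattern, assuming a trivial stand-in statement body:

    import tvm
    from tvm import te

    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "float32")
    stmt = tvm.tir.Evaluate(tvm.tir.const(0, "int32"))  # stand-in for a real body

    mod = tvm.IRModule.from_expr(
        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main"))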
@@ -67,22 +67,3 @@ class PrimFunc(BaseFunc):
        self.__init_handle_by_constructor__(
            _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute.
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : PrimFunc
-            A new copy of the function
-        """
-        return _ffi_api.PrimFuncWithAttr(
-            self, attr_key, tvm.runtime.convert(attr_value))
@@ -23,6 +23,14 @@
 */
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
+// NOTE: reverse dependency on relay/ and tir/.
+// These dependencies do not happen at the interface level,
+// and are only used in a minimum of cases where they are clearly marked.
+//
+// Rationale: we call into the type-specific WithAttr functions.
+#include <tvm/tir/function.h>
+#include <tvm/relay/function.h>

namespace tvm {
@@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
  return func->attrs;
});

+TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
+.set_body_typed([](BaseFunc func) {
+  return func;
+});
+
+TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
+.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+  if (func->IsInstance<tir::PrimFuncNode>()) {
+    return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
+  } else if (func->IsInstance<relay::FunctionNode>()) {
+    return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
+  } else {
+    LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
+    return func;
+  }
+});

}  // namespace tvm
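Since `ir.BaseFuncWithAttr` dispatches on the runtime type, the single Python entry point covers both function kinds; a hedged sketch with trivial function bodies, for illustration only:

    import tvm
    from tvm import relay

    # tir.PrimFunc path: dispatches to the tir::PrimFunc overload of WithAttr.
    prim = tvm.tir.PrimFunc([], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
    prim = prim.with_attr("global_symbol", "main")

    # relay.Function path: dispatches to the relay::Function overload.
    x = relay.var("x", shape=(1,))
    rfunc = relay.Function([x], x).with_attr("Primitive", 1)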
@@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
    const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
  auto mod = IRModule(global_funcs, type_definitions);
  BaseFunc func;
+  std::string gv_name = "main";
+
  if (auto* func_node = expr.as<BaseFuncNode>()) {
    func = GetRef<BaseFunc>(func_node);
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      gv_name = opt.value();
+    }
  } else {
    func = relay::Function(relay::FreeVars(expr), expr, Type(),
                           relay::FreeTypeVars(expr, mod), {});
  }
-  auto main_gv = GlobalVar("main");
+  auto main_gv = GlobalVar(gv_name);
  mod->Add(main_gv, func);
  return mod;
}
...
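With the `gv_name` change, `IRModule::FromExpr` (and thus Python's `IRModule.from_expr`) now respects an explicit `global_symbol` attribute instead of always naming the function "main"; a small sketch (the symbol name is illustrative):

    import tvm

    f = tvm.tir.PrimFunc([], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
    mod = tvm.IRModule.from_expr(f.with_attr("global_symbol", "my_func"))
    assert mod.get_global_vars()[0].name_hint == "my_func"

This is also why the module lookups in the tests below change from "main" to the function's own symbol, e.g. "f_kernel0" and "test_kernel0".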
@@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
      << node->attrs << ")";
});
-
-TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
-.set_body_typed(
-  [](Function func, std::string name, ObjectRef ref) {
-    return WithAttr(std::move(func), name, ref);
-});

}  // namespace relay
}  // namespace tvm
@@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
  return PrimFunc(params, body, ret_type, buffer_map, attrs);
});
-
-TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
-.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
-  return WithAttr(std::move(func), name, ref);
-});

}  // namespace tir
}  // namespace tvm
@@ -39,7 +39,8 @@ def test_dltensor_compatible():
        A[i + 1] = A[i] + 1
    stmt = ib.get()
-    mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
    f = tvm.build(mod, target="stackvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    aview = MyTensorView(a)
...
@@ -57,8 +57,11 @@ def test_dso_module_load():
            tvm.tir.Store(Ab.data,
                          tvm.tir.Load(dtype, Ab.data, i) + 1,
                          i + 1))
-    m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    m = tvm.driver.build(m, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr(
+            "global_symbol", "main")
+    )
+    m = tvm.driver.build(mod, target="llvm")
    for name in names:
        m.save(name)
...
@@ -36,8 +36,11 @@ def test_llvm_intrin():
        "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
    body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr(
+            "global_symbol", "prefetch")
+    )
+    fcode = tvm.build(mod, None, "llvm")

def test_llvm_overloaded_intrin():
@@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
    x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
    ib.emit(x)
    body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
+    fcode = tvm.build(mod, None, "llvm")

def test_llvm_large_uintimm():
...
@@ -20,17 +20,6 @@ import ctypes
import numpy as np
-
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create an API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
def test_static_callback():
    dtype = 'int64'
    n = te.size_var('n')
@@ -44,8 +33,11 @@ def test_static_callback():
    with ib.for_range(0, n, "i", for_type="parallel") as i:
        A[i] = A[i] + 1
    stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
+    )
+    f = tvm.driver.build(mod, target="llvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f(a)
    f(a)
@@ -67,8 +59,9 @@ def test_static_init():
        return sh
    stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
+    f = tvm.driver.build(mod, target="llvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f(a)
...
@@ -26,18 +26,6 @@ def run_jit(fapi, check):
    s = f.get_source()
    check(f)
-
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create an API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
def test_stack_vm_basic():
    a = tvm.nd.array(np.zeros(10, dtype='float32'))
    @tvm.register_func
@@ -48,8 +36,11 @@ def test_stack_vm_basic():
    n = te.size_var('n')
    Ab = tvm.tir.decl_buffer((n, ), "float32")
    stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
-    fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
-    run_jit(fapi, lambda f: f(a))
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))
+    run_jit(mod, lambda f: f(a))

@tvm.register_func
@@ -69,12 +60,13 @@ def test_stack_vm_loop():
    ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
    stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    def check(f):
        f(a)
        np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)

def test_stack_vm_cond():
@@ -91,14 +83,15 @@ def test_stack_vm_cond():
            A[i + 1] = A[i] + 2
    stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        f(a)
        y = np.arange(a.shape[0]) * 2
        y[5:] -= 1
        np.testing.assert_equal(a.asnumpy(), y)
-    run_jit(fapi, check)
+    run_jit(mod, check)

def test_vm_parallel():
    dtype = 'int64'
@@ -110,12 +103,13 @@ def test_vm_parallel():
    with ib.for_range(0, n, "i", for_type="parallel") as i:
        A[i] = A[i] + 1
    stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
        f(a)
        np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)

if __name__ == "__main__":
...
@@ -277,7 +277,7 @@ def test_prim_func():
    assert func.buffer_map[func.params[2]].same_as(b)
    assert len(func.buffer_map) == 1
-    f2 = func.with_attr("calling_conv", 1)
+    f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
    assert f2.attrs["calling_conv"].value == 1
    assert func.attrs is None
...
@@ -92,7 +92,9 @@ def test_flatten_double_buffer():
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert isinstance(stmt.body.body, tvm.tir.Allocate)
    assert stmt.body.body.extents[0].value == 2
-    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
    f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
    count = [0]
...
@@ -43,7 +43,7 @@ def test_lower_warp_memory_local_scope():
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
    mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
    assert(fdevice.body.body.value.value == "local")
    assert(fdevice.body.body.body.extents[0].value == 2)
...
@@ -35,11 +35,11 @@ def test_makeapi():
    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb}, 64)
    num_unpacked_args = 2
-    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt)
-    f = f.with_attr("global_symbol", "myadd")
-    f = f.with_attr("target", tvm.target.create("llvm"))
-
-    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
+            "global_symbol": "main",
+            "target": tvm.target.create("llvm")
+        }))
    f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
    assert(len(f.params) == 7)
...
@@ -39,12 +39,15 @@ def test_thread_storage_sync():
    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
    cuda_target = tvm.target.create("cuda")
-    mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
+            "global_symbol": "test", "target": cuda_target}))
    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
    mod = tvm.IRModule.from_expr(fdevice)
    cuda_target = tvm.target.create("cuda")
-    f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
+    f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
    body_list = tvm.tir.stmt_list(f.body.body.body.body)
    assert(body_list[1].value.name == "tvm_storage_sync")
...