Commit acd48e9a by Tianqi Chen, committed by GitHub

[RUNTIME] Enable ext_dev type for quick plugin of device (#542)

* [RUNTIME] Enable ext_dev type for quick plugin of device

* [TEST] Update testcase to cover all computation
parent 581509ab
@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
    Var b = args[1];
    *rv = a + b;
  });

TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
  });
}  // namespace tvm_ext
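The registration above is only a test stub: it forwards device_api.ext_dev to the already-registered CPU device API, so the extension device behaves like ordinary host memory. A real plugin would hand the runtime its own DeviceAPI instance instead. The following is a minimal sketch of that idea; MyDeviceAPI and GetMyDeviceAPI are hypothetical names standing in for a user-provided tvm::runtime::DeviceAPI implementation and are not part of this commit.

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

namespace my_plugin {

// Hypothetical DeviceAPI subclass implemented by the plugin (AllocDataSpace,
// CopyDataFromTo, StreamSync, ...); only declared here for illustration.
class MyDeviceAPI;
MyDeviceAPI* GetMyDeviceAPI();  // hypothetical accessor returning a singleton

// Route ext_dev (device type 12) to the plugin's DeviceAPI. The runtime looks
// up "device_api.ext_dev" and expects an opaque handle to a DeviceAPI object.
TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
    *rv = static_cast<void*>(GetMyDeviceAPI());
  });

}  // namespace my_plugin

With such a registration in place, Python code can target the device exactly as the new test below does, through tvm.ext_dev(0) and tvm.build(s, [A, B], "ext_dev", "llvm").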
import tvm_ext
import tvm
import numpy as np

def test_bind_add():
    def add(a, b):
@@ -7,6 +8,24 @@ def test_bind_add():
    f = tvm_ext.bind_add(add, 1)
    assert f(2) == 3

def test_ext_dev():
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        f = tvm.build(s, [A, B], "ext_dev", "llvm")
        ctx = tvm.ext_dev(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
    check_llvm()

def test_sym_add():
    a = tvm.var('a')
    b = tvm.var('b')
@@ -26,6 +45,7 @@ def test_ext_vec():
    tvm.convert(ivec_cb)(ivec)

if __name__ == "__main__":
    test_ext_dev()
    test_ext_vec()
    test_bind_add()
    test_sym_add()
@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
  // Extension DRAM type, used to quickly test an extension device.
  // The device API can differ depending on the xpu driver registered.
  kExtDev = 12
  // Add extra TVM device types, which are not in DLPack, here.
} TVMDeviceExtType;
@@ -17,7 +17,7 @@ from . import ir_builder
from . import target
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
        4 : 'opencl',
        8 : 'metal',
        9 : 'vpi',
        10: 'rocm'
        10: 'rocm',
        12: 'ext_dev',
    }
    STR2MASK = {
        'cpu': 1,
@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
        'opencl': 4,
        'metal': 8,
        'vpi': 9,
        'rocm': 10
        'rocm': 10,
        'ext_dev': 12,
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
@@ -345,7 +345,7 @@ def build(sch,
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target)
@@ -247,6 +247,10 @@ class RPCSession(object):
        """Construct remote Metal device."""
        return self.context(8, dev_id)

    def ext_dev(self, dev_id=0):
        """Construct remote extension device."""
        return self.context(12, dev_id)

    def upload(self, data, target=None):
        """Upload file to remote runtime temp folder
@@ -120,6 +120,27 @@ def vpi(dev_id=0):
    """
    return TVMContext(9, dev_id)

def ext_dev(dev_id=0):
    """Construct an extension device.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx : TVMContext
        The created context

    Note
    ----
    This API is reserved for quick testing of new
    devices by plugging in a device API as ext_dev.
    """
    return TVMContext(12, dev_id)

cl = opencl
mtl = metal
@@ -31,6 +31,7 @@ inline std::string DeviceName(int type) {
    case kMetal: return "metal";
    case kVPI: return "vpi";
    case kROCM: return "rocm";
    case kExtDev: return "ext_dev";
    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
  }
}
@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;

ROCMThreadEntry::ROCMThreadEntry()
    : pool(kGPU, ROCMDeviceAPI::Global()) {
    : pool(kROCM, ROCMDeviceAPI::Global()) {
}

ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
@@ -7,7 +7,7 @@ def test_add_pipeline():
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C')
    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = tvm.create_schedule(D.op)
    # GPU schedule has to split by gridIdx and threadIdx
@@ -26,11 +26,11 @@ def test_add_pipeline():
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
    stmt = tvm.ir_pass.LoopPartition(stmt)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
    fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
    fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
@@ -49,10 +49,10 @@ def test_add_pipeline():
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
        f(a, b, c)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_module_save(device, host="stackvm"):
        if not tvm.module.enabled(host):
@@ -73,10 +73,10 @@ def test_add_pipeline():
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
        f(a, b, c)
        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
        f(a, b, d)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())
            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    check_target("cuda", host="stackvm")
    check_target("cuda", host="llvm")