Commit acd48e9a by Tianqi Chen Committed by GitHub

[RUNTIME] Enable ext_dev type for quick plugin of device (#542)

* [RUNTIME] Enable ext_dev type for quick plugin of device

* [TEST] Update testcase to cover all computation
parent 581509ab
...@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add") ...@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
Var b = args[1]; Var b = args[1];
*rv = a + b; *rv = a + b;
}); });
// Register the device API for the "ext_dev" extension device type.
// The body simply looks up the already-registered CPU device API
// ("device_api.cpu") and returns it, so ext_dev is backed by the CPU
// implementation — sufficient for quickly testing a plugin device.
TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    // Forward to the CPU device API singleton retrieved from the global registry.
    *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
  });
} // namespace tvm_ext } // namespace tvm_ext
import tvm_ext import tvm_ext
import tvm import tvm
import numpy as np
def test_bind_add(): def test_bind_add():
def add(a, b): def add(a, b):
...@@ -7,6 +8,24 @@ def test_bind_add(): ...@@ -7,6 +8,24 @@ def test_bind_add():
f = tvm_ext.bind_add(add, 1) f = tvm_ext.bind_add(add, 1)
assert f(2) == 3 assert f(2) == 3
def test_ext_dev():
    """Exercise the ext_dev extension device end to end.

    Builds a simple elementwise B = A + 1 kernel targeting "ext_dev"
    (with an LLVM host), runs it, and checks the numeric result.
    """
    num_elems = 10
    A = tvm.placeholder((num_elems,), name='A')
    B = tvm.compute((num_elems,), lambda *i: A(*i) + 1.0, name='B')
    sched = tvm.create_schedule(B.op)

    def check_llvm():
        # Skip silently when the LLVM backend is not compiled in.
        if not tvm.module.enabled("llvm"):
            return
        func = tvm.build(sched, [A, B], "ext_dev", "llvm")
        ctx = tvm.ext_dev(0)
        # Launch the kernel on the extension device context.
        a_arr = tvm.nd.array(np.random.uniform(size=num_elems).astype(A.dtype), ctx)
        b_arr = tvm.nd.array(np.zeros(num_elems, dtype=B.dtype), ctx)
        func(a_arr, b_arr)
        np.testing.assert_allclose(b_arr.asnumpy(), a_arr.asnumpy() + 1)

    check_llvm()
def test_sym_add(): def test_sym_add():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -26,6 +45,7 @@ def test_ext_vec(): ...@@ -26,6 +45,7 @@ def test_ext_vec():
tvm.convert(ivec_cb)(ivec) tvm.convert(ivec_cb)(ivec)
if __name__ == "__main__": if __name__ == "__main__":
test_ext_dev()
test_ext_vec() test_ext_vec()
test_bind_add() test_bind_add()
test_sym_add() test_sym_add()
...@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t; ...@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */ /*! \brief Extension device types in TVM */
typedef enum { typedef enum {
// Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered.
kExtDev = 12
// AddExtraTVMType which is not in DLPack here // AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType; } TVMDeviceExtType;
......
...@@ -17,7 +17,7 @@ from . import ir_builder ...@@ -17,7 +17,7 @@ from . import ir_builder
from . import target from . import target
from . import ndarray as nd from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function from ._ffi.function import Function
......
...@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure): ...@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
4 : 'opencl', 4 : 'opencl',
8 : 'metal', 8 : 'metal',
9 : 'vpi', 9 : 'vpi',
10: 'rocm' 10: 'rocm',
12: 'ext_dev',
} }
STR2MASK = { STR2MASK = {
'cpu': 1, 'cpu': 1,
...@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure): ...@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
'opencl': 4, 'opencl': 4,
'metal': 8, 'metal': 8,
'vpi': 9, 'vpi': 9,
'rocm': 10 'rocm': 10,
'ext_dev': 12,
} }
def __init__(self, device_type, device_id): def __init__(self, device_type, device_id):
super(TVMContext, self).__init__() super(TVMContext, self).__init__()
......
...@@ -345,7 +345,7 @@ def build(sch, ...@@ -345,7 +345,7 @@ def build(sch,
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
if not target.startswith("llvm") and target != "stackvm" and not fdevice: if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
warnings.warn( warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target) "Specified target %s, but cannot find device code, did you do bind?" % target)
......
...@@ -247,6 +247,10 @@ class RPCSession(object): ...@@ -247,6 +247,10 @@ class RPCSession(object):
"""Construct remote Metal device.""" """Construct remote Metal device."""
return self.context(8, dev_id) return self.context(8, dev_id)
def ext_dev(self, dev_id=0):
    """Construct remote extension device.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id.

    Returns
    -------
    ctx : TVMContext
        The constructed remote context.
    """
    # 12 is the device-type code reserved for the ext_dev extension device.
    ext_dev_type = 12
    return self.context(ext_dev_type, dev_id)
def upload(self, data, target=None): def upload(self, data, target=None):
"""Upload file to remote runtime temp folder """Upload file to remote runtime temp folder
......
...@@ -120,6 +120,27 @@ def vpi(dev_id=0): ...@@ -120,6 +120,27 @@ def vpi(dev_id=0):
""" """
return TVMContext(9, dev_id) return TVMContext(9, dev_id)
def ext_dev(dev_id=0):
    """Construct an extension device context.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id.

    Returns
    -------
    ctx : TVMContext
        The created context (device type 12, the ext_dev slot).

    Note
    ----
    This API is reserved for quick testing of a new device by
    plugging in its device API under the name ``ext_dev``.
    """
    return TVMContext(12, dev_id)
cl = opencl cl = opencl
mtl = metal mtl = metal
......
...@@ -31,6 +31,7 @@ inline std::string DeviceName(int type) { ...@@ -31,6 +31,7 @@ inline std::string DeviceName(int type) {
case kMetal: return "metal"; case kMetal: return "metal";
case kVPI: return "vpi"; case kVPI: return "vpi";
case kROCM: return "rocm"; case kROCM: return "rocm";
case kExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
} }
} }
......
...@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore; typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
ROCMThreadEntry::ROCMThreadEntry() ROCMThreadEntry::ROCMThreadEntry()
: pool(kGPU, ROCMDeviceAPI::Global()) { : pool(kROCM, ROCMDeviceAPI::Global()) {
} }
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
......
...@@ -7,7 +7,7 @@ def test_add_pipeline(): ...@@ -7,7 +7,7 @@ def test_add_pipeline():
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C') D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
s = tvm.create_schedule(D.op) s = tvm.create_schedule(D.op)
# GPU schedule have to split by gridIdx and threadIdx # GPU schedule have to split by gridIdx and threadIdx
...@@ -26,11 +26,11 @@ def test_add_pipeline(): ...@@ -26,11 +26,11 @@ def test_add_pipeline():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C') Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
stmt = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
...@@ -49,10 +49,10 @@ def test_add_pipeline(): ...@@ -49,10 +49,10 @@ def test_add_pipeline():
n = 1027 n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, c) f(a, b, d)
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
def check_module_save(device, host="stackvm"): def check_module_save(device, host="stackvm"):
if not tvm.module.enabled(host): if not tvm.module.enabled(host):
...@@ -73,10 +73,10 @@ def test_add_pipeline(): ...@@ -73,10 +73,10 @@ def test_add_pipeline():
n = 1027 n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, c) f(a, b, d)
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
check_target("cuda", host="stackvm") check_target("cuda", host="stackvm")
check_target("cuda", host="llvm") check_target("cuda", host="llvm")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment