# test_runtime_extension.py
import tvm
import numpy as np

@tvm.register_extension
class MyTensorView(object):
    """Minimal extension type that forwards to a wrapped array's TVM handle.

    The class is registered with ``tvm.register_extension`` and tagged with
    ``tvm.TypeCode.ARRAY_HANDLE``. Presumably this lets TVM functions accept an
    instance wherever an array argument is expected. Confirm against the
    ``register_extension`` documentation for this TVM version.

    Parameters
    ----------
    arr : object
        The wrapped array. Assumed to be a ``tvm.nd.NDArray``; it only needs to
        expose a ``_tvm_handle`` attribute.
    """

    # Type code presumably read by TVM when marshalling instances as arguments.
    _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

    def __init__(self, arr):
        self.arr = arr

    @property
    def _tvm_handle(self):
        """Return the wrapped array's ``_tvm_handle`` unchanged."""
        return self.arr._tvm_handle

def test_dltensor_compatible():
    """Check that a ``MyTensorView`` wrapper can be passed in place of an NDArray.

    The test builds a function named "arange" for the "stackvm" target. It
    takes one 1-D int64 buffer ``A`` of symbolic length ``n`` and sets
    ``A[i + 1] = A[i] + 1`` for every ``i`` in ``[0, n - 1)``.

    The function is called with a ``MyTensorView`` that wraps a zero-filled
    10-element NDArray ``a``, not with ``a`` itself. ``A[0]`` starts at 0, so
    ``a`` must end up equal to ``np.arange(10)``. That can only happen if the
    wrapper's handle reaches the same storage as ``a``.
    """
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n,), dtype)
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    # Each element is its predecessor plus one. The loop stops at n - 1 so
    # that A[i + 1] stays in bounds.
    with ib.for_range(0, n - 1, "i") as i:
        A[i + 1] = A[i] + 1
    stmt = ib.get()
    # NOTE(review): the positional `0, True` arguments are presumably
    # num_unpacked_args and is_restricted in this TVM version -- confirm.
    fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
    f = tvm.codegen.build_module(fapi, "stackvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    # Pass the wrapper rather than `a`. The writes must still land in `a`,
    # because the assertion below reads `a` directly.
    aview = MyTensorView(a)
    f(aview)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))


if __name__ == "__main__":
    test_dltensor_compatible()