# test_runtime_extension.py
import tvm
import numpy as np

@tvm.register_extension
class MyTensorView(object):
    """Minimal registered extension type that wraps a TVM array.

    The class is registered with ``tvm.register_extension``. TVM's FFI
    presumably reads ``_tvm_handle`` and ``_tvm_tcode`` when an instance is
    passed to a packed function (see the ``register_extension`` docs). Here the
    handle is borrowed from the wrapped array and tagged as ``ARRAY_HANDLE``, so
    a callee operates on the wrapped array's storage.
    """

    # Type code reported to TVM for instances of this class. It is declared as
    # a class attribute (not a property) so that it evaluates to the same int
    # whether it is read from the class or from an instance.
    _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

    def __init__(self, arr):
        # `arr` is presumably a tvm.nd.NDArray (as in test_dltensor_compatible).
        # Holding the reference keeps the array alive while the view exists.
        self.arr = arr

    @property
    def _tvm_handle(self):
        # Forward the wrapped array's own handle; nothing is copied.
        return self.arr._tvm_handle


def test_dltensor_compatible():
    """Check that a registered extension object can stand in for an NDArray.

    Builds a stackvm function ``arange`` over a 1-D ``int64`` buffer ``A`` of
    symbolic length ``n``. The function performs ``A[i + 1] = A[i] + 1`` for
    ``i`` in ``[0, n - 1)``. It is called with a ``MyTensorView`` wrapping a
    zero-initialised NDArray instead of the NDArray itself. Because the update
    happens in place, the underlying array must end up equal to ``arange(n)``.
    """
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n,), dtype)
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    # Prefix recurrence: starting from A[0] == 0, each element becomes one
    # more than its predecessor, giving 0, 1, ..., n - 1.
    with ib.for_range(0, n - 1, "i") as i:
        A[i + 1] = A[i] + 1
    stmt = ib.get()
    # Wrap the statement as an API function named "arange" taking the buffer.
    fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    f = tvm.codegen.build_module(fapi, "stackvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    # Pass the extension wrapper rather than `a`. The wrapper forwards a's
    # handle, so the kernel's writes must land in `a`'s storage.
    aview = MyTensorView(a)
    f(aview)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0], dtype=dtype))

if __name__ == "__main__":
    # Allow running the test directly as a script, without a test runner.
    test_dltensor_compatible()