# test_runtime_packed_func.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import tvm.testing
import numpy as np

def test_get_global():
    """Register a Python function globally and call it back via the registry.

    The function is registered under its own name, looked up again with
    ``tvm.get_global_func``, and invoked with an int, a float and a string.
    """
    targs = (10, 10.0, "hello")

    # Register into the global function table.
    @tvm.register_func
    def my_packed_func(*args):
        assert tuple(args) == targs
        return 10

    # Get it back out of the global function table.
    f = tvm.get_global_func("my_packed_func")
    assert isinstance(f, tvm.runtime.PackedFunc)
    y = f(*targs)
    assert y == 10


def test_get_callback_with_node():
    """Pass a converted value and a converted callback through a global function.

    The registered function receives the value and the callback, forwards the
    value to the callback, and returns the callback's result.
    """
    x = tvm.runtime.convert(10)

    def test(y):
        # The value seen by the callback must not share a handle with ``x``.
        assert y.handle != x.handle
        return y

    f2 = tvm.runtime.convert(test)

    # Register into the global function table.
    @tvm.register_func
    def my_callback_with_node(y, f):
        assert y == x
        return f(y)

    # Get it back out of the global function table.
    f = tvm.get_global_func("my_callback_with_node")
    assert isinstance(f, tvm.runtime.PackedFunc)
    y = f(x, f2)
    assert y.value == 10

def test_return_func():
    """Check that a converted function can itself return a callable closure."""

    def addy(y):
        def add(x):
            return tvm.runtime.convert(x + y)

        return add

    myf = tvm.runtime.convert(addy)
    f = myf(10)
    assert f(11).value == 21


def test_convert():
    """Convert a Python function to a TVM function and invoke it."""
    targs = (10, 10.0, "hello", 10)

    def myfunc(*args):
        assert tuple(args) == targs

    f = tvm.runtime.convert(myfunc)
    assert isinstance(f, tvm.runtime.PackedFunc)
    # Actually call the converted function; otherwise the argument check
    # inside ``myfunc`` is never executed.
    f(*targs)


def test_byte_array():
    """Pass a ``bytearray`` to a converted function and compare it inside."""
    s = "hello"
    a = bytearray(s, encoding="ascii")

    def myfunc(ss):
        assert ss == a

    f = tvm.runtime.convert(myfunc)
    f(a)


def test_empty_array():
    """Pass a converted empty tuple to a converted function."""

    def myfunc(ss):
        assert tuple(ss) == ()

    x = tvm.runtime.convert(())
    tvm.runtime.convert(myfunc)(x)


def test_ctx():
    """Check context equality and a round trip through ``context_test``."""

    def test_ctx_func(ctx):
        assert tvm.gpu(7) == ctx
        return tvm.cpu(0)

    x = test_ctx_func(tvm.gpu(7))
    assert x == tvm.cpu(0)

    # Rebuild the context from its device type and id; it must compare equal.
    x = tvm.opencl(10)
    x = tvm.testing.context_test(x, x.device_type, x.device_id)
    assert x == tvm.opencl(10)


def test_trace_default_action():
    """Build and run a kernel using ``tvm.tir.trace`` without a callback name."""
    n = 2
    x = te.placeholder((n, n, n), name="X", dtype="float32")
    y = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([i, j, k, x[i][j][k]]))
    s = te.create_schedule(y.op)
    f = tvm.build(s, [x, y], target="llvm")
    xnd = tvm.nd.array(np.ones((n, n, n), dtype=x.dtype))
    ynd = tvm.nd.array(np.zeros((n, n, n), dtype=y.dtype))
    # Only checks that the traced kernel runs; no output is asserted.
    f(xnd, ynd)

def test_trace_expr_assign():
    """Check that tracing with a callback returning None leaves values intact.

    Two chained traced computes copy ``x`` into ``y`` and ``y`` into ``z``;
    all three buffers are expected to hold ones afterwards.
    """

    @tvm.register_func("tvm.tir.trace_callback2")
    def trace_buffer(x):
        return

    def check_assign(dtype):
        n = 4
        shape = (n, n, n)
        x = te.placeholder(shape, name="X", dtype=dtype)
        y = te.compute(
            x.shape,
            lambda i, j, k: tvm.tir.trace([x[i][j][k]], "tvm.tir.trace_callback2"),
        )
        z = te.compute(
            x.shape,
            lambda i, j, k: tvm.tir.trace([y[i][j][k]], "tvm.tir.trace_callback2"),
        )
        s = te.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones(shape, dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros(shape, dtype=y.dtype))
        znd = tvm.nd.array(np.zeros(shape, dtype=z.dtype))
        f(xnd, ynd, znd)

        expected = np.ones(shape)
        assert np.array_equal(xnd.asnumpy(), expected)
        assert np.array_equal(ynd.asnumpy(), expected)
        assert np.array_equal(znd.asnumpy(), expected)

    for t in ["float64", "float32", "int64", "int32"]:
        check_assign(t)

def test_trace_expr_sum_generated():
    """Check that the sum of two traced expressions equals the plain sum."""

    @tvm.register_func("tvm.tir.trace_callback3")
    def trace_buffer(x):
        return

    def check_expr_sum(dtype):
        n = 4
        shape = (n, n, n)
        a = te.placeholder(shape, name="a", dtype=dtype)
        b = te.placeholder(shape, name="b", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j, k: tvm.tir.trace([a[i][j][k]], "tvm.tir.trace_callback3")
            + tvm.tir.trace([b[i][j][k]], "tvm.tir.trace_callback3"),
        )
        s = te.create_schedule(c.op)
        f = tvm.build(s, [a, b, c])
        xnd = tvm.nd.array(np.ones(shape, dtype=a.dtype))
        ynd = tvm.nd.array(np.ones(shape, dtype=b.dtype))
        znd = tvm.nd.array(np.zeros(shape, dtype=c.dtype))
        f(xnd, ynd, znd)
        assert np.array_equal(znd.asnumpy(), xnd.asnumpy() + ynd.asnumpy())

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)

def test_trace_expr_sum_args():
    """Sum four traced expressions, each passing indices plus a value."""

    @tvm.register_func("tvm.tir.trace_silent")
    def silent(*args):
        return

    def check_expr_sum(dtype):
        n = 4
        shape = (n, n, n)
        a = te.placeholder(shape, name="a", dtype=dtype)
        b = te.placeholder(shape, name="b", dtype=dtype)
        e = te.placeholder(shape, name="e", dtype=dtype)
        d = te.placeholder(shape, name="d", dtype=dtype)

        c = te.compute(
            a.shape,
            lambda i, j, k: tvm.tir.trace([i, j, k, a[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, b[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, d[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent"),
        )
        s = te.create_schedule(c.op)
        f = tvm.build(s, [a, b, d, e, c])
        a_nd = tvm.nd.array(np.ones(shape, dtype=a.dtype))
        b_nd = tvm.nd.array(np.ones(shape, dtype=b.dtype))
        d_nd = tvm.nd.array(np.ones(shape, dtype=d.dtype))
        e_nd = tvm.nd.array(np.ones(shape, dtype=e.dtype))
        c_nd = tvm.nd.array(np.zeros(shape, dtype=c.dtype))
        f(a_nd, b_nd, d_nd, e_nd, c_nd)

        expected = (
            a_nd.asnumpy() + b_nd.asnumpy() + d_nd.asnumpy() + e_nd.asnumpy()
        )
        assert np.array_equal(c_nd.asnumpy(), expected)

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)

def test_trace_expr_sum_custom():
    """Sum two traced 2-D inputs holding identity matrices."""

    @tvm.register_func("tvm.tir.trace_callback4")
    def trace_buffer(x):
        return

    def check_expr_sum_custom(dtype):
        n = 4
        a = te.placeholder((n, n), name="a", dtype=dtype)
        b = te.placeholder((n, n), name="b", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j: tvm.tir.trace([a[i][j]], "tvm.tir.trace_callback4")
            + tvm.tir.trace([b[i][j]], "tvm.tir.trace_callback4"),
        )
        s = te.create_schedule(c.op)
        f = tvm.build(s, [a, b, c])
        # Both inputs are n x n identity matrices, each in its own
        # placeholder's dtype.
        npa = np.eye(n, dtype=a.dtype)
        npb = np.eye(n, dtype=b.dtype)
        xnd = tvm.nd.array(npa)
        ynd = tvm.nd.array(npb)
        znd = tvm.nd.array(np.zeros((n, n), dtype=c.dtype))
        f(xnd, ynd, znd)
        assert np.array_equal(znd.asnumpy(), npa + npb)

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum_custom(t)

def test_trace_can_change_traced_value_int():
    """Check that an integer returned by a trace callback replaces the value.

    The first callback returns 13 and the second returns 14; the two output
    buffers are expected to be filled with those constants.
    """

    # Distinct Python names avoid redefining one local function; the
    # registry names passed to ``register_func`` are what the kernels use.
    @tvm.register_func("tvm.tir.trace_change_int_first")
    def trace_first(x):
        return 13

    @tvm.register_func("tvm.tir.trace_change_int_second")
    def trace_second(x):
        return 14

    def check_assign(dtype):
        n = 4
        x = te.placeholder((n,), name="X", dtype=dtype)
        y = te.compute(
            x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_int_first")
        )
        z = te.compute(
            x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_int_second")
        )
        s = te.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        check_array_first = np.full((n,), 13)
        check_array_second = np.full((n,), 14)
        assert np.array_equal(ynd.asnumpy(), check_array_first)
        assert np.array_equal(znd.asnumpy(), check_array_second)

    for t in ["int64", "int32"]:
        check_assign(t)

def test_trace_can_change_traced_value_float():
    """Check that a float returned by a trace callback replaces the value.

    The first callback returns 13.0 and the second returns 14.0; the two
    output buffers are expected to be filled with those constants.
    """

    # Distinct Python names avoid redefining one local function; the
    # registry names passed to ``register_func`` are what the kernels use.
    @tvm.register_func("tvm.tir.trace_change_float_first")
    def trace_first(x):
        return 13.0

    @tvm.register_func("tvm.tir.trace_change_float_second")
    def trace_second(x):
        return 14.0

    def check_assign(dtype):
        n = 4
        x = te.placeholder((n,), name="X", dtype=dtype)
        y = te.compute(
            x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_float_first")
        )
        z = te.compute(
            x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_float_second")
        )
        s = te.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")

        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        check_array_first = np.full((n,), 13.0)
        check_array_second = np.full((n,), 14.0)
        assert np.array_equal(ynd.asnumpy(), check_array_first)
        assert np.array_equal(znd.asnumpy(), check_array_second)

    for t in ["float64", "float32"]:
        check_assign(t)


if __name__ == "__main__":
    # Run every test in this file when it is executed as a script.
    test_empty_array()
    test_get_global()
    test_get_callback_with_node()
    test_convert()
    test_return_func()
    test_byte_array()
    test_ctx()
    test_trace_expr_assign()
    test_trace_expr_sum_generated()
    test_trace_expr_sum_custom()
    test_trace_expr_sum_args()
    test_trace_default_action()
    test_trace_can_change_traced_value_int()
    test_trace_can_change_traced_value_float()