# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from nose.tools import raises
import tvm
import pickle as pkl

def test_schedule_create():
    m = tvm.var('m')
    n = tvm.var('n')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.placeholder((n, l), name='B')
    AA = tvm.compute((m, l), lambda i, j: A[i, j])
    T = tvm.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k))
    s = tvm.create_schedule(T.op)
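    # put the copy stage AA in shared scope, split T's first axis twice,
    # attach AA's computation at T's xi1 loop, then reorder the inner axes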
    s[AA].set_scope("shared")
    xo, xi = s[T].split(T.op.axis[0], factor=10)
    xi1, xi2 = s[T].split(xi, factor=2)
    s[AA].compute_at(s[T], xi1)
    xo, xi = s[AA].split(AA.op.axis[0], factor=10)
    s[T].reorder(xi2, xi1)
    assert T.op.axis[1] in s[T].leaf_iter_vars

    # save load json
    json_str = tvm.save_json(s)
    s_loaded = tvm.load_json(json_str)
    assert isinstance(s_loaded, tvm.schedule.Schedule)
    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

    # pickle unpickle
    dump = pkl.dumps(s)
    s_loaded = pkl.loads(dump)
    assert isinstance(s_loaded, tvm.schedule.Schedule)
    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))


def test_reorder():
    m = tvm.var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute(m, lambda i: A[i+1])

    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=10)
    xi1, xi2 = s[T].split(xi, factor=2)
    order = (xi2, xi1, xo)
    assert tuple(s[T].leaf_iter_vars) != order
    s[T].reorder(*order)
    assert tuple(s[T].leaf_iter_vars) == order
    try:
        # pass duplicate IterVar
        # must raise an error
        s[T].reorder(xi2, xi1, xi2)
        assert False
    except tvm.TVMError:
        pass

def test_split():
    m = tvm.var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i])

    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=10)
    assert tuple(s[T].leaf_iter_vars) == (xo, xi)


def test_tile():
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.create_schedule(T.op)
    xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
    assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)


def test_fuse():
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.create_schedule(T.op)
    xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
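    # fusing xo and yo records a Fuse relation and collapses them into one leaf iter var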
    fused = s[T].fuse(xo, yo)
    assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
    assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)


def test_singleton():
    A = tvm.placeholder((), name='A')
    T = tvm.compute((), lambda : A() + 1)
    s = tvm.create_schedule(T.op)
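    # fuse() with no axes on a 0-d compute yields one synthetic iteration variable,
    # recorded through a Singleton relation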
    fused = s[T].fuse()
    assert any(isinstance(x, tvm.schedule.Singleton) for x in s[T].relations)
    assert tuple(s[T].leaf_iter_vars) == (fused,)
    dump = pkl.dumps(s)
    s_loaded = pkl.loads(dump)
    assert isinstance(s_loaded, tvm.schedule.Schedule)


def test_vectorize():
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.create_schedule(T.op)
    xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
    s[T].vectorize(yi)
    s[T].unroll(xi)
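    # the unroll/vectorize annotations are recorded per axis in iter_var_attrs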
    UNROLL = tvm.schedule.IterVar.Unrolled
    VECTORIZE = tvm.schedule.IterVar.Vectorized
    assert s[T].iter_var_attrs[xi].iter_type == UNROLL
    assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE

@raises(Exception)
def test_vectorize_commreduce():
    V = tvm.placeholder((128,), name='V')
    ax = tvm.reduce_axis((0, 128), name='ax')
    O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax]))
    s = tvm.create_schedule(O.op)
    s[O].vectorize(ax) # should throw here

def test_pragma():
    m = 100
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i])

    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=10)
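    # custom pragma keys are kept on the axis, while "vectorize" marks the axis as Vectorized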
    s[T].pragma(xo, "pragma1")
    s[T].pragma(xi, "vectorize")
    VECTORIZE = tvm.schedule.IterVar.Vectorized
    assert s[T].iter_var_attrs[xo].pragma_keys[0].value == "pragma1"
    assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE


def test_rfactor():
    n = tvm.var('n')
    k1 = tvm.reduce_axis((0, n), name="k1")
    k2 = tvm.reduce_axis((0, n), name="k2")
    A = tvm.placeholder((n, n, n), name='A')
    B = tvm.compute((n, ), lambda i: tvm.sum(A[i, k1, k2], axis=[k1, k2]))
    # normal schedule
    s = tvm.create_schedule(B.op)
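    # rfactor over k1 creates a new stage BF of shape (n, n) that reduces only over k2,
    # leaving B to reduce over the factored axis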
    BF = s.rfactor(B, k1)
    assert(tuple(BF.shape) == (n, n))
    assert(set(BF.op.body[0].axis) == set([k2]))
    assert(s[B].op.body[0].axis[0].dom.extent == n)
    assert(len(s[B].all_iter_vars) == 2)
    # schedule with split
    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(k1, factor=4)
    xo, xi = s[B].split(B.op.axis[0], factor=8)
    BF = s.rfactor(B, ki)
    assert(BF.shape[0].value == 4)
    assert(BF.shape[1] == n)
    assert(BF.op.body[0].axis[0] == k2)
    assert(BF.op.body[0].axis[1].var == ko.var)
    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
    # schedule with factor_axis
    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(k1, factor=4)
    xo, xi = s[B].split(B.op.axis[0], factor=8)
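    # factor_axis=1 places the factored dimension second, so BF has shape (n, 4)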
    BF = s.rfactor(B, ki, 1)
    assert(n == BF.shape[0])
    assert(BF.shape[1].value == 4)
    assert(BF.op.body[0].axis[0] == k2)
    assert(BF.op.body[0].axis[1].var == ko.var)
    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)

def test_tensor_intrin():
    n = 16
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
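    # declare a tensor intrinsic for the length-16 vector add; the lowering
    # function emits a packed call to an external "vadd" implementation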
    def intrin_func(ins, outs):
        assert(isinstance(ins[0], tvm.schedule.Buffer))
        assert(ins[0].shape[0].value == n)
        return tvm.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])
    intrin = tvm.decl_tensor_intrin(z.op, intrin_func)
    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert(intrin.buffers[0].shape[0].value == n)
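    # tensorize the inner length-16 loop of a larger vector add with the intrinsic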
    m = 32
    x = tvm.placeholder((m,), name='x')
    y = tvm.placeholder((m,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)
    xo, xi = s[z].split(z.op.axis[0], factor=n)
    s[z].tensorize(xi, intrin)
    assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
    assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)


if __name__ == "__main__":
    test_singleton()
    test_pragma()
    test_tensor_intrin()
    test_rfactor()
    test_schedule_create()
    test_reorder()
    test_tile()
    test_split()
    test_fuse()
    test_vectorize()
    test_vectorize_commreduce()