# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call
from tvm.relay.transform import gradient
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type

def check_eval(expr, expected_result, mod=None, rtol=1e-07):
    """Evaluate `expr` on the LLVM interpreter and compare against `expected_result`."""
    target = "llvm"
    intrp = create_executor(mod=mod, ctx=tvm.context(target, 0), target=target)
    got = intrp.evaluate(expr)
    np.testing.assert_allclose(got.asnumpy(), expected_result, rtol=rtol)


def run_opt_pass(expr, passes):
    """Run `passes` (a single pass or a list of passes) over `expr`.

    The expression is wrapped into a fresh module, the pass sequence is
    applied at opt_level 3, and the transformed "main" is returned — the
    whole function when `expr` was a Function, otherwise just its body.
    """
    # Normalize the argument so a single pass can be given directly.
    if not isinstance(passes, list):
        passes = [passes]
    mod = relay.Module.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        # Original had a broken 7-space indent here; normalized to 8.
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def tipe(expr):
    """Partially evaluate `expr`, then run type inference on the result."""
    pe_then_infer = [transform.PartialEvaluate(), transform.InferType()]
    return run_opt_pass(expr, pe_then_infer)


def dcpe(expr, mod=None, grad=False):
    """Run PartialEvaluate followed by DeadCodeElimination over `expr`.

    When `grad` is set, `expr` is first type-inferred and differentiated
    with `gradient`. When `mod` is given, `expr` must be a Function; it is
    installed as the module's "main", the passes run over the whole module,
    and the optimized "main" is returned. Otherwise the passes run over
    `expr` alone via `run_opt_pass`.
    """
    passes = [transform.PartialEvaluate(),
              transform.DeadCodeElimination(inline_once=True)]
    if grad:
        expr = gradient(run_infer_type(expr))
    # Explicit None check: we mean "was a module supplied", not truthiness.
    if mod is not None:
        assert isinstance(expr, Function)
        mod["main"] = expr
        seq = transform.Sequential(passes)
        mod = seq(mod)
        return mod["main"]
    return run_opt_pass(expr, passes)


def test_tuple():
    """Projecting a known element out of a tuple should reduce to that element."""
    t = TypeVar("t")
    x = Var("x", t)
    body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
    f = Function([x], body, None, [t])
    expected = run_opt_pass(relay.Function([x], x, None, [t]),
                            transform.InferType())
    assert alpha_equal(dcpe(f), expected)


def test_const_inline():
    """double(4.0) should constant-fold all the way down to 8.0."""
    scalar = relay.TensorType([], "float32")
    d = Var("d", scalar)
    double = Function([d], d + d)
    assert alpha_equal(dcpe(double(const(4.0))), const(8.0))


def test_ref():
    """Partial evaluation should see through a local mutable reference."""
    scalar = relay.TensorType([], "float32")
    d = relay.Var("d", scalar)
    r = relay.Var("r", relay.RefType(scalar))
    x = relay.Var("x")
    # let r = ref d in let x = (r := !r * !r) in !r
    inner = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), relay.RefRead(r))
    square = Function([d], Let(r, RefCreate(d), inner))
    expected = run_opt_pass(Function([d], d * d), transform.InferType())
    assert alpha_equal(dcpe(square), expected)


def test_empty_ad():
    """Gradient of the identity: the value paired with a ones_like adjoint."""
    t = TensorType((10, 10), "float32")
    d = Var("d", t)
    g = dcpe(Function([d], d), grad=True)
    want = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
    want = run_opt_pass(want, transform.InferType())
    assert alpha_equal(g, want)


def test_ad():
    """dcpe of grad(d * d) matches the hand-built let-normalized gradient."""
    t = TensorType((10, 10), "float32")
    d = Var("d", t)
    g = dcpe(Function([d], d * d), grad=True)
    # expected: let x = d*d in let x1 = ones_like(x) in
    #           (x, (zeros_like(d) + csl(x1*d, d) + csl(x1*d, d),))
    x = relay.Var("x")
    x1 = relay.Var("x1")
    adjoint = (op.zeros_like(d)
               + op.collapse_sum_like(x1 * d, d)
               + op.collapse_sum_like(x1 * d, d))
    inner = relay.Let(x1, op.ones_like(x), Tuple([x, Tuple([adjoint])]))
    expected = Function([d], relay.Let(x, d * d, inner))
    expected = run_opt_pass(expected, transform.InferType())
    assert_alpha_equal(g, expected)


def test_if_ref():
    """A ref doubled inside either branch of an If survives partial evaluation."""
    t = TensorType((), "bool")
    d = Var("d", t)
    r = Var("r")
    u = Var("u")
    eff = Var("eff")
    update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
    body = Let(eff, If(d, u(), u()), RefRead(r))
    f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
    pe_f = tipe(f)
    ex = create_executor()
    # Both the original and the partially evaluated form must compute 2.
    for fn in (f, pe_f):
        res = ex.evaluate(fn)(const(True))
        np.testing.assert_allclose(res.asnumpy(), 2 * np.ones_like(res.asnumpy()))


def test_function_invalidate():
    """A write between a closure's creation and its call must invalidate the
    partial evaluator's knowledge of the ref's contents."""
    t = TensorType((), "bool")
    d = Var("d", t)
    r = Var("r")
    fet = Var("fetch")
    fet_obscured = Var("fetch_obscured")
    u = Var("u")
    fetch = Function([], RefRead(r))
    body = If(d, fet_obscured(), fet_obscured())
    body = Let(u, RefWrite(r, const(1)), body)
    body = Let(fet_obscured, If(d, fet, fet), body)
    body = Let(fet, fetch, body)
    f = Function([d], Let(r, RefCreate(const(0)), body))
    pe_f = tipe(f)
    ex = create_executor()
    # Both forms must observe the written value 1, not the initial 0.
    for fn in (f, pe_f):
        res = ex.evaluate(fn)(const(True))
        np.testing.assert_allclose(res.asnumpy(), np.ones_like(res.asnumpy()))


def test_head_cons():
    """hd(cons(x, nil)) should reduce to x."""
    mod = Module()
    p = Prelude(mod)
    t = TypeVar("t")
    x = Var("x", t)
    f = Function([x], p.hd(p.cons(x, p.nil())), None, [t])
    res = dcpe(f, mod)
    assert alpha_equal(res, Function([x], x, t, [t]))


def test_map():
    """Mapping a global identity over a literal list reduces to the list itself."""
    mod = Module()
    p = Prelude(mod)
    f = GlobalVar("f")
    t = TypeVar("t")
    a = Var("a", t)
    mod[f] = Function([a], a, t, [t])
    xs = p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))
    orig = Function([], p.map(f, xs))
    expected = Function([], p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
    # Capture the checked "main" before dcpe overwrites it with `orig`.
    mod["main"] = expected
    expected = mod["main"]
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, expected.body)


def test_loop():
    """A non-terminating loop cannot be reduced; the call must survive intact."""
    mod = Module()
    t = TypeVar("t")
    x = Var("x", t)
    loop = GlobalVar("loop")
    mod[loop] = Function([x], loop(x), t, [t])
    mod["main"] = Function([], Call(loop, [const(1)]))
    expected = mod["main"].body
    res = dcpe(Function([], loop(const(1))), mod=mod)
    assert alpha_equal(res.body, expected)


def test_swap_loop():
    """loop(x, y) = loop(y, x) never terminates; dcpe must leave the call alone."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    loop = GlobalVar("loop")
    mod[loop] = Function([x, y], loop(y, x), nat)
    prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
    res = dcpe(Function([], prog), mod=mod)
    assert alpha_equal(prog, res.body)


def test_abs_diff():
    """|7 - 3| on Peano naturals partially evaluates to the literal 4."""
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    xp = Var("x'", nat)
    yp = Var("y'", nat)
    diff = GlobalVar("diff")
    # diff z y        = y
    # diff (s x') z   = s x'  (via the outer x, matched below)
    # diff (s x') (s y') = diff y' x'
    y_z_case = Clause(PatternConstructor(p.z, []), x)
    y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
    x_z_case = Clause(PatternConstructor(p.z, []), y)
    x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
    mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
    res = dcpe(Function([], diff(make_nat_expr(p, 7), make_nat_expr(p, 3))), mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 4))


def test_match_nat_id():
    """An identity written as a z/s match still reduces a literal nat."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    nat_id = GlobalVar("nat_id")
    cases = [Clause(PatternConstructor(p.z, []), p.z()),
             Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))]
    mod[nat_id] = Function([x], Match(x, cases))
    res = dcpe(Function([], nat_id(make_nat_expr(p, 3))), mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_nat_id():
    """Applying a global identity to a nat literal reduces to that literal."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    # Removed unused local `y = Var("y", nat)` from the original.
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = Function([], nat_id(make_nat_expr(p, 3)))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_global_match_nat_id():
    """A top-level z/s match over a nat literal folds to that literal."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    cases = [Clause(PatternConstructor(p.z, []), p.z()),
             Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))]
    res = dcpe(Function([], Match(make_nat_expr(p, 3), cases)), mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 3))


def test_double():
    """double(3) on Peano naturals reduces to the literal 6."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    res = dcpe(Function([], p.double(make_nat_expr(p, 3))), mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 6))


def test_concat():
    """Concatenation of two free tensors cannot be reduced; dcpe is a no-op."""
    t = relay.TensorType([10], "float32")
    x = Var("x", t)
    # Original mislabeled this as Var("x", t) — a copy-paste typo. The name
    # is only a hint, so alpha_equal is unaffected; "y" is what was meant.
    y = Var("y", t)
    orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
    assert_alpha_equal(dcpe(orig), orig)


def test_triangle_number():
    """Unrolling sum(0..10) via a let-bound recursive closure yields 55."""
    t = relay.TensorType([], "int32")
    x = Var("x", t)
    f_var = Var("f")
    body = If(op.equal(x, const(0)), const(0), x + f_var(x - const(1)))
    orig = run_infer_type(Let(f_var, Function([x], body), f_var(const(10))))
    assert_alpha_equal(dcpe(orig), const(55))


def test_nat_update():
    """PartialEvaluate must run cleanly over the nat prelude in A-normal form."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    mod = transform.ToANormalForm()(mod)
    transform.PartialEvaluate()(mod)


def test_tuple_match():
    """Matching a constant pair against a tuple pattern folds to the sum."""
    lhs = relay.Var("a")
    rhs = relay.Var("b")
    pattern = relay.PatternTuple([relay.PatternVar(lhs), relay.PatternVar(rhs)])
    clause = relay.Clause(pattern, lhs + rhs)
    match = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
    assert_alpha_equal(dcpe(match), const(2))


if __name__ == '__main__':
    # Run every test in the same order as the original script.
    test_nat_update()
    test_ref()
    test_tuple()
    test_empty_ad()
    test_const_inline()
    test_ad()
    test_if_ref()
    test_function_invalidate()
    test_head_cons()
    test_map()
    test_loop()
    test_swap_loop()
    test_abs_diff()
    test_double()
    test_nat_id()
    test_global_match_nat_id()
    test_match_nat_id()
    test_concat()
    test_triangle_number()
    test_tuple_match()