# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient

def test_clip():
    """Check the gradient of ``relay.clip`` against a NumPy reference.

    A (10, 4) float32 input is clipped to ``[a_min, a_max]``; the gradient
    function produced by ``gradient`` is evaluated on every target/context
    pair yielded by ``ctx_list()`` and compared with the reference gradient.
    """
    # Keep the bounds in one place so the reference and the operator agree.
    a_min, a_max = 1.0, 10.0

    def ref_clip_grad(x):
        # Reference derivative of clip: 0 where x falls outside the bounds,
        # 1 everywhere else.
        outside = (x < a_min) | (x > a_max)
        return np.where(outside, np.zeros_like(x), np.ones_like(x))

    x = relay.var("x", relay.TensorType((10, 4), "float32"))
    y = relay.clip(x, a_min, a_max)

    # Scale the [0, 1) samples by 11 so values can land below, inside and
    # above the clip range.
    data = np.random.rand(10, 4).astype("float32") * 11.0
    ref_grad = ref_clip_grad(data)

    fwd_func = run_infer_type(relay.Function([x], y))
    bwd_func = run_infer_type(gradient(fwd_func))

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        # The result unpacks as (forward output, (gradient for x,)); only the
        # gradient is checked here.
        _, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def verify_transpose_grad(d_shape, axes=None):
    """Run ``check_grad`` on a Relay function that transposes its input.

    Parameters
    ----------
    d_shape : shape of the float32 input tensor variable.
    axes : forwarded unchanged as the ``axes`` argument of
        ``relay.transpose``; defaults to ``None``.
    """
    input_type = relay.TensorType(d_shape, "float32")
    data = relay.var("data", input_type)
    transposed = relay.transpose(data, axes=axes)
    check_grad(relay.Function([data], transposed))


def test_transpose_grad():
    """Verify the transpose gradient with default axes and with an explicit permutation."""
    shape = (1, 2, 3, 4)
    # First with axes left at its default (None), then with an explicit order.
    verify_transpose_grad(shape)
    verify_transpose_grad(shape, axes=(0, 2, 3, 1))


def test_negative_grad():
    """Run ``check_grad`` on ``relay.negative`` applied to a (10, 4) float32 input."""
    input_type = relay.TensorType((10, 4), "float32")
    data = relay.var("data", input_type)
    negated = relay.negative(data)
    check_grad(relay.Function([data], negated))


def test_cast_grad():
    """Run ``check_grad`` on a cast of a (10, 4) float32 input to float64."""
    input_type = relay.TensorType((10, 4), "float32")
    data = relay.var("data", input_type)
    widened = relay.cast(data, "float64")
    check_grad(relay.Function([data], widened))

if __name__ == "__main__":
    # Pass this file explicitly so running the script executes only the tests
    # defined here instead of collecting from the current directory.
    pytest.main([__file__])