# test_te_verify_compute.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te


def _assert_compute_rejected(shape, fcompute, reason):
  """Fail unless ``te.compute(shape, fcompute, name="B")`` raises ``TVMError``.

  Parameters
  ----------
  shape : tuple
      Output shape passed to ``te.compute``.
  fcompute : callable
      Compute rule expected to be rejected.
  reason : str
      Short description of why the rule is invalid; used in the failure
      message.

  Raises
  ------
  AssertionError
      If ``te.compute`` returns without raising ``TVMError``.
  """
  try:
    te.compute(shape, fcompute, name="B")
  except tvm._ffi.base.TVMError:
    return
  # Raised explicitly instead of `assert False` so the check is still
  # enforced when Python runs with -O (which strips assert statements).
  raise AssertionError("te.compute accepted an invalid compute: " + reason)


def test_verify_compute():
  """Check which compute rules ``te.compute`` accepts and which it rejects.

  Two rules are expected to be accepted (a single top-level reduction and a
  plain element-wise expression) and four are expected to raise ``TVMError``
  (reductions nested inside another expression, a batch mixing a reduction
  with a non-reduction, and a batch of reductions over different axes).
  """
  n = te.size_var("n")
  m = te.size_var("m")
  A = te.placeholder((n, m), name='A')
  k = te.reduce_axis((0, m), "k")
  k_ = te.reduce_axis((0, m-1), "k_")
  f1 = lambda i: te.sum(A[i, k], axis=k)
  f2 = lambda i: A[i, 0] + 1
  f3 = lambda i: te.sum(A[i, k], axis=k) + 1
  f4 = lambda i: A[i, 0] * (te.sum(A[i, k], axis=k) + 1)
  f5 = lambda i: (te.sum(A[i, k], axis=k), A[i, 0] + 1)
  f6 = lambda i: (te.sum(A[i, k], axis=k), te.sum(A[i, k_], axis=k_))

  #
  # Valid compute
  # Called directly: an unexpected TVMError propagates with its original
  # message instead of being replaced by a bare assertion failure.
  te.compute((n,), f1, name="B")

  #
  # Valid compute
  te.compute((n,), f2, name="B")

  #
  # Invalid compute with non top level reduction
  _assert_compute_rejected((n,), f3, "non top level reduction (sum + 1)")

  #
  # Invalid compute with non top level reduction
  _assert_compute_rejected((n,), f4, "non top level reduction (A * (sum + 1))")

  #
  # Invalid compute with reduction and non-reduction batch ops
  _assert_compute_rejected((n,), f5, "reduction and non-reduction batch ops")

  #
  # Invalid compute with unequal batch reduction ops
  _assert_compute_rejected((n,), f6, "unequal batch reduction ops")


if __name__ == "__main__":
  # Allow running this test file directly, without a test runner.
  test_verify_compute()