# test_lang_verify_compute.py
import tvm

def test_verify_compute():
  """Check which compute bodies tvm.compute accepts and which it rejects.

  Two bodies are expected to be accepted: a reduction that forms the whole
  body, and a plain expression without any reduction.  Four bodies are
  expected to be rejected with TVMError: a reduction nested inside a larger
  expression (two variants), a batch mixing a reduction with a non-reduction,
  and a batch of two reductions over different reduce axes.

  Raises
  ------
  AssertionError
      If a body expected to be valid is rejected, or a body expected to be
      invalid is accepted.  Raised explicitly rather than via ``assert False``
      so the checks still run under ``python -O`` and the failure names the
      offending case.
  """
  n = tvm.var("n")
  m = tvm.var("m")
  A = tvm.placeholder((n, m), name='A')
  k = tvm.reduce_axis((0, m), "k")
  # Same lower bound as k but a different extent (m-1 instead of m).
  k_ = tvm.reduce_axis((0, m-1), "k_")

  # (description, fcompute) pairs that tvm.compute must accept.
  valid_cases = [
    ("top level reduction",
     lambda i: tvm.sum(A[i, k], axis=k)),
    ("expression without reduction",
     lambda i: A[i, 0] + 1),
  ]

  # (description, fcompute) pairs that tvm.compute must reject.
  invalid_cases = [
    ("non top level reduction: reduction plus constant",
     lambda i: tvm.sum(A[i, k], axis=k) + 1),
    ("non top level reduction: reduction inside a product",
     lambda i: A[i, 0] * (tvm.sum(A[i, k], axis=k) + 1)),
    ("reduction and non-reduction batch ops",
     lambda i: (tvm.sum(A[i, k], axis=k), A[i, 0] + 1)),
    ("unequal batch reduction ops",
     lambda i: (tvm.sum(A[i, k], axis=k), tvm.sum(A[i, k_], axis=k_))),
  ]

  for what, fcompute in valid_cases:
    try:
      tvm.compute((n,), fcompute, name="B")
    except tvm._ffi.base.TVMError as err:
      # Keep the TVM error text so the failure is diagnosable.
      raise AssertionError(
        "valid compute (%s) was rejected: %s" % (what, err))

  for what, fcompute in invalid_cases:
    try:
      tvm.compute((n,), fcompute, name="B")
    except tvm._ffi.base.TVMError:
      # Rejection is the expected outcome for this case.
      continue
    raise AssertionError("invalid compute (%s) was accepted" % what)


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
  test_verify_compute()