/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 *  Code mainly used for test purposes.
 * \file api_test.cc
 */
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/attrs.h>
#include <tvm/api_registry.h>

namespace tvm {
/*!
 * \brief Attribute node used to exercise the attrs machinery from the
 *  Python API tests.
 *
 * Each field is declared in TVM_DECLARE_ATTRS below together with its
 * description and, where given, its default value and bounds.
 */
struct TestAttrs : public AttrsNode<TestAttrs> {
  /*! \brief Integer field; declared with default 10 and bounds [1, 10]. */
  int axis;
  /*! \brief String field; declared without a default value. */
  std::string name;
  /*! \brief Array-of-expression field; declared with default {0, 0}. */
  Array<Expr> padding;
  /*! \brief Typed environment function field; declared with a null default. */
  TypedEnvFunc<int(int)> func;

  TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(10)
        .set_lower_bound(1)
        .set_upper_bound(10)
        .describe("axis field");
    TVM_ATTR_FIELD(name)
        .describe("name");
    TVM_ATTR_FIELD(padding)
        .describe("padding of input")
        .set_default(Array<Expr>({0, 0}));
    TVM_ATTR_FIELD(func)
        .describe("some random env function")
        .set_default(TypedEnvFunc<int(int)>(nullptr));
  }
};

// Register TestAttrs as a node type through TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(TestAttrs);

// "_nop": ignores its arguments and never assigns the return value.
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs, TVMRetValue*) {});

// "_test_wrap_callback": takes a PackedFunc in args[0] and returns a
// zero-argument typed function that calls it with no arguments.
TVM_REGISTER_API("_test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    PackedFunc callback = args[0];
    // The closure keeps its own copy of the callback.
    auto invoke = [callback]() {
      callback();
    };
    *ret = runtime::TypedPackedFunc<void()>(invoke);
  });

// "_test_raise_error_callback": takes a message string in args[0] and returns
// a zero-argument typed function that logs that message at FATAL level.
TVM_REGISTER_API("_test_raise_error_callback")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::string message = args[0];
    // The message is captured by value so the closure owns its own copy.
    auto raise = [message]() {
      LOG(FATAL) << message;
    };
    *ret = runtime::TypedPackedFunc<void()>(raise);
  });

// "_test_check_eq_callback": takes a message string in args[0] and returns a
// typed function (int x, int y) that checks x == y, appending the message to
// the check's output stream.
TVM_REGISTER_API("_test_check_eq_callback")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::string message = args[0];
    // The operands keep the names x and y: they appear verbatim in CHECK_EQ.
    auto check_equal = [message](int x, int y) {
      CHECK_EQ(x, y) << message;
    };
    *ret = runtime::TypedPackedFunc<void(int x, int y)>(check_equal);
  });

// "_context_test": reads a DLContext from args[0] together with an expected
// device type (args[1]) and device id (args[2]) as ints, checks that the
// context carries those values, and returns the context.
TVM_REGISTER_API("_context_test")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    DLContext ctx = args[0];
    // Expected device type and device id, both passed as plain ints.
    int dtype = args[1];
    int did = args[2];
    // The context fields are cast to int so they compare against the int arguments.
    CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
    CHECK_EQ(static_cast<int>(ctx.device_id), did);
    // Hand the same context back to the caller.
    *ret = ctx;
  });


/*!
 * \brief Test helper that fails in one of two ways depending on its inputs.
 *
 * Fails the equality check with a "ValueError: ..." message when x != y.
 * When the check passes and x is 1, logs an "InternalError: ..." message at
 * FATAL level. Otherwise it does nothing.
 *
 * \param x First value to compare.
 * \param y Second value to compare.
 */
void ErrorTest(int x, int y) {
  // The "ValueError:" prefix is part of the message text.
  CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
  if (x != 1) {
    return;
  }
  // The "InternalError:" prefix is part of the message text.
  LOG(FATAL) << "InternalError: cannot reach here";
}

// "_ErrorTest": registers ErrorTest as a typed function with signature void(int, int).
TVM_REGISTER_API("_ErrorTest")
.set_body_typed<void(int, int)>(ErrorTest);

// internal function used for debug and testing purposes
// "_ndarray_use_count": returns the use count of the NDArray in args[0],
// minus one for the handle held inside this function.
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    runtime::NDArray array = args[0];
    // Subtract the reference held by the local handle above.
    const int external_refs = array.use_count() - 1;
    *ret = external_refs;
  });

}  // namespace tvm