/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to arith
 * \file api_arith.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/tensor.h>

namespace tvm {
namespace arith {

33
TVM_REGISTER_API("arith.intset_single_point")
34
.set_body_typed(IntSet::single_point);
35

36
TVM_REGISTER_API("arith.intset_vector")
37
.set_body_typed(IntSet::vector);
38

39
TVM_REGISTER_API("arith.intset_interval")
40
.set_body_typed(IntSet::interval);
41

42
TVM_REGISTER_API("arith.DetectLinearEquation")
43
.set_body_typed(DetectLinearEquation);
44

45
TVM_REGISTER_API("arith.DetectClipBound")
46
.set_body_typed(DetectClipBound);
47

48
TVM_REGISTER_API("arith.DeduceBound")
49 50 51 52 53 54 55
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
  Expr v, Expr cond,
  const Map<Var, IntSet> hint_map,
  const Map<Var, IntSet> relax_map
) {
  return DeduceBound(v, cond, hint_map, relax_map);
});
56

57 58

TVM_REGISTER_API("arith.DomainTouched")
59
.set_body_typed(DomainTouched);
60 61


62
TVM_REGISTER_API("_IntervalSetGetMin")
63
.set_body_method(&IntSet::min);
64

65
TVM_REGISTER_API("_IntervalSetGetMax")
66
.set_body_method(&IntSet::max);
67

68
TVM_REGISTER_API("_IntSetIsNothing")
69
.set_body_method(&IntSet::is_nothing);
70

71
TVM_REGISTER_API("_IntSetIsEverything")
72
.set_body_method(&IntSet::is_everything);
73

74
TVM_REGISTER_API("arith._make_ConstIntBound")
75
.set_body_typed(ConstIntBoundNode::make);
76 77

TVM_REGISTER_API("arith._make_ModularSet")
78
.set_body_typed(ModularSetNode::make);
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97

TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    using runtime::PackedFunc;
    using runtime::TypedPackedFunc;
    auto self = std::make_shared<Analyzer>();
    auto f = [self](std::string name) -> PackedFunc {
      if (name == "const_int_bound") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->const_int_bound(args[0]);
          });
      } else if (name == "modular_set") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->modular_set(args[0]);
        });
      } else if (name == "const_int_bound_update") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            self->const_int_bound.Update(args[0], args[1], args[2]);
        });
98 99 100 101
      } else if (name == "rewrite_simplify") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->rewrite_simplify(args[0]);
        });
102 103 104 105
      } else if (name == "canonical_simplify") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            *ret = self->canonical_simplify(args[0]);
        });
106 107 108 109 110 111 112 113 114 115 116
      } else if (name == "bind") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
            auto& sptr = args[1].node_sptr();
            if (sptr->is_type<Range::ContainerType>()) {
              self->Bind(args[0], args[1].operator Range());
            } else {
              self->Bind(args[0], args[1].operator Expr());
            }
        });
      } else if (name == "enter_constraint_context") {
        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
117 118 119 120
            // can't use make_shared due to noexcept(false) decl in destructor,
            // see https://stackoverflow.com/a/43907314
            auto ctx =
                std::shared_ptr<ConstraintContext>(new ConstraintContext(self.get(), args[0]));
121 122 123 124 125 126 127 128 129 130 131
            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
              ctx.reset();
            };
            *ret = PackedFunc(fexit);
        });
      }
      return PackedFunc();
    };
    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});

132 133
}  // namespace arith
}  // namespace tvm