Unverified commit 18295b27 by Tianqi Chen, committed by GitHub

[REFACTOR] Polish ffi convention. (#4912)

* [REFACTOR] Polish ffi convention.

- Remove src/api and keep registrations local to the C++ functions.
- Remove api_internal, as it is no longer needed.

* Update the codebase walkthrough
parent fccf2268
@@ -133,7 +133,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
src/api/*.cc
src/support/*.cc
)
file(GLOB CODEGEN_SRCS
......
@@ -55,7 +55,7 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
Here, the types of ``A``, ``B``, and ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by a C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as handles to the underlying C++ type of the same name. If you look at the definition of the Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
Here, the types of ``A``, ``B``, and ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/te/tensor.py``. The Python ``Tensor`` is backed by a C++ ``Tensor``, implemented in ``include/tvm/te/tensor.h`` and ``src/te/tensor.cc``. All Python types in TVM can be thought of as handles to the underlying C++ type of the same name. If you look at the definition of the Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
::
@@ -68,24 +68,12 @@ Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``pytho
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn call into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into the ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
::
TVM_REGISTER_GLOBAL("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
You can also check out `FFI Navigator <https://github.com/tqchen/ffi-navigator>`_, which lets you navigate between Python and C++ FFI calls.
A ``Tensor`` object has an associated ``Operation`` object, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and the ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object in turn has an ``input_tensors()`` method, which returns a list of its input ``Tensor`` objects. This way we can keep track of dependencies between ``Operation`` objects.
A ``Tensor`` object has an associated ``Operation`` object, defined in ``python/tvm/te/tensor.py``, ``include/tvm/te/operation.h``, and the ``src/tvm/te/operation`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object in turn has an ``input_tensors()`` method, which returns a list of its input ``Tensor`` objects. This way we can keep track of dependencies between ``Operation`` objects.
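The dependency tracking described above can be sketched with toy Python classes. The names mirror TVM's ``Tensor``/``Operation``, but these classes and the ``dependent_ops`` helper are hypothetical, for illustration only.

```python
# Toy sketch of the Tensor/Operation dependency structure: a Tensor
# records its producing Operation, and an Operation lists its input
# Tensors, so the whole dependency graph is reachable from any output.

class Operation:
    def __init__(self, name, inputs=()):
        self.name = name
        self._inputs = list(inputs)   # input Tensor objects

    def input_tensors(self):
        return list(self._inputs)

    def output(self, i=0):
        return Tensor(op=self, value_index=i)

class Tensor:
    def __init__(self, op, value_index=0):
        self.op = op                  # the Operation that produces this Tensor
        self.value_index = value_index

def dependent_ops(tensor):
    """Walk input_tensors() transitively to collect all upstream Operations."""
    seen, stack = [], [tensor.op]
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.append(op)
        stack.extend(t.op for t in op.input_tensors())
    return seen

A = Operation("A").output()
B = Operation("B").output()
C = Operation("C", inputs=[A, B]).output()
assert {op.name for op in dependent_ops(C)} == {"A", "B", "C"}
```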
We pass the operation corresponding to the output tensor ``C`` to the ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``.
We pass the operation corresponding to the output tensor ``C`` to the ``tvm.create_schedule()`` function in ``python/tvm/te/schedule.py``.
::
@@ -103,7 +91,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any.
``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``.
``Schedule`` and ``Stage`` are defined in ``tvm/python/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_ops.cc``.
To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
@@ -112,7 +100,7 @@ To keep it simple, we call ``tvm.build(...)`` on the default schedule created by
target = "cuda"
fadd = tvm.build(s, [A, B, C], target)
``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor`` objects, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax.
``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor`` objects, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
The process of ``tvm.build()`` can be divided into two steps:
@@ -133,14 +121,14 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modu
stmt = schedule.ScheduleOps(sch, bounds)
...
Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
.. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html
``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``.
Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in the ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in the loop vectorization and unrolling passes below.
Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in the ``src/tir/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in the loop vectorization and unrolling passes below.
::
@@ -157,7 +145,7 @@ Next, we apply a number of lowering passes to ``stmt``. These passes are impleme
After lowering is done, the ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for the CUDA target. In addition to target-specific machine code, TVM also generates host-side code that is responsible for memory management, kernel launches, etc.
Code generation is done by the ``build_module()`` function, defined in ``python/tvm/codegen.py``. On the C++ side, code generation is implemented in the ``src/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/codegen/codegen.cc``:
Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen/codegen.cc``:
::
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of internal API
The functions in this namespace are automatically exported from the C++ side via PackedFuncs
registered by the "TVM_REGISTER_*" macros. This makes calling Python functions from the C++
side very easy.
Each name that starts with "_" in a "TVM_REGISTER_*" macro is an internal API. You can find
all the functions in "api_lang.cc", "api_base.cc", "api_arith.cc" and "api_ir.cc" under "src/api".
"""
@@ -19,7 +19,6 @@
"""FFI registry to register function and objects."""
import sys
import ctypes
from .. import _api_internal
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
@@ -288,17 +287,11 @@ def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
    if prefix == "api":
        fname = name
        if name.startswith("_"):
            target_module = sys.modules["tvm._api_internal"]
        else:
            target_module = module
    else:
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix)+1:]
        target_module = module
    if not name.startswith(prefix):
        continue
    fname = name[len(prefix)+1:]
    target_module = module
    if fname.find(".") != -1:
        continue
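After the refactor, the lookup in ``_init_api_prefix`` reduces to prefix matching and stripping: global names like ``te.ComputeOp`` are matched against a module prefix, the prefix is dropped, and dotted remainders (nested namespaces) are skipped. A self-contained sketch, using a hypothetical ``select_names`` helper:

```python
# Sketch of the simplified prefix-based selection logic: keep only names
# under the given prefix, strip "prefix.", and skip nested (dotted) names.

def select_names(global_names, prefix):
    selected = []
    for name in global_names:
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix) + 1:]   # drop "prefix."
        if fname.find(".") != -1:
            continue                     # nested namespace; handled elsewhere
        selected.append(fname)
    return selected

names = ["te.ComputeOp", "te.schedule.Thing", "tir.Var"]
assert select_names(names, "te") == ["ComputeOp"]
```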
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of API functions related to arith
* \file api_arith.cc
*/
#include <tvm/arith/bound.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/pattern.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
namespace tvm {
namespace arith {
TVM_REGISTER_GLOBAL("arith.intset_single_point")
.set_body_typed(IntSet::single_point);
TVM_REGISTER_GLOBAL("arith.intset_vector")
.set_body_typed(IntSet::vector);
TVM_REGISTER_GLOBAL("arith.intset_interval")
.set_body_typed(IntSet::interval);
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed([](
PrimExpr v, PrimExpr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
return DeduceBound(v, cond, hint_map, relax_map);
});
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
.set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
.set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
.set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}
TVM_REGISTER_GLOBAL("arith.ConstIntBound")
.set_body_typed(MakeConstIntBound);
ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}
TVM_REGISTER_GLOBAL("arith.ModularSet")
.set_body_typed(MakeModularSet);
TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> PackedFunc {
if (name == "const_int_bound") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->const_int_bound(args[0]);
});
} else if (name == "modular_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->modular_set(args[0]);
});
} else if (name == "const_int_bound_update") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "Simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->Simplify(args[0]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
if (args[1].IsObjectRef<Range>()) {
self->Bind(args[0], args[1].operator Range());
} else {
self->Bind(args[0], args[1].operator PrimExpr());
}
});
} else if (name == "enter_constraint_context") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext> >(
new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
ctx.reset();
};
*ret = PackedFunc(fexit);
});
}
return PackedFunc();
};
*ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});
} // namespace arith
} // namespace tvm
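The ``arith.CreateAnalyzer`` registration above returns a dispatcher closure that hands out per-method ``PackedFunc`` objects, all sharing one ``Analyzer`` instance. The same pattern can be sketched in pure Python; the classes and names below are illustrative stand-ins, not TVM's API:

```python
# Pure-Python sketch of the CreateAnalyzer dispatch pattern: a factory
# returns a name-based lookup function whose results are closures over
# one shared object, mirroring the shared_ptr<Analyzer> in the C++ code.

class Analyzer:
    """Stand-in for tvm::arith::Analyzer (illustrative only)."""
    def __init__(self):
        self.bindings = {}

    def bind(self, var, value):
        self.bindings[var] = value

    def simplify(self, expr):
        # Trivial placeholder for the real simplifier: substitute bound vars.
        return self.bindings.get(expr, expr)

def create_analyzer():
    self = Analyzer()                 # shared by every closure handed out
    def dispatch(name):
        if name == "bind":
            return self.bind
        elif name == "Simplify":
            return self.simplify
        return None                   # unknown name, like returning PackedFunc()
    return dispatch

f = create_analyzer()
f("bind")("x", 42)
assert f("Simplify")("x") == 42
```

Returning a single dispatcher keeps the registry surface small: one global name exposes a whole family of methods while the shared state stays alive as long as any closure does.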
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of API functions related to IR build
* \file api_ir.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tir {
TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) {
return Var(s, t);
});
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});
TVM_REGISTER_GLOBAL("tir.abs")
.set_body_typed(tvm::abs);
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);
TVM_REGISTER_GLOBAL("tir.ceil")
.set_body_typed(tvm::ceil);
TVM_REGISTER_GLOBAL("tir.round")
.set_body_typed(tvm::round);
TVM_REGISTER_GLOBAL("tir.nearbyint")
.set_body_typed(tvm::nearbyint);
TVM_REGISTER_GLOBAL("tir.trunc")
.set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast")
.set_body_typed(tvm::cast);
TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("tir.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});
TVM_REGISTER_GLOBAL("tir.For")
.set_body_typed([](
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api),
body);
});
TVM_REGISTER_GLOBAL("tir.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0];
if (args.size() == 3) {
*ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
} else {
*ret = LoadNode::make(t, args[1], args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("tir.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PrimExpr value = args[1];
if (args.size() == 3) {
*ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
*ret = StoreNode::make(args[0], value, args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
Array<PrimExpr> args, int call_type,
FunctionRef func, int value_index
) {
return CallNode::make(type,
name,
args,
static_cast<CallNode::CallType>(call_type),
func,
value_index);
});
TVM_REGISTER_GLOBAL("tir.CommReducer")
.set_body_typed(CommReducerNode::make);
// make from two arguments
#define REGISTER_MAKE(NodeName) \
TVM_REGISTER_GLOBAL("tir."#NodeName) \
.set_body_typed(NodeName ## Node::make);
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
REGISTER_MAKE(StringImm);
REGISTER_MAKE(Add);
REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(FloorDiv);
REGISTER_MAKE(FloorMod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
REGISTER_MAKE(NE);
REGISTER_MAKE(LT);
REGISTER_MAKE(LE);
REGISTER_MAKE(GT);
REGISTER_MAKE(GE);
REGISTER_MAKE(And);
REGISTER_MAKE(Or);
REGISTER_MAKE(Not);
REGISTER_MAKE(Select);
REGISTER_MAKE(Ramp);
REGISTER_MAKE(Cast);
REGISTER_MAKE(Broadcast);
REGISTER_MAKE(Shuffle);
REGISTER_MAKE(Let);
REGISTER_MAKE(LetStmt);
REGISTER_MAKE(AssertStmt);
REGISTER_MAKE(ProducerConsumer);
REGISTER_MAKE(Provide);
REGISTER_MAKE(Prefetch);
REGISTER_MAKE(Free);
REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);
// overloaded, needs special handling
// has default args
TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](
Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
){
return AllocateNode::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
} \
})
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
} // namespace tir
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
#include <tvm/te/schedule.h>
#include <tvm/driver/driver_api.h>
#include <tvm/tir/data_layout.h>
namespace tvm {
TVM_REGISTER_GLOBAL("tir.min_value")
.set_body_typed(min_value);
TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Range(args[0], args[1]);
});
namespace tir {
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
}  // namespace tir
namespace te {
TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_GLOBAL("te.Placeholder")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed(ComputeOpNode::make);
TVM_REGISTER_GLOBAL("te.ScanOp")
.set_body_typed(ScanOpNode::make);
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);
TVM_REGISTER_GLOBAL("te.ExternOp")
.set_body_typed(ExternOpNode::make);
TVM_REGISTER_GLOBAL("te.HybridOp")
.set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("te.CreateSchedule")
.set_body_typed(create_schedule);
TVM_REGISTER_GLOBAL("te.StageSetScope")
.set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("te.StageBind")
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("te.StageFuse")
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
});
TVM_REGISTER_GLOBAL("te.StageComputeAt")
.set_body_method(&Stage::compute_at);
TVM_REGISTER_GLOBAL("te.StageComputeInline")
.set_body_method(&Stage::compute_inline);
TVM_REGISTER_GLOBAL("te.StageComputeRoot")
.set_body_method(&Stage::compute_root);
TVM_REGISTER_GLOBAL("te.StageReorder")
.set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("te.StageTile")
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
PrimExpr x_factor, PrimExpr y_factor
) {
IterVar x_outer, y_outer, x_inner, y_inner;
stage.tile(x_parent, y_parent,
x_factor, y_factor,
&x_outer, &y_outer,
&x_inner, &y_inner);
return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_GLOBAL("te.StageEnvThreads")
.set_body_method(&Stage::env_threads);
TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
.set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_GLOBAL("te.StageUnroll")
.set_body_method(&Stage::unroll);
TVM_REGISTER_GLOBAL("te.StageVectorize")
.set_body_method(&Stage::vectorize);
TVM_REGISTER_GLOBAL("te.StageTensorize")
.set_body_method(&Stage::tensorize);
TVM_REGISTER_GLOBAL("te.StageParallel")
.set_body_method(&Stage::parallel);
TVM_REGISTER_GLOBAL("te.StagePragma")
.set_body_method(&Stage::pragma);
TVM_REGISTER_GLOBAL("te.StagePrefetch")
.set_body_method(&Stage::prefetch);
TVM_REGISTER_GLOBAL("te.StageStorageAlign")
.set_body_method(&Stage::storage_align);
TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
.set_body_method(&Stage::double_buffer);
TVM_REGISTER_GLOBAL("te.StageOpenGL")
.set_body_method(&Stage::opengl);
TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
.set_body_method(&Schedule::normalize);
TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
.set_body_method(&Schedule::create_group);
TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
.set_body_method(&Schedule::cache_read);
TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Tensor>()) {
*ret = args[0].operator Schedule()
.cache_write(args[1].operator Tensor(), args[2]);
} else {
*ret = args[0].operator Schedule()
.cache_write(args[1].operator Array<Tensor>(), args[2]);
}
});
TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
} // namespace te
TVM_REGISTER_GLOBAL("te.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of API functions related to schedule pass.
* \file api_schedule.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include "../te/schedule/graph.h"
namespace tvm {
namespace te {
TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
.set_body_typed(AutoInlineElemWise);
TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
.set_body_typed(AutoInlineInjective);
TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], false);
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_GLOBAL("schedule."#PassName) \
.set_body_typed(PassName);
REGISTER_SCHEDULE_PASS(InferBound);
REGISTER_SCHEDULE_PASS(CreateReadGraph);
REGISTER_SCHEDULE_PASS(PostDFSOrder);
REGISTER_SCHEDULE_PASS(CreateAttachPath);
REGISTER_SCHEDULE_PASS(ScanGetBody);
REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
} // namespace te
} // namespace tvm
@@ -20,6 +20,7 @@
/*!
* \file tvm/arith/analyzer.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
@@ -109,5 +110,64 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
return res;
}
TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> PackedFunc {
if (name == "const_int_bound") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->const_int_bound(args[0]);
});
} else if (name == "modular_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->modular_set(args[0]);
});
} else if (name == "const_int_bound_update") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "Simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->Simplify(args[0]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "canonical_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
if (args[1].IsObjectRef<Range>()) {
self->Bind(args[0], args[1].operator Range());
} else {
self->Bind(args[0], args[1].operator PrimExpr());
}
});
} else if (name == "enter_constraint_context") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext> >(
new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
ctx.reset();
};
*ret = PackedFunc(fexit);
});
}
return PackedFunc();
};
*ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});
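The `arith.CreateAnalyzer` body above returns a single packed function that hands out sub-functions by name, all closing over the same shared `Analyzer` instance, so the frontend needs only one entry point per analyzer. A minimal pure-Python sketch of that dispatch pattern (illustrative only; `create_analyzer`, `bind`, and `simplify` are invented names, not TVM APIs):

```python
def create_analyzer():
    """Return a dispatcher whose sub-functions share one analyzer's state."""
    bindings = {}  # state shared by every sub-function of this analyzer

    def bind(var, value):
        bindings[var] = value

    def simplify(expr):
        # toy "simplification": substitute a known binding, else return as-is
        return bindings.get(expr, expr)

    table = {"bind": bind, "Simplify": simplify}

    def dispatch(name):
        # unknown names return None, mirroring the empty PackedFunc() above
        return table.get(name)

    return dispatch
```

Each call to `create_analyzer()` yields an independent dispatcher; looking up `"bind"` and `"Simplify"` on the same dispatcher operates on the same state, just as the C++ lambdas capture the same `self`.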
} // namespace arith
} // namespace tvm
......@@ -21,11 +21,11 @@
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
#include <unordered_map>
......@@ -362,5 +362,16 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
return DeduceBound(v, e, hmap, rmap);
}
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed([](
PrimExpr v, PrimExpr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
return DeduceBound(v, cond, hint_map, relax_map);
});
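Unlike the raw `TVMArgs` form used for `arith.CreateAnalyzer`, `set_body_typed` takes a typed lambda so argument count and conversion are handled by the wrapper. A rough Python sketch of the idea, using a hypothetical registry rather than TVM's actual FFI machinery:

```python
import inspect

GLOBAL_REGISTRY = {}

def set_body_typed(name, func):
    # store the function with its arity so calls can be checked up front,
    # loosely mirroring how a typed lambda becomes a checked PackedFunc
    GLOBAL_REGISTRY[name] = (func, len(inspect.signature(func).parameters))

def call_global(name, *args):
    func, arity = GLOBAL_REGISTRY[name]
    if len(args) != arity:
        raise TypeError(f"{name} expects {arity} arguments, got {len(args)}")
    return func(*args)

# toy stand-in for the DeduceBound lambda; the return value is illustrative
set_body_typed("arith.DeduceBound",
               lambda v, cond, hint_map, relax_map: ("bound-of", v))
```

Calling with the wrong number of arguments fails immediately at the registry boundary instead of deep inside the body, which is the practical benefit of the typed form.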
} // namespace arith
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file tvm/arith/const_int_bound.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr_functor.h>
#include <algorithm>
......@@ -41,6 +42,13 @@ ConstIntBound::ConstIntBound(
data_ = std::move(node);
}
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}
TVM_REGISTER_GLOBAL("arith.ConstIntBound")
.set_body_typed(MakeConstIntBound);
inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
os << "pos_inf";
......
......@@ -21,6 +21,7 @@
* \file detect_linear_equation.cc
* \brief Utility to detect patterns in the expression.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr_functor.h>
......@@ -268,6 +269,12 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
return ret;
}
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
return DetectClipBound(e, vars);
});
} // namespace arith
} // namespace tvm
......@@ -119,5 +119,8 @@ Domain DomainTouched(Stmt stmt,
return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
}
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);
} // namespace arith
} // namespace tvm
......@@ -820,5 +820,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< "[" << op->min_value << ", "
<< op->max_value << ']';
});
TVM_REGISTER_GLOBAL("arith.intset_single_point")
.set_body_typed(IntSet::single_point);
TVM_REGISTER_GLOBAL("arith.intset_vector")
.set_body_typed(IntSet::vector);
TVM_REGISTER_GLOBAL("arith.intset_interval")
.set_body_typed(IntSet::interval);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
.set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
.set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
.set_body_method(&IntSet::is_everything);
} // namespace arith
} // namespace tvm
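`set_body_method(&IntSet::min)` above exposes a member function directly: the first FFI argument becomes the receiver object and any remaining arguments are forwarded. The same binding, sketched in plain Python with invented names:

```python
class IntSet:
    """Stand-in for the C++ IntSet; lo/hi are illustrative fields."""
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def min(self):
        return self.lo

REGISTRY = {}

def set_body_method(name, unbound_method):
    # the first call argument is treated as the receiver, as with &IntSet::min
    REGISTRY[name] = lambda receiver, *rest: unbound_method(receiver, *rest)

set_body_method("arith.IntervalSetGetMin", IntSet.min)
```

The frontend then passes the object handle as the leading argument when invoking the global, exactly as the Python `IntSet` wrapper would.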
......@@ -21,6 +21,7 @@
* \file modular_set.cc
* \brief Modular set analysis
*/
#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/expr_functor.h>
......@@ -52,6 +53,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< op->base << ')';
});
ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}
TVM_REGISTER_GLOBAL("arith.ModularSet")
.set_body_typed(MakeModularSet);
// internal entry for modular set analysis
struct ModularSetAnalyzer::Entry {
......
......@@ -134,6 +134,14 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Range(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
......
......@@ -18,13 +18,13 @@
*/
/*!
* FFI registration code used for frontend testing purposes.
* \file ffi_testing.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
namespace tvm {
......
......@@ -21,6 +21,7 @@
* \brief Compute Op.
* \file compute_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
......@@ -156,6 +157,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n);
}
TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed(ComputeOpNode::make);
// The schedule-related logic
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
......
......@@ -21,6 +21,7 @@
* \brief External computation rule.
* \file extern_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
......@@ -86,6 +87,10 @@ Operation ExternOpNode::make(std::string name,
return Operation(n);
}
TVM_REGISTER_GLOBAL("te.ExternOp")
.set_body_typed(ExternOpNode::make);
Array<Tensor> ExternOpNode::InputTensors() const {
return inputs;
}
......
......@@ -21,6 +21,7 @@
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
......@@ -83,6 +84,10 @@ Operation HybridOpNode::make(std::string name,
return res;
}
TVM_REGISTER_GLOBAL("te.HybridOp")
.set_body_typed(HybridOpNode::make);
Array<Tensor> HybridOpNode::InputTensors() const {
// Because input tensors can be inlined into hybrid scripts,
// we need to check whether all input tensors are actually used in the body.
......
......@@ -21,6 +21,7 @@
* \brief Placeholder op.
* \file placeholder_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
namespace tvm {
......@@ -67,6 +68,11 @@ Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
TVM_REGISTER_GLOBAL("te.Placeholder")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
Array<Tensor> PlaceholderOpNode::InputTensors() const {
return {};
}
......
......@@ -21,6 +21,7 @@
* \brief Scan Operator.
* \file scan_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
......@@ -120,6 +121,10 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}
TVM_REGISTER_GLOBAL("te.ScanOp")
.set_body_typed(ScanOpNode::make);
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
......
......@@ -21,6 +21,7 @@
* \brief Tensor Compute Op.
* \file tensor_compute_op.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
......@@ -72,6 +73,10 @@ Operation TensorComputeOpNode::make(std::string name,
return Operation(n);
}
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);
Array<Tensor> TensorComputeOpNode::InputTensors() const {
return inputs;
}
......
......@@ -20,6 +20,7 @@
/*!
* \file auto_inline_elem_wise.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr_functor.h>
......@@ -111,5 +112,12 @@ void AutoInlineInjective(Schedule sch) {
}
}
TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
.set_body_typed(AutoInlineElemWise);
TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
.set_body_typed(AutoInlineInjective);
} // namespace te
} // namespace tvm
......@@ -21,6 +21,7 @@
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
......@@ -259,5 +260,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
return Map<IterVar, Range>(ret.begin(), ret.end());
}
TVM_REGISTER_GLOBAL("schedule.InferBound")
.set_body_typed(InferBound);
} // namespace te
} // namespace tvm
......@@ -21,6 +21,7 @@
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/te/operation.h>
......@@ -429,5 +430,24 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
return ret;
}
TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
.set_body_typed(CreateReadGraph);
TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
.set_body_typed([](const Array<Operation>& roots,
const ReadGraph& g) {
return PostDFSOrder(roots, g);
});
TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
.set_body_typed(CreateAttachPath);
TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
.set_body_typed(ScanGetBody);
TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
.set_body_typed(ScanFixPointAnalysis);
} // namespace te
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file schedule_lang.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <unordered_set>
......@@ -848,5 +849,118 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
TVM_REGISTER_GLOBAL("te.CreateSchedule")
.set_body_typed(create_schedule);
TVM_REGISTER_GLOBAL("te.StageSetScope")
.set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("te.StageBind")
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("te.StageFuse")
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
});
TVM_REGISTER_GLOBAL("te.StageComputeAt")
.set_body_method(&Stage::compute_at);
TVM_REGISTER_GLOBAL("te.StageComputeInline")
.set_body_method(&Stage::compute_inline);
TVM_REGISTER_GLOBAL("te.StageComputeRoot")
.set_body_method(&Stage::compute_root);
TVM_REGISTER_GLOBAL("te.StageReorder")
.set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("te.StageTile")
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
PrimExpr x_factor, PrimExpr y_factor
) {
IterVar x_outer, y_outer, x_inner, y_inner;
stage.tile(x_parent, y_parent,
x_factor, y_factor,
&x_outer, &y_outer,
&x_inner, &y_inner);
return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_GLOBAL("te.StageEnvThreads")
.set_body_method(&Stage::env_threads);
TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
.set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_GLOBAL("te.StageUnroll")
.set_body_method(&Stage::unroll);
TVM_REGISTER_GLOBAL("te.StageVectorize")
.set_body_method(&Stage::vectorize);
TVM_REGISTER_GLOBAL("te.StageTensorize")
.set_body_method(&Stage::tensorize);
TVM_REGISTER_GLOBAL("te.StageParallel")
.set_body_method(&Stage::parallel);
TVM_REGISTER_GLOBAL("te.StagePragma")
.set_body_method(&Stage::pragma);
TVM_REGISTER_GLOBAL("te.StagePrefetch")
.set_body_method(&Stage::prefetch);
TVM_REGISTER_GLOBAL("te.StageStorageAlign")
.set_body_method(&Stage::storage_align);
TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
.set_body_method(&Stage::double_buffer);
TVM_REGISTER_GLOBAL("te.StageOpenGL")
.set_body_method(&Stage::opengl);
TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
.set_body_method(&Schedule::normalize);
TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
.set_body_method(&Schedule::create_group);
TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
.set_body_method(&Schedule::cache_read);
TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Tensor>()) {
*ret = args[0].operator Schedule()
.cache_write(args[1].operator Tensor(), args[2]);
} else {
*ret = args[0].operator Schedule()
.cache_write(args[1].operator Array<Tensor>(), args[2]);
}
});
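`te.ScheduleCacheWrite` cannot use `set_body_typed` because `cache_write` is overloaded on `Tensor` versus `Array<Tensor>`, so the untyped body inspects the runtime type of `args[1]` and dispatches accordingly. The same pattern in a standalone Python sketch (all names hypothetical; the returned strings only stand in for cached tensors):

```python
def cache_write(sched, tensor_or_tensors, scope):
    # dispatch on the runtime type, mirroring args[1].IsObjectRef<Tensor>()
    if isinstance(tensor_or_tensors, (list, tuple)):
        return [f"{t}.{scope}" for t in tensor_or_tensors]
    return f"{tensor_or_tensors}.{scope}"
```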
TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
} // namespace te
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file schedule_ops.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
......@@ -423,5 +424,13 @@ Stmt ScheduleOps(
return post_proc(std::move(body));
}
TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], false);
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
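The packed-function convention has no default arguments, so `schedule.ScheduleOps` emulates one by branching on `args.size()` and filling in `false` when the caller omits the third argument. A tiny Python analogue of the same arity-based defaulting (the `debug_keep_trivial_loop` name is assumed here for illustration):

```python
def schedule_ops(*args):
    # a missing third argument is detected by arity and defaulted to False,
    # matching the args.size() == 2 branch above
    if len(args) == 2:
        args = (*args, False)
    sch, dom_map, debug_keep_trivial_loop = args
    return sch, dom_map, debug_keep_trivial_loop
```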
} // namespace te
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file tensor.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor_intrin.h>
......@@ -147,5 +148,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);
} // namespace te
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file expr.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
......@@ -45,6 +46,17 @@ SizeVar::SizeVar(std::string name_hint, DataType t)
SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}
TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) {
return Var(s, t);
});
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
......@@ -57,6 +69,14 @@ IterVar IterVarNode::make(Range dom,
return IterVar(n);
}
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
......@@ -83,6 +103,9 @@ PrimExpr StringImmNode::make(std::string value) {
return PrimExpr(node);
}
TVM_REGISTER_GLOBAL("tir.StringImm")
.set_body_typed(StringImmNode::make);
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
......@@ -311,6 +334,13 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
});
}
TVM_REGISTER_GLOBAL("tir.CommReducer")
.set_body_typed(CommReducerNode::make);
TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
Array<IterVar> axis, PrimExpr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
......@@ -334,6 +364,11 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
return PrimExpr(n);
}
TVM_REGISTER_GLOBAL("tir.Reduce")
.set_body_typed(ReduceNode::make);
PrimExpr AnyNode::make() {
auto n = make_object<AnyNode>();
return PrimExpr(n);
......@@ -659,5 +694,104 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(AnyNode);
TVM_REGISTER_GLOBAL("tir.Add")
.set_body_typed(AddNode::make);
TVM_REGISTER_GLOBAL("tir.Sub")
.set_body_typed(SubNode::make);
TVM_REGISTER_GLOBAL("tir.Mul")
.set_body_typed(MulNode::make);
TVM_REGISTER_GLOBAL("tir.Div")
.set_body_typed(DivNode::make);
TVM_REGISTER_GLOBAL("tir.Mod")
.set_body_typed(ModNode::make);
TVM_REGISTER_GLOBAL("tir.FloorDiv")
.set_body_typed(FloorDivNode::make);
TVM_REGISTER_GLOBAL("tir.FloorMod")
.set_body_typed(FloorModNode::make);
TVM_REGISTER_GLOBAL("tir.Min")
.set_body_typed(MinNode::make);
TVM_REGISTER_GLOBAL("tir.Max")
.set_body_typed(MaxNode::make);
TVM_REGISTER_GLOBAL("tir.EQ")
.set_body_typed(EQNode::make);
TVM_REGISTER_GLOBAL("tir.NE")
.set_body_typed(NENode::make);
TVM_REGISTER_GLOBAL("tir.LT")
.set_body_typed(LTNode::make);
TVM_REGISTER_GLOBAL("tir.LE")
.set_body_typed(LENode::make);
TVM_REGISTER_GLOBAL("tir.GT")
.set_body_typed(GTNode::make);
TVM_REGISTER_GLOBAL("tir.GE")
.set_body_typed(GENode::make);
TVM_REGISTER_GLOBAL("tir.And")
.set_body_typed(AndNode::make);
TVM_REGISTER_GLOBAL("tir.Or")
.set_body_typed(OrNode::make);
TVM_REGISTER_GLOBAL("tir.Not")
.set_body_typed(NotNode::make);
TVM_REGISTER_GLOBAL("tir.Select")
.set_body_typed(SelectNode::make);
TVM_REGISTER_GLOBAL("tir.Ramp")
.set_body_typed(RampNode::make);
TVM_REGISTER_GLOBAL("tir.Cast")
.set_body_typed(CastNode::make);
TVM_REGISTER_GLOBAL("tir.Broadcast")
.set_body_typed(BroadcastNode::make);
TVM_REGISTER_GLOBAL("tir.Shuffle")
.set_body_typed(ShuffleNode::make);
TVM_REGISTER_GLOBAL("tir.Let")
.set_body_typed(LetNode::make);
TVM_REGISTER_GLOBAL("tir.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0];
if (args.size() == 3) {
*ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
} else {
*ret = LoadNode::make(t, args[1], args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
Array<PrimExpr> args, int call_type,
FunctionRef func, int value_index
) {
return CallNode::make(type,
name,
args,
static_cast<CallNode::CallType>(call_type),
func,
value_index);
});
} // namespace tir
} // namespace tvm
......@@ -662,4 +662,90 @@ TVM_REGISTER_GLOBAL("node.LargeUIntImm")
TVM_REGISTER_GLOBAL("node.String")
.set_body_typed(tir::StringImmNode::make);
TVM_REGISTER_GLOBAL("tir.min_value")
.set_body_typed(min_value);
TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("tir.abs")
.set_body_typed(tvm::abs);
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);
TVM_REGISTER_GLOBAL("tir.ceil")
.set_body_typed(tvm::ceil);
TVM_REGISTER_GLOBAL("tir.round")
.set_body_typed(tvm::round);
TVM_REGISTER_GLOBAL("tir.nearbyint")
.set_body_typed(tvm::nearbyint);
TVM_REGISTER_GLOBAL("tir.trunc")
.set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast")
.set_body_typed(tvm::cast);
// operator-overloading wrappers; unlike the raw make functions, these
// apply type promotion before constructing the node
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
} \
})
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
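The `REGISTER_MAKE_BINARY_OP` macro above stamps out one global registration per operator from a single template. A table-driven Python sketch of the same bulk-registration idea (the registry and the `tir.` prefix convention are mirrored from the macro; everything else is invented for illustration):

```python
import operator

REGISTRY = {}

def register_binary_op(name, func):
    # one entry per operator, as each REGISTER_MAKE_BINARY_OP expansion adds
    REGISTRY["tir." + name] = func

BINARY_OPS = [
    ("_OpAdd", operator.add),
    ("_OpSub", operator.sub),
    ("_OpMul", operator.mul),
    ("_OpEQ",  operator.eq),
]
for name, func in BINARY_OPS:
    register_binary_op(name, func)
```

A loop over a table plays the role the preprocessor macro plays in C++: the registration logic is written once and applied uniformly to every operator.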
} // namespace tvm
......@@ -20,7 +20,7 @@
/*!
* \file tvm/tir/stmt.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include "../pass/ir_util.h"
......@@ -40,6 +40,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.LetStmt")
.set_body_typed(LetStmtNode::make);
Stmt AttrStmtNode::make(ObjectRef node,
std::string attr_key,
PrimExpr value,
......@@ -52,6 +55,10 @@ Stmt AttrStmtNode::make(ObjectRef node,
return Stmt(n);
}
TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make);
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
......@@ -66,6 +73,10 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.AssertStmt")
.set_body_typed(AssertStmtNode::make);
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
......@@ -76,6 +87,10 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
.set_body_typed(ProducerConsumerNode::make);
Stmt ForNode::make(Var loop_var,
PrimExpr min,
PrimExpr extent,
......@@ -99,6 +114,19 @@ Stmt ForNode::make(Var loop_var,
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.For")
.set_body_typed([](
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api),
body);
});
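`tir.For` receives `for_type` and `device_api` as plain integers and casts them back to enums at the boundary, since the FFI convention passes enum values as ints. A minimal Python analogue of that boundary conversion (the `ForType` values and the dict node are illustrative, not TVM's):

```python
import enum

class ForType(enum.Enum):
    SERIAL = 0
    PARALLEL = 1
    VECTORIZED = 2
    UNROLLED = 3

def make_for(loop_var, min_val, extent, for_type, body):
    # the frontend passes a plain int; reconstruct the enum here,
    # as static_cast<ForType>(for_type) does above
    return {"var": loop_var, "min": min_val, "extent": extent,
            "for_type": ForType(for_type), "body": body}
```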
Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
CHECK(value.defined());
CHECK(index.defined());
......@@ -114,6 +142,18 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PrimExpr value = args[1];
if (args.size() == 3) {
*ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
*ret = StoreNode::make(args[0], value, args[2], args[3]);
}
});
Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
CHECK(value_index >= 0 && value_index < func->num_outputs())
    << "value_index must be within the range of the function's outputs";
......@@ -131,6 +171,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Provide")
.set_body_typed(ProvideNode::make);
Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
......@@ -157,6 +201,15 @@ Stmt AllocateNode::make(Var buffer_var,
return Stmt(node);
}
// AllocateNode::make is overloaded and has default arguments,
// so wrap it in a lambda with an explicit signature.
TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](
Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
){
return AllocateNode::make(buffer_var, type, extents, condition, body);
});
int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
......@@ -178,12 +231,16 @@ Stmt FreeNode::make(Var buffer_var) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Free")
.set_body_typed(FreeNode::make);
Stmt RealizeNode::make(FunctionRef func,
int value_index,
DataType dtype,
Region bounds,
PrimExpr condition,
Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
......@@ -204,6 +261,11 @@ Stmt RealizeNode::make(FunctionRef func,
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);
Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
......@@ -220,12 +282,21 @@ Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Regio
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Prefetch")
.set_body_typed(PrefetchNode::make);
SeqStmt::SeqStmt(Array<Stmt> seq) {
auto node = make_object<SeqStmtNode>();
node->seq = std::move(seq);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});
Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
......@@ -238,6 +309,10 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.IfThenElse")
.set_body_typed(IfThenElseNode::make);
Stmt EvaluateNode::make(PrimExpr value) {
CHECK(value.defined());
......@@ -246,6 +321,9 @@ Stmt EvaluateNode::make(PrimExpr value) {
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.Evaluate")
.set_body_typed(EvaluateNode::make);
// Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
......@@ -19,7 +19,7 @@
/*!
* Exposure of pass functions.
* \file ffi_api.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
......@@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
// make from two arguments
#define REGISTER_PASS(PassName) \
  TVM_REGISTER_GLOBAL("ir_pass."#PassName) \
  .set_body_typed(PassName);
REGISTER_PASS(ConvertSSA);
......
......@@ -27,7 +27,7 @@ def test_op_translation():
except tvm.error.OpNotImplemented as e:
msg = str(e)
assert isinstance(e, NotImplementedError)
assert msg.find("ffi_testing.cc") != -1
fchk_eq = tvm.testing.test_check_eq_callback(
"InternalError: myop")
......@@ -36,14 +36,14 @@ def test_op_translation():
assert False
except tvm.error.InternalError as e:
msg = str(e)
assert msg.find("ffi_testing.cc") != -1
try:
tvm.testing.ErrorTest(0, 1)
assert False
except ValueError as e:
msg = str(e)
assert msg.find("ffi_testing.cc") != -1
def test_deep_callback():
......