Unverified Commit e0316415 by Tianqi Chen Committed by GitHub

[TIR] Introduce tir::PrimFunc (#5070)

This PR introduces tir::PrimFunc which will be used as the TIR function
container in the unified IR.

Also streamlined the function attributes a bit further.
- All common attributes are under tvm::attr
- TIR specific attributes are under tvm::tir::attr and comes with a tir prefix
- Use stl_style for attributes for now
parent be4e9db4
......@@ -33,6 +33,36 @@
namespace tvm {
/*!
* \brief Possible Calling conventions.
*
* NOTE: The calling convention also implies
* the way we implement the function during lowering.
*/
enum class CallingConv : int {
/*!
* \brief Default calling convetion.
*
* - Uses the native calling convention of the target.
* - Implementation: specified by the native target.
*/
kDefault = 0,
/*!
* \brief Device kernel launch
*
* - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 3,
};
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
......@@ -115,5 +145,74 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function type.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template<typename TFunc,
typename = typename std::enable_if<
std::is_base_of<BaseFunc, TFunc>::value>::type>
inline TFunc WithAttr(TFunc func,
const std::string& attr_key,
ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}
/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";
/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";
} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
......@@ -277,6 +277,13 @@ class TupleType : public Type {
};
/*!
* \return a type that represents void.
*/
inline Type VoidType() {
return TupleType::Empty();
}
/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
*/
......
......@@ -115,33 +115,6 @@ class Function : public BaseFunc {
};
/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/function.h
* \brief TIR Function.
*/
#ifndef TVM_TIR_FUNCTION_H_
#define TVM_TIR_FUNCTION_H_
#include <tvm/ir/function.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <string>
namespace tvm {
namespace tir {
/*!
* \brief Primitive functions that contains TIR statements.
*
* The PrimFunc provides low-level code representation does not
* automatically manage
*
* \sa PrimFunc
*/
class PrimFuncNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
Array<tir::Var> params;
/*! \brief The body of the function */
tir::Stmt body;
/*! \brief The return type of the function. */
Type ret_type;
/*!
* \brief Maps some parameters to specific Buffer data structures.
*
* buffer_map provides a way to express data structure's field and shape
* constraints. The provided information is used in the program analysis
* and the code generation.
*
* - It defines the vars in the Buffer (m, n) in the cases below when
* they appears in the buffer_map for the first time.
* - When a var appears multiple times, they translate into runtime
* assertion to check the field constraint.
*
* \code
*
* # The corresponding fields of f are as follows
* #
* # - f.params = [a, b]
* # - f.buffer_map = {a: A, b: B}
* # - A = decl_buffer(shape=[m, n])
* # - B = decl_buffer(shape=[m, n])
*
* def f(a, b):
* m, n = var(), var()
* A = bind_buffer(a, shape=[m, n])
* B = bind_buffer(b, shape=[m, n])
* # body
*
* \endcode
*
* buffer_map is a sugar to express:
* - Parameter unpacking: e.g. I can load a.shape[0] to get value of m
* - Constraint checking: a.shape[0] must equal b.shape[0] because they
* both corresponds to m.
* While we could have express parameter unpacking and constraint using
* normal statements, making buffer_map as first class citizen of PrimFunc
* will make program analysis much easier.
*
* \note This field can be nullptr
*/
Map<tir::Var, Buffer> buffer_map;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
/*!
* \brief Return the derived function annotation of this function.
*
* \return The function type annotation.
* \note The function type annotation of PrimExpr is
* directly derived from the Vars without the need of type inference.
*/
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
/*!
* \brief Managed reference to PrimFuncNode.
* \sa PrimFuncNode
*/
class PrimFunc : public BaseFunc {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
* \param ret_type The return type of the function.
* \param buffer_map The buffer map for parameter buffer unpacking.
* \param attrs Additional function attributes.
*/
TVM_DLL PrimFunc(Array<tir::Var> params,
Stmt body,
Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>());
TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};
/*!
* \brief PrimFunc specific attribute names.
*
* \sa tvm::attr
*/
namespace attr {
/*!
* \brief List of thread IterVar that a DeviceLaunch function corresponds to.
*
* Type: Array<tir::IterVar>
*
* We call a device kernel launch function f using the following convention:
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
*
* The list of device_thread_axis indicates how can be bind the
* work_size arguments to the corresponding threads.
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
/*!
* \brief Whether to set noalias rule on the function arguments.
*
* Type: Integer
*/
constexpr const char* kNoAlias = "tir.noalias";
} // namespace attr
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_FUNCTION_H_
......@@ -28,6 +28,7 @@
#ifndef TVM_TIR_OP_H_
#define TVM_TIR_OP_H_
#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
......@@ -37,6 +38,7 @@
namespace tvm {
// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
// It is also necessary to overload operators for PrimExpr.
......@@ -45,6 +47,16 @@ namespace tvm {
// as they are more specific to the tir namespace.
/*!
* \brief Get the type of the expression under the unified type system.
*
* This function could return a more refined type than
* the runtime type provided by expr->dtype
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL Type GetType(const PrimExpr& expr);
/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
......
......@@ -21,7 +21,8 @@ from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .function import BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
......
......@@ -51,15 +51,6 @@ class RelayExpr(BaseExpr):
return ret
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)
@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Function defintiions."""
from .expr import RelayExpr
from . import _ffi_api
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)
......@@ -282,7 +282,8 @@ class Function(BaseFunc):
func : Function
A new copy of the function
"""
return _expr.FunctionWithAttr(self, attr_key, attr_value)
return _expr.FunctionWithAttr(
self, attr_key, convert(attr_value))
......
......@@ -31,6 +31,8 @@ from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .function import PrimFunc
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Function data types."""
import tvm._ffi
import tvm.runtime
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var
from . import _ffi_api
@tvm._ffi.register_object("tir.PrimFunc")
class PrimFunc(BaseFunc):
"""A function declaration expression.
Parameters
----------
params: List[Union[tvm.tir.Var, tvm.tir.Buffer]]
List of input parameters to the function.
body: tvm.tir.Stmt
The body of the function.
ret_type: tvm.ir.Type
The return type annotation of the function.
buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
The buffer binding map.
attrs: Optional[tvm.Attrs]
Attributes of the function, can be None
"""
def __init__(self,
params,
body,
ret_type=None,
buffer_map=None,
attrs=None):
param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
if isinstance(x, Buffer):
var = Var(x.name, dtype="handle")
param_list.append(var)
buffer_map[var] = x
elif isinstance(x, Var):
param_list.append(x)
else:
raise TypeError("params can only contain Var or Buffer")
self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.PrimFuncWithAttr(
self, attr_key, tvm.runtime.convert(attr_value))
......@@ -30,4 +30,5 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
.set_body_typed([](BaseFunc func) {
return func->attrs;
});
} // namespace tvm
......@@ -99,7 +99,11 @@ class RelayTextPrinter :
}
Doc PrintFinal(const ObjectRef& node) {
if (node.as<ExprNode>()) {
if (node->IsInstance<BaseFuncNode>() &&
!node->IsInstance<relay::FunctionNode>()) {
// Temporarily skip non-relay functions.
// TODO(tvm-team) enhance the code to work for all functions
} else if (node.as<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}
......@@ -122,7 +126,10 @@ class RelayTextPrinter :
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) {
if (node.as<ExprNode>()) {
bool is_non_relay_func =
node->IsInstance<BaseFuncNode>() &&
!node->IsInstance<relay::FunctionNode>();
if (node.as<ExprNode>() && !is_non_relay_func) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
......@@ -134,7 +141,7 @@ class RelayTextPrinter :
// default module.
std::ostringstream os;
os << node;
return Doc() << os.str();
return Doc::RawText(os.str());
}
}
......
......@@ -60,18 +60,6 @@ bool FunctionNode::UseDefaultCompiler() const {
return !val.defined() || val->value == "default";
}
Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
FunctionNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
......@@ -94,9 +82,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tir/ir/function.cc
* \brief The function data structure.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
namespace tvm {
namespace tir {
PrimFunc::PrimFunc(Array<tir::Var> params,
Stmt body,
Type ret_type,
Map<tir::Var, Buffer> buffer_map,
DictAttrs attrs) {
// Assume void-return type for now
// TODO(tvm-team) consider type deduction from body.
if (!ret_type.defined()) {
ret_type = VoidType();
}
auto n = make_object<PrimFuncNode>();
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->buffer_map = std::move(buffer_map);
n->attrs = std::move(attrs);
data_ = std::move(n);
}
FuncType PrimFuncNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
param_types.push_back(GetType(param));
}
return FuncType(param_types, ret_type, {}, {});
}
TVM_REGISTER_NODE_TYPE(PrimFuncNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
// TODO(tvm-team) redirect to Text printer once we have a good text format.
auto* node = static_cast<const PrimFuncNode*>(ref.get());
p->stream << "PrimFunc(" << node->params << ") ";
if (node->attrs.defined()) {
p->stream << "attrs=" << node->attrs;
}
p->stream << " {\n";
p->indent += 2;
p->Print(node->body);
p->indent -= 2;
p->stream << "}\n";
});
TVM_REGISTER_GLOBAL("tir.PrimFunc")
.set_body_typed([](Array<tir::Var> params,
Stmt body,
Type ret_type,
Map<tir::Var, Buffer> buffer_map,
DictAttrs attrs) {
return PrimFunc(params, body, ret_type, buffer_map, attrs);
});
TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace tir
} // namespace tvm
......@@ -32,6 +32,18 @@ namespace tvm {
using namespace tir;
Type GetType(const PrimExpr& expr) {
runtime::DataType dtype = expr.dtype();
// These types already implies the specific type.
if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
return PrimType(dtype);
}
// TODO(tqchen): add recursive type inference for Var and Call here
// once we introduced the corresponding fields to the IR.
return PrimType(dtype);
}
// simple cast that only checks if type matches and cast
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
......
......@@ -19,6 +19,7 @@ from tvm import te
import numpy as np
def test_const():
x = tvm.tir.const(1, "int32")
print(x.dtype)
......@@ -46,8 +47,8 @@ def test_make():
x = tvm.tir.const(1, "int32")
y = te.var("x")
z = x + y
assert isinstance(tvm.te.max(x, y), tvm.tir.Max)
assert isinstance(tvm.te.min(x, y), tvm.tir.Min)
assert isinstance(tvm.tir.max(x, y), tvm.tir.Max)
assert isinstance(tvm.tir.min(x, y), tvm.tir.Min)
def test_ir():
......@@ -111,7 +112,6 @@ def test_stmt():
tvm.tir.For.Serial, 0,
x)
def test_dir():
x = te.var('x')
dir(x)
......@@ -247,8 +247,26 @@ def test_equality_string_imm():
x == y.value
x == y
def test_prim_func():
x = te.var('x')
y = te.var('y')
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1));
func = tvm.tir.PrimFunc(
[x, y, b], stmt)
assert func.buffer_map[func.params[2]].same_as(b)
assert len(func.buffer_map) == 1
f2 = func.with_attr("calling_conv", 1)
assert f2.attrs["calling_conv"].value == 1
assert func.attrs is None
if __name__ == "__main__":
test_prim_func()
test_cast()
test_attr()
test_const()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment