Commit 2083513f by Jared Roesch Committed by Tianqi Chen

Implement explicit IR representation of memory alloction (#3560)

parent 19164063
...@@ -272,6 +272,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) ...@@ -272,6 +272,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(USE_RELAY_DEBUG) if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...") message(STATUS "Building Relay in debug mode...")
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG")
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "DMLC_LOG_DEBUG")
else() else()
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
endif(USE_RELAY_DEBUG) endif(USE_RELAY_DEBUG)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/attrs/memory.h
* \brief Attributes for memory operators.
*/
#ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_
#include <tvm/attrs.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for allocating tensors.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
Constant const_shape;
Array<IndexExpr> assert_shape;
DataType dtype;
TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The dtype of the tensor to allocate.")
.set_default(Float(32, 1));
TVM_ATTR_FIELD(const_shape)
.describe(
"The shape of constant used to aid in type inference.");
TVM_ATTR_FIELD(assert_shape)
.describe(
"The shape to cast the return type of the allocation to, "\
"used to specify the shape obtained via further analysis.");
}
};
/*!
* \brief Options for the shape function operator.
*/
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
Array<Integer> is_input;
TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
TVM_ATTR_FIELD(is_input)
.describe(
"A bool indicating whether the shape function should"\
"expect shape or input in each position.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_MEMORY_H_
...@@ -47,6 +47,12 @@ namespace relay { ...@@ -47,6 +47,12 @@ namespace relay {
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
} }
#define RELAY_DEBUG_INTERP(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}
/*! /*!
* \brief We always used NodeRef for referencing nodes. * \brief We always used NodeRef for referencing nodes.
* *
......
...@@ -76,7 +76,8 @@ class ModuleNode : public RelayNode { ...@@ -76,7 +76,8 @@ class ModuleNode : public RelayNode {
} }
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs, TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs); tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
...@@ -235,6 +236,11 @@ class ModuleNode : public RelayNode { ...@@ -235,6 +236,11 @@ class ModuleNode : public RelayNode {
*/ */
TVM_DLL void ImportFromStd(const std::string& path); TVM_DLL void ImportFromStd(const std::string& path);
/*!
* \brief The set of imported files.
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
/*! \brief Construct a module from a standalone expression. /*! \brief Construct a module from a standalone expression.
* *
* Allows one to optionally pass a global function map and * Allows one to optionally pass a global function map and
......
...@@ -283,6 +283,8 @@ class Object { ...@@ -283,6 +283,8 @@ class Object {
* \note The deleter will be called when ref_counter_ becomes zero. * \note The deleter will be called when ref_counter_ becomes zero.
*/ */
inline void DecRef(); inline void DecRef();
private:
/*! /*!
* \return The usage count of the cell. * \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr. * \note We use stl style naming to be consistent with known API in shared_ptr.
...@@ -675,6 +677,16 @@ struct ObjectEqual { ...@@ -675,6 +677,16 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
// Implementations details below // Implementations details below
// Object reference counting. // Object reference counting.
......
...@@ -138,6 +138,7 @@ enum class Opcode { ...@@ -138,6 +138,7 @@ enum class Opcode {
GetTag = 13U, GetTag = 13U,
LoadConsti = 14U, LoadConsti = 14U,
Fatal = 15U, Fatal = 15U,
AllocStorage = 16U,
}; };
/*! \brief A single virtual machine instruction. /*! \brief A single virtual machine instruction.
...@@ -158,6 +159,8 @@ struct Instruction { ...@@ -158,6 +159,8 @@ struct Instruction {
union { union {
struct /* AllocTensor Operands */ { struct /* AllocTensor Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The number of dimensions. */ /*! \brief The number of dimensions. */
uint32_t ndim; uint32_t ndim;
/*! \brief The shape of tensor. */ /*! \brief The shape of tensor. */
...@@ -166,6 +169,8 @@ struct Instruction { ...@@ -166,6 +169,8 @@ struct Instruction {
DLDataType dtype; DLDataType dtype;
} alloc_tensor; } alloc_tensor;
struct /* AllocTensorReg Operands */ { struct /* AllocTensorReg Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The register to read the shape out of. */ /*! \brief The register to read the shape out of. */
RegName shape_register; RegName shape_register;
/*! \brief The datatype of tensor to be allocated. */ /*! \brief The datatype of tensor to be allocated. */
...@@ -253,6 +258,14 @@ struct Instruction { ...@@ -253,6 +258,14 @@ struct Instruction {
/*! \brief The free variables as an array. */ /*! \brief The free variables as an array. */
RegName* free_vars; RegName* free_vars;
}; };
struct /* AllocStorage Operands */ {
/*! \brief The size of the allocation. */
RegName allocation_size;
/*! \brief The alignment of the allocation. */
RegName alignment;
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
}; };
/*! \brief Construct a return instruction. /*! \brief Construct a return instruction.
...@@ -274,19 +287,23 @@ struct Instruction { ...@@ -274,19 +287,23 @@ struct Instruction {
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args); const std::vector<RegName>& args);
/*! \brief Construct an allocate tensor instruction with constant shape. /*! \brief Construct an allocate tensor instruction with constant shape.
* \param storage The storage to allocate out of.
* \param shape The shape of the tensor. * \param shape The shape of the tensor.
* \param dtype The dtype of the tensor. * \param dtype The dtype of the tensor.
* \param dst The destination register. * \param dst The destination register.
* \return The allocate tensor instruction. * \return The allocate tensor instruction.
*/ */
static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst); static Instruction AllocTensor(RegName storage,
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register. /*! \brief Construct an allocate tensor instruction with register.
* \param storage The storage to allocate out of.
* \param shape_register The register containing the shape. * \param shape_register The register containing the shape.
* \param dtype The dtype of the tensor. * \param dtype The dtype of the tensor.
* \param dst The destination register. * \param dst The destination register.
* \return The allocate tensor instruction. * \return The allocate tensor instruction.
*/ */
static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst); static Instruction AllocTensorReg(RegName storage,
RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction. /*! \brief Construct an allocate datatype instruction.
* \param tag The datatype tag. * \param tag The datatype tag.
* \param num_fields The number of fields for the datatype. * \param num_fields The number of fields for the datatype.
...@@ -295,7 +312,7 @@ struct Instruction { ...@@ -295,7 +312,7 @@ struct Instruction {
* \return The allocate instruction tensor. * \return The allocate instruction tensor.
*/ */
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields, static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst); RegName dst);
/*! \brief Construct an allocate closure instruction. /*! \brief Construct an allocate closure instruction.
* \param func_index The index of the function table. * \param func_index The index of the function table.
* \param num_freevar The number of free variables. * \param num_freevar The number of free variables.
...@@ -364,6 +381,16 @@ struct Instruction { ...@@ -364,6 +381,16 @@ struct Instruction {
*/ */
static Instruction Move(RegName src, RegName dst); static Instruction Move(RegName src, RegName dst);
/*! \brief Allocate a storage block.
* \param size The size of the allocation.
* \param alignment The allocation's alignment.
* \param dtype_hint The data type hint for the allocator.
* \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/
static Instruction AllocStorage(RegName size, RegName alignment,
DLDataType dtype_hint, RegName dst);
Instruction(); Instruction();
Instruction(const Instruction& instr); Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr); Instruction& operator=(const Instruction& instr);
......
...@@ -59,6 +59,8 @@ from . import quantize ...@@ -59,6 +59,8 @@ from . import quantize
from . import qnn from . import qnn
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc
# Required to traverse large programs # Required to traverse large programs
setrecursionlimit(10000) setrecursionlimit(10000)
......
...@@ -99,6 +99,10 @@ class CompileEngine(NodeBase): ...@@ -99,6 +99,10 @@ class CompileEngine(NodeBase):
msg += "--------------------------\n" msg += "--------------------------\n"
raise RuntimeError(msg) raise RuntimeError(msg)
def lower_shape_func(self, source_func, target=None):
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLowerShapeFunc(self, key)
def jit(self, source_func, target=None): def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function. """JIT a source_func to a tvm.Function.
......
...@@ -25,9 +25,14 @@ def _debugger_init(expr, stack): ...@@ -25,9 +25,14 @@ def _debugger_init(expr, stack):
import pdb import pdb
pdb.set_trace() pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug") @register_func("relay.debug")
def _debug(*args): def _debug(*args):
import pdb
pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug_interp")
def _debug_interp(*args):
_, _, _, ist = args _, _, _, ist = args
print("Relay Debugger") print("Relay Debugger")
print(" You can manipulate the expression under evaluation with the name `expr`.") print(" You can manipulate the expression under evaluation with the name `expr`.")
......
...@@ -317,6 +317,9 @@ class Function(Expr): ...@@ -317,6 +317,9 @@ class Function(Expr):
return _expr.FunctionSetParams(self, params) return _expr.FunctionSetParams(self, params)
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
@register_relay_node @register_relay_node
class Call(Expr): class Call(Expr):
......
...@@ -28,6 +28,7 @@ from .transform import * ...@@ -28,6 +28,7 @@ from .transform import *
from .algorithm import * from .algorithm import *
from . import nn from . import nn
from . import annotation from . import annotation
from . import memory
from . import image from . import image
from . import vision from . import vision
from . import contrib from . import contrib
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Operators for manipulating low level memory."""
from __future__ import absolute_import as _abs
from .memory import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.memory._make", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operators for manipulating low-level memory."""
from __future__ import absolute_import as _abs
from . import _make
def invoke_tvm_op(func, inputs, outputs):
"""Call a primitive function with the TVM operator calling convention.
Parameters
----------
inputs : tvm.relay.Expr
A tuple of the inputs to pass to the TVM function.
outputs : tvm.relay.Expr
A tuple of the outputs to pass to the TVM function.
Returns
-------
result : tvm.relay.Expr
The invoke_tvm_op call node.
"""
return _make.invoke_tvm_op(func, inputs, outputs)
def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
"""Allocate a tensor with the provided shape, and dtype.
Parameters
----------
storage : tvm.relay.Expr
The storage to allocate from.
shape : tvm.relay.Expr
The shape of the tensor to allocate.
dtype: str
The dtype of the tensor.
assert_shape: Control the static shape when computed by dynamic shape expression.
Returns
-------
result : tvm.relay.Expr
The alloc_tensor expression.
"""
return _make.alloc_tensor(storage, shape, dtype, assert_shape)
def alloc_storage(size, alignment, dtype_hint='float32'):
"""Allocate a piece of tensor storage.
Parameters
----------
size : tvm.relay.Expr
The size of the allocation.
alignment : tvm.relay.Expr
The alignment of the allocation.
dtype : str
The dtype_hint of the allocation.
Returns
-------
result : tvm.relay.Expr
The alloc_storage expression.
"""
return _make.alloc_storage(size, alignment, dtype_hint)
def shape_func(func, inputs, outputs, dependent=False):
"""Invoke the shape function of the passed function.
Parameters
----------
func : tvm.relay.Expr
The primitive function from which to compute the shape function.
inputs : tvm.relay.Tuple
The tupled inputs.
outputs : tvm.relay.Tuple
The tupled outputs.
Returns
-------
result : tvm.relay.Expr
The shape function expression.
"""
return _make.shape_func(func, inputs, outputs, dependent)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4
extern type Storage
...@@ -52,6 +52,9 @@ class Type(RelayNode): ...@@ -52,6 +52,9 @@ class Type(RelayNode):
""" """
return TypeCall(self, args) return TypeCall(self, args)
def is_dynamic(self):
return _make.IsDynamic(self)
@register_relay_node @register_relay_node
class TensorType(Type): class TensorType(Type):
"""A concrete TensorType in Relay. """A concrete TensorType in Relay.
...@@ -317,7 +320,6 @@ class RefType(Type): ...@@ -317,7 +320,6 @@ class RefType(Type):
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value) self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype): def scalar_type(dtype):
"""Creates a scalar type. """Creates a scalar type.
......
...@@ -72,6 +72,10 @@ bool IsDynamic(const Type& ty) { ...@@ -72,6 +72,10 @@ bool IsDynamic(const Type& ty) {
return v.is_dyn; return v.is_dyn;
} }
// TODO(@jroesch): MOVE ME
TVM_REGISTER_API("relay._make.IsDynamic")
.set_body_typed(IsDynamic);
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) { Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible // for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64. // even if the result of shape inference becomes int64.
...@@ -775,6 +779,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") ...@@ -775,6 +779,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
return self->Lower(key); return self->Lower(key);
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>( .set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) { [](CompileEngine self, CCacheKey key) {
......
...@@ -458,7 +458,7 @@ class Interpreter : ...@@ -458,7 +458,7 @@ class Interpreter :
if (dattrs->debug_func.defined()) { if (dattrs->debug_func.defined()) {
dattrs->debug_func(interp_state); dattrs->debug_func(interp_state);
} else { } else {
RELAY_DEBUG(interp_state); RELAY_DEBUG_INTERP(interp_state);
} }
return args[0]; return args[0];
...@@ -479,7 +479,8 @@ class Interpreter : ...@@ -479,7 +479,8 @@ class Interpreter :
if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) { if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
arg_len += tuple_type->fields.size(); arg_len += tuple_type->fields.size();
} else { } else {
CHECK(func->body->checked_type().as<TensorTypeNode>()); CHECK(func->body->checked_type().as<TensorTypeNode>())
<< func->body->checked_type();
arg_len += 1; arg_len += 1;
} }
std::vector<TVMValue> values(arg_len); std::vector<TVMValue> values(arg_len);
......
...@@ -48,6 +48,19 @@ namespace backend { ...@@ -48,6 +48,19 @@ namespace backend {
inline const PackedFunc* GetPackedFunc(const std::string& func_name) { inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
return tvm::runtime::Registry::Get(func_name); return tvm::runtime::Registry::Get(func_name);
} }
/*!
* \brief Get a typed packed function.
*
* \param func_name
* \return const PackedFunc*
*/
template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf);
}
/*! /*!
* \brief Convert type to string * \brief Convert type to string
* *
......
...@@ -355,5 +355,11 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") ...@@ -355,5 +355,11 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
return temp->Realize(); return temp->Realize();
}); });
TVM_REGISTER_API("relay._expr.FunctionSetAttr")
.set_body_typed<Function(Function, std::string, NodeRef)>(
[](Function func, std::string name, NodeRef ref) {
return FunctionSetAttr(func, name, ref);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -35,13 +35,16 @@ using tvm::IRPrinter; ...@@ -35,13 +35,16 @@ using tvm::IRPrinter;
using namespace runtime; using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs) { tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports
) {
auto n = make_node<ModuleNode>(); auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs); n->type_definitions = std::move(global_type_defs);
n->global_type_var_map_ = {}; n->global_type_var_map_ = {};
n->global_var_map_ = {}; n->global_var_map_ = {};
n->constructor_tag_map_ = {}; n->constructor_tag_map_ = {};
n->import_set_ = imports;
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
// set global var map // set global var map
...@@ -283,9 +286,9 @@ Module ModuleNode::FromExpr( ...@@ -283,9 +286,9 @@ Module ModuleNode::FromExpr(
} }
void ModuleNode::Import(const std::string& path) { void ModuleNode::Import(const std::string& path) {
LOG(INFO) << "Importing: " << path;
if (this->import_set_.count(path) == 0) { if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path); this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in); std::fstream src_file(path, std::fstream::in);
std::string file_contents { std::string file_contents {
std::istreambuf_iterator<char>(src_file), std::istreambuf_iterator<char>(src_file),
...@@ -302,6 +305,10 @@ void ModuleNode::ImportFromStd(const std::string& path) { ...@@ -302,6 +305,10 @@ void ModuleNode::ImportFromStd(const std::string& path) {
return this->Import(std_path + "/" + path); return this->Import(std_path + "/" + path);
} }
std::unordered_set<std::string> ModuleNode::Imports() const {
return this->import_set_;
}
Module FromText(const std::string& source, const std::string& source_name) { Module FromText(const std::string& source, const std::string& source_name) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext"); auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
...@@ -312,7 +319,10 @@ Module FromText(const std::string& source, const std::string& source_name) { ...@@ -312,7 +319,10 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
.set_body_typed(ModuleNode::make); .set_body_typed<Module(tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData>)>(
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
});
TVM_REGISTER_API("relay._module.Module_Add") TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -24,14 +24,15 @@ ...@@ -24,14 +24,15 @@
* \brief Registration of annotation operators. * \brief Registration of annotation operators.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../type_relations.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
* used as "barrier" to avoid fusing operators belonging to differen devices. * used as "barrier" to avoid fusing operators belonging to differen devices.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <vector> #include <vector>
#include <string>
#include <unordered_map>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/alter_op_layout.h" #include "../pass/alter_op_layout.h"
...@@ -105,6 +107,50 @@ namespace relay { ...@@ -105,6 +107,50 @@ namespace relay {
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout) BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */
template<typename R>
class OpMatch {
public:
using MatchFunc =
std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
/*! \brief Match an operator with the given name.
* \param op_name The name of the operator to match.
* \param func The function to execute when it matches.
* \return A self-reference for builder style API.
*/
inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
auto op = Op::Get(op_name);
match_map_.insert({op, func});
return *this;
}
/*! \brief Rewrite a call operation based on the operator and the registered
* match functions.
* \param call The call to rewrite.
* \return The result of rewriting.
*/
inline R operator()(const Call& call) {
auto it = match_map_.find(Downcast<Op>(call->op));
if (it != match_map_.end()) {
return it->second(call->args, call->attrs, call->type_args);
} else {
if (default_ != nullptr) {
return default_(call->args, call->attrs, call->type_args);
} else {
LOG(FATAL) << "unexpected operation " << call->op;
}
}
}
private:
/*! \brief The match function map. */
std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
/*! \brief An optional default case. */
MatchFunc default_;
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -286,8 +286,8 @@ bool ShapeOfRel(const Array<Type>& types, ...@@ -286,8 +286,8 @@ bool ShapeOfRel(const Array<Type>& types,
CHECK(tt != nullptr); CHECK(tt != nullptr);
const auto* param = attrs.as<ShapeOfAttrs>(); const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
auto vector_out = tvm::Integer(tt->shape.size()); auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype)); reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
return true; return true;
} }
......
...@@ -144,5 +144,13 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -144,5 +144,13 @@ bool BroadcastCompRel(const Array<Type>& types,
return false; return false;
} }
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
if (shape.size() == 0) {
return {};
} else {
return { tvm::Integer(shape.size()) };
}
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -80,6 +80,8 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -80,6 +80,8 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter); const TypeReporter& reporter);
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
* 3. Collect the device allocation of each expression. * 3. Collect the device allocation of each expression.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "./pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -73,13 +74,12 @@ bool ConstantCheck(const Expr& e) { ...@@ -73,13 +74,12 @@ bool ConstantCheck(const Expr& e) {
TVM_REGISTER_API("relay._analysis.check_constant") TVM_REGISTER_API("relay._analysis.check_constant")
.set_body_typed(ConstantCheck); .set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder. // TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator. // or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator { class ConstantFolder : public ExprMutator {
public: public:
explicit ConstantFolder(FInterpreter executor) explicit ConstantFolder(FInterpreter executor, Module module)
: executor_(executor) { : executor_(executor), module_(module) {
} }
Expr VisitExpr_(const LetNode* op) final { Expr VisitExpr_(const LetNode* op) final {
...@@ -123,6 +123,15 @@ class ConstantFolder : public ExprMutator { ...@@ -123,6 +123,15 @@ class ConstantFolder : public ExprMutator {
if (call->op.same_as(Op::Get("shape_of"))) { if (call->op.same_as(Op::Get("shape_of"))) {
return EvaluateShapeOf(res, origin_args, call->attrs); return EvaluateShapeOf(res, origin_args, call->attrs);
} }
// We should think about potentially constant evaluation over these ops too.
if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
call->op.same_as(Op::Get("memory.shape_func")) ||
call->op.same_as(Op::Get("memory.alloc_tensor")) ||
call->op.same_as(Op::Get("memory.alloc_storage"))) {
return GetRef<Call>(call);
}
bool all_const_args = true; bool all_const_args = true;
for (Expr arg : call->args) { for (Expr arg : call->args) {
if (!checker_.Check(arg)) { if (!checker_.Check(arg)) {
...@@ -151,10 +160,16 @@ class ConstantFolder : public ExprMutator { ...@@ -151,10 +160,16 @@ class ConstantFolder : public ExprMutator {
FInterpreter executor_; FInterpreter executor_;
// Internal constant checker // Internal constant checker
ConstantChecker checker_; ConstantChecker checker_;
// Module
Module module_;
// Convert value to expression. // Convert value to expression.
Expr ValueToExpr(Value value) { Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) { if (const auto* val = value.as<TensorValueNode>()) {
for (auto dim : val->data.Shape()) {
CHECK_GT(dim, 0)
<< "invalid dimension after constant eval";
}
return ConstantNode::make(val->data); return ConstantNode::make(val->data);
} else if (const auto* val = value.as<TupleValueNode>()) { } else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields; Array<Expr> fields;
...@@ -171,18 +186,33 @@ class ConstantFolder : public ExprMutator { ...@@ -171,18 +186,33 @@ class ConstantFolder : public ExprMutator {
Expr ConstEvaluate(Expr expr) { Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), std::vector<transform::Pass> passes = {transform::FuseOps(0),
transform::InferType()}; transform::InferType()};
auto mod = ModuleNode::FromExpr(expr); Function func;
if (expr.as<FunctionNode>()) {
func = Downcast<Function>(expr);
} else {
// TODO(@jroesch): fix this
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
auto mod = ModuleNode::make(
{},
module_->type_definitions,
module_->Imports());
auto global = GlobalVarNode::make("main");
mod->Add(global, func);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = mod->Lookup("main");
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ValueToExpr(executor_(expr)); return ValueToExpr(executor_(expr));
} }
// Evaluate shape_of op
// Evaluate a call to the shape_of operator for tensors with constant
// shapes.
Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) { Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
Expr input = args[0]; Expr input = args[0];
const auto* param = attrs.as<ShapeOfAttrs>(); const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
tvm::Array<IndexExpr> ishape; tvm::Array<IndexExpr> ishape;
if (const ConstantNode* op = input.as<ConstantNode>()) { if (const ConstantNode* op = input.as<ConstantNode>()) {
ishape = op->tensor_type()->shape; ishape = op->tensor_type()->shape;
...@@ -191,33 +221,48 @@ class ConstantFolder : public ExprMutator { ...@@ -191,33 +221,48 @@ class ConstantFolder : public ExprMutator {
} else { } else {
return expr; return expr;
} }
// Get the constant shape // Get the constant shape
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
auto val = runtime::NDArray::Empty( runtime::NDArray value;
{(int64_t)ishape.size()}, Type2TVMType(Int(32)), ctx); auto cdtype = Type2TVMType(Int(32));
int32_t* dims = static_cast<int32_t*>(val->data); if (ishape.size() == 0) {
using ::tvm::ir::IntImm; value = runtime::NDArray::Empty({}, cdtype, ctx);
for (size_t i = 0; i < ishape.size(); ++i) { } else {
if (const IntImm* dim = ishape[i].as<IntImm>()) { CHECK_NE(ishape.size(), 0);
dims[i] = dim->value; std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
} else { value = runtime::NDArray::Empty(cshape, cdtype, ctx);
return expr; int32_t* dims = static_cast<int32_t*>(value->data);
using ::tvm::ir::IntImm;
for (size_t i = 0; i < ishape.size(); ++i) {
if (const IntImm* dim = ishape[i].as<IntImm>()) {
dims[i] = dim->value;
} else {
return expr;
}
} }
} }
Expr shape = ValueToExpr(TensorValueNode::make(val));
Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
shape = ConstantNode::make(ndarray);
}
// Cast the constant into correct dtype // Cast the constant into correct dtype
auto cast_attrs = make_node<CastAttrs>(); auto cast_attrs = make_node<CastAttrs>();
cast_attrs->dtype = param->dtype; cast_attrs->dtype = param->dtype;
static const Op& cast_op = Op::Get("cast"); static const Op& cast_op = Op::Get("cast");
Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {}); Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
return ConstEvaluate(ret); return ConstEvaluate(ret);
} }
}; };
Expr FoldConstant(const Expr& expr) { Expr FoldConstant(const Expr& expr, const Module& mod) {
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
...@@ -227,7 +272,7 @@ Expr FoldConstant(const Expr& expr) { ...@@ -227,7 +272,7 @@ Expr FoldConstant(const Expr& expr) {
With<BuildConfig> fresh_build_ctx(BuildConfig::Create()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter( return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr); mod, ctx, target), mod).Mutate(expr);
} }
namespace transform { namespace transform {
...@@ -235,7 +280,7 @@ namespace transform { ...@@ -235,7 +280,7 @@ namespace transform {
Pass FoldConstant() { Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f)); return Downcast<Function>(FoldConstant(f, m));
}; };
return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
} }
......
...@@ -862,6 +862,13 @@ class FuseMutator : private ExprMutator { ...@@ -862,6 +862,13 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const CallNode* call) { Expr VisitExpr_(const CallNode* call) {
static const Op& stop_fusion = Op::Get("annotation.stop_fusion"); static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
if (call->op.as<OpNode>()) { if (call->op.as<OpNode>()) {
static auto fnoncomputational =
Op::GetAttr<TNonComputational>("TNonComputational");
if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
return ExprMutator::VisitExpr_(call);
}
// If it is a primitive op call // If it is a primitive op call
// then we must have a group assignment for it already. // then we must have a group assignment for it already.
CHECK(gmap_.count(call)); CHECK(gmap_.count(call));
......
...@@ -314,7 +314,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -314,7 +314,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->opt_level; << pass_info->opt_level;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions); Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) { for (const auto& it : updated_mod->functions) {
auto updated_func = SkipFunction(it.second) auto updated_func = SkipFunction(it.second)
......
...@@ -311,8 +311,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -311,8 +311,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Match match = GetRef<Match>(op); Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_); Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) { if (unmatched_cases.size() != 0) {
LOG(FATAL) << "Match clause " << match << " does not handle the following cases: " RelayErrorStream ss;
<< unmatched_cases; ss << "match expression does not handle the following cases: ";
int i = 0;
for (auto cs : unmatched_cases) {
ss << "case " << i << ": \n" << PrettyPrint(cs);
}
this->ReportFatalError(
match,
ss);
} }
} }
......
...@@ -530,8 +530,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { ...@@ -530,8 +530,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
}; };
// constructor // constructor
TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, TypeSolver::TypeSolver(
ErrorReporter* err_reporter) const GlobalVar& current_func,
const Module& module,
ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)), : reporter_(make_node<Reporter>(this)),
current_func(current_func), current_func(current_func),
err_reporter_(err_reporter), err_reporter_(err_reporter),
......
...@@ -287,9 +287,13 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -287,9 +287,13 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim // Number of fields = 5 + instr.alloc_tensor.ndim
fields.push_back(instr.alloc_tensor.storage);
// Save `DLDataType` and the dst register. // Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype; const auto& dtype = instr.alloc_tensor.dtype;
fields.assign({dtype.code, dtype.bits, dtype.lanes}); fields.push_back(dtype.code);
fields.push_back(dtype.bits);
fields.push_back(dtype.lanes);
// The number of dimensions is not needed for constructing an // The number of dimensions is not needed for constructing an
// `AllocTensor` instruction as it equals to the length of the `shape` // `AllocTensor` instruction as it equals to the length of the `shape`
...@@ -305,10 +309,22 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -305,10 +309,22 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
break; break;
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
// Number of fields = 5 // Number of fields = 6
fields.push_back(instr.alloc_tensor_reg.storage);
fields.push_back(instr.alloc_tensor_reg.shape_register); fields.push_back(instr.alloc_tensor_reg.shape_register);
// Save `DLDataType` and the dst register. // Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype; const auto& dtype = instr.alloc_tensor_reg.dtype;
fields.push_back(dtype.code);
fields.push_back(dtype.bits);
fields.push_back(dtype.lanes);
fields.push_back(instr.dst);
break;
}
case Opcode::AllocStorage: {
fields.push_back(instr.alloc_storage.allocation_size);
fields.push_back(instr.alloc_storage.alignment);
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_storage.dtype_hint;
fields.push_back(dtype.code); fields.push_back(dtype.code);
fields.push_back(dtype.bits); fields.push_back(dtype.bits);
fields.push_back(dtype.lanes); fields.push_back(dtype.lanes);
...@@ -521,35 +537,39 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -521,35 +537,39 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::InvokePacked(packed_index, arity, output_size, args); return Instruction::InvokePacked(packed_index, arity, output_size, args);
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim // Number of fields = 6 + instr.alloc_tensor.ndim
DCHECK_GE(instr.fields.size(), 5U); DCHECK_GE(instr.fields.size(), 6U);
DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3])); DCHECK_EQ(instr.fields.size(), 6U + static_cast<size_t>(instr.fields[4]));
RegName storage_reg = instr.fields[0];
DLDataType dtype; DLDataType dtype;
dtype.code = instr.fields[0]; dtype.code = instr.fields[1];
dtype.bits = instr.fields[1]; dtype.bits = instr.fields[2];
dtype.lanes = instr.fields[2]; dtype.lanes = instr.fields[3];
Index ndim = instr.fields[3]; Index ndim = instr.fields[4];
RegName dst = instr.fields[4]; RegName dst = instr.fields[5];
std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim); std::vector<Index> shape = ExtractFields(instr.fields, 6, ndim);
return Instruction::AllocTensor(shape, dtype, dst); return Instruction::AllocTensor(storage_reg, shape, dtype, dst);
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
// Number of fields = 5 // Number of fields = 5
DCHECK_EQ(instr.fields.size(), 5U); DCHECK_EQ(instr.fields.size(), 6U);
Index shape_register = instr.fields[0];
RegName storage_reg = instr.fields[0];
Index shape_register = instr.fields[1];
DLDataType dtype; DLDataType dtype;
dtype.code = instr.fields[1]; dtype.code = instr.fields[2];
dtype.bits = instr.fields[2]; dtype.bits = instr.fields[3];
dtype.lanes = instr.fields[3]; dtype.lanes = instr.fields[4];
RegName dst = instr.fields[4]; RegName dst = instr.fields[5];
return Instruction::AllocTensorReg(shape_register, dtype, dst); return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst);
} }
case Opcode::AllocADT: { case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields // Number of fields = 3 + instr.num_fields
...@@ -575,6 +595,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -575,6 +595,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
} }
case Opcode::AllocStorage: {
DCHECK_GE(instr.fields.size(), 6U);
Index allocation_size = instr.fields[0];
Index alignment = instr.fields[1];
DLDataType dtype;
dtype.code = instr.fields[2];
dtype.bits = instr.fields[3];
dtype.lanes = instr.fields[4];
RegName dst = instr.fields[5];
return Instruction::AllocStorage(
allocation_size,
alignment,
dtype,
dst);
}
case Opcode::If: { case Opcode::If: {
// Number of fields = 4 // Number of fields = 4
DCHECK_EQ(instr.fields.size(), 4U); DCHECK_EQ(instr.fields.size(), 4U);
......
...@@ -32,6 +32,30 @@ namespace tvm { ...@@ -32,6 +32,30 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
static void BufferDeleter(NDArray::Container* ptr) {
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
Free(*(buffer));
delete buffer;
delete ptr;
}
void StorageObj::Deleter(NDArray::Container* ptr) {
// When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor.
//
// We did bump the reference count by 1 to keep alive the StorageObj
// allocation in case this NDArray is the sole owner.
//
// We decrement the object allowing for the buffer to release our
// reference count from allocation.
StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx);
storage->DecRef();
delete ptr;
}
inline void VerifyDataType(DLDataType dtype) { inline void VerifyDataType(DLDataType dtype) {
CHECK_GE(dtype.lanes, 1); CHECK_GE(dtype.lanes, 1);
if (dtype.code == kDLFloat) { if (dtype.code == kDLFloat) {
...@@ -50,6 +74,22 @@ inline size_t GetDataAlignment(const DLTensor& arr) { ...@@ -50,6 +74,22 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
return align; return align;
} }
NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype) {
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u);
VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data;
return NDArray(container);
}
MemoryManager* MemoryManager::Global() { MemoryManager* MemoryManager::Global() {
static MemoryManager memory_manager; static MemoryManager memory_manager;
return &memory_manager; return &memory_manager;
...@@ -66,15 +106,6 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { ...@@ -66,15 +106,6 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
return allocators_.at(ctx).get(); return allocators_.at(ctx).get();
} }
static void BufferDeleter(NDArray::Container* ptr) {
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
Free(*(buffer));
delete buffer;
delete ptr;
}
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) { NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
...@@ -108,6 +109,38 @@ class MemoryManager { ...@@ -108,6 +109,38 @@ class MemoryManager {
std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_; std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_;
}; };
/*! \brief An object representing a storage allocation. */
class StorageObj : public Object {
public:
/*! \brief The index into the VM function table. */
Buffer buffer;
/*! \brief Allocate an NDArray from a given piece of storage. */
NDArray AllocNDArray(size_t offset,
std::vector<int64_t> shape,
DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr);
~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
alloc->Free(buffer);
}
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "vm.Storage";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
};
/*! \brief reference to storage. */
class Storage : public ObjectRef {
public:
explicit Storage(Buffer buffer);
TVM_DEFINE_OBJECT_REF_METHODS_MUT(Storage, ObjectRef, StorageObj);
};
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License
import tvm
import numpy as np
from tvm import relay
from tvm.relay import memory_alloc
def check_vm_alloc(func, check_fn):
mod = relay.Module()
mod['main'] = func
ex = relay.create_executor('vm', mod)
args = []
for param in func.params:
param = param.type_annotation
sh = [int(sh) for sh in param.shape]
data = np.random.rand(*sh).astype(param.dtype)
args.append(tvm.nd.array(data))
result = ex.evaluate(mod['main'])(*args)
py_res = check_fn(*[arg.asnumpy() for arg in args])
np.testing.assert_allclose(result.asnumpy(), py_res)
def storage_type(mod):
return relay.TypeCall(mod.get_global_type_var("Storage"), [])
def test_tyck_alloc_storage():
mod = relay.Module()
mod.import_from_std("core.rly")
def test_tyck_alloc_tensor():
mod = relay.Module()
mod.import_from_std("core.rly")
sto = relay.Var("x", storage_type(mod))
sh = relay.const(np.array([1, 2]), dtype="int64")
at = relay.op.memory.alloc_tensor(sto, sh)
mod['main'] = relay.Function([sto], at)
relay.transform.InferType()(mod)
def check_add(x):
return x + x
def test_add():
x = relay.var('x', shape=(2,))
z = x + x
func = relay.Function([x,], z)
check_vm_alloc(func, check_add)
def check_add_sub(x, y):
z = x + x
return z - y
def test_add_sub():
x = relay.var('x', shape=(10,))
y = relay.var('y', shape=(10,))
z = x + x
z = z - y
func = relay.Function([x, y], z)
check_vm_alloc(func, check_add_sub)
if __name__ == "__main__":
test_tyck_alloc_tensor()
test_add()
test_add_sub()
...@@ -107,9 +107,9 @@ def test_serializer(): ...@@ -107,9 +107,9 @@ def test_serializer():
assert any(item.startswith('fused_multiply') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops)
code = exe.bytecode code = exe.bytecode
assert "main 5 2 5" in code assert "main 8 2 8" in code
assert "f1 2 1 3" in code assert "f1 5 1 6" in code
assert "f2 2 1 3" in code assert "f2 5 1 6" in code
code, lib = exe.save() code, lib = exe.save()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment