Commit 2083513f by Jared Roesch Committed by Tianqi Chen

Implement explicit IR representation of memory alloction (#3560)

parent 19164063
...@@ -272,6 +272,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) ...@@ -272,6 +272,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(USE_RELAY_DEBUG) if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...") message(STATUS "Building Relay in debug mode...")
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG")
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "DMLC_LOG_DEBUG")
else() else()
set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
endif(USE_RELAY_DEBUG) endif(USE_RELAY_DEBUG)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/attrs/memory.h
* \brief Attributes for memory operators.
*/
#ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_
#include <tvm/attrs.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for allocating tensors.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
Constant const_shape;
Array<IndexExpr> assert_shape;
DataType dtype;
TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The dtype of the tensor to allocate.")
.set_default(Float(32, 1));
TVM_ATTR_FIELD(const_shape)
.describe(
"The shape of constant used to aid in type inference.");
TVM_ATTR_FIELD(assert_shape)
.describe(
"The shape to cast the return type of the allocation to, "\
"used to specify the shape obtained via further analysis.");
}
};
/*!
* \brief Options for the shape function operator.
*/
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
Array<Integer> is_input;
TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
TVM_ATTR_FIELD(is_input)
.describe(
"A bool indicating whether the shape function should"\
"expect shape or input in each position.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_MEMORY_H_
...@@ -47,6 +47,12 @@ namespace relay { ...@@ -47,6 +47,12 @@ namespace relay {
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
} }
#define RELAY_DEBUG_INTERP(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}
/*! /*!
* \brief We always used NodeRef for referencing nodes. * \brief We always used NodeRef for referencing nodes.
* *
......
...@@ -76,7 +76,8 @@ class ModuleNode : public RelayNode { ...@@ -76,7 +76,8 @@ class ModuleNode : public RelayNode {
} }
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs, TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs); tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
...@@ -235,6 +236,11 @@ class ModuleNode : public RelayNode { ...@@ -235,6 +236,11 @@ class ModuleNode : public RelayNode {
*/ */
TVM_DLL void ImportFromStd(const std::string& path); TVM_DLL void ImportFromStd(const std::string& path);
/*!
* \brief The set of imported files.
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
/*! \brief Construct a module from a standalone expression. /*! \brief Construct a module from a standalone expression.
* *
* Allows one to optionally pass a global function map and * Allows one to optionally pass a global function map and
......
...@@ -283,6 +283,8 @@ class Object { ...@@ -283,6 +283,8 @@ class Object {
* \note The deleter will be called when ref_counter_ becomes zero. * \note The deleter will be called when ref_counter_ becomes zero.
*/ */
inline void DecRef(); inline void DecRef();
private:
/*! /*!
* \return The usage count of the cell. * \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr. * \note We use stl style naming to be consistent with known API in shared_ptr.
...@@ -675,6 +677,16 @@ struct ObjectEqual { ...@@ -675,6 +677,16 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
// Implementations details below // Implementations details below
// Object reference counting. // Object reference counting.
......
...@@ -138,6 +138,7 @@ enum class Opcode { ...@@ -138,6 +138,7 @@ enum class Opcode {
GetTag = 13U, GetTag = 13U,
LoadConsti = 14U, LoadConsti = 14U,
Fatal = 15U, Fatal = 15U,
AllocStorage = 16U,
}; };
/*! \brief A single virtual machine instruction. /*! \brief A single virtual machine instruction.
...@@ -158,6 +159,8 @@ struct Instruction { ...@@ -158,6 +159,8 @@ struct Instruction {
union { union {
struct /* AllocTensor Operands */ { struct /* AllocTensor Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The number of dimensions. */ /*! \brief The number of dimensions. */
uint32_t ndim; uint32_t ndim;
/*! \brief The shape of tensor. */ /*! \brief The shape of tensor. */
...@@ -166,6 +169,8 @@ struct Instruction { ...@@ -166,6 +169,8 @@ struct Instruction {
DLDataType dtype; DLDataType dtype;
} alloc_tensor; } alloc_tensor;
struct /* AllocTensorReg Operands */ { struct /* AllocTensorReg Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The register to read the shape out of. */ /*! \brief The register to read the shape out of. */
RegName shape_register; RegName shape_register;
/*! \brief The datatype of tensor to be allocated. */ /*! \brief The datatype of tensor to be allocated. */
...@@ -253,6 +258,14 @@ struct Instruction { ...@@ -253,6 +258,14 @@ struct Instruction {
/*! \brief The free variables as an array. */ /*! \brief The free variables as an array. */
RegName* free_vars; RegName* free_vars;
}; };
struct /* AllocStorage Operands */ {
/*! \brief The size of the allocation. */
RegName allocation_size;
/*! \brief The alignment of the allocation. */
RegName alignment;
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
}; };
/*! \brief Construct a return instruction. /*! \brief Construct a return instruction.
...@@ -274,19 +287,23 @@ struct Instruction { ...@@ -274,19 +287,23 @@ struct Instruction {
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args); const std::vector<RegName>& args);
/*! \brief Construct an allocate tensor instruction with constant shape. /*! \brief Construct an allocate tensor instruction with constant shape.
* \param storage The storage to allocate out of.
* \param shape The shape of the tensor. * \param shape The shape of the tensor.
* \param dtype The dtype of the tensor. * \param dtype The dtype of the tensor.
* \param dst The destination register. * \param dst The destination register.
* \return The allocate tensor instruction. * \return The allocate tensor instruction.
*/ */
static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst); static Instruction AllocTensor(RegName storage,
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register. /*! \brief Construct an allocate tensor instruction with register.
* \param storage The storage to allocate out of.
* \param shape_register The register containing the shape. * \param shape_register The register containing the shape.
* \param dtype The dtype of the tensor. * \param dtype The dtype of the tensor.
* \param dst The destination register. * \param dst The destination register.
* \return The allocate tensor instruction. * \return The allocate tensor instruction.
*/ */
static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst); static Instruction AllocTensorReg(RegName storage,
RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction. /*! \brief Construct an allocate datatype instruction.
* \param tag The datatype tag. * \param tag The datatype tag.
* \param num_fields The number of fields for the datatype. * \param num_fields The number of fields for the datatype.
...@@ -364,6 +381,16 @@ struct Instruction { ...@@ -364,6 +381,16 @@ struct Instruction {
*/ */
static Instruction Move(RegName src, RegName dst); static Instruction Move(RegName src, RegName dst);
/*! \brief Allocate a storage block.
* \param size The size of the allocation.
* \param alignment The allocation's alignment.
* \param dtype_hint The data type hint for the allocator.
* \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/
static Instruction AllocStorage(RegName size, RegName alignment,
DLDataType dtype_hint, RegName dst);
Instruction(); Instruction();
Instruction(const Instruction& instr); Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr); Instruction& operator=(const Instruction& instr);
......
...@@ -59,6 +59,8 @@ from . import quantize ...@@ -59,6 +59,8 @@ from . import quantize
from . import qnn from . import qnn
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc
# Required to traverse large programs # Required to traverse large programs
setrecursionlimit(10000) setrecursionlimit(10000)
......
...@@ -99,6 +99,10 @@ class CompileEngine(NodeBase): ...@@ -99,6 +99,10 @@ class CompileEngine(NodeBase):
msg += "--------------------------\n" msg += "--------------------------\n"
raise RuntimeError(msg) raise RuntimeError(msg)
def lower_shape_func(self, source_func, target=None):
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLowerShapeFunc(self, key)
def jit(self, source_func, target=None): def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function. """JIT a source_func to a tvm.Function.
......
...@@ -25,9 +25,14 @@ def _debugger_init(expr, stack): ...@@ -25,9 +25,14 @@ def _debugger_init(expr, stack):
import pdb import pdb
pdb.set_trace() pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug") @register_func("relay.debug")
def _debug(*args): def _debug(*args):
import pdb
pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug_interp")
def _debug_interp(*args):
_, _, _, ist = args _, _, _, ist = args
print("Relay Debugger") print("Relay Debugger")
print(" You can manipulate the expression under evaluation with the name `expr`.") print(" You can manipulate the expression under evaluation with the name `expr`.")
......
...@@ -317,6 +317,9 @@ class Function(Expr): ...@@ -317,6 +317,9 @@ class Function(Expr):
return _expr.FunctionSetParams(self, params) return _expr.FunctionSetParams(self, params)
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
@register_relay_node @register_relay_node
class Call(Expr): class Call(Expr):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition
"""
A pass for manifesting explicit memory allocations.
"""
import numpy as np
from .expr_functor import ExprMutator
from .scope_builder import ScopeBuilder
from . import transform
from . import op, ty, expr
from .. import TVMType, register_func
from .backend import compile_engine
def is_primitive(call):
return hasattr(call.op, 'attrs') and int(call.op.attrs.Primitive) == 1
# TODO(@jroesch): port to c++ and unify with existing code
class LinearizeRetType:
"""A linear view of a Relay type, handles a linear order
for nested tuples, and tensor types.
"""
def __init__(self, typ):
"""Initialize the linearizer."""
self.typ = typ
def unpack(self):
"""Return the linear representation of the type."""
def _unpack(typ, out):
# TODO(@jroesch): replace with new flattening pass
if isinstance(typ, ty.TensorType):
out.append(typ)
elif isinstance(typ, ty.TupleType):
for field_ty in typ.fields:
_unpack(field_ty, out)
else:
raise Exception(f"unsupported Relay type: {typ}")
output = []
_unpack(self.typ, output)
return output
def pack(self, seq):
"""Repack a linear type as a nested type."""
def _pack(value, typ, out):
if isinstance(typ, ty.TensorType):
out.append(value)
elif isinstance(typ, ty.TupleType):
tuple_out = []
for i, field_ty in enumerate(typ.fields):
_pack(value[i], field_ty, tuple_out)
out.append(expr.Tuple(tuple_out))
else:
raise Exception(f"unsupported Relay type: {typ}")
if len(seq) == 1:
return seq[0]
else:
out = []
_pack(seq, self.typ, out)
assert len(out) == 1, "must return fully packed type"
return out[0]
class ManifestAllocPass(ExprMutator):
"""A pass for explictly manifesting all memory allocations in Relay."""
def __init__(self, target_host):
self.invoke_tvm = op.memory.invoke_tvm_op
self.alloc_storage = op.memory.alloc_storage
self.alloc_tensor = op.memory.alloc_tensor
self.shape_func = op.memory.shape_func
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.compute_dtype = "int64"
super().__init__()
def current_scope(self):
return self.scopes[-1]
def shape_of(self, e):
return op.shape_of(e, self.compute_dtype)
def visit_tuple(self, tup):
scope = self.current_scope()
new_fields = []
for field in tup.fields:
field = self.visit(field)
if isinstance(field, expr.Constant):
field = scope.let('const', field)
new_fields.append(field)
return expr.Tuple(new_fields)
def compute_alignment(self, dtype):
dtype = TVMType(dtype)
align = (dtype.bits // 8) * dtype.lanes
# MAGIC CONSTANT FROM device_api.h
if align < 64:
align = 64
return expr.const(align, dtype="int64")
def compute_storage_in_relay(self, shape, dtype):
dtype = TVMType(dtype)
els = op.prod(shape)
num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
num = num + expr.const(7, self.compute_dtype)
div = expr.const(8, self.compute_dtype)
return els * (num / div)
def compute_storage(self, tensor_type):
dtype = TVMType(tensor_type.dtype)
shape = [int(sh) for sh in tensor_type.shape]
size = 1
for sh in shape:
size *= sh
size *= (dtype.bits * dtype.lanes + 7) // 8
return expr.const(size, dtype=self.compute_dtype)
def make_static_allocation(self, scope, tensor_type, i):
"""Allocate a tensor with a statically known shape."""
shape = [int(sh) for sh in tensor_type.shape]
if len(shape) == 0:
shape = expr.const(np.array([]).astype(
self.compute_dtype), dtype=self.compute_dtype)
else:
shape = expr.const(np.array(shape), dtype=self.compute_dtype)
size = self.compute_storage(tensor_type)
alignment = self.compute_alignment(tensor_type.dtype)
dtype = tensor_type.dtype
sto = scope.let(f"storage_{i}", self.alloc_storage(
size, alignment, dtype))
# TODO(@jroesch): There is a bug with typing based on the constant shape.
tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape)
return scope.let(f"tensor_{i}", tensor)
def visit_let(self, let):
scope = ScopeBuilder()
self.scopes.append(scope)
while isinstance(let, expr.Let):
new_val = self.visit(let.value)
scope.let(let.var, new_val)
let = let.body
new_body = self.visit(let)
scope.ret(new_body)
self.scopes.pop()
return scope.get()
def visit_call(self, call):
if is_primitive(call):
# Because we are in ANF we do not need to visit the arguments.
scope = self.current_scope()
new_args = [self.visit(arg) for arg in call.args]
ins = expr.Tuple(new_args)
ret_type = call.checked_type
is_dynamic = ret_type.is_dynamic()
# TODO(@jroesch): restore this code, more complex then it seems
# for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
if is_dynamic:
assert isinstance(ret_type, ty.TensorType)
shape_func_ins = []
engine = compile_engine.get()
cfunc = engine.lower_shape_func(call.op, self.target_host)
input_states = cfunc.shape_func_param_states
is_inputs = []
for i, (arg, state) in enumerate(zip(new_args, input_states)):
state = int(state)
# Pass Shapes
if state == 2:
sh_of = self.visit(self.shape_of(arg))
shape_func_ins.append(
scope.let(f"in_shape_{i}", sh_of))
is_inputs.append(0)
# Pass Inputs
elif state == 1:
new_arg = self.visit(arg)
shape_func_ins.append(
scope.let(f"in_shape_{i}", new_arg))
is_inputs.append(1)
# TODO(@jroesch): handle 3rd case
else:
raise Exception("unsupported shape function input state")
out_shapes = []
for i, out in enumerate(cfunc.outputs):
tt = ty.TensorType(out.shape, out.dtype)
alloc = self.make_static_allocation(scope, tt, i)
alloc = scope.let(f"shape_func_out_{i}", alloc)
out_shapes.append(alloc)
shape_call = self.shape_func(
call.op,
expr.Tuple(shape_func_ins),
expr.Tuple(out_shapes), is_inputs)
scope.let("shape_func", shape_call)
out_types = []
out_types.append(call.checked_type)
storages = []
for out_shape, out_type in zip(out_shapes, out_types):
size = self.compute_storage_in_relay(
out_shape, out_type.dtype)
alignment = self.compute_alignment(out_type.dtype)
sto = scope.let(f"storage_{i}", self.alloc_storage(
size, alignment, out_type.dtype))
storages.append(sto)
outs = []
sh_ty_storage = zip(out_shapes, out_types, storages)
for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
alloc = self.alloc_tensor(
storage,
out_shape,
out_type.dtype,
out_type.shape)
alloc = scope.let(f"out_{i}", alloc)
outs.append(alloc)
invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs))
scope.let("", invoke)
return outs[0]
else:
view = LinearizeRetType(ret_type)
out_tys = view.unpack()
outs = []
for i, out_ty in enumerate(out_tys):
out = self.make_static_allocation(scope, out_ty, i)
outs.append(out)
output = expr.Tuple(outs)
invoke = self.invoke_tvm(call.op, ins, output)
scope.let("", invoke)
return view.pack(output)
else:
return super().visit_call(call)
@transform.function_pass(opt_level=0)
class ManifestAlloc:
"""The explicit pass wrapper around ManifestAlloc."""
def __init__(self, target_host):
self.target_host = target_host
def transform_function(self, func, mod, _):
# TODO(@jroesch): Is there a way to do one shot initilization?
# can we have def pass_init?
mod.import_from_std("core.rly")
ea = ManifestAllocPass(self.target_host)
func = ea.visit(func)
return func
register_func("relay.transform.ManifestAlloc", ManifestAlloc)
...@@ -28,6 +28,7 @@ from .transform import * ...@@ -28,6 +28,7 @@ from .transform import *
from .algorithm import * from .algorithm import *
from . import nn from . import nn
from . import annotation from . import annotation
from . import memory
from . import image from . import image
from . import vision from . import vision
from . import contrib from . import contrib
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Operators for manipulating low level memory."""
from __future__ import absolute_import as _abs
from .memory import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.memory._make", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operators for manipulating low-level memory."""
from __future__ import absolute_import as _abs
from . import _make
def invoke_tvm_op(func, inputs, outputs):
"""Call a primitive function with the TVM operator calling convention.
Parameters
----------
inputs : tvm.relay.Expr
A tuple of the inputs to pass to the TVM function.
outputs : tvm.relay.Expr
A tuple of the outputs to pass to the TVM function.
Returns
-------
result : tvm.relay.Expr
The invoke_tvm_op call node.
"""
return _make.invoke_tvm_op(func, inputs, outputs)
def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
"""Allocate a tensor with the provided shape, and dtype.
Parameters
----------
storage : tvm.relay.Expr
The storage to allocate from.
shape : tvm.relay.Expr
The shape of the tensor to allocate.
dtype: str
The dtype of the tensor.
assert_shape: Control the static shape when computed by dynamic shape expression.
Returns
-------
result : tvm.relay.Expr
The alloc_tensor expression.
"""
return _make.alloc_tensor(storage, shape, dtype, assert_shape)
def alloc_storage(size, alignment, dtype_hint='float32'):
"""Allocate a piece of tensor storage.
Parameters
----------
size : tvm.relay.Expr
The size of the allocation.
alignment : tvm.relay.Expr
The alignment of the allocation.
dtype : str
The dtype_hint of the allocation.
Returns
-------
result : tvm.relay.Expr
The alloc_storage expression.
"""
return _make.alloc_storage(size, alignment, dtype_hint)
def shape_func(func, inputs, outputs, dependent=False):
"""Invoke the shape function of the passed function.
Parameters
----------
func : tvm.relay.Expr
The primitive function from which to compute the shape function.
inputs : tvm.relay.Tuple
The tupled inputs.
outputs : tvm.relay.Tuple
The tupled outputs.
Returns
-------
result : tvm.relay.Expr
The shape function expression.
"""
return _make.shape_func(func, inputs, outputs, dependent)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4
extern type Storage
...@@ -52,6 +52,9 @@ class Type(RelayNode): ...@@ -52,6 +52,9 @@ class Type(RelayNode):
""" """
return TypeCall(self, args) return TypeCall(self, args)
def is_dynamic(self):
return _make.IsDynamic(self)
@register_relay_node @register_relay_node
class TensorType(Type): class TensorType(Type):
"""A concrete TensorType in Relay. """A concrete TensorType in Relay.
...@@ -317,7 +320,6 @@ class RefType(Type): ...@@ -317,7 +320,6 @@ class RefType(Type):
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value) self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype): def scalar_type(dtype):
"""Creates a scalar type. """Creates a scalar type.
......
...@@ -72,6 +72,10 @@ bool IsDynamic(const Type& ty) { ...@@ -72,6 +72,10 @@ bool IsDynamic(const Type& ty) {
return v.is_dyn; return v.is_dyn;
} }
// TODO(@jroesch): MOVE ME
TVM_REGISTER_API("relay._make.IsDynamic")
.set_body_typed(IsDynamic);
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) { Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible // for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64. // even if the result of shape inference becomes int64.
...@@ -775,6 +779,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") ...@@ -775,6 +779,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
return self->Lower(key); return self->Lower(key);
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>( .set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) { [](CompileEngine self, CCacheKey key) {
......
...@@ -458,7 +458,7 @@ class Interpreter : ...@@ -458,7 +458,7 @@ class Interpreter :
if (dattrs->debug_func.defined()) { if (dattrs->debug_func.defined()) {
dattrs->debug_func(interp_state); dattrs->debug_func(interp_state);
} else { } else {
RELAY_DEBUG(interp_state); RELAY_DEBUG_INTERP(interp_state);
} }
return args[0]; return args[0];
...@@ -479,7 +479,8 @@ class Interpreter : ...@@ -479,7 +479,8 @@ class Interpreter :
if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) { if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
arg_len += tuple_type->fields.size(); arg_len += tuple_type->fields.size();
} else { } else {
CHECK(func->body->checked_type().as<TensorTypeNode>()); CHECK(func->body->checked_type().as<TensorTypeNode>())
<< func->body->checked_type();
arg_len += 1; arg_len += 1;
} }
std::vector<TVMValue> values(arg_len); std::vector<TVMValue> values(arg_len);
......
...@@ -48,6 +48,19 @@ namespace backend { ...@@ -48,6 +48,19 @@ namespace backend {
inline const PackedFunc* GetPackedFunc(const std::string& func_name) { inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
return tvm::runtime::Registry::Get(func_name); return tvm::runtime::Registry::Get(func_name);
} }
/*!
* \brief Get a typed packed function.
*
* \param func_name
* \return const PackedFunc*
*/
template <typename R, typename... Args>
inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
auto *pf = GetPackedFunc(func_name);
CHECK(pf != nullptr) << "can not find packed function";
return runtime::TypedPackedFunc<R(Args...)>(*pf);
}
/*! /*!
* \brief Convert type to string * \brief Convert type to string
* *
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h> #include <topi/tags.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
...@@ -44,6 +45,7 @@ ...@@ -44,6 +45,7 @@
#include "../../../runtime/vm/naive_allocator.h" #include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h" #include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h" #include "../../pass/pass_util.h"
#include "../../op/op_common.h"
#include "compiler.h" #include "compiler.h"
namespace tvm { namespace tvm {
...@@ -54,6 +56,12 @@ namespace transform { ...@@ -54,6 +56,12 @@ namespace transform {
Pass LambdaLift(); Pass LambdaLift();
Pass InlinePrimitives(); Pass InlinePrimitives();
Pass ManifestAlloc(Target target_host) {
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
CHECK(f != nullptr) << "could not load memory allocation pass";
return (*f)(target_host);
}
} // namespace transform } // namespace transform
namespace vm { namespace vm {
...@@ -194,6 +202,39 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> ...@@ -194,6 +202,39 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause>
return else_branch; return else_branch;
} }
std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
std::vector<int64_t> raw_shape;
DLTensor tensor = shape.ToDLPack()->dl_tensor;
CHECK_EQ(tensor.ndim, 1u);
CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
// TODO(@jroesch): we really need to standaridize the bit width of
// all of the shape manipulating code.
CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
for (auto i = 0; i < tensor.shape[0]; i++) {
raw_shape.push_back(int_ptr[i]);
}
return raw_shape;
}
std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
std::vector<int64_t> raw_shape;
DLTensor tensor = shape.ToDLPack()->dl_tensor;
CHECK_EQ(tensor.ndim, 1u);
CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
// TODO(@jroesch): we really need to standaridize the bit width of
// all of the shape manipulating code.
CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
for (auto i = 0; i < tensor.shape[0]; i++) {
raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
}
return raw_shape;
}
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public: public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
...@@ -248,13 +289,12 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -248,13 +289,12 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::LoadConsti: case Opcode::LoadConsti:
case Opcode::Invoke: case Opcode::Invoke:
case Opcode::AllocClosure: case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::Move: case Opcode::Move:
case Opcode::InvokeClosure: case Opcode::InvokeClosure:
last_register_ = instr.dst; last_register_ = instr.dst;
break; break;
case Opcode::InvokePacked: case Opcode::InvokePacked:
last_register_ = instr.packed_args[instr.arity - 1];
break;
case Opcode::If: case Opcode::If:
case Opcode::Ret: case Opcode::Ret:
case Opcode::Goto: case Opcode::Goto:
...@@ -302,7 +342,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -302,7 +342,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
} }
void VisitExpr_(const LetNode* let_node) { void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << AsText(let_node->value); DLOG(INFO) << PrettyPrint(let_node->value);
this->VisitExpr(let_node->value); this->VisitExpr(let_node->value);
var_register_map_.insert({let_node->var, this->last_register_}); var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body); this->VisitExpr(let_node->body);
...@@ -369,100 +409,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -369,100 +409,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->last_register_ = true_register; this->last_register_ = true_register;
} }
Index EmitGetShape(const TensorTypeNode* ttype, Index reg) { void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
bool const_shape = true;
std::vector<int64_t> shape;
for (auto dim : ttype->shape) {
if (auto kdim = dim.as<IntImm>()) {
shape.push_back(kdim->value);
} else {
const_shape = false;
}
}
if (const_shape) {
int64_t ndim = shape.size();
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
NDArray shape_tensor;
if (ndim == 0) {
shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
} else {
shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
int64_t* dims = reinterpret_cast<int64_t*>(shape_tensor->data);
for (size_t i = 0; i < shape.size(); ++i) {
dims[i] = shape[i];
}
}
size_t konst_idx = context_->constants.size();
context_->constants.push_back(shape_tensor);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
return last_register_;
}
// For dynamic shape, we need insert shape_of op to get its shape at runtime
auto attrs = make_node<ShapeOfAttrs>();
attrs->dtype = Int(64);
static const Op& op = Op::Get("shape_of");
auto input = VarNode::make("input", GetRef<Type>(ttype));
auto expr = CallNode::make(op, {input}, Attrs(attrs), {});
auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {});
auto mod = ModuleNode::make({}, {});
auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
func = mod->Lookup(main_gv);
// shape_of op has to be run on the host target
// TODO(@icemelon9): handle heterogeneous target, such as cuda
auto key = CCacheKeyNode::make(func, target_host_);
auto cfunc = engine_->Lower(key);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
std::vector<Index> arg_regs{reg};
int64_t ndim = ttype->shape.size();
if (ndim == 0) {
Emit(Instruction::AllocTensor({}, Int(64), NewRegister()));
} else {
Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister()));
}
Index shape_reg = last_register_;
arg_regs.push_back(shape_reg);
Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs));
return shape_reg;
}
std::vector<Index> EmitShapeFunc(const Type& ret_type, const Function& func,
const std::vector<Index>& unpacked_arg_regs) {
// Find the mapping from params to registers
int idx = 0;
std::vector<std::vector<Index>> param_regs;
std::vector<std::vector<const TensorTypeNode*>> param_types;
for (auto param : func->params) {
auto ty = param->checked_type();
std::vector<Index> regs;
std::vector<const TensorTypeNode*> types;
if (auto ttype = ty.as<TensorTypeNode>()) {
regs.push_back(unpacked_arg_regs[idx++]);
types.push_back(ttype);
} else if (const auto tuple_ty = ret_type.as<TupleTypeNode>()) {
for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) {
regs.push_back(unpacked_arg_regs[idx]);
auto ttype = tuple_ty->fields[j].as<TensorTypeNode>();
CHECK(ttype);
types.push_back(ttype);
}
} else {
LOG(FATAL) << "unsupported parameter type " << ty;
}
param_regs.push_back(regs);
param_types.push_back(types);
}
// Lower shape function // Lower shape function
auto key = CCacheKeyNode::make(func, target_host_); auto key = CCacheKeyNode::make(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key); auto cfunc = engine_->LowerShapeFunc(key);
...@@ -476,125 +423,60 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -476,125 +423,60 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
} }
// Prepare input and output registers // Prepare input and output registers
std::vector<Index> shape_func_args; std::vector<Index> argument_registers;
std::vector<Index> shape_regs; for (auto input : inputs) {
for (size_t i = 0; i < func->params.size(); ++i) { auto reg = var_register_map_.find(Downcast<Var>(input));
int state = cfunc->shape_func_param_states[i]->value; CHECK(reg != var_register_map_.end())
if (state & kNeedInputData) { << "internal error: all variables should be in the register mapping";
for (auto reg : param_regs[i]) { argument_registers.push_back(reg->second);
// TODO(@icemelon9): Need to copy data here for heterogeneous exec
shape_func_args.push_back(reg);
}
}
if (state & kNeedInputShape) {
for (size_t j = 0; j < param_regs[i].size(); ++j) {
shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j]));
}
} }
for (auto output : outputs) {
auto reg = var_register_map_.find(Downcast<Var>(output));
CHECK(reg != var_register_map_.end())
<< "internal error: all variables should be in the register mapping";
argument_registers.push_back(reg->second);
} }
for (auto t : cfunc->outputs) {
int64_t ndim = t->shape[0].as<IntImm>()->value; Emit(Instruction::InvokePacked(op_index,
Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister())); argument_registers.size(),
shape_func_args.push_back(last_register_); outputs.size(),
shape_regs.push_back(last_register_); argument_registers));
} }
int arity = shape_func_args.size(); void EmitInvokeTVMOp(const Function& func,
int ret_count = shape_regs.size(); const Expr& inputs,
Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args)); const Expr& outputs) {
std::vector<Index> argument_registers;
// Alloc return tensors given the shape regs CHECK(func->IsPrimitive())
std::vector<DataType> ret_dtypes; << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
if (const auto* tuple_type = ret_type.as<TupleTypeNode>()) {
for (auto field : tuple_type->fields) { auto input_tuple = inputs.as<TupleNode>();
const TensorTypeNode* tty = field.as<TensorTypeNode>(); CHECK(input_tuple)
CHECK(tty); << "internal error: invoke_tvm_op inputs must be a tuple,"
ret_dtypes.push_back(tty->dtype); << "please file a bug in the memory manifestation pass";
}
} else { auto output_tuple = outputs.as<TupleNode>();
auto tty = ret_type.as<TensorTypeNode>(); CHECK(output_tuple)
CHECK(tty); << "internal error: invoke_tvm_op outputs must be a tuple,"
ret_dtypes.push_back(tty->dtype); << "please file a bug in the memory manifestation pass";
}
std::vector<Index> ret_regs; for (auto input : input_tuple->fields) {
for (size_t i = 0; i < shape_regs.size(); ++i) { auto reg = var_register_map_.find(Downcast<Var>(input));
Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister())); CHECK(reg != var_register_map_.end())
ret_regs.push_back(last_register_); << "internal error: all variables should be in the register mapping";
} argument_registers.push_back(reg->second);
return ret_regs;
}
std::vector<Index> AllocReturnType(const Type& ret_type, const Function& func,
const std::vector<Index>& unpacked_arg_regs) {
auto op = func->body.as<CallNode>()->op;
// 1. If either func param types or ret type is dynamic, we need to insert
// shape func to perform type checking at runtime.
// 2. We skip the shape_of function since currently Relay doesn't support
// dynamic rank tensor.
if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) {
return EmitShapeFunc(ret_type, func, unpacked_arg_regs);
}
std::vector<Index> ret_regs;
auto alloc_tensor = [&](const TensorTypeNode* ttype) {
const TensorType& tensor_type = GetRef<TensorType>(ttype);
std::vector<int64_t> shape;
for (auto dim : tensor_type->shape) {
shape.push_back(Downcast<tvm::Integer>(dim)->value);
}
Emit(Instruction::AllocTensor(shape, Type2TVMType(tensor_type->dtype), NewRegister()));
ret_regs.push_back(last_register_);
};
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
alloc_tensor(ttype);
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
for (auto field : ttype->fields) {
alloc_tensor(field.as<TensorTypeNode>());
}
} else {
LOG(FATAL) << "Unsupported return value type";
}
return ret_regs;
}
void EmitInvokePrimitive(const Function& func,
const std::vector<Index>& arg_registers,
const Type& ret_type) {
std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
// Arity calculation must flatten tuples.
size_t arity = 0;
CHECK_EQ(func->params.size(), arg_registers.size());
for (size_t i = 0; i < func->params.size(); i++) {
auto ty = func->params[i]->checked_type();
if (ty.as<TensorTypeNode>()) {
unpacked_arg_regs.push_back(arg_registers[i]);
arity += 1;
} else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
const auto& field = tuple_ty->fields[f];
CHECK(field.as<TensorTypeNode>())
<< "only supports non-nested tuples currently "
<< "found " << field;
auto dst = NewRegister();
Emit(Instruction::GetField(arg_registers[i], f, dst));
unpacked_arg_regs.push_back(dst);
}
arity += tuple_ty->fields.size();
} else {
LOG(FATAL) << "unsupported parameter type " << ty;
}
} }
auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs); for (auto output : output_tuple->fields) {
size_t return_count = ret_regs.size(); auto reg = var_register_map_.find(Downcast<Var>(output));
arity += return_count; CHECK(reg != var_register_map_.end())
for (auto reg : ret_regs) { << "internal error: all variables should be in the register mapping";
unpacked_arg_regs.push_back(reg); argument_registers.push_back(reg->second);
} }
// Next generate the invoke instruction. // Next generate the invoke instruction.
CHECK(func->IsPrimitive());
Target target; Target target;
if (targets_.size() == 1) { if (targets_.size() == 1) {
// homogeneous execution. // homogeneous execution.
...@@ -605,8 +487,10 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -605,8 +487,10 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// heterogeneous execution. // heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
} }
auto key = CCacheKeyNode::make(func, target); auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine_->Lower(key); auto cfunc = engine_->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets // TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1); CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1; auto op_index = -1;
...@@ -618,19 +502,99 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -618,19 +502,99 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
op_index = context_->seen_funcs[cfunc->funcs[0]]; op_index = context_->seen_funcs[cfunc->funcs[0]];
} }
Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs)); Emit(Instruction::InvokePacked(op_index,
argument_registers.size(),
output_tuple->fields.size(),
argument_registers));
}
if (return_count > 1) { void VisitExpr_(const CallNode* call_node) {
// return value is a tuple, we need to create a tuple Expr op = call_node->op;
std::vector<Index> fields_registers;
for (size_t i = arity - return_count; i < arity; ++i) { // First we handle the case in which we are using an opaque
fields_registers.push_back(unpacked_arg_regs[i]); // operator used to define a sub-dialect, such as memory
// allocation operations.
if (op.as<OpNode>()) {
OpMatch<void> matcher;
matcher.Match("memory.invoke_tvm_op",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 3);
EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
}).Match("memory.alloc_tensor",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 2);
// Get the attributes.
auto alloc_attrs = attrs.as<AllocTensorAttrs>();
CHECK(alloc_attrs != nullptr)
<< "must be the alloc tensor attrs";
auto dtype = alloc_attrs->dtype;
// The storage will be passed dynamically.
this->VisitExpr(args[0]);
auto storage_register = last_register_;
// If the shape is constant then we will emit a static tensor allocation instruction.
auto const_shape = args[1].as<ConstantNode>();
if (const_shape) {
NDArray shape = const_shape->data;
std::vector<int64_t> raw_shape;
DLTensor tensor = shape.ToDLPack()->dl_tensor;
// TODO(@jroesch): we need to get an RFC done to standarize this
if (tensor.dtype.bits == 64) {
raw_shape = ToAllocTensorShape64(shape);
} else if (tensor.dtype.bits == 32) {
raw_shape = ToAllocTensorShape32(shape);
} else {
LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
} }
Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister()));
// Add context field.
Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
} else {
this->VisitExpr(args[1]);
auto shape_register = last_register_;
Emit(Instruction::AllocTensorReg(
storage_register,
shape_register,
dtype,
NewRegister()));
} }
}).Match("memory.alloc_storage",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 2);
// Compute the size of the allocation.
this->VisitExpr(args[0]);
auto size_register = last_register_;
this->VisitExpr(args[1]);
auto alignment_register = last_register_;
// Get the dtype hint from the attributes.
auto alloc_attrs = attrs.as<AllocTensorAttrs>();
CHECK(alloc_attrs != nullptr)
<< "must be the alloc tensor attrs";
auto dtype = alloc_attrs->dtype;
Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
}).Match("memory.shape_func",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 3);
auto shape_func = Downcast<Function>(args[0]);
auto inputs = Downcast<Tuple>(args[1]);
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
}).Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
});
matcher(GetRef<Call>(call_node));
return;
} }
void VisitExpr_(const CallNode* call_node) { // In the case its not one of these specialized operators we will generate code
// for one of the "standard" cases.
std::vector<Index> args_registers; std::vector<Index> args_registers;
for (auto arg : call_node->args) { for (auto arg : call_node->args) {
...@@ -638,18 +602,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -638,18 +602,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
args_registers.push_back(last_register_); args_registers.push_back(last_register_);
} }
Expr op = call_node->op; if (auto global_node = op.as<GlobalVarNode>()) {
// In the case we are invoking a global we need to find its
if (auto func_node = op.as<FunctionNode>()) { // global ID, and then check whether it is closure invocation
CHECK(func_node->IsPrimitive()); // or whether it is a standard global, and emit the correct
EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type()); // calling convention.
} else if (auto global_node = op.as<GlobalVarNode>()) {
auto global = GetRef<GlobalVar>(global_node); auto global = GetRef<GlobalVar>(global_node);
auto it = context_->global_map.find(global); auto it = context_->global_map.find(global);
CHECK(it != context_->global_map.end()); CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second; << " with func_index=" << it->second;
auto func = context_->module->Lookup(global); auto func = context_->module->Lookup(global);
if (IsClosure(func)) { if (IsClosure(func)) {
auto arity = func->params.size(); auto arity = func->params.size();
...@@ -658,14 +620,21 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -658,14 +620,21 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
} }
} else if (auto constructor_node = op.as<ConstructorNode>()) { } else if (auto constructor_node = op.as<ConstructorNode>()) {
// In the constructor case, we simply need to find its tag
// and emit a call to allocate the data structure.
auto constructor = GetRef<Constructor>(constructor_node); auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers, Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister())); NewRegister()));
} else if (auto var_node = op.as<VarNode>()) { } else if (auto var_node = op.as<VarNode>()) {
// If we are calling a variable, it must be the case that it is a closure so we
// emit invoke closure here.
VisitExpr(GetRef<Var>(var_node)); VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else { } else {
LOG(FATAL) << "unsupported case in vm compiler: " << op; // Finally if there are any other cases this is a bug.
LOG(FATAL) << "internal error: unreachable code,"
<< "should be transformed away by previous passes"
<< PrettyPrint(GetRef<Expr>(call_node));
} }
} }
...@@ -836,7 +805,6 @@ relay::Function VMCompiler::BindParamsByName( ...@@ -836,7 +805,6 @@ relay::Function VMCompiler::BindParamsByName(
return ret; return ret;
} }
void VMCompiler::Compile(Module mod, void VMCompiler::Compile(Module mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host) { const tvm::Target& target_host) {
...@@ -852,8 +820,7 @@ void VMCompiler::Compile(Module mod, ...@@ -852,8 +820,7 @@ void VMCompiler::Compile(Module mod,
targets_ = targets; targets_ = targets;
target_host_ = target_host; target_host_ = target_host;
// Run some optimizations first, this code should // Run the optimizations necessary to target the VM.
// be moved to pass manager.
context_.module = OptimizeModule(mod, targets_); context_.module = OptimizeModule(mod, targets_);
// Populate the global map. // Populate the global map.
...@@ -885,7 +852,7 @@ void VMCompiler::Compile(Module mod, ...@@ -885,7 +852,7 @@ void VMCompiler::Compile(Module mod,
// populate constants // populate constants
for (auto data : context_.constants) { for (auto data : context_.constants) {
exec_->constants.push_back(runtime::vm::Tensor(data)); exec_->constants.push_back(vm::Tensor(data));
} }
LibraryCodegen(); LibraryCodegen();
...@@ -942,6 +909,15 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) ...@@ -942,6 +909,15 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
transform::Sequential seq(pass_seqs); transform::Sequential seq(pass_seqs);
transform::PassContext pass_ctx = PassContext::Current(); transform::PassContext pass_ctx = PassContext::Current();
// TODO(wweic): Support heterogenous execution // TODO(wweic): Support heterogenous execution
......
...@@ -355,5 +355,11 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") ...@@ -355,5 +355,11 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
return temp->Realize(); return temp->Realize();
}); });
TVM_REGISTER_API("relay._expr.FunctionSetAttr")
.set_body_typed<Function(Function, std::string, NodeRef)>(
[](Function func, std::string name, NodeRef ref) {
return FunctionSetAttr(func, name, ref);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -35,13 +35,16 @@ using tvm::IRPrinter; ...@@ -35,13 +35,16 @@ using tvm::IRPrinter;
using namespace runtime; using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs) { tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports
) {
auto n = make_node<ModuleNode>(); auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs); n->type_definitions = std::move(global_type_defs);
n->global_type_var_map_ = {}; n->global_type_var_map_ = {};
n->global_var_map_ = {}; n->global_var_map_ = {};
n->constructor_tag_map_ = {}; n->constructor_tag_map_ = {};
n->import_set_ = imports;
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
// set global var map // set global var map
...@@ -283,9 +286,9 @@ Module ModuleNode::FromExpr( ...@@ -283,9 +286,9 @@ Module ModuleNode::FromExpr(
} }
void ModuleNode::Import(const std::string& path) { void ModuleNode::Import(const std::string& path) {
LOG(INFO) << "Importing: " << path;
if (this->import_set_.count(path) == 0) { if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path); this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in); std::fstream src_file(path, std::fstream::in);
std::string file_contents { std::string file_contents {
std::istreambuf_iterator<char>(src_file), std::istreambuf_iterator<char>(src_file),
...@@ -302,6 +305,10 @@ void ModuleNode::ImportFromStd(const std::string& path) { ...@@ -302,6 +305,10 @@ void ModuleNode::ImportFromStd(const std::string& path) {
return this->Import(std_path + "/" + path); return this->Import(std_path + "/" + path);
} }
std::unordered_set<std::string> ModuleNode::Imports() const {
return this->import_set_;
}
Module FromText(const std::string& source, const std::string& source_name) { Module FromText(const std::string& source, const std::string& source_name) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext"); auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
...@@ -312,7 +319,10 @@ Module FromText(const std::string& source, const std::string& source_name) { ...@@ -312,7 +319,10 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
.set_body_typed(ModuleNode::make); .set_body_typed<Module(tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData>)>(
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
});
TVM_REGISTER_API("relay._module.Module_Add") TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -24,14 +24,15 @@ ...@@ -24,14 +24,15 @@
* \brief Registration of annotation operators. * \brief Registration of annotation operators.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/alter_op_layout.h"
#include "../type_relations.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
* used as "barrier" to avoid fusing operators belonging to differen devices. * used as "barrier" to avoid fusing operators belonging to differen devices.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/op/memory/memory.cc
* \brief Operators for manifest shape-aware memory allocation in Relay.
*/
#include <topi/elemwise.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/memory.h>
#include "../op_common.h"
#include "../../pass/alter_op_layout.h"
#include "../type_relations.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
// The passing value in attrs and args doesn't seem super great.
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
TVM_REGISTER_API("relay.op.memory._make.alloc_storage")
.set_body_typed<Expr(Expr, Expr, DataType)>([](Expr size, Expr alignment, DataType dtype) {
auto attrs = make_node<AllocTensorAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("memory.alloc_storage");
return CallNode::make(op, {size, alignment}, Attrs(attrs), {});
});
bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3u);
auto size_type = types[0];
auto tensor_type = size_type.as<TensorTypeNode>();
CHECK(tensor_type != nullptr);
CHECK_EQ(tensor_type->dtype, Int(64));
CHECK_EQ(tensor_type->shape.size(), 0);
auto align_type = types[1];
auto align_ttype = align_type.as<TensorTypeNode>();
CHECK(align_ttype != nullptr);
CHECK_EQ(align_ttype->dtype, Int(64));
CHECK_EQ(align_ttype->shape.size(), 0);
auto mod = reporter->GetModule();
CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = TypeCallNode::make(storage_name, {});
reporter->Assign(types[2], storage);
return true;
}
RELAY_REGISTER_OP("memory.alloc_storage")
.describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("size", "Tensor", "The size of the storage to allocate.")
.add_argument("alignment", "Tensor", "The alignment of the storage.")
.add_type_rel("AllocStorage", AllocStorageRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_API("relay.op.memory._make.alloc_tensor")
.set_body_typed<Expr(Expr, Expr, DataType, Array<IndexExpr> assert_shape)>(
[](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) {
auto attrs = make_node<AllocTensorAttrs>();
attrs->dtype = dtype;
if (assert_shape.defined()) {
attrs->assert_shape = assert_shape;
} else {
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return CallNode::make(op, {storage, shape}, Attrs(attrs), {});
});
std::vector<int64_t> FromConstShape(Constant konst) {
runtime::NDArray shape = konst->data;
std::vector<int64_t> raw_shape;
DLTensor tensor = shape.ToDLPack()->dl_tensor;
CHECK_EQ(tensor.ndim, 1u);
CHECK_EQ(tensor.dtype.code, 0U)
<< "found " << tensor.dtype.code;
CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
<< "found " << static_cast<int>(tensor.dtype.bits);
if (tensor.dtype.bits == 32) {
const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
for (auto i = 0; i < tensor.shape[0]; i++) {
raw_shape.push_back(int_ptr[i]);
}
} else if (tensor.dtype.bits == 64) {
const int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
for (auto i = 0; i < tensor.shape[0]; i++) {
raw_shape.push_back(int_ptr[i]);
}
}
return raw_shape;
}
bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3u);
auto alloc_attrs = attrs.as<AllocTensorAttrs>();
CHECK(alloc_attrs != nullptr) << "must be alloc_tensor attributes";
// First argument should be storage.
auto mod = reporter->GetModule();
CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = relay::TypeCallNode::make(storage_name, {});
reporter->Assign(types[0], storage);
// Second argument should be shape tensor.
auto tt = types[1].as<TensorTypeNode>();
CHECK(tt != nullptr) << "must be tensor type";
auto rank = tt->shape[0].as<tvm::IntImm>();
CHECK(rank != nullptr);
auto dims = rank->value;
// Constant node case.
Type alloc_type;
if (alloc_attrs->const_shape.defined()) {
auto con = alloc_attrs->const_shape;
auto sh = FromConstShape(con);
Array<IndexExpr> out_shape;
for (auto i = 0u; i < dims; i++) {
out_shape.push_back(tvm::Integer(sh[i]));
}
alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype);
} else {
CHECK(alloc_attrs->assert_shape.defined())
<< "the assert_shape must be set when const_shape is not";
alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype);
return true;
}
reporter->Assign(types[2], alloc_type);
return true;
}
RELAY_REGISTER_OP("memory.alloc_tensor")
.describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("storage", "Storage", "The storage to allocate from.")
.add_argument("shape", "Tensor", "The shape of the tensor to allocate.")
.add_type_rel("AllocTensor", AllocTensorRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
auto func_type = types[0].as<FuncTypeNode>();
CHECK(func_type != nullptr) << "input must be operator with known type";
auto input_type = types[1].as<TupleTypeNode>();
auto output_type = types[2].as<TupleTypeNode>();
CHECK(input_type != nullptr)
<< "internal invariant violated: invoke_tvm_op inputs must be a tuple";
CHECK(output_type != nullptr)
<< "internal invariant violated: invoke_tvm_op outputs must be a tuple";
Type ex_output;
if (func_type->ret_type.as<TensorTypeNode>()) {
ex_output = TupleTypeNode::make({func_type->ret_type});
} else {
CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
ex_output = func_type->ret_type;
}
auto ex_input = TupleTypeNode::make(func_type->arg_types);
reporter->Assign(ex_input, GetRef<Type>(input_type));
reporter->Assign(ex_output, GetRef<Type>(output_type));
reporter->Assign(types[3], TupleTypeNode::make({}));
return true;
}
TVM_REGISTER_API("relay.op.memory._make.invoke_tvm_op")
.set_body_typed<Expr(Expr, Expr, Expr)>(
[](Expr func, Expr inputs, Expr outputs) {
return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
});
RELAY_REGISTER_OP("memory.invoke_tvm_op")
.describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("op", "Function", "The operation to call")
.add_argument("ins", "Tuple", "The input tensors.")
.add_argument("outs", "Tuple", "The output tensors.")
.add_type_rel("InvokeTVMOP", InvokeTVMOPRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
bool KillRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2u);
// TODO(@jroesch): should only support tensors.
reporter->Assign(types[1], TupleTypeNode::make({}));
return true;
}
RELAY_REGISTER_OP("memory.kill")
.describe(R"code(Mark a tensor for release to the allocator.)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("to_free", "Tensor", "The tensor to free.")
.add_type_rel("Kill", KillRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_API("relay.op.memory._make.shape_func")
.set_body_typed<Expr(Expr, Expr, Expr, Array<tvm::Integer>)>(
[](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("memory.shape_func");
auto attrs = make_node<ShapeFuncAttrs>();
attrs->is_input = is_input;
return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {});
});
static void FlattenTypeAux(const Type& type, std::vector<TensorType>* out) {
if (auto tt = type.as<TensorTypeNode>()) {
out->push_back(GetRef<TensorType>(tt));
} else if (auto tuple_ty = type.as<TupleTypeNode>()) {
for (auto field : tuple_ty->fields) {
FlattenTypeAux(field, out);
}
} else {
LOG(FATAL) << "unsupported " << type;
}
}
std::vector<TensorType> FlattenType(const Type& type) {
std::vector<TensorType> out;
FlattenTypeAux(type, &out);
return out;
}
Expr PackByType(const Type& t, const Array<Expr>& exprs) {
LOG(FATAL) << "NYI";
return Expr();
}
bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
auto shape_func_attrs = attrs.as<ShapeFuncAttrs>();
CHECK(shape_func_attrs != nullptr) << "Internal compiler error";
auto func_type = types[0].as<FuncTypeNode>();
CHECK(func_type != nullptr);
auto tuple = TupleTypeNode::make(func_type->arg_types);
auto in_types = FlattenType(tuple);
auto out_types = FlattenType(func_type->ret_type);
Array<Type> shape_func_ins, shape_func_outs;
for (size_t i = 0; i < in_types.size(); i++) {
auto in_type = in_types[i];
if (shape_func_attrs->is_input[i]) {
shape_func_ins.push_back(in_type);
} else {
auto shape = RankShape(in_type->shape);
shape_func_ins.push_back(TensorTypeNode::make(shape, Int(64)));
}
}
for (auto out_type : out_types) {
auto rank_shape = RankShape(out_type->shape);
shape_func_outs.push_back(TensorTypeNode::make(rank_shape, Int(64)));
}
auto input_type = TupleTypeNode::make(shape_func_ins);
auto output_type = TupleTypeNode::make(shape_func_outs);
reporter->Assign(types[1], input_type);
reporter->Assign(types[2], output_type);
reporter->Assign(types[3], TupleTypeNode::make({}));
return true;
}
RELAY_REGISTER_OP("memory.shape_func")
.describe(R"code(Get the shape of a tensor.)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("tensor", "Tensor", "The tensor to retrieve the shape for.")
.add_type_rel("ShapeFuncRel", ShapeFuncRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
} // namespace relay
} // namespace tvm
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <vector> #include <vector>
#include <string>
#include <unordered_map>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/alter_op_layout.h" #include "../pass/alter_op_layout.h"
...@@ -105,6 +107,50 @@ namespace relay { ...@@ -105,6 +107,50 @@ namespace relay {
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout) BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */
template<typename R>
class OpMatch {
public:
using MatchFunc =
std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
/*! \brief Match an operator with the given name.
* \param op_name The name of the operator to match.
* \param func The function to execute when it matches.
* \return A self-reference for builder style API.
*/
inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
auto op = Op::Get(op_name);
match_map_.insert({op, func});
return *this;
}
/*! \brief Rewrite a call operation based on the operator and the registered
* match functions.
* \param call The call to rewrite.
* \return The result of rewriting.
*/
inline R operator()(const Call& call) {
auto it = match_map_.find(Downcast<Op>(call->op));
if (it != match_map_.end()) {
return it->second(call->args, call->attrs, call->type_args);
} else {
if (default_ != nullptr) {
return default_(call->args, call->attrs, call->type_args);
} else {
LOG(FATAL) << "unexpected operation " << call->op;
}
}
}
private:
/*! \brief The match function map. */
std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
/*! \brief An optional default case. */
MatchFunc default_;
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -286,8 +286,8 @@ bool ShapeOfRel(const Array<Type>& types, ...@@ -286,8 +286,8 @@ bool ShapeOfRel(const Array<Type>& types,
CHECK(tt != nullptr); CHECK(tt != nullptr);
const auto* param = attrs.as<ShapeOfAttrs>(); const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
auto vector_out = tvm::Integer(tt->shape.size()); auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype)); reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
return true; return true;
} }
......
...@@ -144,5 +144,13 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -144,5 +144,13 @@ bool BroadcastCompRel(const Array<Type>& types,
return false; return false;
} }
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
if (shape.size() == 0) {
return {};
} else {
return { tvm::Integer(shape.size()) };
}
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -80,6 +80,8 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -80,6 +80,8 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter); const TypeReporter& reporter);
Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
* 3. Collect the device allocation of each expression. * 3. Collect the device allocation of each expression.
*/ */
#include <tvm/expr.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "./pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -73,13 +74,12 @@ bool ConstantCheck(const Expr& e) { ...@@ -73,13 +74,12 @@ bool ConstantCheck(const Expr& e) {
TVM_REGISTER_API("relay._analysis.check_constant") TVM_REGISTER_API("relay._analysis.check_constant")
.set_body_typed(ConstantCheck); .set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder. // TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator. // or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator { class ConstantFolder : public ExprMutator {
public: public:
explicit ConstantFolder(FInterpreter executor) explicit ConstantFolder(FInterpreter executor, Module module)
: executor_(executor) { : executor_(executor), module_(module) {
} }
Expr VisitExpr_(const LetNode* op) final { Expr VisitExpr_(const LetNode* op) final {
...@@ -123,6 +123,15 @@ class ConstantFolder : public ExprMutator { ...@@ -123,6 +123,15 @@ class ConstantFolder : public ExprMutator {
if (call->op.same_as(Op::Get("shape_of"))) { if (call->op.same_as(Op::Get("shape_of"))) {
return EvaluateShapeOf(res, origin_args, call->attrs); return EvaluateShapeOf(res, origin_args, call->attrs);
} }
// We should think about potentially constant evaluation over these ops too.
if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
call->op.same_as(Op::Get("memory.shape_func")) ||
call->op.same_as(Op::Get("memory.alloc_tensor")) ||
call->op.same_as(Op::Get("memory.alloc_storage"))) {
return GetRef<Call>(call);
}
bool all_const_args = true; bool all_const_args = true;
for (Expr arg : call->args) { for (Expr arg : call->args) {
if (!checker_.Check(arg)) { if (!checker_.Check(arg)) {
...@@ -151,10 +160,16 @@ class ConstantFolder : public ExprMutator { ...@@ -151,10 +160,16 @@ class ConstantFolder : public ExprMutator {
FInterpreter executor_; FInterpreter executor_;
// Internal constant checker // Internal constant checker
ConstantChecker checker_; ConstantChecker checker_;
// Module
Module module_;
// Convert value to expression. // Convert value to expression.
Expr ValueToExpr(Value value) { Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) { if (const auto* val = value.as<TensorValueNode>()) {
for (auto dim : val->data.Shape()) {
CHECK_GT(dim, 0)
<< "invalid dimension after constant eval";
}
return ConstantNode::make(val->data); return ConstantNode::make(val->data);
} else if (const auto* val = value.as<TupleValueNode>()) { } else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields; Array<Expr> fields;
...@@ -171,18 +186,33 @@ class ConstantFolder : public ExprMutator { ...@@ -171,18 +186,33 @@ class ConstantFolder : public ExprMutator {
Expr ConstEvaluate(Expr expr) { Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), std::vector<transform::Pass> passes = {transform::FuseOps(0),
transform::InferType()}; transform::InferType()};
auto mod = ModuleNode::FromExpr(expr); Function func;
if (expr.as<FunctionNode>()) {
func = Downcast<Function>(expr);
} else {
// TODO(@jroesch): fix this
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
auto mod = ModuleNode::make(
{},
module_->type_definitions,
module_->Imports());
auto global = GlobalVarNode::make("main");
mod->Add(global, func);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = mod->Lookup("main");
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ValueToExpr(executor_(expr)); return ValueToExpr(executor_(expr));
} }
// Evaluate shape_of op
// Evaluate a call to the shape_of operator for tensors with constant
// shapes.
Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) { Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
Expr input = args[0]; Expr input = args[0];
const auto* param = attrs.as<ShapeOfAttrs>(); const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
tvm::Array<IndexExpr> ishape; tvm::Array<IndexExpr> ishape;
if (const ConstantNode* op = input.as<ConstantNode>()) { if (const ConstantNode* op = input.as<ConstantNode>()) {
ishape = op->tensor_type()->shape; ishape = op->tensor_type()->shape;
...@@ -191,13 +221,20 @@ class ConstantFolder : public ExprMutator { ...@@ -191,13 +221,20 @@ class ConstantFolder : public ExprMutator {
} else { } else {
return expr; return expr;
} }
// Get the constant shape // Get the constant shape
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
auto val = runtime::NDArray::Empty( runtime::NDArray value;
{(int64_t)ishape.size()}, Type2TVMType(Int(32)), ctx); auto cdtype = Type2TVMType(Int(32));
int32_t* dims = static_cast<int32_t*>(val->data); if (ishape.size() == 0) {
value = runtime::NDArray::Empty({}, cdtype, ctx);
} else {
CHECK_NE(ishape.size(), 0);
std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
value = runtime::NDArray::Empty(cshape, cdtype, ctx);
int32_t* dims = static_cast<int32_t*>(value->data);
using ::tvm::ir::IntImm; using ::tvm::ir::IntImm;
for (size_t i = 0; i < ishape.size(); ++i) { for (size_t i = 0; i < ishape.size(); ++i) {
if (const IntImm* dim = ishape[i].as<IntImm>()) { if (const IntImm* dim = ishape[i].as<IntImm>()) {
...@@ -206,18 +243,26 @@ class ConstantFolder : public ExprMutator { ...@@ -206,18 +243,26 @@ class ConstantFolder : public ExprMutator {
return expr; return expr;
} }
} }
Expr shape = ValueToExpr(TensorValueNode::make(val)); }
Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
shape = ConstantNode::make(ndarray);
}
// Cast the constant into correct dtype // Cast the constant into correct dtype
auto cast_attrs = make_node<CastAttrs>(); auto cast_attrs = make_node<CastAttrs>();
cast_attrs->dtype = param->dtype; cast_attrs->dtype = param->dtype;
static const Op& cast_op = Op::Get("cast"); static const Op& cast_op = Op::Get("cast");
Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {}); Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
return ConstEvaluate(ret); return ConstEvaluate(ret);
} }
}; };
Expr FoldConstant(const Expr& expr) { Expr FoldConstant(const Expr& expr, const Module& mod) {
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
...@@ -227,7 +272,7 @@ Expr FoldConstant(const Expr& expr) { ...@@ -227,7 +272,7 @@ Expr FoldConstant(const Expr& expr) {
With<BuildConfig> fresh_build_ctx(BuildConfig::Create()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter( return ConstantFolder(CreateInterpreter(
Module(nullptr), ctx, target)).Mutate(expr); mod, ctx, target), mod).Mutate(expr);
} }
namespace transform { namespace transform {
...@@ -235,7 +280,7 @@ namespace transform { ...@@ -235,7 +280,7 @@ namespace transform {
Pass FoldConstant() { Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f)); return Downcast<Function>(FoldConstant(f, m));
}; };
return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
} }
......
...@@ -862,6 +862,13 @@ class FuseMutator : private ExprMutator { ...@@ -862,6 +862,13 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const CallNode* call) { Expr VisitExpr_(const CallNode* call) {
static const Op& stop_fusion = Op::Get("annotation.stop_fusion"); static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
if (call->op.as<OpNode>()) { if (call->op.as<OpNode>()) {
static auto fnoncomputational =
Op::GetAttr<TNonComputational>("TNonComputational");
if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
return ExprMutator::VisitExpr_(call);
}
// If it is a primitive op call // If it is a primitive op call
// then we must have a group assignment for it already. // then we must have a group assignment for it already.
CHECK(gmap_.count(call)); CHECK(gmap_.count(call));
......
...@@ -314,7 +314,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -314,7 +314,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->opt_level; << pass_info->opt_level;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions); Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) { for (const auto& it : updated_mod->functions) {
auto updated_func = SkipFunction(it.second) auto updated_func = SkipFunction(it.second)
......
...@@ -311,8 +311,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -311,8 +311,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Match match = GetRef<Match>(op); Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_); Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) { if (unmatched_cases.size() != 0) {
LOG(FATAL) << "Match clause " << match << " does not handle the following cases: " RelayErrorStream ss;
<< unmatched_cases; ss << "match expression does not handle the following cases: ";
int i = 0;
for (auto cs : unmatched_cases) {
ss << "case " << i << ": \n" << PrettyPrint(cs);
}
this->ReportFatalError(
match,
ss);
} }
} }
......
...@@ -530,7 +530,9 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { ...@@ -530,7 +530,9 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
}; };
// constructor // constructor
TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, TypeSolver::TypeSolver(
const GlobalVar& current_func,
const Module& module,
ErrorReporter* err_reporter) ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)), : reporter_(make_node<Reporter>(this)),
current_func(current_func), current_func(current_func),
......
...@@ -287,9 +287,13 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -287,9 +287,13 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim // Number of fields = 5 + instr.alloc_tensor.ndim
fields.push_back(instr.alloc_tensor.storage);
// Save `DLDataType` and the dst register. // Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype; const auto& dtype = instr.alloc_tensor.dtype;
fields.assign({dtype.code, dtype.bits, dtype.lanes}); fields.push_back(dtype.code);
fields.push_back(dtype.bits);
fields.push_back(dtype.lanes);
// The number of dimensions is not needed for constructing an // The number of dimensions is not needed for constructing an
// `AllocTensor` instruction as it equals to the length of the `shape` // `AllocTensor` instruction as it equals to the length of the `shape`
...@@ -305,10 +309,22 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -305,10 +309,22 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
break; break;
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
// Number of fields = 5 // Number of fields = 6
fields.push_back(instr.alloc_tensor_reg.storage);
fields.push_back(instr.alloc_tensor_reg.shape_register); fields.push_back(instr.alloc_tensor_reg.shape_register);
// Save `DLDataType` and the dst register. // Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype; const auto& dtype = instr.alloc_tensor_reg.dtype;
fields.push_back(dtype.code);
fields.push_back(dtype.bits);
fields.push_back(dtype.lanes);
fields.push_back(instr.dst);
break;
}
case Opcode::AllocStorage: {
fields.push_back(instr.alloc_storage.allocation_size);
fields.push_back(instr.alloc_storage.alignment);
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_storage.dtype_hint;
fields.push_back(dtype.code); fields.push_back(dtype.code);
fields.push_back(dtype.bits); fields.push_back(dtype.bits);
fields.push_back(dtype.lanes); fields.push_back(dtype.lanes);
...@@ -521,35 +537,39 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -521,35 +537,39 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::InvokePacked(packed_index, arity, output_size, args); return Instruction::InvokePacked(packed_index, arity, output_size, args);
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim // Number of fields = 6 + instr.alloc_tensor.ndim
DCHECK_GE(instr.fields.size(), 5U); DCHECK_GE(instr.fields.size(), 6U);
DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3])); DCHECK_EQ(instr.fields.size(), 6U + static_cast<size_t>(instr.fields[4]));
RegName storage_reg = instr.fields[0];
DLDataType dtype; DLDataType dtype;
dtype.code = instr.fields[0]; dtype.code = instr.fields[1];
dtype.bits = instr.fields[1]; dtype.bits = instr.fields[2];
dtype.lanes = instr.fields[2]; dtype.lanes = instr.fields[3];
Index ndim = instr.fields[3]; Index ndim = instr.fields[4];
RegName dst = instr.fields[4]; RegName dst = instr.fields[5];
std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim); std::vector<Index> shape = ExtractFields(instr.fields, 6, ndim);
return Instruction::AllocTensor(shape, dtype, dst); return Instruction::AllocTensor(storage_reg, shape, dtype, dst);
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
// Number of fields = 5 // Number of fields = 5
DCHECK_EQ(instr.fields.size(), 5U); DCHECK_EQ(instr.fields.size(), 6U);
Index shape_register = instr.fields[0];
RegName storage_reg = instr.fields[0];
Index shape_register = instr.fields[1];
DLDataType dtype; DLDataType dtype;
dtype.code = instr.fields[1]; dtype.code = instr.fields[2];
dtype.bits = instr.fields[2]; dtype.bits = instr.fields[3];
dtype.lanes = instr.fields[3]; dtype.lanes = instr.fields[4];
RegName dst = instr.fields[4]; RegName dst = instr.fields[5];
return Instruction::AllocTensorReg(shape_register, dtype, dst); return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst);
} }
case Opcode::AllocADT: { case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields // Number of fields = 3 + instr.num_fields
...@@ -575,6 +595,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -575,6 +595,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
} }
case Opcode::AllocStorage: {
DCHECK_GE(instr.fields.size(), 6U);
Index allocation_size = instr.fields[0];
Index alignment = instr.fields[1];
DLDataType dtype;
dtype.code = instr.fields[2];
dtype.bits = instr.fields[3];
dtype.lanes = instr.fields[4];
RegName dst = instr.fields[5];
return Instruction::AllocStorage(
allocation_size,
alignment,
dtype,
dst);
}
case Opcode::If: { case Opcode::If: {
// Number of fields = 4 // Number of fields = 4
DCHECK_EQ(instr.fields.size(), 4U); DCHECK_EQ(instr.fields.size(), 4U);
......
...@@ -32,6 +32,30 @@ namespace tvm { ...@@ -32,6 +32,30 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
static void BufferDeleter(NDArray::Container* ptr) {
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
Free(*(buffer));
delete buffer;
delete ptr;
}
void StorageObj::Deleter(NDArray::Container* ptr) {
// When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor.
//
// We did bump the reference count by 1 to keep alive the StorageObj
// allocation in case this NDArray is the sole owner.
//
// We decrement the object allowing for the buffer to release our
// reference count from allocation.
StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx);
storage->DecRef();
delete ptr;
}
inline void VerifyDataType(DLDataType dtype) { inline void VerifyDataType(DLDataType dtype) {
CHECK_GE(dtype.lanes, 1); CHECK_GE(dtype.lanes, 1);
if (dtype.code == kDLFloat) { if (dtype.code == kDLFloat) {
...@@ -50,6 +74,22 @@ inline size_t GetDataAlignment(const DLTensor& arr) { ...@@ -50,6 +74,22 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
return align; return align;
} }
NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype) {
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u);
VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data;
return NDArray(container);
}
MemoryManager* MemoryManager::Global() { MemoryManager* MemoryManager::Global() {
static MemoryManager memory_manager; static MemoryManager memory_manager;
return &memory_manager; return &memory_manager;
...@@ -66,15 +106,6 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { ...@@ -66,15 +106,6 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
return allocators_.at(ctx).get(); return allocators_.at(ctx).get();
} }
static void BufferDeleter(NDArray::Container* ptr) {
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
Free(*(buffer));
delete buffer;
delete ptr;
}
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) { NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
...@@ -108,6 +109,38 @@ class MemoryManager { ...@@ -108,6 +109,38 @@ class MemoryManager {
std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_; std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_;
}; };
/*! \brief An object representing a storage allocation. */
class StorageObj : public Object {
public:
/*! \brief The index into the VM function table. */
Buffer buffer;
/*! \brief Allocate an NDArray from a given piece of storage. */
NDArray AllocNDArray(size_t offset,
std::vector<int64_t> shape,
DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr);
~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
alloc->Free(buffer);
}
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "vm.Storage";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
};
/*! \brief reference to storage. */
class Storage : public ObjectRef {
public:
explicit Storage(Buffer buffer);
TVM_DEFINE_OBJECT_REF_METHODS_MUT(Storage, ObjectRef, StorageObj);
};
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
...@@ -42,6 +44,17 @@ namespace tvm { ...@@ -42,6 +44,17 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
inline Storage make_storage(size_t size, size_t alignment, TVMType dtype_hint, TVMContext ctx) {
// We could put cache in here, from ctx to storage allocator.
auto storage_obj = SimpleObjAllocator().make<StorageObj>();
auto alloc = MemoryManager::Global()->GetAllocator(ctx);
DCHECK(alloc != nullptr)
<< "allocator must not null";
storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint);
return Storage(storage_obj);
}
Instruction::Instruction() {} Instruction::Instruction() {}
template <typename T> template <typename T>
...@@ -65,12 +78,14 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -65,12 +78,14 @@ Instruction::Instruction(const Instruction& instr) {
this->result = instr.result; this->result = instr.result;
return; return;
case Opcode::AllocTensor: case Opcode::AllocTensor:
this->alloc_tensor.storage = instr.alloc_tensor.storage;
this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape, this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
instr.alloc_tensor.ndim); instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype; this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return; return;
case Opcode::AllocTensorReg: case Opcode::AllocTensorReg:
this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return; return;
...@@ -119,6 +134,9 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -119,6 +134,9 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::Goto: case Opcode::Goto:
this->pc_offset = instr.pc_offset; this->pc_offset = instr.pc_offset;
return; return;
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return;
default: default:
std::ostringstream out; std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op); out << "Invalid instruction " << static_cast<int>(instr.op);
...@@ -150,12 +168,14 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -150,12 +168,14 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->result = instr.result; this->result = instr.result;
return *this; return *this;
case Opcode::AllocTensor: case Opcode::AllocTensor:
this->alloc_tensor.storage = this->alloc_tensor.storage;
this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape, this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
instr.alloc_tensor.ndim); instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype; this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return *this; return *this;
case Opcode::AllocTensorReg: case Opcode::AllocTensorReg:
this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return *this; return *this;
...@@ -206,6 +226,9 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -206,6 +226,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
case Opcode::Goto: case Opcode::Goto:
this->pc_offset = instr.pc_offset; this->pc_offset = instr.pc_offset;
return *this; return *this;
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return *this;
default: default:
std::ostringstream out; std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op); out << "Invalid instruction " << static_cast<int>(instr.op);
...@@ -224,6 +247,7 @@ Instruction::~Instruction() { ...@@ -224,6 +247,7 @@ Instruction::~Instruction() {
case Opcode::GetTag: case Opcode::GetTag:
case Opcode::Goto: case Opcode::Goto:
case Opcode::LoadConsti: case Opcode::LoadConsti:
case Opcode::AllocStorage:
case Opcode::Fatal: case Opcode::Fatal:
return; return;
case Opcode::AllocTensor: case Opcode::AllocTensor:
...@@ -279,10 +303,14 @@ Instruction Instruction::InvokePacked(Index packed_index, ...@@ -279,10 +303,14 @@ Instruction Instruction::InvokePacked(Index packed_index,
return instr; return instr;
} }
Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtype, Index dst) { Instruction Instruction::AllocTensor(
RegName storage,
const std::vector<int64_t>& shape,
DLDataType dtype, Index dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::AllocTensor; instr.op = Opcode::AllocTensor;
instr.dst = dst; instr.dst = dst;
instr.alloc_tensor.storage = storage;
instr.alloc_tensor.ndim = shape.size(); instr.alloc_tensor.ndim = shape.size();
instr.alloc_tensor.shape = new int64_t[shape.size()]; instr.alloc_tensor.shape = new int64_t[shape.size()];
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
...@@ -292,15 +320,32 @@ Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtyp ...@@ -292,15 +320,32 @@ Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtyp
return instr; return instr;
} }
Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype, Index dst) { Instruction Instruction::AllocTensorReg(
RegName storage,
RegName shape_register,
DLDataType dtype, Index dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::AllocTensorReg; instr.op = Opcode::AllocTensorReg;
instr.dst = dst; instr.dst = dst;
instr.alloc_tensor_reg.storage = storage;
instr.alloc_tensor_reg.shape_register = shape_register; instr.alloc_tensor_reg.shape_register = shape_register;
instr.alloc_tensor_reg.dtype = dtype; instr.alloc_tensor_reg.dtype = dtype;
return instr; return instr;
} }
Instruction Instruction::AllocStorage(RegName size,
Index alignment,
TVMType dtype_hint,
Index dst) {
Instruction instr;
instr.op = Opcode::AllocStorage;
instr.dst = dst;
instr.alloc_storage.allocation_size = size;
instr.alloc_storage.alignment = alignment;
instr.alloc_storage.dtype_hint = dtype_hint;
return instr;
}
Instruction Instruction::AllocADT(Index tag, Index num_fields, Instruction Instruction::AllocADT(Index tag, Index num_fields,
const std::vector<RegName>& datatype_fields, Index dst) { const std::vector<RegName>& datatype_fields, Index dst) {
Instruction instr; Instruction instr;
...@@ -472,7 +517,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -472,7 +517,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break; break;
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
os << "alloc_tensor $" << instr.dst << " [" os << "alloc_tensor $" << instr.dst << " $"
<< instr.alloc_tensor.storage << " ["
<< StrJoin<int64_t>(instr.alloc_tensor.shape, 0, << StrJoin<int64_t>(instr.alloc_tensor.shape, 0,
instr.alloc_tensor.ndim) instr.alloc_tensor.ndim)
<< "] "; << "] ";
...@@ -481,6 +527,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -481,6 +527,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $" os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.storage << " $"
<< instr.alloc_tensor_reg.shape_register << " "; << instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
break; break;
...@@ -534,6 +581,14 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -534,6 +581,14 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
os << "goto " << instr.pc_offset; os << "goto " << instr.pc_offset;
break; break;
} }
case Opcode::AllocStorage: {
os << "alloc_storage " <<
instr.dst << " " <<
instr.alloc_storage.allocation_size << " " <<
instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint);
break;
}
default: default:
LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op); LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
break; break;
...@@ -827,17 +882,21 @@ void VirtualMachine::RunLoop() { ...@@ -827,17 +882,21 @@ void VirtualMachine::RunLoop() {
goto main_loop; goto main_loop;
} }
case Opcode::InvokePacked: { case Opcode::InvokePacked: {
DLOG(INFO) << "InvokedPacked "
<< "arity=" << instr.arity;
const auto& func = packed_funcs[instr.packed_index]; const auto& func = packed_funcs[instr.packed_index];
const auto& arity = instr.arity; const auto& arity = instr.arity;
std::vector<ObjectRef> args; std::vector<ObjectRef> args;
for (Index i = 0; i < arity; ++i) { for (Index i = 0; i < arity; ++i) {
args.push_back(ReadRegister(instr.packed_args[i])); DLOG(INFO) <<
"arg" << i << " $" << instr.packed_args[i];
auto arg = ReadRegister(instr.packed_args[i]);
args.push_back(arg);
} }
// We no longer need to write the registers back, we write directly
// through the registers mutably.
InvokePacked(instr.packed_index, func, arity, instr.output_size, args); InvokePacked(instr.packed_index, func, arity, instr.output_size, args);
for (Index i = 0; i < instr.output_size; ++i) {
WriteRegister(instr.packed_args[instr.arity - instr.output_size + i],
args[instr.arity - instr.output_size + i]);
}
pc++; pc++;
goto main_loop; goto main_loop;
} }
...@@ -901,12 +960,15 @@ void VirtualMachine::RunLoop() { ...@@ -901,12 +960,15 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim); auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
shape[i] = instr.alloc_tensor.shape[i]; shape[i] = instr.alloc_tensor.shape[i];
} }
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
auto obj = Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
...@@ -916,19 +978,22 @@ void VirtualMachine::RunLoop() { ...@@ -916,19 +978,22 @@ void VirtualMachine::RunLoop() {
DLContext cpu_ctx; DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>(); const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr); CHECK(tensor != nullptr);
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx); NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
int64_t* dims = static_cast<int64_t*>(shape_tensor->data); CHECK_EQ(dl_tensor->dtype.code, 0u);
CHECK_LE(dl_tensor->dtype.bits, 64);
int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
auto num_dims = shape_tensor->shape[0]; auto num_dims = shape_tensor->shape[0];
auto shape = std::vector<int64_t>(shape_tensor->shape[0]); auto shape = std::vector<int64_t>(num_dims);
shape.assign(dims, dims + num_dims); shape.assign(dims, dims + num_dims);
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
auto obj = Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
...@@ -953,6 +1018,20 @@ void VirtualMachine::RunLoop() { ...@@ -953,6 +1018,20 @@ void VirtualMachine::RunLoop() {
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocStorage: {
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
auto alignment = LoadScalarInt(instr.alloc_storage.alignment);
DLOG(INFO) <<
"AllocStorage: allocation_size=" << size <<
"alignment=" << alignment <<
"dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint);
auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs[0]);
WriteRegister(instr.dst, storage);
pc++;
goto main_loop;
}
case Opcode::Ret: { case Opcode::Ret: {
// If we have hit the point from which we started // If we have hit the point from which we started
// running, we should return to the caller breaking // running, we should return to the caller breaking
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License
import tvm
import numpy as np
from tvm import relay
from tvm.relay import memory_alloc
def check_vm_alloc(func, check_fn):
mod = relay.Module()
mod['main'] = func
ex = relay.create_executor('vm', mod)
args = []
for param in func.params:
param = param.type_annotation
sh = [int(sh) for sh in param.shape]
data = np.random.rand(*sh).astype(param.dtype)
args.append(tvm.nd.array(data))
result = ex.evaluate(mod['main'])(*args)
py_res = check_fn(*[arg.asnumpy() for arg in args])
np.testing.assert_allclose(result.asnumpy(), py_res)
def storage_type(mod):
return relay.TypeCall(mod.get_global_type_var("Storage"), [])
def test_tyck_alloc_storage():
mod = relay.Module()
mod.import_from_std("core.rly")
def test_tyck_alloc_tensor():
mod = relay.Module()
mod.import_from_std("core.rly")
sto = relay.Var("x", storage_type(mod))
sh = relay.const(np.array([1, 2]), dtype="int64")
at = relay.op.memory.alloc_tensor(sto, sh)
mod['main'] = relay.Function([sto], at)
relay.transform.InferType()(mod)
def check_add(x):
return x + x
def test_add():
x = relay.var('x', shape=(2,))
z = x + x
func = relay.Function([x,], z)
check_vm_alloc(func, check_add)
def check_add_sub(x, y):
z = x + x
return z - y
def test_add_sub():
x = relay.var('x', shape=(10,))
y = relay.var('y', shape=(10,))
z = x + x
z = z - y
func = relay.Function([x, y], z)
check_vm_alloc(func, check_add_sub)
if __name__ == "__main__":
test_tyck_alloc_tensor()
test_add()
test_add_sub()
...@@ -107,9 +107,9 @@ def test_serializer(): ...@@ -107,9 +107,9 @@ def test_serializer():
assert any(item.startswith('fused_multiply') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops)
code = exe.bytecode code = exe.bytecode
assert "main 5 2 5" in code assert "main 8 2 8" in code
assert "f1 2 1 3" in code assert "f1 5 1 6" in code
assert "f2 2 1 3" in code assert "f2 5 1 6" in code
code, lib = exe.save() code, lib = exe.save()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment