Commit 3fb84e2b by Jared Roesch Committed by Thierry Moreau

[Relay][RFC] Implement type checking for Any (#3221)

* Implement type checking for Any

Remove code generation related changes

Remove compile changes

Remove more

Remove unification hack

Add some code back that was needed, and clean up test

Refactor test cases

WIP

Implement TypeHint AST

Add test case which should fail

Remove unification changes, and fix bug with let rec

Restore unification for shapes

Improve error reporting while debugging

All examples type check

All examples type check

WIP

First version that works with hints, needs clean up

Remove dead code

Tweaks

Remove type hint

Remove unecessary type hint stuff

Remove more type hints

Clean up

Expose Any expression node

Address CR

Fix

Fix solver

Kill unecessary code

Fix

PyLint

Fix

Relocate loops

Fix license and test

Lint again

Lint again

Fix loops

Fix docstring

Fix template error

Fix compiler issue

Fix compile err

Remove more runtime changes

Restore buffer

Fix segfault

Fix

Fix arange

* Address feedback

* Fix typo

* Fix arange

* Fix op level3

* Fix issue with Python wrapper
parent b10dda69
...@@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* _type_key = "Reduce"; static constexpr const char* _type_key = "Reduce";
}; };
/*! \brief Any shape. */
struct Any : public ExprNode<Any> {
TVM_DLL static Expr make();
void VisitAttrs(AttrVisitor* v) final {}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
};
/*! /*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor. * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/ */
......
...@@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in arange operators */ /*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> { struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start; Expr start;
tvm::Expr stop; Expr stop;
tvm::Expr step; Expr step;
DataType dtype; DataType dtype;
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") { TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0)) TVM_ATTR_FIELD(start)
.describe("Start of interval. The interval includes this value."); .describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop) TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value."); .describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1)) TVM_ATTR_FIELD(step)
.describe("Spacing between values."); .describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>()) TVM_ATTR_FIELD(dtype)
.describe("Target data type."); .describe("Target data type.");
} }
}; // struct ArangeAttrs }; // struct ArangeAttrs
......
...@@ -64,9 +64,10 @@ struct RelayErrorStream { ...@@ -64,9 +64,10 @@ struct RelayErrorStream {
struct Error : public dmlc::Error { struct Error : public dmlc::Error {
Span sp; Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {} explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
Error() : dmlc::Error(""), sp(nullptr) {}
}; };
/*! \brief An abstraction around how errors are stored and reported. /*! \brief An abstraction around how errors are stored and reported.
...@@ -118,7 +119,8 @@ class ErrorReporter { ...@@ -118,7 +119,8 @@ class ErrorReporter {
* \param err The error message to report. * \param err The error message to report.
*/ */
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err)); std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg));
} }
/*! \brief Report an error against a program, using the full program /*! \brief Report an error against a program, using the full program
......
...@@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
return node; return node;
} }
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);
/*! /*!
* \brief Render the node as a string in the Relay text format. * \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered. * \param node The node to be rendered.
......
...@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc< ...@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc<
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call, using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>; const Expr& output_grad)>;
/*!
* \brief The codegeneration strategy for dynamic dimensions.
*/
enum AnyCodegenStrategy {
/*! \brief The default strategy of using completely variable dimensions. */
kVariableDimensions
};
/* \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_ #endif // TVM_RELAY_OP_ATTR_TYPES_H_
...@@ -35,6 +35,8 @@ ...@@ -35,6 +35,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using Any = tvm::ir::Any;
/*! \brief Base type of the Relay type hiearchy. */ /*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode { class TypeNode : public RelayNode {
public: public:
...@@ -384,6 +386,7 @@ class TypeReporterNode : public Node { ...@@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
* But it is possible for the solver to resolve src by dst as well. * But it is possible for the solver to resolve src by dst as well.
*/ */
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*! /*!
* \brief assert shape expression comparison. * \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic. * \note Use assert only if any of the condition input is symbolic.
......
...@@ -190,6 +190,8 @@ class NDArray { ...@@ -190,6 +190,8 @@ class NDArray {
TVM_DLL static void CopyFromTo( TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
TVM_DLL std::vector<int64_t> Shape() const;
// internal namespace // internal namespace
struct Internal; struct Internal;
protected: protected:
......
...@@ -294,7 +294,7 @@ def get_last_ffi_error(): ...@@ -294,7 +294,7 @@ def get_last_ffi_error():
""" """
c_err_msg = py_str(_LIB.TVMGetLastError()) c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg) py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type.startswith("tvm.error."): if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:] err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg) return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
......
...@@ -479,7 +479,8 @@ def extern(shape, ...@@ -479,7 +479,8 @@ def extern(shape,
raise ValueError("nested tag is not allowed for now") raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
shape = [shape]
if in_buffers is not None: if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers): if len(inputs) != len(in_buffers):
......
...@@ -63,6 +63,7 @@ TupleType = ty.TupleType ...@@ -63,6 +63,7 @@ TupleType = ty.TupleType
TensorType = ty.TensorType TensorType = ty.TensorType
Kind = ty.Kind Kind = ty.Kind
TypeVar = ty.TypeVar TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType FuncType = ty.FuncType
TypeRelation = ty.TypeRelation TypeRelation = ty.TypeRelation
...@@ -71,6 +72,7 @@ scalar_type = ty.scalar_type ...@@ -71,6 +72,7 @@ scalar_type = ty.scalar_type
RefType = ty.RefType RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall TypeCall = ty.TypeCall
Any = ty.Any
# Expr # Expr
Expr = expr.Expr Expr = expr.Expr
......
...@@ -570,6 +570,7 @@ def const(value, dtype=None): ...@@ -570,6 +570,7 @@ def const(value, dtype=None):
""" """
if isinstance(value, (_base.numeric_types, (bool, list))): if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype) value = _np.array(value, dtype=dtype)
if not dtype: if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32" # when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = { map_dtype = {
...@@ -578,6 +579,7 @@ def const(value, dtype=None): ...@@ -578,6 +579,7 @@ def const(value, dtype=None):
}.get(value.dtype, None) }.get(value.dtype, None)
if map_dtype: if map_dtype:
value = value.astype(map_dtype) value = value.astype(map_dtype)
if isinstance(value, (_np.ndarray, _np.generic)): if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value) value = _nd.array(value)
......
...@@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs): ...@@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs):
raise tvm.error.OpAttributeUnimplemented( raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.') 'Attribute "repeat" is not supported in operator arange.')
new_attrs = {} new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0) new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
new_attrs["stop"] = attrs.get_float("stop") new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
new_attrs["step"] = attrs.get_float("step", 1) new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32") new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs) return _op.arange(**new_attrs)
......
...@@ -1059,9 +1059,9 @@ def _range(): ...@@ -1059,9 +1059,9 @@ def _range():
return AttrCvt( return AttrCvt(
op_name="arange", op_name="arange",
ignores=['Tidx'], ignores=['Tidx'],
extras={'start': start, extras={'start': _expr.const(start),
"stop": limit, "stop": _expr.const(limit),
'step': delta, 'step': _expr.const(delta),
'dtype': dtype})([], attr) 'dtype': dtype})([], attr)
return _impl return _impl
...@@ -1269,8 +1269,8 @@ def _batch_to_space_nd(): ...@@ -1269,8 +1269,8 @@ def _batch_to_space_nd():
crop = crops[axis - 1] crop = crops[axis - 1]
if crop != [0, 0]: if crop != [0, 0]:
indices = tvm.relay.arange( indices = tvm.relay.arange(
crop[0], _expr.const(crop[0]),
reshaped_permuted_shape[axis] - crop[1], _expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype='int32' dtype='int32'
) )
cropped = tvm.relay.take(cropped, indices=indices, axis=axis) cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr
def while_loop(cond, loop_vars, loop_bodies):
"""
Construct a while loop.
Parameters
----------
cond: Callable[Tuple[relay.Expr], relay.Expr]
The condition of the loop.
loop_vars: Tuple[relay.Expr]
The variables being looped over.
The initial values of the loop, will be used to
construct the loop variables.
loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
The body of the loop, should be a function which
given loop variables produces the output result
also as a tuple
Returns
-------
loop: relay.Expr
The loop expression.
"""
sb = ScopeBuilder()
loop = _expr.Var("while_loop")
fresh_vars = []
for i, loop_var in enumerate(loop_vars):
name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
fresh_vars.append(new_var)
with sb.if_scope(cond(*fresh_vars)):
sb.ret(loop(*loop_bodies(*fresh_vars)))
with sb.else_scope():
sb.ret(_expr.Tuple(fresh_vars))
func = _expr.Function(fresh_vars, sb.get())
let = _expr.Let(loop, func, loop)
return let
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Transform operators.""" """Transform operators."""
from . import _make from . import _make
from ..expr import TupleWrapper from ..expr import TupleWrapper, const
def cast(data, dtype): def cast(data, dtype):
...@@ -272,7 +272,7 @@ def full_like(data, fill_value): ...@@ -272,7 +272,7 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value) return _make.full_like(data, fill_value)
def arange(start, stop=None, step=1, dtype="float32"): def arange(start, stop=None, step=None, dtype="float32"):
"""Return evenly spaced values within a given interval. """Return evenly spaced values within a given interval.
.. note:: .. note::
...@@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"): ...@@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"):
relay.arange(1, 5) = [1, 2, 3, 4] relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4] relay.arange(1, 5, 1.5) = [1, 2.5, 4]
""" """
if step is None:
step = const(1, dtype)
if stop is None: if stop is None:
stop = start stop = start
start = 0 start = const(0, dtype=dtype)
return _make.arange(start, stop, step, dtype) return _make.arange(start, stop, step, dtype)
......
...@@ -42,7 +42,6 @@ class WithScope(object): ...@@ -42,7 +42,6 @@ class WithScope(object):
else: else:
self._exit_cb() self._exit_cb()
def _make_lets(bindings, ret_value): def _make_lets(bindings, ret_value):
"""Make a nested let expressions. """Make a nested let expressions.
...@@ -176,6 +175,24 @@ class ScopeBuilder(object): ...@@ -176,6 +175,24 @@ class ScopeBuilder(object):
false_branch) false_branch)
return WithScope(None, _on_exit) return WithScope(None, _on_exit)
def type_of(self, expr):
"""
Compute the type of an expression.
Parameters
----------
expr: relay.Expr
The expression to compute the type of.
"""
if isinstance(expr, _expr.Var):
return expr.type_annotation
ity = _ty.IncompleteType()
var = _expr.var("unify", ity)
self.let(var, expr)
return ity
def ret(self, value): def ret(self, value):
"""Set the return value of this scope. """Set the return value of this scope.
......
...@@ -20,6 +20,7 @@ from enum import IntEnum ...@@ -20,6 +20,7 @@ from enum import IntEnum
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
Any = _make.Any
class Type(RelayNode): class Type(RelayNode):
"""The base type for all Relay types.""" """The base type for all Relay types."""
...@@ -137,6 +138,19 @@ class TypeVar(Type): ...@@ -137,6 +138,19 @@ class TypeVar(Type):
""" """
self.__init_handle_by_constructor__(_make.TypeVar, var, kind) self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
def ShapeVar(name):
"""A helper which constructs a type var of which the shape kind.
Parameters
----------
name : str
Returns
-------
type_var : tvm.relay.TypeVar
The shape variable.
"""
return TypeVar(name, kind=Kind.ShapeVar)
@register_relay_node @register_relay_node
class GlobalTypeVar(Type): class GlobalTypeVar(Type):
......
...@@ -970,7 +970,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) { ...@@ -970,7 +970,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
op->call_type == Call::PureExtern) { op->call_type == Call::PureExtern) {
return CreateCallExtern(op); return CreateCallExtern(op);
} else { } else {
LOG(FATAL) << "Unknown call type "; LOG(FATAL) << "Unknown call type " <<
"name= " << op->name <<
" call_type= " << op->call_type;
return nullptr; return nullptr;
} }
} }
......
...@@ -246,6 +246,12 @@ inline Expr MergeMulMod(const Expr &base) { ...@@ -246,6 +246,12 @@ inline Expr MergeMulMod(const Expr &base) {
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset; Expr base = n->elem_offset;
if (n->strides.size() == 0) { if (n->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImm>();
CHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
CHECK_EQ(n->shape.size(), index.size()); CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) { if (index.size() > 0) {
Expr offset = index[0]; Expr offset = index[0];
...@@ -254,6 +260,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { ...@@ -254,6 +260,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
} }
base = base + offset; base = base + offset;
} }
}
} else { } else {
CHECK_EQ(n->strides.size(), index.size()); CHECK_EQ(n->strides.size(), index.size());
if (is_zero(base)) { if (is_zero(base)) {
......
...@@ -35,14 +35,25 @@ namespace Internal { ...@@ -35,14 +35,25 @@ namespace Internal {
using tvm::ir::CommReducerNode; using tvm::ir::CommReducerNode;
using tvm::ir::Reduce; using tvm::ir::Reduce;
using tvm::ir::Any;
using tvm::ir::AttrStmt; using tvm::ir::AttrStmt;
template<> template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor"; LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
}
template<>
void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
p->stream << "?";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) { .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner=" p->stream << "reduce(combiner="
<< op->combiner; << op->combiner;
...@@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source, ...@@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
return Expr(n); return Expr(n);
} }
Expr Any::make() {
auto n = make_node<Any>();
return Expr(n);
}
TVM_REGISTER_NODE_TYPE(CommReducerNode); TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce); TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt); TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm); TVM_REGISTER_NODE_TYPE(FloatImm);
......
...@@ -38,9 +38,12 @@ Expr Tensor::operator()(Array<Var> indices) const { ...@@ -38,9 +38,12 @@ Expr Tensor::operator()(Array<Var> indices) const {
Expr Tensor::operator()(Array<Expr> indices) const { Expr Tensor::operator()(Array<Expr> indices) const {
using HalideIR::Internal::Call; using HalideIR::Internal::Call;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size()) CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read" << "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); << "ndim = " << ndim() << ", indices.size=" << indices.size();
}
auto n = Call::make( auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide, (*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index); (*this)->op, (*this)->value_index);
......
...@@ -417,10 +417,10 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -417,10 +417,10 @@ class RelayBuildModule : public runtime::ModuleNode {
} }
/*! /*!
* \brief Build relay function to runtime module * \brief Compile a Relay function to runtime module.
* *
* \param func Relay Function * \param func The Relay function.
* \param params parameters * \param params The parameters.
*/ */
void BuildRelay( void BuildRelay(
Function func, Function func,
...@@ -444,9 +444,14 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -444,9 +444,14 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON(); ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams(); ret_.params = graph_codegen_->GetParams();
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, auto lowered_funcs = graph_codegen_->GetLoweredFunc();
if (lowered_funcs.size() != 0) {
ret_.mod = tvm::build(
lowered_funcs,
target_host_,
BuildConfig::Current()); BuildConfig::Current());
} }
}
protected: protected:
std::unique_ptr<GraphCodegen> graph_codegen_; std::unique_ptr<GraphCodegen> graph_codegen_;
......
...@@ -67,6 +67,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { ...@@ -67,6 +67,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
std::stringstream err_msg; std::stringstream err_msg;
err_msg << rang::fg::red; err_msg << rang::fg::red;
err_msg << " ";
for (auto index : error_indicies) { for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; "; err_msg << this->errors_[index].what() << "; ";
} }
...@@ -88,7 +89,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { ...@@ -88,7 +89,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we output a header for the errors. // First we output a header for the errors.
annotated_prog << annotated_prog <<
rang::style::bold << std::endl << rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:" "Error(s) have occurred. The program has been annotated with them:"
<< std::endl << std::endl << rang::style::reset; << std::endl << std::endl << rang::style::reset;
// For each global function which contains errors, we will // For each global function which contains errors, we will
......
...@@ -287,6 +287,8 @@ RefCreate RefCreateNode::make(Expr value) { ...@@ -287,6 +287,8 @@ RefCreate RefCreateNode::make(Expr value) {
return RefCreate(n); return RefCreate(n);
} }
TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_API("relay._make.RefCreate") TVM_REGISTER_API("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make); .set_body_typed(RefCreateNode::make);
...@@ -301,6 +303,8 @@ RefRead RefReadNode::make(Expr ref) { ...@@ -301,6 +303,8 @@ RefRead RefReadNode::make(Expr ref) {
return RefRead(n); return RefRead(n);
} }
TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_API("relay._make.RefRead") TVM_REGISTER_API("relay._make.RefRead")
.set_body_typed(RefReadNode::make); .set_body_typed(RefReadNode::make);
...@@ -316,6 +320,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { ...@@ -316,6 +320,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
return RefWrite(n); return RefWrite(n);
} }
TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_API("relay._make.RefWrite") TVM_REGISTER_API("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make); .set_body_typed(RefWriteNode::make);
......
...@@ -686,7 +686,9 @@ class PrettyPrinter : ...@@ -686,7 +686,9 @@ class PrettyPrinter :
Doc PrintAttr(const NodeRef& value, bool meta = false) { Doc PrintAttr(const NodeRef& value, bool meta = false) {
if (value.defined()) { if (value.defined()) {
Doc printed_attr; Doc printed_attr;
if (meta) { if (value.as<tvm::ir::Any>()) {
printed_attr << "?";
} else if (meta) {
printed_attr = meta_.GetMetaNode(value); printed_attr = meta_.GetMetaNode(value);
} else { } else {
printed_attr = VisitAttr(value); printed_attr = VisitAttr(value);
...@@ -846,6 +848,12 @@ std::string PrettyPrint_(const NodeRef& node, ...@@ -846,6 +848,12 @@ std::string PrettyPrint_(const NodeRef& node,
return doc.str(); return doc.str();
} }
std::string PrettyPrint(const NodeRef& node) {
Doc doc;
doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
return doc.str();
}
std::string AsText(const NodeRef& node, std::string AsText(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
......
...@@ -228,5 +228,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -228,5 +228,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefTypeNode(" << node->value << ")"; p->stream << "RefTypeNode(" << node->value << ")";
}); });
TVM_REGISTER_API("relay._make.Any")
.set_body_typed<IndexExpr()>([]() { return Any::make(); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
* \brief Transform operators. * \brief Transform operators.
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/error.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
...@@ -184,40 +185,77 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -184,40 +185,77 @@ bool ConcatenateRel(const Array<Type>& types,
const TypeReporter& reporter) { const TypeReporter& reporter) {
// types: [data, result] // types: [data, result]
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
/* If we receive a tuple we can continue, if we receive
* anything but an incomplete type we should signal an
* error.
*/
const auto* tensor_tuple = types[0].as<TupleTypeNode>(); const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) { if (tensor_tuple == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>()) throw relay::Error(
<< "cast: expect input type to be TupleType but get " RELAY_ERROR(
<< types[0]; "concatenate requires a tuple of tensors as the first argument, found "
<< PrettyPrint(types[0])));
} else if (types[0].as<IncompleteTypeNode>() != nullptr) {
return false; return false;
} }
const auto* param = attrs.as<ConcatenateAttrs>(); const auto* param = attrs.as<ConcatenateAttrs>();
if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
return false;
}
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
// Sanity check: ndim and dtype. // Sanity check: ndim and dtype.
const int ndim = static_cast<int>(first->shape.size()); const int ndim = static_cast<int>(first->shape.size());
const DataType dtype = first->dtype; const DataType dtype = first->dtype;
for (const Type& ele : tensor_tuple->fields) { for (const Type& ele : tensor_tuple->fields) {
if (ele.as<IncompleteTypeNode>()) {
return false;
}
const auto& e = Downcast<TensorType>(ele); const auto& e = Downcast<TensorType>(ele);
int e_ndim = static_cast<int>(e->shape.size()); int e_ndim = static_cast<int>(e->shape.size());
const DataType& e_dtype = e->dtype; const DataType& e_dtype = e->dtype;
CHECK_EQ(e_ndim, ndim) << "relay.concatenate requires all tensors have the same ndim"; if (e_ndim != ndim) {
CHECK_EQ(e_dtype, dtype) << "relay.concatenate requires all tensors have the same dtype"; throw relay::Error("relay.concatenate requires all tensors have the same ndim");
}
if (e_dtype != dtype) {
throw relay::Error("relay.concatenate requires all tensors have the same dtype");
}
} }
// Sanity check: axis // Sanity check: axis
int axis = param->axis; int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim) if (!(-ndim <= axis && axis < ndim)) {
<< "concatenate only accepts `axis` in [-ndim, ndim)" throw relay::Error(RELAY_ERROR(
<< ", but got axis = " << axis "concatenate only accepts `axis` in [-ndim, ndim)" <<
<< ", and ndim = " << ndim; ", but got axis = " << axis <<
", and ndim = " << ndim));
}
axis = axis < 0 ? ndim + axis : axis; axis = axis < 0 ? ndim + axis : axis;
// Calculate shape // Calculate shape
std::vector<IndexExpr>&& oshape = AsVector(first->shape); std::vector<IndexExpr>&& oshape = AsVector(first->shape);
IndexExpr &concat_dim = oshape[axis]; IndexExpr &concat_dim = oshape[axis];
bool has_any = false;
if (concat_dim.as<Any>()) {
has_any = true;
} else {
for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) { for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]); const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
if (e->shape[axis].as<Any>()) {
has_any = true;
break;
}
concat_dim += e->shape[axis]; concat_dim += e->shape[axis];
} }
reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype)); }
if (has_any) {
concat_dim = Any::make();
}
auto rtype = TensorTypeNode::make(oshape, dtype);
reporter->Assign(types[1], rtype);
return true; return true;
} }
...@@ -499,6 +537,8 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -499,6 +537,8 @@ bool ReshapeRel(const Array<Type>& types,
newshape = param->newshape; newshape = param->newshape;
} }
Array<IndexExpr> oshape; Array<IndexExpr> oshape;
std::unordered_set<size_t> used_input_dims;
std::unordered_set<size_t> used_output_dims;
size_t src_idx = 0; size_t src_idx = 0;
int infer_idx = -1; int infer_idx = -1;
...@@ -511,6 +551,8 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -511,6 +551,8 @@ bool ReshapeRel(const Array<Type>& types,
} else if (svalue == 0) { } else if (svalue == 0) {
// keep same // keep same
CHECK_LT(src_idx, data_shape.size()); CHECK_LT(src_idx, data_shape.size());
used_input_dims.insert(src_idx);
used_output_dims.insert(oshape.size());
oshape.push_back(data_shape[src_idx++]); oshape.push_back(data_shape[src_idx++]);
} else if (svalue == -1) { } else if (svalue == -1) {
// inference based on rest // inference based on rest
...@@ -522,31 +564,49 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -522,31 +564,49 @@ bool ReshapeRel(const Array<Type>& types,
} else if (svalue == -2) { } else if (svalue == -2) {
// copy all remaining dims from source // copy all remaining dims from source
while (src_idx < data_shape.size()) { while (src_idx < data_shape.size()) {
used_input_dims.insert(src_idx);
used_output_dims.insert(oshape.size());
oshape.push_back(data_shape[src_idx++]); oshape.push_back(data_shape[src_idx++]);
} }
} else if (svalue == -3) { } else if (svalue == -3) {
// merge two dims from source // merge two dims from source
CHECK_LT(src_idx + 1, data_shape.size()); CHECK_LT(src_idx + 1, data_shape.size());
used_input_dims.insert(src_idx);
IndexExpr d1 = data_shape[src_idx++]; IndexExpr d1 = data_shape[src_idx++];
used_input_dims.insert(src_idx);
IndexExpr d2 = data_shape[src_idx++]; IndexExpr d2 = data_shape[src_idx++];
used_output_dims.insert(oshape.size());
oshape.push_back(d1 * d2); oshape.push_back(d1 * d2);
} else if (svalue == -4) { } else if (svalue == -4) {
// split the source dim s into two dims // split the source dim s into two dims
// read the left dim and then the right dim (either can be -1) // read the left dim and then the right dim (either can be -1)
CHECK_LT(i + 2, newshape.size()); CHECK_LT(i + 2, newshape.size());
CHECK_LT(src_idx, data_shape.size()); CHECK_LT(src_idx, data_shape.size());
used_input_dims.insert(src_idx);
IndexExpr d0 = data_shape[src_idx++]; IndexExpr d0 = data_shape[src_idx++];
Integer d1 = newshape[++i]; Integer d1 = newshape[++i];
Integer d2 = newshape[++i]; Integer d2 = newshape[++i];
if (d1->value == -1) { if (d1->value == -1) {
CHECK(d2->value != -1) CHECK(d2->value != -1)
<< "Split dims cannot both be -1."; << "Split dims cannot both be -1.";
used_output_dims.insert(oshape.size());
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d2); oshape.push_back(d0 / d2);
}
used_output_dims.insert(oshape.size());
oshape.push_back(d2); oshape.push_back(d2);
} else { } else {
used_output_dims.insert(oshape.size());
oshape.push_back(d1); oshape.push_back(d1);
used_output_dims.insert(oshape.size());
if (d2->value == -1) { if (d2->value == -1) {
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d1); oshape.push_back(d0 / d1);
}
} else { } else {
oshape.push_back(d2); oshape.push_back(d2);
} }
...@@ -555,9 +615,30 @@ bool ReshapeRel(const Array<Type>& types, ...@@ -555,9 +615,30 @@ bool ReshapeRel(const Array<Type>& types,
} }
if (infer_idx >= 0) { if (infer_idx >= 0) {
IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1); IndexExpr infer_dim = 1;
IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1); for (size_t i = 0; i < data_shape.size(); ++i) {
oshape.Set(infer_idx, old_size / new_size); if (used_input_dims.count(i) != 0) {
continue;
}
if (data_shape[i].as<Any>()) {
infer_dim = Any::make();
break;
}
infer_dim *= data_shape[i];
}
if (!infer_dim.as<Any>()) {
for (size_t i = 0; i < oshape.size(); ++i) {
if (used_output_dims.count(i) != 0) {
continue;
}
if (oshape[i].as<Any>()) {
infer_dim = Any::make();
break;
}
infer_dim /= oshape[i];
}
}
oshape.Set(infer_idx, infer_dim);
} }
if (param->reverse) { if (param->reverse) {
...@@ -978,21 +1059,51 @@ and type as the input array. ...@@ -978,21 +1059,51 @@ and type as the input array.
// arange operator // arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs); TVM_REGISTER_NODE_TYPE(ArangeAttrs);
double ToScalar(const runtime::NDArray& array) {
if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
return reinterpret_cast<int32_t*>(array->data)[0];
} else {
return reinterpret_cast<float*>(array->data)[0];
}
}
bool ArangeRel(const Array<Type>& types, bool ArangeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& raw_attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 1); CHECK_EQ(types.size(), 4);
const ArangeAttrs* param = attrs.as<ArangeAttrs>(); const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil( const ConstantNode *cstart, *cstop, *cstep;
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) { reporter->Assign(types[0], types[1]);
CHECK_GT(val->value, 0) reporter->Assign(types[1], types[2]);
<< "Invalid arange attributes (start, stop, step): " << param->start reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
<< ", " << param->stop << ", " << param->step;
} if ((cstart = attrs->start.as<ConstantNode>()) &&
reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype)); (cstop = attrs->stop.as<ConstantNode>()) &&
(cstep = attrs->step.as<ConstantNode>())) {
double start = ToScalar(cstart->data);
double stop = ToScalar(cstop->data);
double step = ToScalar(cstep->data);
int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
CHECK_GT(num_elem, 0)
<< "Invalid arange attributes (start, stop, step): " << attrs->start
<< ", " << attrs->stop << ", " << attrs->step;
reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
return true;
} else {
reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
return true; return true;
}
}
inline Tensor DynamicArange(const tvm::Tensor& start, const tvm::Tensor& stop,
const tvm::Tensor& step, tvm::Type dtype, std::string name = "tensor",
std::string tag = topi::kInjective) {
tvm::Expr num_elem = tvm::Var("num_elem");
return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
return tvm::cast(dtype, start[0] + step[0] * indices[0]);
}, name, tag);
} }
Array<Tensor> ArangeCompute(const Attrs& attrs, Array<Tensor> ArangeCompute(const Attrs& attrs,
...@@ -1000,35 +1111,53 @@ Array<Tensor> ArangeCompute(const Attrs& attrs, ...@@ -1000,35 +1111,53 @@ Array<Tensor> ArangeCompute(const Attrs& attrs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>(); const ArangeAttrs* param = attrs.as<ArangeAttrs>();
return { topi::arange(param->start, param->stop, param->step, param->dtype) }; Tensor start = inputs[0];
Tensor stop = inputs[1];
Tensor step = inputs[2];
Array<tvm::Expr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
} }
Expr MakeArange(tvm::Expr start, Expr MakeArange(Expr start,
tvm::Expr stop, Expr stop,
tvm::Expr step, Expr step,
DataType dtype) { DataType dtype) {
auto attrs = make_node<ArangeAttrs>(); auto attrs = make_node<ArangeAttrs>();
attrs->start = std::move(start); attrs->start = start;
attrs->stop = std::move(stop); attrs->stop = stop;
attrs->step = std::move(step); attrs->step = step;
attrs->dtype = std::move(dtype); attrs->dtype = dtype;
static const Op& op = Op::Get("arange"); static const Op& op = Op::Get("arange");
return CallNode::make(op, {}, Attrs(attrs), {}); return CallNode::make(op, {start, stop, step}, Attrs(attrs), {});
} }
TVM_REGISTER_API("relay.op._make.arange") TVM_REGISTER_API("relay.op._make.arange")
.set_body_typed(MakeArange); .set_body_typed(MakeArange);
// An issue with the existing design is that we require dependency
// to type the operator precisely.
//
// Supporting this in general is challenging so we duplicate the
// secondary arguments as args and attributes.
//
// In this way reify the arguments at both the value and type level.
//
// In the case our arguments are constant we can immediately recover
// the type of arange.
//
// In general I think we should avoid this pattern, and introduce
// a secondary shape analysis to recover more precise information.
RELAY_REGISTER_OP("arange") RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval. .describe(R"code(Returns evenly spaced values within a given interval.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs") .set_attrs_type_key("relay.attrs.ArangeAttrs")
.set_num_inputs(0) .set_num_inputs(3)
.set_support_level(3) .set_support_level(3)
.add_type_rel("Arange", ArangeRel) .add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute) .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator // repeat operator
TVM_REGISTER_NODE_TYPE(RepeatAttrs); TVM_REGISTER_NODE_TYPE(RepeatAttrs);
......
...@@ -87,6 +87,11 @@ Type ConcreteBroadcast(const TensorType& t1, ...@@ -87,6 +87,11 @@ Type ConcreteBroadcast(const TensorType& t1,
oshape.push_back(s2); oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) { } else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1); oshape.push_back(s1);
} else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
// TODO(@jroesch): we need to come back to this
oshape.push_back(s2);
} else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
oshape.push_back(s1);
} else { } else {
RELAY_ERROR( RELAY_ERROR(
"Incompatible broadcast type " "Incompatible broadcast type "
......
...@@ -313,17 +313,24 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -313,17 +313,24 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type VisitExpr_(const LetNode* let) final { Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion // if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr; bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
Type let_type = IncompleteTypeNode::make(Kind::kType);
if (is_functional_literal) { if (is_functional_literal) {
type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType); let_type = GetType(let->var);
type_map_[let->var].checked_type = let_type;
} }
Type vtype = GetType(let->value);
if (let->var->type_annotation.defined()) { if (let->var->type_annotation.defined()) {
vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let)); let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
} }
Type vtype = GetType(let->value);
let_type = Unify(let_type, vtype, GetRef<Let>(let));
CHECK(is_functional_literal || !type_map_.count(let->var)); CHECK(is_functional_literal || !type_map_.count(let->var));
// NOTE: no scoping is necessary because var are unique in program // NOTE: no scoping is necessary because var are unique in program
type_map_[let->var].checked_type = vtype; type_map_[let->var].checked_type = let_type;
return GetType(let->body); return GetType(let->body);
} }
...@@ -473,7 +480,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -473,7 +480,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
} }
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]); this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef<Call>(call));
} }
for (auto cs : fn_ty->type_constraints) { for (auto cs : fn_ty->type_constraints) {
...@@ -556,6 +563,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -556,6 +563,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types), return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
td->type_vars, {}); td->type_vars, {});
} }
void Solve() {
solver_.Solve();
if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}
}
}; };
class TypeInferencer::Resolver : public ExprMutator, PatternMutator { class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
...@@ -673,7 +688,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { ...@@ -673,7 +688,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
update_missing_type_annotation_ && update_missing_type_annotation_ &&
!new_var->type_annotation.defined()); !new_var->type_annotation.defined());
bool need_update_fn = ( bool need_update_fn =(
std::is_base_of<FunctionNode, T>::value && std::is_base_of<FunctionNode, T>::value &&
update_missing_type_annotation_ && update_missing_type_annotation_ &&
!new_fn->ret_type.defined()); !new_fn->ret_type.defined());
...@@ -738,16 +753,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { ...@@ -738,16 +753,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
Expr TypeInferencer::Infer(Expr expr) { Expr TypeInferencer::Infer(Expr expr) {
// Step 0: Populate the constraints. // Step 1: Populate the constraints.
GetType(expr); GetType(expr);
// Step 1: Solve the constraints.
solver_.Solve();
if (err_reporter.AnyErrors()) { // Step 2: Solve the constraints.
err_reporter.RenderErrors(mod_); Solve();
}
// Step 2: Attach resolved types to checked_type field. // Step 3: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr)); CHECK(WellFormed(resolved_expr));
return resolved_expr; return resolved_expr;
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
*/ */
#include <string> #include <string>
#include <memory> #include <memory>
#include <tuple>
#include <utility>
#include "type_solver.h" #include "type_solver.h"
#include "../ir/type_functor.h" #include "../ir/type_functor.h"
...@@ -90,7 +92,7 @@ class TypeSolver::OccursChecker : public TypeVisitor { ...@@ -90,7 +92,7 @@ class TypeSolver::OccursChecker : public TypeVisitor {
class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
public: public:
explicit Unifier(TypeSolver* solver) : solver_(solver) {} explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {}
Type Unify(const Type& src, const Type& dst) { Type Unify(const Type& src, const Type& dst) {
// Known limitation // Known limitation
...@@ -102,35 +104,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -102,35 +104,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
if (lhs->FindRoot() == rhs->FindRoot()) { if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type; return lhs->resolved_type;
} }
if (lhs->resolved_type.as<IncompleteTypeNode>()) { if (lhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!CheckOccurs(lhs, rhs->resolved_type)) CHECK(!OccursCheck(lhs, rhs->resolved_type))
<< "Incomplete type " << lhs->resolved_type << " occurs in " << "Incomplete type " << lhs->resolved_type << " occurs in "
<< rhs->resolved_type << ", cannot unify"; << rhs->resolved_type << ", cannot unify";
solver_->MergeFromTo(lhs, rhs); solver_->MergeFromTo(lhs, rhs);
return rhs->resolved_type; return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) { } else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!CheckOccurs(rhs, lhs->resolved_type)) CHECK(!OccursCheck(rhs, lhs->resolved_type))
<< "Incomplete type " << rhs->resolved_type << " occurs in " << "Incomplete type " << rhs->resolved_type << " occurs in "
<< lhs->resolved_type << ", cannot unify"; << lhs->resolved_type << ", cannot unify";
solver_->MergeFromTo(rhs, lhs); solver_->MergeFromTo(rhs, lhs);
return lhs->resolved_type; return lhs->resolved_type;
} else { } else {
Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
CHECK(resolved.defined()) if (!resolved.defined()) {
<< "Unable to unify parent types: " solver_->ReportError(RELAY_ERROR("unable to unify: "
<< lhs->resolved_type << " and " << rhs->resolved_type; << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
<< PrettyPrint(rhs->resolved_type) << "`"),
this->loc);
return lhs->resolved_type;
} else {
TypeNode* top = solver_->GetTypeNode(resolved); TypeNode* top = solver_->GetTypeNode(resolved);
solver_->MergeFromTo(lhs, top); solver_->MergeFromTo(lhs, top);
solver_->MergeFromTo(rhs, top); solver_->MergeFromTo(rhs, top);
return resolved; return resolved;
} }
} }
}
// Checks whether lhs (taken to be a type var) occurs in t, meaning // Checks whether lhs (taken to be a type var) occurs in t, meaning
// there is a recursive equality constraint, which should be rejected. // there is a recursive equality constraint, which should be rejected.
// N.b.: A tautology like ?a = ?a is okay and should be checked for // N.b.: A tautology like ?a = ?a is okay and should be checked for
// *before* calling this method // *before* calling this method
bool CheckOccurs(TypeNode* lhs, const Type& t) { //
// See: https://en.wikipedia.org/wiki/Occurs_check
bool OccursCheck(TypeNode* lhs, const Type& t) {
OccursChecker rc(solver_, lhs); OccursChecker rc(solver_, lhs);
return rc.Check(t); return rc.Check(t);
} }
...@@ -145,6 +156,118 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -145,6 +156,118 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return t1; return t1;
} }
IndexExpr GetShape(const IndexExpr& e) {
IndexExpr ex = e;
while (true) {
auto it = solver_->shape_uf_.find(ex);
if (it == solver_->shape_uf_.end()) {
return ex;
} else {
ex = (*it).second;
}
}
}
IndexExpr UnifyDim(const IndexExpr& lhs, const IndexExpr& rhs) {
auto ulhs = GetShape(lhs);
auto urhs = GetShape(rhs);
if (ulhs.same_as(urhs)) {
return ulhs;
}
if (ulhs.as<Any>() || urhs.as<Any>()) {
return Any::make();
}
auto left_index0 = ulhs.as<tvm::Variable>();
auto right_index0 = urhs.as<tvm::IntImm>();
if (left_index0 && right_index0) {
solver_->shape_uf_.Set(ulhs, urhs);
return urhs;
}
auto left_index1 = ulhs.as<tvm::IntImm>();
auto right_index1 = urhs.as<tvm::Variable>();
if (left_index1 && right_index1) {
solver_->shape_uf_.Set(urhs, ulhs);
return ulhs;
}
auto left_index2 = ulhs.as<tvm::IntImm>();
auto right_index2 = urhs.as<tvm::IntImm>();
if (left_index2 && right_index2 && left_index2->value == right_index2->value) {
return ulhs;
}
return tvm::Expr();
}
Type VisitType_(const TensorTypeNode* op, const Type& tn) final {
const auto* tt_node = tn.as<TensorTypeNode>();
if (!tt_node) {
return Type(nullptr);
}
auto tt1 = GetRef<TensorType>(op);
auto tt2 = GetRef<TensorType>(tt_node);
if (AlphaEqual(tt1, tt2)) {
return std::move(tt1);
}
if (tt1->dtype != tt2->dtype) {
return Type(nullptr);
}
tvm::Array<IndexExpr> shape;
if (tt1->shape.size() != tt2->shape.size()) {
this->solver_->ReportError(
RELAY_ERROR(
"tensor type `" << PrettyPrint(tt1) <<
"` has " << tt1->shape.size() <<
" dimensions, while `" <<
PrettyPrint(tt2) <<
"` has " << tt2->shape.size() <<
" dimensions"), this->loc);
return Type(nullptr);
}
std::vector<std::tuple<size_t, IndexExpr, IndexExpr>> mismatches;
CHECK_EQ(tt1->shape.size(), tt2->shape.size());
for (size_t i = 0; i < tt1->shape.size(); i++) {
auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]);
if (!dim.defined()) {
// NB: We push an arbitrary dimension here so we can continue error propogation.
shape.push_back(tt1->shape[i]);
tvm::Expr shape1 = tt1->shape[i];
tvm::Expr shape2 = tt2->shape[i];
std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2);
mismatches.push_back(tuple);
} else {
shape.push_back(dim);
}
}
if (mismatches.size() != 0) {
RelayErrorStream err;
err << "in particular ";
for (auto mismatch : mismatches) {
err << "dimension "
<< std::get<0>(mismatch)
<< " conflicts "
<< std::get<1>(mismatch)
<< " does not match "
<< std::get<2>(mismatch);
}
Error error(err);
this->solver_->ReportError(error, this->loc);
return Type(nullptr);
}
return TensorTypeNode::make(shape, tt1->dtype);
}
Type VisitType_(const TupleTypeNode* op, const Type& tn) final { Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
const auto* ttn = tn.as<TupleTypeNode>(); const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) { if (!ttn || op->fields.size() != ttn->fields.size()) {
...@@ -225,6 +348,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -225,6 +348,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
private: private:
TypeSolver* solver_; TypeSolver* solver_;
NodeRef loc;
}; };
class TypeSolver::Resolver : public TypeMutator { class TypeSolver::Resolver : public TypeMutator {
...@@ -412,14 +536,14 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { ...@@ -412,14 +536,14 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
} }
// Add equality constraint // Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) { Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef& loc) {
// NB(@jroesch): we should probably pass location into the unifier to do better Unifier unifier(this, loc);
// error reporting as well.
Unifier unifier(this);
return unifier.Unify(dst, src); return unifier.Unify(dst, src);
} }
void TypeSolver::ReportError(const Error& err, const NodeRef& location) { void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
CHECK(location.defined());
CHECK(current_func.defined());
err_reporter_->ReportAt(current_func, location, err); err_reporter_->ReportAt(current_func, location, err);
} }
...@@ -460,7 +584,6 @@ Type TypeSolver::Resolve(const Type& type) { ...@@ -460,7 +584,6 @@ Type TypeSolver::Resolve(const Type& type) {
} }
bool TypeSolver::Solve() { bool TypeSolver::Solve() {
// Update until queue is empty.
while (!update_queue_.empty()) { while (!update_queue_.empty()) {
RelationNode* rnode = update_queue_.front(); RelationNode* rnode = update_queue_.front();
const auto& rel = rnode->rel; const auto& rel = rnode->rel;
...@@ -494,11 +617,10 @@ bool TypeSolver::Solve() { ...@@ -494,11 +617,10 @@ bool TypeSolver::Solve() {
rnode->resolved = false; rnode->resolved = false;
} catch (const dmlc::Error& err) { } catch (const dmlc::Error& err) {
rnode->resolved = false; rnode->resolved = false;
this->ReportError( this->ReportError(RELAY_ERROR("an internal invariant was violated while "
RELAY_ERROR( "typechecking your program "
"an internal invariant was violated while " \ << err.what()),
"typechecking your program " << rnode->location);
err.what()), rnode->location);
} }
// Mark inqueue as false after the function call // Mark inqueue as false after the function call
...@@ -516,17 +638,21 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") ...@@ -516,17 +638,21 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
ErrorReporter err_reporter; ErrorReporter *err_reporter = new ErrorReporter();
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter); auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
auto mod = [solver](std::string name) -> PackedFunc { auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
if (name == "Solve") { if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() { return TypedPackedFunc<bool()>([solver]() {
return solver->Solve(); return solver->Solve();
}); });
} else if (name == "Unify") { } else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) { return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
return solver->Unify(lhs, rhs, lhs); auto res = solver->Unify(lhs, rhs, lhs);
if (err_reporter->AnyErrors()) {
err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
}
return res;
}); });
} else if (name == "Resolve") { } else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) { return TypedPackedFunc<Type(Type)>([solver](Type t) {
......
...@@ -89,7 +89,6 @@ class TypeSolver { ...@@ -89,7 +89,6 @@ class TypeSolver {
* \param location The location at which the unification problem arose. * \param location The location at which the unification problem arose.
*/ */
Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
/*! /*!
* \brief Report an error at the provided location. * \brief Report an error at the provided location.
* \param err The error to report. * \param err The error to report.
...@@ -124,6 +123,7 @@ class TypeSolver { ...@@ -124,6 +123,7 @@ class TypeSolver {
TypeNode* parent{nullptr}; TypeNode* parent{nullptr};
/*! \brief set of relations that is related to this type node */ /*! \brief set of relations that is related to this type node */
std::unordered_set<RelationNode*> rel_set; std::unordered_set<RelationNode*> rel_set;
/*! /*!
* \brief Find the root type node, perform path compression * \brief Find the root type node, perform path compression
* \return The root type node. * \return The root type node.
...@@ -159,13 +159,15 @@ class TypeSolver { ...@@ -159,13 +159,15 @@ class TypeSolver {
NodeRef location; NodeRef location;
}; };
/*! \brief A simple union find between shapes. */
tvm::Map<IndexExpr, IndexExpr> shape_uf_;
/*! \brief List of all allocated type nodes */ /*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_; std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */ /*! \brief List of all allocated relation nodes */
std::vector<RelationNode*> rel_nodes_; std::vector<RelationNode*> rel_nodes_;
/*! \brief Number of resolved relations */ /*! \brief Number of resolved relations */
size_t num_resolved_rels_{0}; size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */ /*! \brief map from types to type nodes. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_; std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \brief Internal queue to update the relation */ /*! \brief Internal queue to update the relation */
std::queue<RelationNode*> update_queue_; std::queue<RelationNode*> update_queue_;
...@@ -205,6 +207,7 @@ class TypeSolver { ...@@ -205,6 +207,7 @@ class TypeSolver {
rel->inqueue = true; rel->inqueue = true;
update_queue_.push(rel); update_queue_.push(rel);
} }
/*! /*!
* \brief Merge rhs type node to lhs * \brief Merge rhs type node to lhs
* \param src The source operand * \param src The source operand
......
...@@ -184,6 +184,10 @@ void NDArray::CopyFromTo(DLTensor* from, ...@@ -184,6 +184,10 @@ void NDArray::CopyFromTo(DLTensor* from,
from_size, from->ctx, to->ctx, from->dtype, stream); from_size, from->ctx, to->ctx, from->dtype, stream);
} }
std::vector<int64_t> NDArray::Shape() const {
return data_->shape_;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop
import numpy as np
def infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod[mod.entry_func]
return entry if isinstance(expr, relay.Function) else entry.body
def int32(val):
return relay.const(val, 'int32')
def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1)
ex = relay.create_executor()
f = relay.Function([x], y2, type_params=[m, n, k])
# TODO(@jroesch): Restore after code generation.
# data = np.random.rand(10, 5, 3).astype('float32')
# result = ex.evaluate(f)(data)
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat():
"""
fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
if (%i < 10) {
let %i = reshape(cast(i, "float32"), newshape=(1, ))
let %new_st = concatenate((st, i), axis=0)
concat_loop(%i + 1, )
} else {
st
}
}
"""
# Initial Values.
i = relay.var('i', shape=(), dtype='int32')
st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
def _cond(i, st):
return relay.op.min(relay.op.less(i, int32(10)))
def _body(i, st):
i_vec = relay.op.reshape(i, (1,1))
ret = relay.op.concatenate([st, i_vec], axis=0)
return i + int32(1), ret
loop = while_loop(_cond, [i, st], _body)
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
func = infer_type(func)
# TODO(@jroesch, @haichen): We should restore this code when codegeneration
# is merged
# ret_shape = func.checked_type.ret_type.shape
# assert len(ret_shape) == 2, "expected 2-dim output"
# assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
# import pdb; pdb.set_trace()
# mod = relay.module.Module()
# print(relay.ir_pass.infer_type(func, mod=mod))
# ret = relay.Call(loop, [relay.const(0, 'int32'), init])
# mod[mod.entry_func] = relay.Function([], ret)
# print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
# initial = np.array(0.0, dtype='float32').reshape((1,))
# iter_stop = np.array(10, dtype='int32')
# ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
# result = ex.evaluate(mod.entry_func)()
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat_with_wrong_annotation():
"""
v0.0.1
fn (%start: int32) {
%7 = {
let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
%0 = less(%i, 10)
%1 = min(%0)
if (%1) {
%2 = add(%i, 1)
%3 = reshape(%i, newshape=[1, 1])
%4 = (%st, %3)
/* The result of concat should be 1,1 but it is 2, 1. */
%5 = concatenate(%4)
%while_loop(%2, %5)
} else {
(%i, %st)
}
}
%6 = reshape(0, newshape=[1, 1])
%while_loop(%start, %6)
}
%7.1
}
"""
# Initial Values.
i = relay.var('i', shape=(), dtype='int32')
st = relay.var('st', shape=(1, 1), dtype='int32')
def _cond(i, st):
return relay.op.min(relay.op.less(i, int32(10)))
def _body(i, st):
i_vec = relay.op.reshape(i, (1,1))
ret = relay.op.concatenate([st, i_vec], axis=0)
return i + int32(1), ret
loop = while_loop(_cond, [i, st], _body)
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
try:
func = infer_type(func)
assert False
except Exception as e:
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__":
test_arange_with_dynamic_shape()
test_dynamic_concat()
test_dynamic_concat_with_wrong_annotation()
...@@ -493,17 +493,20 @@ def test_arange(): ...@@ -493,17 +493,20 @@ def test_arange():
def verify_arange(start, stop, step): def verify_arange(start, stop, step):
dtype = "float32" dtype = "float32"
if start is None and step is None: if start is None and step is None:
x = relay.arange(stop) x = relay.arange(relay.const(stop, dtype=dtype))
ref_res = np.arange(stop) ref_res = np.arange(stop).astype(dtype)
elif start is None: elif start is None:
x = relay.arange(stop, step=step) x = relay.arange(relay.const(stop, dtype=dtype), step=relay.const(step, dtype=dtype))
ref_res = np.arange(stop, step=step) ref_res = np.arange(stop, step=step).astype(dtype)
elif step is None: elif step is None:
x = relay.arange(start, stop) x = relay.arange(relay.const(start, dtype=dtype), relay.const(stop, dtype=dtype))
ref_res = np.arange(start, stop) ref_res = np.arange(start, stop).astype(dtype)
else: else:
x = relay.arange(start, stop, step) x = relay.arange(
ref_res = np.arange(start, stop, step) relay.const(start, dtype=dtype),
relay.const(stop, dtype=dtype),
relay.const(step, dtype=dtype))
ref_res = np.arange(start, stop, step).astype(dtype)
func = relay.Function([], x) func = relay.Function([], x)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
...@@ -515,11 +518,13 @@ def test_arange(): ...@@ -515,11 +518,13 @@ def test_arange():
verify_arange(None, 20, 2) verify_arange(None, 20, 2)
verify_arange(1, 20, None) verify_arange(1, 20, None)
verify_arange(1, 20, 2) verify_arange(1, 20, 2)
verify_arange(1, 20, 1.5) # arange doesnt' support floating point right now, see type relation
# verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None) verify_arange(1, 20.5, None)
verify_arange(1, 20, 3) verify_arange(1, 20, 3)
verify_arange(20, 1, -1) verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5) # arange doesnt' support floating point right now, see type relation
# verify_arange(20, 1, -1.5)
def test_tile(): def test_tile():
def verify_tile(dshape, reps): def verify_tile(dshape, reps):
...@@ -616,6 +621,7 @@ def test_gather_nd(): ...@@ -616,6 +621,7 @@ def test_gather_nd():
if __name__ == "__main__": if __name__ == "__main__":
test_arange()
test_cast() test_cast()
test_zeros_ones() test_zeros_ones()
test_unary_identity() test_unary_identity()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment