Commit 3fb84e2b authored by Jared Roesch, committed by Thierry Moreau

[Relay][RFC] Implement type checking for Any (#3221)

* Implement type checking for Any

Remove code generation related changes

Remove compile changes

Remove more

Remove unification hack

Add some code back that was needed, and clean up test

Refactor test cases

WIP

Implement TypeHint AST

Add test case which should fail

Remove unification changes, and fix bug with let rec

Restore unification for shapes

Improve error reporting while debugging

All examples type check

All examples type check

WIP

First version that works with hints, needs clean up

Remove dead code

Tweaks

Remove type hint

Remove unnecessary type hint stuff

Remove more type hints

Clean up

Expose Any expression node

Address CR

Fix

Fix solver

Kill unnecessary code

Fix

PyLint

Fix

Relocate loops

Fix license and test

Lint again

Lint again

Fix loops

Fix docstring

Fix template error

Fix compiler issue

Fix compile err

Remove more runtime changes

Restore buffer

Fix segfault

Fix

Fix arange

* Address feedback

* Fix typo

* Fix arange

* Fix op level3

* Fix issue with Python wrapper
parent b10dda69
...@@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* _type_key = "Reduce";
 };

+/*! \brief Any shape. */
+struct Any : public ExprNode<Any> {
+  TVM_DLL static Expr make();
+
+  void VisitAttrs(AttrVisitor* v) final {}
+
+  static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
+  static constexpr const char* _type_key = "Any";
+};
+
 /*!
  * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
  */
......
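For orientation, a minimal sketch of how the new Any node surfaces in the Python frontend (not part of the diff; it relies on the bindings added later in this commit):

    import tvm
    from tvm import relay

    # A tensor whose first dimension is dynamic. Any renders as "?"
    # in the Relay text format (see the IRPrinter dispatch added below).
    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')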
...@@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
 /*! \brief Attributes used in arange operators */
 struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
-  tvm::Expr start;
-  tvm::Expr stop;
-  tvm::Expr step;
+  Expr start;
+  Expr stop;
+  Expr step;
   DataType dtype;

   TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
-    TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
+    TVM_ATTR_FIELD(start)
         .describe("Start of interval. The interval includes this value.");
     TVM_ATTR_FIELD(stop)
         .describe("Stop of interval. The interval does not include this value.");
-    TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
+    TVM_ATTR_FIELD(step)
         .describe("Spacing between values.");
-    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+    TVM_ATTR_FIELD(dtype)
         .describe("Target data type.");
   }
 };  // struct ArangeAttrs
......
...@@ -64,9 +64,10 @@ struct RelayErrorStream {
 struct Error : public dmlc::Error {
   Span sp;
-  explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
-  Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {}  // NOLINT(*)
-  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {}  // NOLINT(*)
+  explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
+  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {}  // NOLINT(*)
+  Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
+  Error() : dmlc::Error(""), sp(nullptr) {}
 };

 /*! \brief An abstraction around how errors are stored and reported.
...@@ -118,7 +119,8 @@ class ErrorReporter {
    * \param err The error message to report.
    */
   inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
-    this->ReportAt(global, node, Error(err));
+    std::string err_msg = err.str();
+    this->ReportAt(global, node, Error(err_msg));
   }

   /*! \brief Report an error against a program, using the full program
......
...@@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
   return node;
 }

+/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
+std::string PrettyPrint(const NodeRef& node);
+
 /*!
  * \brief Render the node as a string in the Relay text format.
  * \param node The node to be rendered.
......
...@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc<
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
                                                                   const Expr& output_grad)>;

+/*!
+ * \brief The code generation strategy for dynamic dimensions.
+ */
+enum AnyCodegenStrategy {
+  /*! \brief The default strategy of using completely variable dimensions. */
+  kVariableDimensions
+};
+
+/*! \brief A runtime representation of shape. */
+using Shape = Array<IndexExpr>;
+
+using FShapeFunc = runtime::TypedPackedFunc<
+  Array<Tensor>(const Attrs& attrs,
+                const Array<Tensor>& inputs,
+                const Array<Shape>& out_shapes)>;
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_ATTR_TYPES_H_
...@@ -35,6 +35,8 @@
 namespace tvm {
 namespace relay {

+using Any = tvm::ir::Any;
+
 /*! \brief Base type of the Relay type hierarchy. */
 class TypeNode : public RelayNode {
  public:
...@@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
    * But it is possible for the solver to resolve src by dst as well.
    */
   TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
+
   /*!
    * \brief assert shape expression comparison.
    * \note Use assert only if any of the condition input is symbolic.
......
...@@ -190,6 +190,8 @@ class NDArray {
   TVM_DLL static void CopyFromTo(
       DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);

+  TVM_DLL std::vector<int64_t> Shape() const;
+
   // internal namespace
   struct Internal;
  protected:
......
...@@ -294,7 +294,7 @@ def get_last_ffi_error():
     """
     c_err_msg = py_str(_LIB.TVMGetLastError())
     py_err_msg, err_type = c2pyerror(c_err_msg)
-    if err_type.startswith("tvm.error."):
+    if err_type is not None and err_type.startswith("tvm.error."):
         err_type = err_type[10:]
     return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
......
...@@ -479,7 +479,8 @@ def extern(shape,
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
-    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
+    if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
+        shape = [shape]
     if in_buffers is not None:
         in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
         if len(inputs) != len(in_buffers):
......
...@@ -63,6 +63,7 @@ TupleType = ty.TupleType
 TensorType = ty.TensorType
 Kind = ty.Kind
 TypeVar = ty.TypeVar
+ShapeVar = ty.ShapeVar
 TypeConstraint = ty.TypeConstraint
 FuncType = ty.FuncType
 TypeRelation = ty.TypeRelation
...@@ -71,6 +72,7 @@ scalar_type = ty.scalar_type
 RefType = ty.RefType
 GlobalTypeVar = ty.GlobalTypeVar
 TypeCall = ty.TypeCall
+Any = ty.Any

 # Expr
 Expr = expr.Expr
......
...@@ -570,6 +570,7 @@ def const(value, dtype=None):
     """
     if isinstance(value, (_base.numeric_types, (bool, list))):
         value = _np.array(value, dtype=dtype)
+
     if not dtype:
         # when dtype is None: int maps to "int32", float maps to "float32"
         map_dtype = {
...@@ -578,6 +579,7 @@ def const(value, dtype=None):
         }.get(value.dtype, None)
         if map_dtype:
             value = value.astype(map_dtype)
+
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
......
...@@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs):
         raise tvm.error.OpAttributeUnimplemented(
             'Attribute "repeat" is not supported in operator arange.')
     new_attrs = {}
-    new_attrs["start"] = attrs.get_float("start", 0)
-    new_attrs["stop"] = attrs.get_float("stop")
-    new_attrs["step"] = attrs.get_float("step", 1)
+    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
+    new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
+    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
     new_attrs["dtype"] = attrs.get_str("dtype", "float32")
     return _op.arange(**new_attrs)
......
...@@ -1059,9 +1059,9 @@ def _range():
         return AttrCvt(
             op_name="arange",
             ignores=['Tidx'],
-            extras={'start': start,
-                    "stop": limit,
-                    'step': delta,
+            extras={'start': _expr.const(start),
+                    "stop": _expr.const(limit),
+                    'step': _expr.const(delta),
                     'dtype': dtype})([], attr)
     return _impl
...@@ -1269,8 +1269,8 @@ def _batch_to_space_nd():
             crop = crops[axis - 1]
             if crop != [0, 0]:
                 indices = tvm.relay.arange(
-                    crop[0],
-                    reshaped_permuted_shape[axis] - crop[1],
+                    _expr.const(crop[0]),
+                    _expr.const(reshaped_permuted_shape[axis] - crop[1]),
                     dtype='int32'
                 )
                 cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr


def while_loop(cond, loop_vars, loop_bodies):
    """
    Construct a while loop.

    Parameters
    ----------
    cond: Callable[Tuple[relay.Expr], relay.Expr]
        The condition of the loop.

    loop_vars: Tuple[relay.Expr]
        The initial values of the loop variables; they are
        used to construct fresh variables for the loop.

    loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
        The body of the loop: a function which, given the loop
        variables, produces the next values of the loop variables,
        also as a tuple.

    Returns
    -------
    loop: relay.Expr
        The loop expression.
    """
    sb = ScopeBuilder()
    loop = _expr.Var("while_loop")
    fresh_vars = []

    for i, loop_var in enumerate(loop_vars):
        name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
        new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
        fresh_vars.append(new_var)

    with sb.if_scope(cond(*fresh_vars)):
        sb.ret(loop(*loop_bodies(*fresh_vars)))
    with sb.else_scope():
        sb.ret(_expr.Tuple(fresh_vars))

    func = _expr.Function(fresh_vars, sb.get())
    let = _expr.Let(loop, func, loop)
    return let
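Condensed from the dynamic-shape test at the end of this commit, a minimal usage sketch for while_loop (a scalar counter loop; the int32 helper comes from that test):

    from tvm import relay
    from tvm.relay.loops import while_loop

    def int32(val):
        return relay.const(val, 'int32')

    i = relay.var('i', shape=(), dtype='int32')

    def _cond(i):
        # Loop while i < 10; min reduces the comparison to a scalar condition.
        return relay.op.min(relay.op.less(i, int32(10)))

    def _body(i):
        # Return the next values of the loop variables as a tuple.
        return (i + int32(1),)

    loop = while_loop(_cond, [i], _body)
    # Applying the resulting expression to the initial values yields the
    # tuple of final loop variables.
    ret = loop(int32(0))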
...@@ -17,7 +17,7 @@
 """Transform operators."""

 from . import _make
-from ..expr import TupleWrapper
+from ..expr import TupleWrapper, const


 def cast(data, dtype):
...@@ -272,7 +272,7 @@ def full_like(data, fill_value):
     return _make.full_like(data, fill_value)


-def arange(start, stop=None, step=1, dtype="float32"):
+def arange(start, stop=None, step=None, dtype="float32"):
     """Return evenly spaced values within a given interval.

     .. note::
...@@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"):
         relay.arange(1, 5) = [1, 2, 3, 4]
         relay.arange(1, 5, 1.5) = [1, 2.5, 4]
     """
+    if step is None:
+        step = const(1, dtype)
     if stop is None:
         stop = start
-        start = 0
+        start = const(0, dtype=dtype)
     return _make.arange(start, stop, step, dtype)
......
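With the new signature, start/stop/step are Relay expressions and the defaults are materialized as constants of dtype. A sketch of the updated calling convention, matching the revised tests at the end of this commit:

    from tvm import relay

    # stop only: start defaults to const(0, dtype), step to const(1, dtype)
    x0 = relay.arange(relay.const(20, dtype="float32"))

    # explicit start, stop, and step, all passed as constant expressions
    x1 = relay.arange(relay.const(1, dtype="float32"),
                      relay.const(20, dtype="float32"),
                      relay.const(2, dtype="float32"))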
...@@ -42,7 +42,6 @@ class WithScope(object):
         else:
             self._exit_cb()

-
 def _make_lets(bindings, ret_value):
     """Make a nested let expressions.
...@@ -176,6 +175,24 @@ class ScopeBuilder(object):
                            false_branch)
         return WithScope(None, _on_exit)

+    def type_of(self, expr):
+        """
+        Compute the type of an expression.
+
+        Parameters
+        ----------
+        expr: relay.Expr
+            The expression to compute the type of.
+        """
+        if isinstance(expr, _expr.Var):
+            return expr.type_annotation
+
+        ity = _ty.IncompleteType()
+        var = _expr.var("unify", ity)
+        self.let(var, expr)
+        return ity
+
     def ret(self, value):
         """Set the return value of this scope.
......
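A small sketch of how type_of behaves (relay.exp is assumed here and is not part of this diff; this mirrors the method's use in loops.py above):

    from tvm import relay
    from tvm.relay.scope_builder import ScopeBuilder

    sb = ScopeBuilder()
    x = relay.var('x', shape=(10,), dtype='float32')

    # For a Var, type_of simply returns its annotation.
    t1 = sb.type_of(x)

    # For any other expression, it lets the expression under a fresh
    # variable annotated with an IncompleteType; the type inferencer
    # later unifies that hole with the expression's actual type.
    t2 = sb.type_of(relay.exp(x))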
...@@ -20,6 +20,7 @@ from enum import IntEnum
 from .base import RelayNode, register_relay_node
 from . import _make

+Any = _make.Any

 class Type(RelayNode):
     """The base type for all Relay types."""
...@@ -137,6 +138,19 @@ class TypeVar(Type):
         """
         self.__init_handle_by_constructor__(_make.TypeVar, var, kind)

+def ShapeVar(name):
+    """A helper which constructs a type var of the shape kind.
+
+    Parameters
+    ----------
+    name : str
+        The name of the shape variable.
+
+    Returns
+    -------
+    type_var : tvm.relay.TypeVar
+        The shape variable.
+    """
+    return TypeVar(name, kind=Kind.ShapeVar)
+
 @register_relay_node
 class GlobalTypeVar(Type):
......
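As exercised by the new test file at the end of this commit, a ShapeVar makes a function polymorphic over individual dimensions; a minimal sketch (relay.exp is assumed, not part of this diff):

    from tvm import relay

    m, n = relay.ShapeVar('m'), relay.ShapeVar('n')
    x = relay.var('x', shape=(m.var, n.var), dtype='float32')

    # type_params quantifies the function over the shape variables.
    f = relay.Function([x], relay.exp(x), type_params=[m, n])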
...@@ -970,7 +970,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
              op->call_type == Call::PureExtern) {
     return CreateCallExtern(op);
   } else {
-    LOG(FATAL) << "Unknown call type ";
+    LOG(FATAL) << "Unknown call type " <<
+      "name= " << op->name <<
+      " call_type= " << op->call_type;
     return nullptr;
   }
 }
......
...@@ -246,13 +246,20 @@ inline Expr MergeMulMod(const Expr &base) {
 inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
-    CHECK_EQ(n->shape.size(), index.size());
-    if (index.size() > 0) {
-      Expr offset = index[0];
-      for (size_t i = 1; i < index.size(); ++i) {
-        offset = MergeMulMod(offset * n->shape[i] + index[i]);
+    // Scalar case
+    if (n->shape.size() == 0 && index.size() == 1) {
+      auto is_int = index[0].as<IntImm>();
+      CHECK(is_int && is_int->value == 0);
+      base = base + index[0];
+    } else {
+      CHECK_EQ(n->shape.size(), index.size());
+      if (index.size() > 0) {
+        Expr offset = index[0];
+        for (size_t i = 1; i < index.size(); ++i) {
+          offset = MergeMulMod(offset * n->shape[i] + index[i]);
+        }
+        base = base + offset;
       }
-      base = base + offset;
     }
   } else {
     CHECK_EQ(n->strides.size(), index.size());
......
...@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -35,14 +35,25 @@ namespace Internal {
 using tvm::ir::CommReducerNode;
 using tvm::ir::Reduce;
+using tvm::ir::Any;
 using tvm::ir::AttrStmt;

 template<>
 void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
-  LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
+  LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
+}
+
+template<>
+void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
+  LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
 }

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
+  p->stream << "?";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
   p->stream << "reduce(combiner="
             << op->combiner;
...@@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
   return Expr(n);
 }

+Expr Any::make() {
+  auto n = make_node<Any>();
+  return Expr(n);
+}
+
 TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(Reduce);
+TVM_REGISTER_NODE_TYPE(Any);
 TVM_REGISTER_NODE_TYPE(AttrStmt);
 TVM_REGISTER_NODE_TYPE(FloatImm);
......
...@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -38,9 +38,12 @@ Expr Tensor::operator()(Array<Var> indices) const {
 Expr Tensor::operator()(Array<Expr> indices) const {
   using HalideIR::Internal::Call;
-  CHECK_EQ(ndim(), indices.size())
-      << "Tensor dimension mismatch in read"
-      << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  if (ndim() != 0) {
+    CHECK_EQ(ndim(), indices.size())
+        << "Tensor dimension mismatch in read"
+        << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  }
+
   auto n = Call::make(
       (*this)->dtype, (*this)->op->name, indices, Call::Halide,
       (*this)->op, (*this)->value_index);
......
...@@ -417,10 +417,10 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
   /*!
-   * \brief Build relay function to runtime module
+   * \brief Compile a Relay function to runtime module.
    *
-   * \param func Relay Function
-   * \param params parameters
+   * \param func The Relay function.
+   * \param params The parameters.
    */
   void BuildRelay(
       Function func,
...@@ -444,8 +444,13 @@ class RelayBuildModule : public runtime::ModuleNode {
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();

-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
-                          BuildConfig::Current());
+    auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+    if (lowered_funcs.size() != 0) {
+      ret_.mod = tvm::build(
+        lowered_funcs,
+        target_host_,
+        BuildConfig::Current());
+    }
   }

  protected:
......
...@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -67,6 +67,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
     std::stringstream err_msg;

     err_msg << rang::fg::red;
+    err_msg << " ";
     for (auto index : error_indicies) {
       err_msg << this->errors_[index].what() << "; ";
     }
...@@ -88,7 +89,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
   // First we output a header for the errors.
   annotated_prog <<
   rang::style::bold << std::endl <<
-  "Error(s) have occurred. We have annotated the program with them:"
+  "Error(s) have occurred. The program has been annotated with them:"
   << std::endl << std::endl << rang::style::reset;

   // For each global function which contains errors, we will
......
...@@ -287,6 +287,8 @@ RefCreate RefCreateNode::make(Expr value) {
   return RefCreate(n);
 }

+TVM_REGISTER_NODE_TYPE(RefCreateNode);
+
 TVM_REGISTER_API("relay._make.RefCreate")
 .set_body_typed(RefCreateNode::make);
...@@ -301,6 +303,8 @@ RefRead RefReadNode::make(Expr ref) {
   return RefRead(n);
 }

+TVM_REGISTER_NODE_TYPE(RefReadNode);
+
 TVM_REGISTER_API("relay._make.RefRead")
 .set_body_typed(RefReadNode::make);
...@@ -316,6 +320,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
   return RefWrite(n);
 }

+TVM_REGISTER_NODE_TYPE(RefWriteNode);
+
 TVM_REGISTER_API("relay._make.RefWrite")
 .set_body_typed(RefWriteNode::make);
......
...@@ -686,7 +686,9 @@ class PrettyPrinter :
   Doc PrintAttr(const NodeRef& value, bool meta = false) {
     if (value.defined()) {
       Doc printed_attr;
-      if (meta) {
+      if (value.as<tvm::ir::Any>()) {
+        printed_attr << "?";
+      } else if (meta) {
         printed_attr = meta_.GetMetaNode(value);
       } else {
         printed_attr = VisitAttr(value);
...@@ -846,6 +848,12 @@ std::string PrettyPrint_(const NodeRef& node,
   return doc.str();
 }

+std::string PrettyPrint(const NodeRef& node) {
+  Doc doc;
+  doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
+  return doc.str();
+}
+
 std::string AsText(const NodeRef& node,
                    bool show_meta_data,
                    runtime::TypedPackedFunc<std::string(Expr)> annotate) {
......
...@@ -228,5 +228,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     p->stream << "RefTypeNode(" << node->value << ")";
 });

+TVM_REGISTER_API("relay._make.Any")
+.set_body_typed<IndexExpr()>([]() { return Any::make(); });
+
 }  // namespace relay
 }  // namespace tvm
...@@ -87,6 +87,11 @@ Type ConcreteBroadcast(const TensorType& t1,
       oshape.push_back(s2);
     } else if (EqualConstInt(s2, 1)) {
       oshape.push_back(s1);
+    } else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
+      // TODO(@jroesch): we need to come back to this
+      oshape.push_back(s2);
+    } else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
+      oshape.push_back(s1);
     } else {
       RELAY_ERROR(
           "Incompatible broadcast type "
......
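A hedged sketch of what the relaxed broadcast relation is meant to admit (frontend names from this commit; the exact output shapes are still subject to the TODO above):

    from tvm import relay

    x = relay.var('x', shape=(relay.Any(), 1), dtype='float32')
    y = relay.var('y', shape=(1, 4), dtype='float32')

    # Broadcasting a dynamic (Any) dimension against a constant-1 dimension
    # now passes the broadcast type relation instead of erroring.
    z = relay.add(x, y)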
...@@ -313,17 +313,24 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   Type VisitExpr_(const LetNode* let) final {
     // if the definition is a function literal, permit recursion
     bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
+    Type let_type = IncompleteTypeNode::make(Kind::kType);
+
     if (is_functional_literal) {
-      type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType);
+      let_type = GetType(let->var);
+      type_map_[let->var].checked_type = let_type;
     }

-    Type vtype = GetType(let->value);
     if (let->var->type_annotation.defined()) {
-      vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let));
+      let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
     }

+    Type vtype = GetType(let->value);
+    let_type = Unify(let_type, vtype, GetRef<Let>(let));
+
     CHECK(is_functional_literal || !type_map_.count(let->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[let->var].checked_type = vtype;
+    type_map_[let->var].checked_type = let_type;
     return GetType(let->body);
   }
...@@ -473,7 +480,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }

     for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
-      this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
+      this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef<Call>(call));
     }

     for (auto cs : fn_ty->type_constraints) {
...@@ -556,6 +563,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
                               td->type_vars, {});
   }

+  void Solve() {
+    solver_.Solve();
+
+    if (err_reporter.AnyErrors()) {
+      err_reporter.RenderErrors(mod_);
+    }
+  }
 };

 class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
...@@ -673,7 +688,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
         update_missing_type_annotation_ &&
         !new_var->type_annotation.defined());

-    bool need_update_fn = (
+    bool need_update_fn =(
         std::is_base_of<FunctionNode, T>::value &&
         update_missing_type_annotation_ &&
         !new_fn->ret_type.defined());
...@@ -738,16 +753,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {

 Expr TypeInferencer::Infer(Expr expr) {
-  // Step 0: Populate the constraints.
+  // Step 1: Populate the constraints.
   GetType(expr);
-  // Step 1: Solve the constraints.
-  solver_.Solve();
-  if (err_reporter.AnyErrors()) {
-    err_reporter.RenderErrors(mod_);
-  }
-  // Step 2: Attach resolved types to checked_type field.
+
+  // Step 2: Solve the constraints.
+  Solve();
+
+  // Step 3: Attach resolved types to checked_type field.
   auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
   CHECK(WellFormed(resolved_expr));
   return resolved_expr;
......
...@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -89,7 +89,6 @@ class TypeSolver {
    * \param location The location at which the unification problem arose.
    */
   Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
-
   /*!
    * \brief Report an error at the provided location.
    * \param err The error to report.
...@@ -124,6 +123,7 @@ class TypeSolver {
     TypeNode* parent{nullptr};
     /*! \brief set of relations that is related to this type node */
     std::unordered_set<RelationNode*> rel_set;
+
     /*!
      * \brief Find the root type node, perform path compression
      * \return The root type node.
...@@ -159,13 +159,15 @@ class TypeSolver {
     NodeRef location;
   };

+  /*! \brief A simple union find between shapes. */
+  tvm::Map<IndexExpr, IndexExpr> shape_uf_;
   /*! \brief List of all allocated type nodes */
   std::vector<TypeNode*> type_nodes_;
   /*! \brief List of all allocated relation nodes */
   std::vector<RelationNode*> rel_nodes_;
   /*! \brief Number of resolved relations */
   size_t num_resolved_rels_{0};
-  /*! \brief map from type node to types. */
+  /*! \brief map from types to type nodes. */
   std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
   /*! \brief Internal queue to update the relation */
   std::queue<RelationNode*> update_queue_;
...@@ -205,6 +207,7 @@ class TypeSolver {
     rel->inqueue = true;
     update_queue_.push(rel);
   }
+
   /*!
    * \brief Merge rhs type node to lhs
    * \param src The source operand
......
...@@ -184,6 +184,10 @@ void NDArray::CopyFromTo(DLTensor* from,
       from_size, from->ctx, to->ctx, from->dtype, stream);
 }

+std::vector<int64_t> NDArray::Shape() const {
+  return data_->shape_;
+}
+
 }  // namespace runtime
 }  // namespace tvm
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop
import numpy as np


def infer_type(expr):
    mod = relay.Module.from_expr(expr)
    mod = transform.InferType()(mod)
    entry = mod[mod.entry_func]
    return entry if isinstance(expr, relay.Function) else entry.body


def int32(val):
    return relay.const(val, 'int32')


def test_arange_with_dynamic_shape():
    m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
    x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
    y0 = relay.shape_of(x)
    y1 = relay.take(y0, relay.const(0, 'int32'))
    y2 = relay.op.arange(y1)
    ex = relay.create_executor()
    f = relay.Function([x], y2, type_params=[m, n, k])
    # TODO(@jroesch): Restore after code generation.
    # data = np.random.rand(10, 5, 3).astype('float32')
    # result = ex.evaluate(f)(data)
    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))


def test_dynamic_concat():
    """
    fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
        if (%i < 10) {
            let %i = reshape(cast(i, "float32"), newshape=(1, ))
            let %new_st = concatenate((st, i), axis=0)
            concat_loop(%i + 1, )
        } else {
            st
        }
    }
    """
    # Initial values.
    i = relay.var('i', shape=(), dtype='int32')
    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')

    def _cond(i, st):
        return relay.op.min(relay.op.less(i, int32(10)))

    def _body(i, st):
        i_vec = relay.op.reshape(i, (1, 1))
        ret = relay.op.concatenate([st, i_vec], axis=0)
        return i + int32(1), ret

    loop = while_loop(_cond, [i, st], _body)
    start = relay.var('start', shape=(), dtype='int32')
    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
    func = relay.Function([start], relay.TupleGetItem(body, 1))
    func = infer_type(func)
    # TODO(@jroesch, @haichen): We should restore this code when code
    # generation is merged.
    # ret_shape = func.checked_type.ret_type.shape
    # assert len(ret_shape) == 2, "expected 2-dim output"
    # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
    # import pdb; pdb.set_trace()
    # mod = relay.module.Module()
    # print(relay.ir_pass.infer_type(func, mod=mod))
    # ret = relay.Call(loop, [relay.const(0, 'int32'), init])
    # mod[mod.entry_func] = relay.Function([], ret)
    # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
    # initial = np.array(0.0, dtype='float32').reshape((1,))
    # iter_stop = np.array(10, dtype='int32')
    # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
    # result = ex.evaluate(mod.entry_func)()
    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))


def test_dynamic_concat_with_wrong_annotation():
    """
    v0.0.1
    fn (%start: int32) {
        %7 = {
            let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
                %0 = less(%i, 10)
                %1 = min(%0)
                if (%1) {
                    %2 = add(%i, 1)
                    %3 = reshape(%i, newshape=[1, 1])
                    %4 = (%st, %3)
                    /* The result of concat should be (1, 1) but it is (2, 1). */
                    %5 = concatenate(%4)
                    %while_loop(%2, %5)
                } else {
                    (%i, %st)
                }
            }
            %6 = reshape(0, newshape=[1, 1])
            %while_loop(%start, %6)
        }
        %7.1
    }
    """
    # Initial values.
    i = relay.var('i', shape=(), dtype='int32')
    st = relay.var('st', shape=(1, 1), dtype='int32')

    def _cond(i, st):
        return relay.op.min(relay.op.less(i, int32(10)))

    def _body(i, st):
        i_vec = relay.op.reshape(i, (1, 1))
        ret = relay.op.concatenate([st, i_vec], axis=0)
        return i + int32(1), ret

    loop = while_loop(_cond, [i, st], _body)
    start = relay.var('start', shape=(), dtype='int32')
    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
    func = relay.Function([start], relay.TupleGetItem(body, 1))
    try:
        func = infer_type(func)
        assert False
    except Exception as e:
        assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)


if __name__ == "__main__":
    test_arange_with_dynamic_shape()
    test_dynamic_concat()
    test_dynamic_concat_with_wrong_annotation()
...@@ -493,17 +493,20 @@ def test_arange():
     def verify_arange(start, stop, step):
         dtype = "float32"
         if start is None and step is None:
-            x = relay.arange(stop)
-            ref_res = np.arange(stop)
+            x = relay.arange(relay.const(stop, dtype=dtype))
+            ref_res = np.arange(stop).astype(dtype)
         elif start is None:
-            x = relay.arange(stop, step=step)
-            ref_res = np.arange(stop, step=step)
+            x = relay.arange(relay.const(stop, dtype=dtype), step=relay.const(step, dtype=dtype))
+            ref_res = np.arange(stop, step=step).astype(dtype)
         elif step is None:
-            x = relay.arange(start, stop)
-            ref_res = np.arange(start, stop)
+            x = relay.arange(relay.const(start, dtype=dtype), relay.const(stop, dtype=dtype))
+            ref_res = np.arange(start, stop).astype(dtype)
         else:
-            x = relay.arange(start, stop, step)
-            ref_res = np.arange(start, stop, step)
+            x = relay.arange(
+                relay.const(start, dtype=dtype),
+                relay.const(stop, dtype=dtype),
+                relay.const(step, dtype=dtype))
+            ref_res = np.arange(start, stop, step).astype(dtype)

         func = relay.Function([], x)
         for target, ctx in ctx_list():
...@@ -515,11 +518,13 @@ def test_arange():
     verify_arange(None, 20, 2)
     verify_arange(1, 20, None)
     verify_arange(1, 20, 2)
-    verify_arange(1, 20, 1.5)
+    # arange doesn't support floating point right now, see type relation
+    # verify_arange(1, 20, 1.5)
     verify_arange(1, 20.5, None)
     verify_arange(1, 20, 3)
     verify_arange(20, 1, -1)
-    verify_arange(20, 1, -1.5)
+    # arange doesn't support floating point right now, see type relation
+    # verify_arange(20, 1, -1.5)

 def test_tile():
     def verify_tile(dshape, reps):
...@@ -616,6 +621,7 @@ def test_gather_nd():

 if __name__ == "__main__":
+    test_arange()
     test_cast()
     test_zeros_ones()
     test_unary_identity()
......