Commit 3fb84e2b by Jared Roesch, committed by Thierry Moreau

[Relay][RFC] Implement type checking for Any (#3221)

* Implement type checking for Any

Remove code generation related changes

Remove compile changes

Remove more

Remove unification hack

Add some code back that was needed, and clean up test

Refactor test cases

WIP

Implement TypeHint AST

Add test case which should fail

Remove unification changes, and fix bug with let rec

Restore unification for shapes

Improve error reporting while debugging

All examples type check

All examples type check

WIP

First version that works with hints, needs clean up

Remove dead code

Tweaks

Remove type hint

Remove unnecessary type hint stuff

Remove more type hints

Clean up

Expose Any expression node

Address CR

Fix

Fix solver

Kill unnecessary code

Fix

PyLint

Fix

Relocate loops

Fix license and test

Lint again

Lint again

Fix loops

Fix docstring

Fix template error

Fix compiler issue

Fix compile err

Remove more runtime changes

Restore buffer

Fix segfault

Fix

Fix arange

* Address feedback

* Fix typo

* Fix arange

* Fix op level3

* Fix issue with Python wrapper
parent b10dda69
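At a glance: this commit teaches Relay's type checker to accept tensor types with dimensions that are unknown until runtime, written `Any`. A minimal sketch of the new surface, pieced together from the tests added at the end of this diff (illustrative only, not part of the commit itself):

    from tvm import relay
    from tvm.relay import transform

    # Tensor[(?, 1), int32]: the first dimension is unknown at compile time.
    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
    func = relay.Function([st], st)

    # Type inference now succeeds; Any unifies like any other dimension.
    mod = relay.Module.from_expr(func)
    mod = transform.InferType()(mod)
    print(mod[mod.entry_func])  # the dynamic dimension pretty-prints as "?"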
......@@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* _type_key = "Reduce";
};
/*! \brief Any shape. */
struct Any : public ExprNode<Any> {
TVM_DLL static Expr make();
void VisitAttrs(AttrVisitor* v) final {}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
};
/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
......
......@@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start;
tvm::Expr stop;
tvm::Expr step;
Expr start;
Expr stop;
Expr step;
DataType dtype;
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
TVM_ATTR_FIELD(start)
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
TVM_ATTR_FIELD(step)
.describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
TVM_ATTR_FIELD(dtype)
.describe("Target data type.");
}
}; // struct ArangeAttrs
......
......@@ -64,9 +64,10 @@ struct RelayErrorStream {
struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
Error() : dmlc::Error(""), sp(nullptr) {}
};
/*! \brief An abstraction around how errors are stored and reported.
......@@ -118,7 +119,8 @@ class ErrorReporter {
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg));
}
/*! \brief Report an error against a program, using the full program
......
......@@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);
/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
......
......@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc<
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>;
/*!
* \brief The code generation strategy for dynamic dimensions.
*/
enum AnyCodegenStrategy {
/*! \brief The default strategy of using completely variable dimensions. */
kVariableDimensions
};
/*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
......@@ -35,6 +35,8 @@
namespace tvm {
namespace relay {
using Any = tvm::ir::Any;
/*! \brief Base type of the Relay type hierarchy. */
class TypeNode : public RelayNode {
public:
......@@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
......
......@@ -190,6 +190,8 @@ class NDArray {
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
TVM_DLL std::vector<int64_t> Shape() const;
// internal helper struct
struct Internal;
protected:
......
......@@ -294,7 +294,7 @@ def get_last_ffi_error():
"""
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type.startswith("tvm.error."):
if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
......
......@@ -479,7 +479,8 @@ def extern(shape,
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
......
......@@ -63,6 +63,7 @@ TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
......@@ -71,6 +72,7 @@ scalar_type = ty.scalar_type
RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
Any = ty.Any
# Expr
Expr = expr.Expr
......
......@@ -570,6 +570,7 @@ def const(value, dtype=None):
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)
if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
......@@ -578,6 +579,7 @@ def const(value, dtype=None):
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
......
......@@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs):
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0)
new_attrs["stop"] = attrs.get_float("stop")
new_attrs["step"] = attrs.get_float("step", 1)
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)
......
......@@ -1059,9 +1059,9 @@ def _range():
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
extras={'start': start,
"stop": limit,
'step': delta,
extras={'start': _expr.const(start),
"stop": _expr.const(limit),
'step': _expr.const(delta),
'dtype': dtype})([], attr)
return _impl
......@@ -1269,8 +1269,8 @@ def _batch_to_space_nd():
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
crop[0],
reshaped_permuted_shape[axis] - crop[1],
_expr.const(crop[0]),
_expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype='int32'
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr
def while_loop(cond, loop_vars, loop_bodies):
"""
Construct a while loop.
Parameters
----------
cond: Callable[Tuple[relay.Expr], relay.Expr]
The condition of the loop.
loop_vars: Tuple[relay.Expr]
    The initial values of the loop variables; they are
    used to construct the fresh variables bound by the loop.

loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
    The body of the loop: a function which, given the loop
    variables, produces the next values, also as a tuple.
Returns
-------
loop: relay.Expr
The loop expression.
"""
sb = ScopeBuilder()
loop = _expr.Var("while_loop")
fresh_vars = []
for i, loop_var in enumerate(loop_vars):
name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
fresh_vars.append(new_var)
with sb.if_scope(cond(*fresh_vars)):
sb.ret(loop(*loop_bodies(*fresh_vars)))
with sb.else_scope():
sb.ret(_expr.Tuple(fresh_vars))
func = _expr.Function(fresh_vars, sb.get())
let = _expr.Let(loop, func, loop)
return let
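A usage sketch of while_loop, condensed from the dynamic-concat test added later in this commit: the condition and body are Python callables over the loop variables, and the returned let expression can be applied to the initial values.

    i = relay.var('i', shape=(), dtype='int32')
    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')

    def _cond(i, st):
        return relay.op.min(relay.op.less(i, relay.const(10, 'int32')))

    def _body(i, st):
        i_vec = relay.op.reshape(i, (1, 1))
        return i + relay.const(1, 'int32'), relay.op.concatenate([st, i_vec], axis=0)

    loop = while_loop(_cond, [i, st], _body)
    # loop(start, init_st) is a relay.Expr producing the final (i, st) tuple.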
......@@ -17,7 +17,7 @@
"""Transform operators."""
from . import _make
from ..expr import TupleWrapper
from ..expr import TupleWrapper, const
def cast(data, dtype):
......@@ -272,7 +272,7 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value)
def arange(start, stop=None, step=1, dtype="float32"):
def arange(start, stop=None, step=None, dtype="float32"):
"""Return evenly spaced values within a given interval.
.. note::
......@@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"):
relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4]
"""
if step is None:
step = const(1, dtype)
if stop is None:
stop = start
start = 0
start = const(0, dtype=dtype)
return _make.arange(start, stop, step, dtype)
......
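After this change, start/stop/step are Relay expressions rather than Python numbers, and omitted arguments default to constants of the requested dtype. A sketch of the updated calling convention, mirroring the test updates near the end of this diff:

    dtype = 'float32'
    x = relay.arange(relay.const(20, dtype))        # stop only: start = const(0), step = const(1)
    y = relay.arange(relay.const(1, dtype),
                     relay.const(20, dtype),
                     relay.const(2, dtype))         # explicit start, stop, step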
......@@ -42,7 +42,6 @@ class WithScope(object):
else:
self._exit_cb()
def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
......@@ -176,6 +175,24 @@ class ScopeBuilder(object):
false_branch)
return WithScope(None, _on_exit)
def type_of(self, expr):
"""
Compute the type of an expression.
Parameters
----------
expr: relay.Expr
    The expression to compute the type of.

Returns
-------
ty: relay.Type
    The type of the expression.
"""
if isinstance(expr, _expr.Var):
return expr.type_annotation
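# Otherwise bind the expression to a fresh variable annotated with an
# IncompleteType; the type solver will later unify it with the inferred type.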
ity = _ty.IncompleteType()
var = _expr.var("unify", ity)
self.let(var, expr)
return ity
def ret(self, value):
"""Set the return value of this scope.
......
......@@ -20,6 +20,7 @@ from enum import IntEnum
from .base import RelayNode, register_relay_node
from . import _make
Any = _make.Any
class Type(RelayNode):
"""The base type for all Relay types."""
......@@ -137,6 +138,19 @@ class TypeVar(Type):
"""
self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
def ShapeVar(name):
"""A helper which constructs a type var of which the shape kind.
Parameters
----------
name : str
Returns
-------
type_var : tvm.relay.TypeVar
The shape variable.
"""
return TypeVar(name, kind=Kind.ShapeVar)
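# Usage sketch, as in the dynamic-shape test added by this commit:
#   m = relay.ShapeVar('m')
#   x = relay.var('x', shape=(m.var,), dtype='float32')
#   f = relay.Function([x], relay.shape_of(x), type_params=[m])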
@register_relay_node
class GlobalTypeVar(Type):
......
......@@ -970,7 +970,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
op->call_type == Call::PureExtern) {
return CreateCallExtern(op);
} else {
LOG(FATAL) << "Unknown call type ";
LOG(FATAL) << "Unknown call type " <<
"name= " << op->name <<
" call_type= " << op->call_type;
return nullptr;
}
}
......
......@@ -246,13 +246,20 @@ inline Expr MergeMulMod(const Expr &base) {
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImm>();
CHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
}
base = base + offset;
}
base = base + offset;
}
} else {
CHECK_EQ(n->strides.size(), index.size());
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -35,14 +35,25 @@ namespace Internal {
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
using tvm::ir::Any;
using tvm::ir::AttrStmt;
template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
}
template<>
void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
p->stream << "?";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
<< op->combiner;
......@@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
return Expr(n);
}
Expr Any::make() {
auto n = make_node<Any>();
return Expr(n);
}
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -38,9 +38,12 @@ Expr Tensor::operator()(Array<Var> indices) const {
Expr Tensor::operator()(Array<Expr> indices) const {
using HalideIR::Internal::Call;
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
......
......@@ -417,10 +417,10 @@ class RelayBuildModule : public runtime::ModuleNode {
}
/*!
* \brief Build relay function to runtime module
* \brief Compile a Relay function to runtime module.
*
* \param func Relay Function
* \param params parameters
* \param func The Relay function.
* \param params The parameters.
*/
void BuildRelay(
Function func,
......@@ -444,8 +444,13 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
BuildConfig::Current());
auto lowered_funcs = graph_codegen_->GetLoweredFunc();
if (lowered_funcs.size() != 0) {
ret_.mod = tvm::build(
lowered_funcs,
target_host_,
BuildConfig::Current());
}
}
protected:
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -67,6 +67,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
std::stringstream err_msg;
err_msg << rang::fg::red;
err_msg << " ";
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
......@@ -88,7 +89,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:"
"Error(s) have occurred. The program has been annotated with them:"
<< std::endl << std::endl << rang::style::reset;
// For each global function which contains errors, we will
......
......@@ -287,6 +287,8 @@ RefCreate RefCreateNode::make(Expr value) {
return RefCreate(n);
}
TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_API("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
......@@ -301,6 +303,8 @@ RefRead RefReadNode::make(Expr ref) {
return RefRead(n);
}
TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_API("relay._make.RefRead")
.set_body_typed(RefReadNode::make);
......@@ -316,6 +320,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
return RefWrite(n);
}
TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_API("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make);
......
......@@ -686,7 +686,9 @@ class PrettyPrinter :
Doc PrintAttr(const NodeRef& value, bool meta = false) {
if (value.defined()) {
Doc printed_attr;
if (meta) {
if (value.as<tvm::ir::Any>()) {
printed_attr << "?";
} else if (meta) {
printed_attr = meta_.GetMetaNode(value);
} else {
printed_attr = VisitAttr(value);
......@@ -846,6 +848,12 @@ std::string PrettyPrint_(const NodeRef& node,
return doc.str();
}
std::string PrettyPrint(const NodeRef& node) {
Doc doc;
doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
return doc.str();
}
std::string AsText(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
......
......@@ -228,5 +228,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefTypeNode(" << node->value << ")";
});
TVM_REGISTER_API("relay._make.Any")
.set_body_typed<IndexExpr()>([]() { return Any::make(); });
} // namespace relay
} // namespace tvm
......@@ -87,6 +87,11 @@ Type ConcreteBroadcast(const TensorType& t1,
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
// TODO(@jroesch): we need to come back to this
oshape.push_back(s2);
} else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
oshape.push_back(s1);
} else {
RELAY_ERROR(
"Incompatible broadcast type "
......
......@@ -313,17 +313,24 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
Type let_type = IncompleteTypeNode::make(Kind::kType);
if (is_functional_literal) {
type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType);
let_type = GetType(let->var);
type_map_[let->var].checked_type = let_type;
}
Type vtype = GetType(let->value);
if (let->var->type_annotation.defined()) {
vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let));
let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
}
Type vtype = GetType(let->value);
let_type = Unify(let_type, vtype, GetRef<Let>(let));
CHECK(is_functional_literal || !type_map_.count(let->var));
// NOTE: no scoping is necessary because vars are unique in the program
type_map_[let->var].checked_type = vtype;
type_map_[let->var].checked_type = let_type;
return GetType(let->body);
}
......@@ -473,7 +480,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef<Call>(call));
}
for (auto cs : fn_ty->type_constraints) {
......@@ -556,6 +563,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
td->type_vars, {});
}
void Solve() {
solver_.Solve();
if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}
}
};
class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
......@@ -673,7 +688,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
bool need_update_fn = (
bool need_update_fn =(
std::is_base_of<FunctionNode, T>::value &&
update_missing_type_annotation_ &&
!new_fn->ret_type.defined());
......@@ -738,16 +753,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
Expr TypeInferencer::Infer(Expr expr) {
// Step 0: Populate the constraints.
// Step 1: Populate the constraints.
GetType(expr);
// Step 1: Solve the constraints.
solver_.Solve();
if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}
// Step 2: Solve the constraints.
Solve();
// Step 2: Attach resolved types to checked_type field.
// Step 3: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr));
return resolved_expr;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -89,7 +89,6 @@ class TypeSolver {
* \param location The location at which the unification problem arose.
*/
Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
/*!
* \brief Report an error at the provided location.
* \param err The error to report.
......@@ -124,6 +123,7 @@ class TypeSolver {
TypeNode* parent{nullptr};
/*! \brief set of relations that are related to this type node */
std::unordered_set<RelationNode*> rel_set;
/*!
* \brief Find the root type node, perform path compression
* \return The root type node.
......@@ -159,13 +159,15 @@ class TypeSolver {
NodeRef location;
};
/*! \brief A simple union find between shapes. */
tvm::Map<IndexExpr, IndexExpr> shape_uf_;
/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
std::vector<RelationNode*> rel_nodes_;
/*! \brief Number of resolved relations */
size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */
/*! \brief map from types to type nodes. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \brief Internal queue to update the relation */
std::queue<RelationNode*> update_queue_;
......@@ -205,6 +207,7 @@ class TypeSolver {
rel->inqueue = true;
update_queue_.push(rel);
}
/*!
* \brief Merge rhs type node to lhs
* \param src The source operand
......
......@@ -184,6 +184,10 @@ void NDArray::CopyFromTo(DLTensor* from,
from_size, from->ctx, to->ctx, from->dtype, stream);
}
std::vector<int64_t> NDArray::Shape() const {
return data_->shape_;
}
} // namespace runtime
} // namespace tvm
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop
import numpy as np
def infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod[mod.entry_func]
return entry if isinstance(expr, relay.Function) else entry.body
def int32(val):
return relay.const(val, 'int32')
def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1)
ex = relay.create_executor()
f = relay.Function([x], y2, type_params=[m, n, k])
# TODO(@jroesch): Restore after code generation.
# data = np.random.rand(10, 5, 3).astype('float32')
# result = ex.evaluate(f)(data)
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat():
"""
fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
if (%i < 10) {
let %i = reshape(cast(i, "float32"), newshape=(1, ))
let %new_st = concatenate((st, i), axis=0)
concat_loop(%i + 1, )
} else {
st
}
}
"""
# Initial Values.
i = relay.var('i', shape=(), dtype='int32')
st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
def _cond(i, st):
return relay.op.min(relay.op.less(i, int32(10)))
def _body(i, st):
i_vec = relay.op.reshape(i, (1,1))
ret = relay.op.concatenate([st, i_vec], axis=0)
return i + int32(1), ret
loop = while_loop(_cond, [i, st], _body)
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
func = infer_type(func)
# TODO(@jroesch, @haichen): We should restore this code when codegeneration
# is merged
# ret_shape = func.checked_type.ret_type.shape
# assert len(ret_shape) == 2, "expected 2-dim output"
# assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
# import pdb; pdb.set_trace()
# mod = relay.module.Module()
# print(relay.ir_pass.infer_type(func, mod=mod))
# ret = relay.Call(loop, [relay.const(0, 'int32'), init])
# mod[mod.entry_func] = relay.Function([], ret)
# print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
# initial = np.array(0.0, dtype='float32').reshape((1,))
# iter_stop = np.array(10, dtype='int32')
# ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
# result = ex.evaluate(mod.entry_func)()
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
def test_dynamic_concat_with_wrong_annotation():
"""
v0.0.1
fn (%start: int32) {
%7 = {
let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
%0 = less(%i, 10)
%1 = min(%0)
if (%1) {
%2 = add(%i, 1)
%3 = reshape(%i, newshape=[1, 1])
%4 = (%st, %3)
/* The result of concat should be (1, 1) but it is (2, 1). */
%5 = concatenate(%4)
%while_loop(%2, %5)
} else {
(%i, %st)
}
}
%6 = reshape(0, newshape=[1, 1])
%while_loop(%start, %6)
}
%7.1
}
"""
# Initial Values.
i = relay.var('i', shape=(), dtype='int32')
st = relay.var('st', shape=(1, 1), dtype='int32')
def _cond(i, st):
return relay.op.min(relay.op.less(i, int32(10)))
def _body(i, st):
i_vec = relay.op.reshape(i, (1,1))
ret = relay.op.concatenate([st, i_vec], axis=0)
return i + int32(1), ret
loop = while_loop(_cond, [i, st], _body)
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
try:
func = infer_type(func)
assert False
except Exception as e:
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__":
test_arange_with_dynamic_shape()
test_dynamic_concat()
test_dynamic_concat_with_wrong_annotation()
......@@ -493,17 +493,20 @@ def test_arange():
def verify_arange(start, stop, step):
dtype = "float32"
if start is None and step is None:
x = relay.arange(stop)
ref_res = np.arange(stop)
x = relay.arange(relay.const(stop, dtype=dtype))
ref_res = np.arange(stop).astype(dtype)
elif start is None:
x = relay.arange(stop, step=step)
ref_res = np.arange(stop, step=step)
x = relay.arange(relay.const(stop, dtype=dtype), step=relay.const(step, dtype=dtype))
ref_res = np.arange(stop, step=step).astype(dtype)
elif step is None:
x = relay.arange(start, stop)
ref_res = np.arange(start, stop)
x = relay.arange(relay.const(start, dtype=dtype), relay.const(stop, dtype=dtype))
ref_res = np.arange(start, stop).astype(dtype)
else:
x = relay.arange(start, stop, step)
ref_res = np.arange(start, stop, step)
x = relay.arange(
relay.const(start, dtype=dtype),
relay.const(stop, dtype=dtype),
relay.const(step, dtype=dtype))
ref_res = np.arange(start, stop, step).astype(dtype)
func = relay.Function([], x)
for target, ctx in ctx_list():
......@@ -515,11 +518,13 @@ def test_arange():
verify_arange(None, 20, 2)
verify_arange(1, 20, None)
verify_arange(1, 20, 2)
verify_arange(1, 20, 1.5)
# arange doesn't support floating point right now, see type relation
# verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None)
verify_arange(1, 20, 3)
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)
# arange doesn't support floating point right now, see type relation
# verify_arange(20, 1, -1.5)
def test_tile():
def verify_tile(dshape, reps):
......@@ -616,6 +621,7 @@ def test_gather_nd():
if __name__ == "__main__":
test_arange()
test_cast()
test_zeros_ones()
test_unary_identity()
......