Unverified Commit 9afde69b by Tianqi Chen Committed by GitHub

[RELAY][OP] conv2d, ShapeExpr->IndexExpr (#1798)

parent 147f3ad5
......@@ -56,6 +56,22 @@ namespace tvm {
__fvisit__(#FieldName, &FieldName)
/*!
* \brief Create a NodeRef type that represents null.
* \tparam TNodeRef the type to be created.
* \return A instance that will represent None.
*/
template<typename TNodeRef>
inline TNodeRef NullValue() {
return TNodeRef(NodePtr<Node>(nullptr));
}
template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
}
/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!
......
......@@ -114,7 +114,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/nn.h
* \brief Auxiliary attributes for nn operators.
*/
#ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Attributes used in convolution operators */
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string weight_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
......@@ -37,7 +37,7 @@ using DataType = ::tvm::Type;
/*!
* \brief Symbolic expression for tensor shape.
*/
using ShapeExpr = ::tvm::Expr;
using IndexExpr = ::tvm::Expr;
/*!
* \brief Hash function for nodes.
......
......@@ -286,7 +286,9 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
......
......@@ -70,9 +70,9 @@ class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by ShapeExpr(tvm::Expr).
* represented by IndexExpr(tvm::Expr).
*/
Array<ShapeExpr> shape;
Array<IndexExpr> shape;
/*! \brief The content data type */
DataType dtype;
......@@ -82,7 +82,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}
TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
/*! \brief Construct an scalar containing elements of dtype. */
TVM_DLL static TensorType Scalar(DataType dtype);
......@@ -273,8 +273,10 @@ class TypeReporterNode : public Node {
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) = 0;
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
......
......@@ -521,6 +521,12 @@ class TVMArgValue : public TVMPODValue_ {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
// None type
if (type_code_ == kNull) {
TVMType t;
t.code = kHandle; t.bits = 0; t.lanes = 0;
return t;
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
......@@ -878,6 +884,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
#endif
inline std::string TVMType2String(TVMType t) {
if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
......@@ -896,6 +903,11 @@ inline std::string TVMType2String(TVMType t) {
inline TVMType String2TVMType(std::string s) {
TVMType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
......
......@@ -9,6 +9,7 @@ from . import ir_builder
# Operators
from .op import Op
from .op.tensor import *
from .op import nn
# Span
Span = base.Span
......
......@@ -11,17 +11,19 @@ class Environment(NodeBase):
options and more.
"""
def __init__(self, funcs):
def __init__(self, funcs=None):
"""Construct an environment.
Parameters
------
funcs: list of relay.Function
funcs : optional, dict
Map of global var to Function
Returns
------
env: A new environment containing :py:class:`~relay.env.Environment`.
"""
funcs = funcs if funcs else {}
self.__init_handle_by_constructor__(_make.Environment, funcs)
def add(self, var, func):
......
......@@ -6,10 +6,26 @@ Exposes an interface for configuring the passes and scripting
them in Python.
"""
from . import _ir_pass
# Expose checking expression, should rename to infer_type.
# pylint: disable=invalid-name
check_expr = _ir_pass.check_expr
def infer_type(env, expr):
"""Infer the type of expr under the context of env
Parameters
----------
env : relay.Environment
The global environmemt.
expr : relay.Expr
The input expression.
Returns
-------
checked_expr : relay.Expr
The checked expression.
"""
return _ir_pass.infer_type(env, expr)
well_formed = _ir_pass.well_formed
......
......@@ -5,6 +5,8 @@ from .op import get, register, Op
# Operators
from .tensor import *
from . import nn
# operator registry
from . import _tensor
......
"""Neural network operations."""
from __future__ import absolute_import as _abs
from . import _make
def conv2d(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
weight_layout="OIHW",
out_layout="",
out_dtype=""):
"""Two dimensional convolution operator.
Parameters
----------
data : relay.Expr
The input data to the operator.
weight : relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convoltution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
data_layout : str, optional
Layout of the input.
weight_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
"""
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
weight_layout, out_layout, out_dtype)
......@@ -117,6 +117,9 @@ Operation ComputeOpNode::make(std::string name,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
......
......@@ -43,6 +43,9 @@ Operation ExternOpNode::make(std::string name,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ExternOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
......
......@@ -51,6 +51,9 @@ Operation ScanOpNode::make(std::string name,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
......
......@@ -418,6 +418,19 @@ bool Equal(const Stmt& lhs, const Stmt& rhs) {
}
bool Equal(const Expr& lhs, const Expr& rhs) {
// quick pass for constant expressions.
if (const int64_t *a = as_const_int(lhs)) {
if (const int64_t *b = as_const_int(rhs)) {
return a[0] == b[0];
}
}
if (!lhs.defined()) {
if (rhs.defined()) return false;
if (!rhs.defined()) return true;
} else {
if (!rhs.defined()) return false;
}
// deep comparison.
return IRDeepCompare().Equal(lhs, rhs);
}
......
......@@ -11,7 +11,7 @@ namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
n->shape = std::move(shape);
n->dtype = std::move(dtype);
......@@ -24,7 +24,7 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Array<ShapeExpr> shape = args[0];
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});
......
/*!
* Copyright (c) 2018 by Contributors
* \file convolution.cc
* \brief Convolution operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ConvAttrs);
bool Conv2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const ConvAttrs* param = attrs.as<ConvAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout);
CHECK(in_layout.convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout);
if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape(
{param->channels / param->groups,
data->shape[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape[kernel_layout.indexof('O')] *= param->groups;
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0});
oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
// Positional relay function to create conv2d operator
// used by frontend FFI.
Expr MakeConv2D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string weight_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<ConvAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->data_layout = std::move(data_layout);
attrs->weight_layout = std::move(weight_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("conv2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
});
RELAY_REGISTER_OP("conv2d")
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel);
} // namespace relay
} // namespace tvm
......@@ -69,8 +69,8 @@ Type ConcreteBroadcast(const TensorType& t1,
rev_sh2++;
}
Array<ShapeExpr> larger;
Array<ShapeExpr> smaller;
Array<IndexExpr> larger;
Array<IndexExpr> smaller;
for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(make_const(tvm::Int(64), 1));
......@@ -93,7 +93,7 @@ Type ConcreteBroadcast(const TensorType& t1,
CHECK_EQ(larger.size(), smaller.size());
Array<ShapeExpr> out_shape;
Array<IndexExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>();
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc
* \brief Compute the set of variables not bound in the expression.
* \brief The structral equivalence comparison.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
......@@ -19,9 +20,23 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
TypeAlphaEq() : eq_map(), equal(true) {}
void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
equal = equal && dt1 == dt2;
if (dt1 != dt2) {
equal = false;
}
}
void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
if (s1.size() != s2.size()) {
equal = false;
return;
}
for (size_t i = 0; i < s1.size(); ++i) {
if (!tvm::ir::Equal(s1[i], s2[i])) {
equal = false;
return;
}
}
}
void ShapeEqual(Array<ShapeExpr> s1, Array<ShapeExpr> s2) {}
void VisitType_(const TensorTypeNode *tt1, const Type& t2) final {
if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) {
......
......@@ -354,8 +354,8 @@ Expr TypeInferencer::Infer(Expr expr) {
return Resolver(type_map_, &solver_).VisitExpr(expr);
}
Expr InferType(const Environment& env, const Expr& e) {
return TypeInferencer(env).Infer(e);
Expr InferType(const Environment& env, const Expr& expr) {
return TypeInferencer(env).Infer(expr);
}
Expr InferType(const Environment& env,
......@@ -370,11 +370,9 @@ Expr InferType(const Environment& env,
return func_ret;
}
TVM_REGISTER_API("relay._ir_pass.check_expr")
TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0];
Expr e = args[1];
*ret = InferType(env, e);
*ret = InferType(args[0], args[1]);
});
} // namespace relay
......
......@@ -18,8 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode {
solver_->Unify(dst, src);
}
void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) final {
// TODO(tqchen)
bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final {
// early warning constant case.
IndexExpr diff = lhs - rhs;
if (const int64_t* pdiff = as_const_int(diff)) {
return pdiff[0] == 0;
}
return true;
}
private:
......
......@@ -8,7 +8,6 @@ ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
# print(r) # uncomment this line to debug
def test_constant():
......
import tvm
from tvm import relay
def test_conv2d_infer_type():
# symbolic in batch dimension
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
kernel_size=(3, 3),
padding=(1, 1),
channels=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 224, 224), "float32")
assert ftype.arg_types[1] == relay.ty.TensorType(
(2, 10, 3, 3), "float32")
# infer by shape of w, mixed precision
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 222, 222), "int32")
# Infer with a different layout
ib = relay.ir_builder.IRBuilder()
n, c, h, w = 4, 32, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NCHW4n4c",
weight_layout="OIHW4o4i",
out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(1, 4, 224, 224, 4, 4), "int32")
assert ftype.arg_types[1] == relay.ty.TensorType(
(4, 8, 3, 3, 4, 4), "int8")
if __name__ == "__main__":
test_conv2d_infer_type()
import tvm
from tvm import relay
def test_type_alpha_eq():
t1 = relay.ty.TensorType((3, 4), "float32")
t2 = relay.ty.TensorType((3, 4), "float32")
t3 = relay.ty.TensorType((3, 4, 5), "float32")
assert t1 == t2
assert t1 != t3
t1 = relay.ty.TensorType((), "float32")
t2 = relay.ty.TensorType((), "float32")
assert t1 == t2
if __name__ == "__main__":
test_type_alpha_eq()
......@@ -3,7 +3,7 @@
"""
import tvm
import numpy as np
from tvm.relay.ir_pass import check_expr
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
......@@ -11,8 +11,11 @@ from tvm.relay.op import log, add, equal, subtract, concat
from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = check_expr(env, expr)
assert checked_expr.checked_type() == typ
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type()
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def assert_decl_has_type(env, name, typ):
......@@ -47,6 +50,7 @@ def test_add_op():
}
"""
b = IRBuilder()
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
......@@ -71,8 +75,9 @@ def test_add_broadcast_op():
b.ret(add(x.var, y.var))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
expected_ty = func_type([ttype, ttype], ttype)
expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
tensor_type(5, 10, 4))
assert_has_type(func.to_func(), expected_ty)
def test_dual_op():
......@@ -89,7 +94,9 @@ def test_dual_op():
t1 = b.let('t1', log(x))
t2 = b.let('t2', add(t1, x))
b.ret(t2)
assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
assert_has_type(func.to_func(),
func_type([tensor_type(10, 10)], tensor_type(10, 10)))
def test_decl():
......@@ -152,12 +159,12 @@ def test_concat():
assert_decl_has_type(ib.env, try_concat2, fn_ty)
if __name__ == "__main__":
test_recursion()
test_dual_op()
test_recursion()
test_monomorphic_let()
test_single_op()
test_add_op()
test_add_broadcast_op()
test_dual_op()
test_decl()
test_concat()
......@@ -59,6 +59,7 @@ inline int64_t GetConstInt(Expr expr) {
*/
inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
std::vector<int> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
......@@ -77,6 +78,7 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
*/
inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
std::vector<int64_t> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment