Unverified Commit b9038343 by Tianqi Chen Committed by GitHub

[RELAY][OP] Move computes to cxx, enable concat as injective (#2166)

parent f8f06595
......@@ -236,18 +236,14 @@ class GraphRuntimeCodegen(ExprFunctor):
self.lowered_funcs.add(loweredf)
inputs = []
tuple_arg_count = 0
# flatten tuple in the call.
for arg in call.args:
res = self.visit(arg)
if isinstance(arg.checked_type, TupleType):
tuple_arg_count += 1
inputs.append(self.visit(arg))
# We need to specially handle tuple inputs and
# tuple output cases.
# Tuple input function(e.g. concat)
if tuple_arg_count:
assert len(call.args) == 1
assert isinstance(inputs[0], tuple)
inputs = list(inputs[0])
assert isinstance(res, tuple)
inputs += res
else:
inputs.append(res)
inputs = [x.to_json() for x in inputs]
op_name = cached_func.func_name
......
......@@ -589,11 +589,11 @@ def from_mxnet(symbol,
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(symbol, shape, dtype)
elif isinstance(symbol, mx.gluon.HybridBlock):
if args_params is not None or aux_params is not None:
if arg_params is not None or aux_params is not None:
raise ValueError("arg_params and aux_params ae not used when importing HybridBlock")
params = {}
for k, v in symbol.collect_params().items():
params[k] = tvm.nd.array(v.data().asnumpy())
params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data")
sym = symbol(data)
shape, dtype = _update_shape_dtype(shape, dtype, params)
......
......@@ -5,223 +5,37 @@ import topi
from .op import register_compute, register_schedule, register_pattern
from .op import schedule_injective, OpPattern
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
# log
@register_compute("log")
def log_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.log(inputs[0])]
register_schedule("log", schedule_broadcast)
# exp
@register_compute("exp")
def exp_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.exp(inputs[0])]
register_schedule("exp", schedule_broadcast)
# sqrt
@register_compute("sqrt")
def sqrt_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sqrt(inputs[0])]
register_schedule("sqrt", schedule_broadcast)
# sigmoid
@register_compute("sigmoid")
def sigmoid_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sigmoid(inputs[0])]
register_schedule("sigmoid", schedule_broadcast)
# floor
@register_compute("floor")
def floor_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.floor(inputs[0])]
register_schedule("floor", schedule_broadcast)
# ceil
@register_compute("ceil")
def ceil_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.ceil(inputs[0])]
register_schedule("ceil", schedule_broadcast)
# trunc
@register_compute("trunc")
def trunc_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.trunc(inputs[0])]
register_schedule("trunc", schedule_broadcast)
# round
@register_compute("round")
def round_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.round(inputs[0])]
register_schedule("round", schedule_broadcast)
# abs
@register_compute("abs")
def abs_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.abs(inputs[0])]
register_schedule("abs", schedule_broadcast)
# tanh
@register_compute("tanh")
def tanh_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.tanh(inputs[0])]
register_schedule("tanh", schedule_broadcast)
# negative
@register_compute("negative")
def negative_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.negative(inputs[0])]
register_schedule("negative", schedule_broadcast)
# add
@register_compute("add")
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]
register_schedule("add", schedule_injective)
# subtract
@register_compute("subtract")
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]
register_schedule("add", schedule_broadcast)
register_schedule("subtract", schedule_broadcast)
# multiply
@register_compute("multiply")
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]
register_schedule("multiply", schedule_broadcast)
# divide
@register_compute("divide")
def divide_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.divide(inputs[0], inputs[1])]
register_schedule("divide", schedule_broadcast)
# power
@register_compute("power")
def power_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.power(inputs[0], inputs[1])]
register_schedule("power", schedule_injective)
# mod
@register_compute("mod")
def mod_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.mod(inputs[0], inputs[1])]
register_schedule("mod", schedule_broadcast)
# equal
@register_compute("equal")
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]
register_schedule("equal", schedule_broadcast)
# not_equal
@register_compute("not_equal")
def not_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.not_equal(inputs[0], inputs[1])]
register_schedule("not_equal", schedule_broadcast)
# less
@register_compute("less")
def less_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less(inputs[0], inputs[1])]
register_schedule("less", schedule_broadcast)
# less equal
@register_compute("less_equal")
def less_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less_equal(inputs[0], inputs[1])]
register_schedule("less_equal", schedule_broadcast)
# greater
@register_compute("greater")
def greater_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater(inputs[0], inputs[1])]
register_schedule("greater", schedule_broadcast)
# greater equal
@register_compute("greater_equal")
def greater_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater_equal(inputs[0], inputs[1])]
register_schedule("greater_equal", schedule_broadcast)
# maximum
@register_compute("maximum")
def maximum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.maximum(inputs[0], inputs[1])]
register_schedule("maximum_compute", schedule_injective)
# minimum
@register_compute("minimum")
def minimum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.minimum(inputs[0], inputs[1])]
register_schedule("minimum", schedule_injective)
# right shift
@register_compute("right_shift")
def right_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.right_shift(inputs[0], inputs[1])]
register_schedule("right_shift", schedule_injective)
# left shift
@register_compute("left_shift")
def left_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.left_shift(inputs[0], inputs[1])]
register_schedule("left_shift", schedule_injective)
# zeros
......@@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
# TODO(tqchen): renable concat as injective
register_pattern("concatenate", OpPattern.OPAQUE)
register_pattern("concatenate", OpPattern.INJECTIVE)
......@@ -56,30 +56,26 @@ class ScheduleGetter :
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
auto cache_node = make_node<CachedFuncNode>();
cache_node->target = target_;
if (prim_func->params.size() == 1 &&
prim_func->params[0]->checked_type().as<TupleTypeNode>()) {
// Handle tuple input type by flattening them.
// This is the current calling convention of tuple input.
for (Var param : prim_func->params) {
Array<tvm::Tensor> inputs;
for (Type field : prim_func->params[0]->type_as<TupleTypeNode>()->fields) {
const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype != nullptr);
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
} else {
// flatten tuple of tensor type.
const auto* tuple_type = param->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype != nullptr);
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
}
memo_[prim_func->params[0]] = inputs;
} else {
for (Var param : prim_func->params) {
const auto* ttype = param->type_as<TensorTypeNode>();
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
memo_[param] = Array<Tensor>({tensor});
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
cache_node->outputs = this->VisitExpr(prim_func->body);
......@@ -161,8 +157,9 @@ class ScheduleGetter :
int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined())
<< "Two complicated op in a primitive function";
CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce)
<< "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op;
}
if (op_pattern >= master_op_patetrn_) {
master_op_ = op;
......
......@@ -212,7 +212,7 @@ class Interpreter :
// Marshal the arguments.
// Handle tuple input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); i++) {
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<TensorValueNode>()) {
++arg_len;
} else {
......@@ -242,22 +242,19 @@ class Interpreter :
<< context_ << ", but get " << arg_ctx;
};
if (func->params.size() == 1 &&
func->params[0]->checked_type().as<TupleTypeNode>()) {
// handle tuple input.
const TupleValueNode* tuple = args[0].as<TupleValueNode>();
CHECK(tuple);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(i, tuple->fields[i]);
}
} else {
CHECK_EQ(num_inputs, args.size());
// Decide the target context.
// Primitive functions always sit in the same context.
for (size_t i = 0; i < args.size(); i++) {
fset_input(i, args[i]);
int arg_counter = 0;
for (Value arg : args) {
if (arg.as<TensorValueNode>()) {
fset_input(arg_counter++, arg);
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i]);
}
}
}
// TVM's calling convention is that the final argument is the output
// buffer. To preserve the illusion of being a functional language
// we need to allocate space for the output buffer based on the
......
......@@ -5,54 +5,75 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <topi/broadcast.h>
#include "../type_relations.h"
#include "../op_common.h"
namespace tvm {
namespace relay {
#define RELAY_BINARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<Tensor> { \
CHECK_EQ(inputs.size(), 2U); \
return {FTOPI(inputs[0], inputs[1])}; \
} \
// Addition
RELAY_REGISTER_BINARY_OP("relay.op._make.", "add")
.describe("Elementwise add with with broadcasting")
.set_support_level(1);
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
// Subtraction
RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract")
.describe("Elementwise substract with broadcasting")
.set_support_level(1);
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
// Right shift
RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift")
.describe("Elementwise right shift with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift")
.describe("Elementwise left shift with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum")
.describe("Elementwise maximum of two tensors with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum")
.describe("Elementwise minimum of two tensors with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide")
.describe("Elementwise divide with broadcasting")
.set_support_level(1);
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
.describe("Elementwise multiply with broadcasting")
.set_support_level(1);
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
.describe("Elementwise power with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));
RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
.describe("Elementwise mod with broadcasting")
.set_support_level(1);
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
......@@ -70,22 +91,38 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal));
RELAY_REGISTER_CMP_OP("not_equal")
.describe("Elementwise not equal with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal));
RELAY_REGISTER_CMP_OP("less")
.describe("Elementwise less than with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less));
RELAY_REGISTER_CMP_OP("less_equal")
.describe("Elementwise less than or equal compare with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal));
RELAY_REGISTER_CMP_OP("greater")
.describe("Elementwise greater than compare with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater));
RELAY_REGISTER_CMP_OP("greater_equal")
.describe("Elementwise greater than or equal compare with broadcasting")
.set_support_level(4);
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal));
} // namespace relay
} // namespace tvm
......@@ -5,12 +5,21 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <topi/elemwise.h>
#include "../type_relations.h"
#include "../op_common.h"
namespace tvm {
namespace relay {
#define RELAY_UNARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<Tensor> { \
return {FTOPI(inputs[0])}; \
} \
RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
.describe(R"code(Returns the log input array, computed element-wise.
......@@ -20,7 +29,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
.describe(R"code(Returns the exp input array, computed element-wise.
......@@ -30,7 +41,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
......@@ -41,7 +53,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input.
......@@ -49,6 +63,7 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like")
.describe(R"code(Returns an array of ones, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
......@@ -63,13 +78,17 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy")
.describe(R"code(Copy a tensor.
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
......@@ -107,7 +126,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
.describe(R"code(Returns the ceil of input array, computed element-wise.
......@@ -117,7 +138,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
.describe(R"code(Returns the trunc of input array, computed element-wise.
......@@ -127,7 +150,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
.describe(R"code(Returns the round of input array, computed element-wise.
......@@ -137,7 +162,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
.describe(R"code(Returns the abs of input array, computed element-wise.
......@@ -147,7 +174,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
.describe(R"code(Returns the tanh of input array, computed element-wise.
......@@ -157,7 +186,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.
......@@ -167,7 +198,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));
} // namespace relay
} // namespace tvm
......@@ -188,20 +188,22 @@ def test_concatenate():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
t = relay.var("z", shape=())
z = relay.concatenate((x, y), axis=1)
z = relay.add(z, t)
# Check result.
func = relay.Function([x, y], z)
func = relay.Function([x, y, t], z)
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
ref_res = np.concatenate((x_data, y_data), axis=1)
t_data = np.random.uniform(size=()).astype('float32')
ref_res = np.concatenate((x_data, y_data), axis=1) + t_data
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, y_data)
op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
op_res2 = intrp2.evaluate(func)(x_data, y_data)
op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)
def test_dropout():
......@@ -306,11 +308,11 @@ def test_dense():
if __name__ == "__main__":
test_concatenate()
test_bias_add()
test_unary_op()
test_binary_op()
test_expand_dims_infer_type()
test_concatenate()
test_expand_dims()
test_softmax()
test_log_softmax()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment