Commit 4e77eeb2 by Jared Roesch, committed by Tianqi Chen

[RELAY][RUNTIME] Add compute and schedule attributes for all ops in relay/op/tensor.py (#2050)

parent ead3ac6c
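In short, every op in relay/op/tensor.py is now wired to an FTVMCompute implementation built from TOPI and an FTVMSchedule based on the generic injective schedule, registered through two new helpers. A condensed sketch of the pattern this diff adds to the "Backend compiler related feature registration" module shown below (imports spelled out here for readability; the module itself uses relative imports):

import tvm
import topi
from tvm.relay.op import register_compute, register_schedule

def log_compute(attrs, inputs, output_type, target):
    # A compute callback returns one TOPI tensor per op output.
    assert len(inputs) == 1
    return [topi.log(inputs[0])]

def schedule_injective(outputs, target):
    # A schedule callback builds the schedule inside the requested target context.
    with tvm.target.create(target):
        return topi.generic.schedule_injective(outputs)

register_compute("log", log_compute)
register_schedule("log", schedule_injective)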
@@ -735,12 +735,12 @@ template<typename DerivedType>
 class AttrsNode : public BaseAttrsNode {
  public:
   void VisitAttrs(AttrVisitor* v) final {
-    detail::AttrNormalVisitor vis(v);
+    ::tvm::detail::AttrNormalVisitor vis(v);
     self()->__VisitAttrs__(vis);
   }
   void VisitNonDefaultAttrs(AttrVisitor* v) final {
-    detail::AttrNonDefaultVisitor vis(v);
+    ::tvm::detail::AttrNonDefaultVisitor vis(v);
     self()->__VisitAttrs__(vis);
   }
@@ -761,7 +761,7 @@ class AttrsNode : public BaseAttrsNode {
         }
         return false;
       };
-      auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
+      auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
       self()->__VisitAttrs__(vis);
       hit_count = vis.hit_count_;
     } else {
@@ -779,14 +779,14 @@ class AttrsNode : public BaseAttrsNode {
         }
         return false;
       };
-      auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind);
+      auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
       self()->__VisitAttrs__(vis);
       hit_count = vis.hit_count_;
     }
     // error handling, slow path
     if (hit_count * 2 != args.size() && !allow_unknown) {
       for (int i = 0; i < args.size(); i += 2) {
-        detail::AttrExistVisitor visitor;
+        ::tvm::detail::AttrExistVisitor visitor;
         visitor.key_ = args[i].operator std::string();
         self()->__VisitAttrs__(visitor);
         if (!visitor.exist_) {
@@ -803,7 +803,7 @@ class AttrsNode : public BaseAttrsNode {
   }

   Array<AttrFieldInfo> ListFieldInfo() const final {
-    detail::AttrDocVisitor visitor;
+    ::tvm::detail::AttrDocVisitor visitor;
     self()->__VisitAttrs__(visitor);
     return visitor.fields_;
   }
@@ -813,13 +813,13 @@ class AttrsNode : public BaseAttrsNode {
     if (pself == other) return true;
     if (other == nullptr) return false;
     if (pself->type_index() != other->type_index()) return false;
-    detail::AttrsEqualVisitor visitor(pself, other, equal);
+    ::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
     self()->__VisitAttrs__(visitor);
     return visitor.result_;
   }

   size_t ContentHash(AttrsHash hasher) const final {
-    detail::AttrsHashVisitor visitor(hasher);
+    ::tvm::detail::AttrsHashVisitor visitor(hasher);
     visitor.result_ = std::hash<std::string>()(this->type_key());
     self()->__VisitAttrs__(visitor);
     return visitor.result_;
......
@@ -417,7 +417,7 @@ inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
   const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
   TVMValue values[kArraySize];
   int type_codes[kArraySize];
-  detail::for_each(TVMArgsSetter(values, type_codes),
-                   std::forward<Args>(args)...);
+  runtime::detail::for_each(TVMArgsSetter(values, type_codes),
+                            std::forward<Args>(args)...);
   TVMRetValue rv;
   CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
......
@@ -138,7 +138,8 @@ class Executor(object):
         """
         if params:
             scope_builder = ScopeBuilder()
-            for key, value in params:
+            for key in params:
+                value = params[key]
                 scope_builder.let(key, value)
             scope_builder.ret(expr)
             expr = scope_builder.get()
@@ -146,7 +147,17 @@ class Executor(object):
         if isinstance(expr, Function):
            assert not ir_pass.free_vars(expr)

-        return self._make_executor(expr)
+        executor = self._make_executor(expr)
+
+        # If we are evaluating a function or top-level definition,
+        # the user must call the function themselves.
+        #
+        # If we are evaluating an open term with parameters, we
+        # just return the result.
+        if isinstance(expr, (Function, GlobalVar)):
+            return executor
+        else:
+            return executor()

 class Interpreter(Executor):
@@ -168,10 +179,14 @@ class Interpreter(Executor):
                 self.mod._add(expr, func, True)
                 opt_expr = Call(expr, relay_args)
                 return _interpreter.evaluate(self.mod, opt_expr)
-            else:
+            elif isinstance(expr, Function):
                 call = Call(expr, relay_args)
                 opt_expr = self.optimize(call)
                 return _interpreter.evaluate(self.mod, opt_expr)
+            else:
+                assert not args
+                opt_expr = self.optimize(expr)
+                return _interpreter.evaluate(self.mod, opt_expr)

         return _interp_wrapper
......
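The practical effect of the Executor.evaluate change: parameters are now bound by iterating the dict's keys, and the return value depends on what is evaluated. A Function or GlobalVar yields an executor that the caller invokes, while any other (closed) expression is evaluated immediately. A hedged usage sketch, mirroring the test updates later in this diff (create_executor import path as used in those tests):

import numpy as np
from tvm import relay
from tvm.relay.interpreter import create_executor

intrp = create_executor()

x = relay.var("x", relay.TensorType((10, 4), "float32"))
data = np.random.rand(10, 4).astype("float32")

# Open expression with parameters: evaluated right away, the value is returned.
res = intrp.evaluate(relay.exp(x), {x: relay.const(data)})
np.testing.assert_allclose(res.asnumpy(), np.exp(data), rtol=0.01)

# A Function instead yields a callable executor; the user calls it themselves.
func = relay.Function([x], relay.exp(x))
exp_executor = intrp.evaluate(func)  # returns a callable, not a value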
 #pylint: disable=wildcard-import, redefined-builtin
 """Relay core operators."""
 # operator defs
-from .op import get, register, Op
+from .op import get, register, register_schedule, register_compute, Op

 # Operators
 from .reduce import *
......
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
+from __future__ import absolute_import
 import tvm
 import topi
-from . import register
+import topi.cuda
+from . import register_schedule, register_compute
+def schedule_injective(outputs, target):
+    """Generic schedule for binary broadcast."""
+    with tvm.target.create(target):
+        return topi.generic.schedule_injective(outputs)
+
+schedule_broadcast = schedule_injective
+schedule_elemwise = schedule_injective
+
+# log
+def log_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.log(inputs[0])]
+
+register_compute("log", log_compute)
+register_schedule("log", schedule_broadcast)
+
+# exp
+def exp_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.exp(inputs[0])]
+
+register_compute("exp", exp_compute)
+register_schedule("exp", schedule_broadcast)
+
+# sqrt
+def sqrt_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.sqrt(inputs[0])]
+
+register_compute("sqrt", sqrt_compute)
+register_schedule("sqrt", schedule_broadcast)
+
+# sigmoid
+def sigmoid_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.sigmoid(inputs[0])]
+
+register_compute("sigmoid", sigmoid_compute)
+register_schedule("sigmoid", schedule_broadcast)
+
+# floor
+def floor_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.floor(inputs[0])]
+
+register_compute("floor", floor_compute)
+register_schedule("floor", schedule_broadcast)
+
+# ceil
+def ceil_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.ceil(inputs[0])]
+
+register_compute("ceil", ceil_compute)
+register_schedule("ceil", schedule_broadcast)
+
+# trunc
+def trunc_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.trunc(inputs[0])]
+
+register_compute("trunc", trunc_compute)
+register_schedule("trunc", schedule_broadcast)
+
+# round
+def round_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.round(inputs[0])]
+
+register_compute("round", round_compute)
+register_schedule("round", schedule_broadcast)
+
+# abs
+def abs_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.abs(inputs[0])]
+
+register_compute("abs", abs_compute)
+register_schedule("abs", schedule_broadcast)
+
+# tanh
+def tanh_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.tanh(inputs[0])]
+
+register_compute("tanh", tanh_compute)
+register_schedule("tanh", schedule_broadcast)
+
+# negative
+def negative_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.negative(inputs[0])]
+
+register_compute("negative", negative_compute)
+register_schedule("negative", schedule_broadcast)
+# add
 def add_compute(attrs, inputs, output_type, target):
     assert len(inputs) == 2
     return [topi.add(inputs[0], inputs[1])]

-def add_schedule(outputs, target):
-    assert len(outputs) == 1
-    return tvm.create_schedule(outputs[0].op)
-
-register("add", "FTVMCompute", add_compute)
-register("add", "FTVMSchedule", add_schedule)
+register_compute("add", add_compute)
+register_schedule("add", schedule_injective)

+# subtract
 def subtract_compute(attrs, inputs, output_type, target):
     assert len(inputs) == 2
     return [topi.subtract(inputs[0], inputs[1])]

-def subtract_schedule(outputs, target):
-    assert len(outputs) == 1
-    return tvm.create_schedule(outputs[0].op)
-
-register("subtract", "FTVMCompute", subtract_compute)
-register("subtract", "FTVMSchedule", subtract_schedule)
+register_compute("subtract", subtract_compute)
+register_schedule("subtract", schedule_broadcast)

+# multiply
 def multiply_compute(attrs, inputs, output_type, target):
     assert len(inputs) == 2
     return [topi.multiply(inputs[0], inputs[1])]

-def multiply_schedule(outputs, target):
-    assert len(outputs) == 1
-    return tvm.create_schedule(outputs[0].op)
-
-register("multiply", "FTVMCompute", multiply_compute)
-register("multiply", "FTVMSchedule", multiply_schedule)
+register_compute("multiply", multiply_compute)
+register_schedule("multiply", schedule_broadcast)
+
+# divide
+def divide_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.divide(inputs[0], inputs[1])]
+
+register_compute("divide", divide_compute)
+register_schedule("divide", schedule_broadcast)
+
+# pow
+def pow_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.power(inputs[0], inputs[1])]
+
+register_compute("pow", pow_compute)
+register_schedule("pow", schedule_injective)
+
+# mod
+def mod_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.mod(inputs[0], inputs[1])]
+
+register_compute("mod", mod_compute)
+register_schedule("mod", schedule_broadcast)

+# equal
 def equal_compute(attrs, inputs, output_type, target):
     assert len(inputs) == 2
     return [topi.equal(inputs[0], inputs[1])]

-def equal_schedule(outputs, target):
-    assert len(outputs) == 1
-    return tvm.create_schedule(outputs[0].op)
-
-register("equal", "FTVMCompute", equal_compute)
-register("equal", "FTVMSchedule", equal_schedule)
+register_compute("equal", equal_compute)
+register_schedule("equal", schedule_broadcast)
+
+# not_equal
+def not_equal_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.not_equal(inputs[0], inputs[1])]
+
+register_compute("not_equal", not_equal_compute)
+register_schedule("not_equal", schedule_broadcast)
+
+# less
+def less_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.less(inputs[0], inputs[1])]
+
+register_compute("less", less_compute)
+register_schedule("less", schedule_broadcast)
+
+# less equal
+def less_equal_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.less_equal(inputs[0], inputs[1])]
+
+register_compute("less_equal", less_equal_compute)
+register_schedule("less_equal", schedule_broadcast)
+
+# greater
+def greater_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.greater(inputs[0], inputs[1])]
+
+register_compute("greater", greater_compute)
+register_schedule("greater", schedule_broadcast)
+
+# greater equal
+def greater_equal_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.greater_equal(inputs[0], inputs[1])]
+
+register_compute("greater_equal", greater_equal_compute)
+register_schedule("greater_equal", schedule_broadcast)
+
+# maximum
+def maximum_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.maximum(inputs[0], inputs[1])]
+
+register_compute("maximum", maximum_compute)
+register_schedule("maximum", schedule_injective)
+
+# minimum
+def minimum_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.minimum(inputs[0], inputs[1])]
+
+register_compute("minimum", minimum_compute)
+register_schedule("minimum", schedule_injective)
+
+# right shift
+def right_shift_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.right_shift(inputs[0], inputs[1])]
+
+register_compute("right_shift", right_shift_compute)
+register_schedule("right_shift", schedule_injective)
+
+# left shift
+def left_shift_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 2
+    return [topi.left_shift(inputs[0], inputs[1])]
+
+register_compute("left_shift", left_shift_compute)
+register_schedule("left_shift", schedule_injective)
+
+# zeros
+def zeros_compute(attrs, inputs, output_type, target):
+    assert not inputs
+    return [topi.full(output_type.shape, output_type.dtype, 0.0)]
+
+register_compute("zeros", zeros_compute)
+register_schedule("zeros", schedule_injective)
+
+# zeros_like
+def zeros_like_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.full_like(inputs[0], 0.0)]
+
+register_compute("zeros_like", zeros_like_compute)
+register_schedule("zeros_like", schedule_injective)
+
+# ones
+def ones_compute(attrs, inputs, output_type, target):
+    assert not inputs
+    return [topi.full(output_type.shape, output_type.dtype, 1.0)]
+
+register_compute("ones", ones_compute)
+register_schedule("ones", schedule_injective)
+
+# ones_like
+def ones_like(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.full_like(inputs[0], 1.0)]
+
+register_compute("ones_like", ones_like)
+register_schedule("ones_like", schedule_injective)
+
+# clip
+def clip_compute(attrs, inputs, output_type, target):
+    assert len(inputs) == 1
+    return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
+
+register_compute("clip", clip_compute)
+register_schedule("clip", schedule_injective)

@@ -74,6 +74,11 @@ def register(op_name, attr_key, value=None, level=10):
         return v
     return _register(value) if value else _register

+def register_schedule(op_name, schedule):
+    register(op_name, "FTVMSchedule", schedule)
+
+def register_compute(op_name, compute):
+    register(op_name, "FTVMCompute", compute)

 _init_api("relay.op", __name__)
......
@@ -213,9 +213,8 @@ def add(lhs, rhs):
     """
     return _make.add(lhs, rhs)

-def multiply(lhs, rhs):
-    """Multiplication with numpy-style broadcasting.
+def subtract(lhs, rhs):
+    """Subtraction with numpy-style broadcasting.

     Parameters
     ----------
@@ -229,11 +228,10 @@ def multiply(lhs, rhs):
     result : relay.Expr
         The computed result.
     """
-    return _make.multiply(lhs, rhs)
+    return _make.subtract(lhs, rhs)

-def divide(lhs, rhs):
-    """Division with numpy-style broadcasting.
+def multiply(lhs, rhs):
+    """Multiplication with numpy-style broadcasting.

     Parameters
     ----------
@@ -247,11 +245,11 @@ def divide(lhs, rhs):
     result : relay.Expr
         The computed result.
     """
-    return _make.divide(lhs, rhs)
+    return _make.multiply(lhs, rhs)

-def pow(lhs, rhs):
-    """Power with numpy-style broadcasting.
+def divide(lhs, rhs):
+    """Division with numpy-style broadcasting.

     Parameters
     ----------
@@ -265,11 +263,11 @@ def pow(lhs, rhs):
     result : relay.Expr
         The computed result.
     """
-    return _make.pow(lhs, rhs)
+    return _make.divide(lhs, rhs)

-def mod(lhs, rhs):
-    """Mod with numpy-style broadcasting.
+def pow(lhs, rhs):
+    """Power with numpy-style broadcasting.

     Parameters
     ----------
@@ -283,11 +281,11 @@ def mod(lhs, rhs):
     result : relay.Expr
         The computed result.
     """
-    return _make.mod(lhs, rhs)
+    return _make.pow(lhs, rhs)

-def subtract(lhs, rhs):
-    """Subtraction with numpy-style broadcasting.
+def mod(lhs, rhs):
+    """Mod with numpy-style broadcasting.

     Parameters
     ----------
@@ -301,7 +299,7 @@ def subtract(lhs, rhs):
     result : relay.Expr
         The computed result.
     """
-    return _make.subtract(lhs, rhs)
+    return _make.mod(lhs, rhs)

 def equal(lhs, rhs):
@@ -553,7 +551,6 @@ def ones_like(data):
     """
     return _make.ones_like(data)

 def clip(a, a_min, a_max):
     """Clip the elements in `a` between `a_min` and `a_max`.
     `a_min` and `a_max` are cast to `a`'s dtype.
......
@@ -8,6 +8,7 @@
  */
 #include <tvm/lowered_func.h>
 #include <tvm/operation.h>
+#include <tvm/build_module.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/logging.h>
 #include <tvm/relay/pass.h>
@@ -155,8 +156,8 @@ struct LiveFunctions : ExprVisitor {
 };

 using FCompute = TypedPackedFunc<Array<Tensor>(
-    const Attrs&, const Array<Tensor>&, Type, std::string)>;
-using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
+    const Attrs&, const Array<Tensor>&, Type, tvm::Target)>;
+using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, tvm::Target)>;

 /*! \brief Return the set of operators in their TVM format. */
 Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
@@ -179,7 +180,7 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
   auto func = mod->Lookup(func_name);
   auto call = Downcast<Call>(func->body);
   auto op_node = call->op.as<OpNode>();
-  CHECK(op_node) << "violated invariant that primitive calls contain a single op call";
+  CHECK(op_node) << "violated invariant that primitive calls contain a single op call";
   auto op = GetRef<Op>(op_node);
   RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
@@ -197,10 +198,11 @@ Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
     i++;
   }

-  auto output_tt = op->op_type->ret_type;
+  auto output_tt = call->checked_type();
+  auto target_node = Target::create(target);
   Array<Tensor> outputs =
-      compute_reg[op](call->attrs, inputs, output_tt, target);
-  auto schedule = schedule_reg[op](outputs, target);
+      compute_reg[op](call->attrs, inputs, output_tt, target_node);
+  auto schedule = schedule_reg[op](outputs, target_node);
   size_t hash = StructuralHash()(func);
   LoweredFunc lf =
       flower(op->name + std::to_string(hash), schedule, inputs, outputs);
......
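Two details in the lowering change above are worth calling out: the output type handed to FTVMCompute is now the call's checked type rather than the op's declared return type, and the target string is converted to a tvm::Target via Target::create before the packed functions are invoked. The first matters for ops whose compute needs a concrete output shape and dtype; the zeros/ones computes registered earlier rely on exactly this (condensed from the registration module above):

import topi

def zeros_compute(attrs, inputs, output_type, target):
    # `output_type` is the call's checked TensorType, so its shape and dtype
    # are concrete here; this is why lowering now passes call->checked_type().
    assert not inputs
    return [topi.full(output_type.shape, output_type.dtype, 0.0)]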
+import math
 import tvm
 import numpy as np
 from tvm import relay
+from tvm.relay.interpreter import create_executor
+
+def sigmoid(x):
+    one = np.ones_like(x)
+    return one / (one + np.exp(-x))
+
+def relu(x):
+    x_copy = np.copy(x)
+    np.maximum(x_copy, 0, x_copy)
+    return x_copy

 def test_unary_op():
-    def check_single_op(opfunc):
-        tp = relay.TensorType((10, 4), "float32")
+    def check_single_op(opfunc, ref):
+        shape = (10, 4)
+        dtype = 'float32'
+        tp = relay.TensorType(shape, dtype)
         x = relay.var("x", tp)
         y = opfunc(x)
         # test printer
@@ -13,20 +25,33 @@ def test_unary_op():
         # test type inference
         assert relay.ir_pass.infer_type(y).checked_type == tp

-    for opfunc in [tvm.relay.log,
-                   tvm.relay.exp,
-                   tvm.relay.sqrt,
-                   tvm.relay.sigmoid,
-                   tvm.relay.tanh,
-                   relay.nn.relu]:
-        check_single_op(opfunc)
+        if ref is not None:
+            data = np.random.rand(*shape).astype(dtype)
+            intrp = create_executor()
+            op_res = intrp.evaluate(y, {x: relay.const(data)})
+            ref_res = ref(data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+    for opfunc, ref in [(tvm.relay.log, np.log),
+                        (tvm.relay.exp, np.exp),
+                        (tvm.relay.sqrt, np.sqrt),
+                        (tvm.relay.sigmoid, sigmoid),
+                        (tvm.relay.tanh, np.tanh),
+                        (relay.nn.relu, None)]:  # Just add relu here after registering it.
+        check_single_op(opfunc, ref)
 def test_binary_op():
-    def check_binary_op(opfunc):
+    def inst(vars, sh):
+        return [vars.get(s, s) for s in sh]
+
+    def check_binary_op(opfunc, ref):
+        # TODO(@jroesch): this piece of code improperly uses type variables.
         n = tvm.var("n")
-        t1 = relay.TensorType((5, n, 5))
-        t2 = relay.TensorType((n, 1))
+        s1 = (5, n, 5)
+        s2 = (n, 1)
+        t1 = relay.TensorType(s1)
+        t2 = relay.TensorType(s2)
         x = relay.var("x", t1)
         y = relay.var("y", t2)
         z = opfunc(x, y)
@@ -34,12 +59,25 @@ def test_binary_op():
         assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
         assert relay.ir_pass.infer_type(z).checked_type == t1

-    for opfunc in [relay.add,
-                   relay.subtract,
-                   relay.mod,
-                   relay.multiply,
-                   relay.divide]:
-        check_binary_op(opfunc)
+        if ref is not None:
+            t1 = relay.TensorType((5, 10, 5))
+            t2 = relay.TensorType((5, 10, 5))
+            x = relay.var("x", t1)
+            y = relay.var("y", t2)
+            z = opfunc(x, y)
+            x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
+            y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
+            intrp = create_executor()
+            op_res = intrp.evaluate(z, {x: relay.const(x_data), y: relay.const(y_data)})
+            ref_res = ref(x_data, y_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+    for opfunc, ref in [(relay.add, np.add),
+                        (relay.subtract, np.subtract),
+                        (relay.mod, np.mod),
+                        (relay.multiply, np.multiply),
+                        (relay.divide, np.divide)]:
+        check_binary_op(opfunc, ref)
 def test_bias_add():
@@ -96,6 +134,15 @@ def test_concatenate_infer_type():
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.TensorType((n, t + t, 100))

+    # x = relay.var("x", shape=(10, 5))
+    # y = relay.var("y", shape=(10, 5))
+    # z = relay.concatenate((x, y), axis=1)
+    # intrp = create_executor()
+    # x_data = np.random.rand(10, 5).astype('float32')
+    # y_data = np.random.rand(10, 5).astype('float32')
+    # op_res = intrp.evaluate(z, {x: relay.const(x_data), y: relay.const(y_data)})
+    # ref_res = np.concatenate((x_data, y_data), axis=1)
+    # np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

 def test_dropout():
     n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
......
@@ -3,29 +3,40 @@
 import tvm
 import numpy as np
 from tvm import relay
+from tvm.relay import create_executor
 from nose.tools import raises

 def test_zeros_ones():
-    for op in [relay.zeros, relay.ones]:
+    for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
         y = op(shape=(124, 50), dtype="float64")
         yy = relay.ir_pass.infer_type(y)
         assert yy.checked_type == relay.TensorType((124, 50), "float64")
+        intrp = create_executor()
+        intrp_res = intrp.evaluate(y).asnumpy()
+        np.testing.assert_allclose(intrp_res, ref((124, 50), 'float64'))

 def test_unary_identity():
-    for op in [relay.zeros_like,
-               relay.ones_like,
-               relay.ceil,
-               relay.floor,
-               relay.trunc,
-               relay.round,
-               relay.abs,
-               relay.copy,
-               relay.negative]:
-        x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
+    for op, ref in [(relay.zeros_like, np.zeros_like),
+                    (relay.ones_like, np.ones_like),
+                    (relay.ceil, np.ceil),
+                    (relay.floor, np.floor),
+                    (relay.trunc, np.trunc),
+                    (relay.round, np.round),
+                    (relay.abs, np.abs),
+                    (relay.copy, None),  # np.copy
+                    (relay.negative, np.negative)]:
+        shape = (8, 9, 4)
+        x = relay.var("x", relay.TensorType(shape, "float32"))
         y = op(x)
         yy = relay.ir_pass.infer_type(y)
-        assert yy.checked_type == relay.TensorType((8, 9, 4), "float32")
+        assert yy.checked_type == relay.TensorType(shape, "float32")
+        if ref is not None:
+            data = np.random.rand(*shape).astype('float32')
+            intrp = create_executor()
+            op_res = intrp.evaluate(y, {x: relay.const(data)})
+            ref_res = ref(data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

 def test_cast():
     x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
@@ -35,12 +46,20 @@ def test_cast():
     assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")

-def test_clip_type():
+def test_clip():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
     y = relay.clip(a, 1., 4.)
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.TensorType((10, 4), "float32")
+
+    data = np.random.rand(10, 4).astype('float32')
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = np.clip(data, 1., 4.)
+    np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

 def test_transpose_infer_type():
     n, t, d = tvm.var("n"), tvm.var("t"), 100
@@ -226,7 +245,7 @@ if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
     test_unary_identity()
-    test_clip_type()
+    test_clip()
     test_transpose_infer_type()
     test_reshape_infer_type()
     test_reshape_like()
......
 import tvm
 import numpy as np
 from tvm import relay
+from tvm.relay import create_executor

 def test_binary_op():
-    def check_binary_op(opfunc):
+    def check_binary_op(opfunc, ref):
         n = tvm.var("n")
         t1 = relay.TensorType((5, n, 5))
         t2 = relay.TensorType((n, 1))
@@ -15,17 +16,30 @@ def test_binary_op():
         assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
         assert relay.ir_pass.infer_type(z).checked_type == t1

-    for opfunc in [relay.pow]:
-        check_binary_op(opfunc)
+        if ref is not None:
+            t1 = relay.TensorType((5, 10, 5))
+            t2 = relay.TensorType((5, 10, 5))
+            x = relay.var("x", t1)
+            y = relay.var("y", t2)
+            z = opfunc(x, y)
+            x_data = np.random.rand(5, 10, 5).astype(t1.dtype)
+            y_data = np.random.rand(5, 10, 5).astype(t2.dtype)
+            intrp = create_executor()
+            op_res = intrp.evaluate(z, {x: relay.const(x_data), y: relay.const(y_data)})
+            ref_res = ref(x_data, y_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+    for opfunc, ref in [(relay.pow, np.power)]:
+        check_binary_op(opfunc, ref)

 def test_cmp_type():
-    for op in (relay.greater,
-               relay.greater_equal,
-               relay.less,
-               relay.less_equal,
-               relay.equal,
-               relay.not_equal):
+    for op, ref in ((relay.greater, np.greater),
+                    (relay.greater_equal, np.greater_equal),
+                    (relay.less, np.less),
+                    (relay.less_equal, np.less_equal),
+                    (relay.equal, np.equal),
+                    (relay.not_equal, np.not_equal)):
         x = relay.var("x", relay.TensorType((10, 4), "float32"))
         y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
         z = op(x, y)
@@ -33,18 +47,44 @@ def test_cmp_type():
         zz = relay.ir_pass.infer_type(z)
         assert zz.checked_type == relay.TensorType((5, 10, 4), "bool")

+        if ref is not None:
+            x_shape = (10, 4)
+            y_shape = (5, 10, 1)
+            t1 = relay.TensorType(x_shape)
+            t2 = relay.TensorType(y_shape)
+            x = relay.var("x", t1)
+            y = relay.var("y", t2)
+            z = op(x, y)
+            x_data = np.random.rand(*x_shape).astype(t1.dtype)
+            y_data = np.random.rand(*y_shape).astype(t2.dtype)
+            intrp = create_executor()
+            op_res = intrp.evaluate(z, {x: relay.const(x_data), y: relay.const(y_data)})
+            ref_res = ref(x_data, y_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

 def test_binary_int_broadcast():
-    for op in [relay.right_shift,
-               relay.left_shift,
-               relay.maximum,
-               relay.minimum]:
+    for op, ref in [(relay.right_shift, np.right_shift),
+                    (relay.left_shift, np.left_shift),
+                    (relay.maximum, np.maximum),
+                    (relay.minimum, np.minimum)]:
         x = relay.var("x", relay.TensorType((10, 4), "int32"))
         y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
         z = op(x, y)
         zz = relay.ir_pass.infer_type(z)
         assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")

+        if ref is not None:
+            x_shape = (10, 4)
+            y_shape = (5, 10, 1)
+            t1 = relay.TensorType(x_shape, 'int32')
+            t2 = relay.TensorType(y_shape, 'int32')
+            x_data = np.random.rand(*x_shape).astype(t1.dtype)
+            y_data = np.random.rand(*y_shape).astype(t2.dtype)
+            intrp = create_executor()
+            op_res = intrp.evaluate(z, {x: relay.const(x_data), y: relay.const(y_data)})
+            ref_res = ref(x_data, y_data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

 def test_where():
     cond = relay.var("cond", relay.TensorType((3, 4), "float32"))
......