Unverified Commit 32af4d28 by Tianqi Chen Committed by GitHub

[IR] eager constant folding in operator overloading (#1789)

parent 3455c8a5
......@@ -10,6 +10,7 @@
#include "base.h"
#include "expr.h"
#include "ir_operator.h"
#include "node/container.h"
namespace tvm {
......
......@@ -7,7 +7,6 @@
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
......@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
using HalideIR::Internal::make_const;
using HalideIR::Internal::make_zero;
using HalideIR::Internal::make_one;
using HalideIR::Internal::as_const_int;
using HalideIR::Internal::as_const_uint;
using HalideIR::Internal::const_true;
using HalideIR::Internal::const_false;
using HalideIR::Internal::is_no_op;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
......
......@@ -495,8 +495,6 @@ using HalideIR::Internal::Block;
using HalideIR::Internal::IfThenElse;
using HalideIR::Internal::Evaluate;
using HalideIR::Internal::Shuffle;
// ir functions
using HalideIR::Internal::is_const_power_of_two_integer;
/*!
* \brief Create a type annotation expression
......
......@@ -13,6 +13,7 @@
#include "base.h"
#include "expr.h"
#include "ir_operator.h"
#include "arithmetic.h"
#include "node/container.h"
......
......@@ -354,7 +354,7 @@ Example::
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
Expr count = make_one(inputs[0]->dtype);
Expr count = make_const(inputs[0]->dtype, 1);
for (auto& i : r_axes) {
count *= inputs[0]->shape[i];
}
......
......@@ -156,9 +156,9 @@ def any(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _expr.Or(args[0], args[1])
ret = _make._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = _expr.Or(ret, args[i])
ret = _make._OpOr(ret, args[i])
return ret
......@@ -180,9 +180,9 @@ def all(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _expr.And(args[0], args[1])
ret = _make._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = _expr.And(ret, args[i])
ret = _make._OpAnd(ret, args[i])
return ret
......@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max')
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
......@@ -60,7 +60,7 @@ class ExprOp(object):
return self.__rdiv__(other)
def __mod__(self, other):
return _make.Mod(self, other)
return _make._OpMod(self, other)
def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
......@@ -85,10 +85,10 @@ class ExprOp(object):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
def __lt__(self, other):
return _make.LT(self, other)
return _make._OpLT(self, other)
def __le__(self, other):
return _make.LE(self, other)
return _make._OpLE(self, other)
def __eq__(self, other):
return EqualOp(self, other)
......@@ -97,10 +97,10 @@ class ExprOp(object):
return NotEqualOp(self, other)
def __gt__(self, other):
return _make.GT(self, other)
return _make._OpGT(self, other)
def __ge__(self, other):
return _make.GE(self, other)
return _make._OpGE(self, other)
def __nonzero__(self):
raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
......@@ -122,7 +122,7 @@ class ExprOp(object):
ret : Expr
The equality expression.
"""
return _make.EQ(self, other)
return _make._OpEQ(self, other)
def astype(self, dtype):
"""Cast the expression to other type.
......@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp):
def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)
return _make._OpEQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
......@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp):
def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)
return _make._OpNE(self.a, self.b)
class Expr(ExprOp, NodeBase):
......
......@@ -24,7 +24,7 @@ def add(lhs, rhs):
op : tvm.Expr
The result Expr of add operaton.
"""
return _make.Add(lhs, rhs)
return _make._OpAdd(lhs, rhs)
def subtract(lhs, rhs):
......@@ -42,7 +42,7 @@ def subtract(lhs, rhs):
op : tvm.Expr
The result Expr of subtract operaton.
"""
return _make.Sub(lhs, rhs)
return _make._OpSub(lhs, rhs)
def multiply(lhs, rhs):
......@@ -60,7 +60,7 @@ def multiply(lhs, rhs):
op : tvm.Expr
The result Expr of multiply operaton.
"""
return _make.Mul(lhs, rhs)
return _make._OpMul(lhs, rhs)
def divide(lhs, rhs):
......@@ -78,7 +78,7 @@ def divide(lhs, rhs):
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.Div(lhs, rhs)
return _make._OpDiv(lhs, rhs)
def cast(src, dtype):
......
......@@ -5,7 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IROperator.h>
#include <tvm/ir_operator.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>
......@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE2(Add);
REGISTER_MAKE2(Sub);
REGISTER_MAKE2(Mul);
REGISTER_MAKE2(Div);
REGISTER_MAKE2(Mod);
REGISTER_MAKE2(Min);
REGISTER_MAKE2(Max);
REGISTER_MAKE2(EQ);
REGISTER_MAKE2(NE);
REGISTER_MAKE2(LT);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT);
REGISTER_MAKE2(GE);
REGISTER_MAKE2(And);
REGISTER_MAKE2(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
......@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer")
} \
})
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -14,10 +14,6 @@
namespace tvm {
namespace arith {
using HalideIR::Internal::add_would_overflow;
using HalideIR::Internal::sub_would_overflow;
using HalideIR::Internal::mul_would_overflow;
/*!
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
......@@ -42,23 +38,9 @@ template<typename Op>
inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value);
template<typename T>
inline bool GetConst(Expr e, T* out);
template<>
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
inline bool GetConst(Expr e, int64_t* out) {
if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e);
const int64_t* v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
......@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
// get a small constant int
inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false;
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua OP ub); \
} \
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
if (is_zero(a)) return b;
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(add, +);
return ir::Add::make(a, b);
return a + b;
}
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(sub, -);
return ir::Sub::make(a, b);
return a - b;
}
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
if (is_one(a)) return b;
if (is_one(b)) return a;
TVM_CONST_PROPAGATION(mul, *);
return ir::Mul::make(a, b);
return a * b;
}
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a;
return ir::Div::make(a, b);
return a / b;
}
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
if (is_zero(a)) return make_zero(a.type());
return ir::Mod::make(a, b);
return a % b;
}
template<>
......
......@@ -194,7 +194,7 @@ bool DetectClipBound(
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_one(ret.coeff)) {
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base);
......@@ -203,7 +203,7 @@ bool DetectClipBound(
}
return true;
}
if (is_const(ret.coeff, -1)) {
if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base);
......
......@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() {
}
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CHECK(is_zero(op->min));
CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) {
PrintIndent();
stream << "#pragma unroll\n";
......
......@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor {
ChannelEntry& cb = cmap_.at(ch->handle_var.get());
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size.
int trigger_size;
int trigger_size = 0;
if (attr->attr_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0));
......
......@@ -5,6 +5,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
......
......@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/runtime/device_api.h>
#include <vector>
......@@ -75,7 +76,7 @@ inline Expr TVMStructGet(
Array<Expr> args ={
handle,
make_const(Int(32), index),
make_const(Int(32), kind)};
make_const(Int(32), static_cast<int>(kind))};
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
}
......@@ -125,7 +126,7 @@ inline Stmt TVMStructSet(
Array<Expr> args ={
handle,
make_const(Int(32), index),
make_const(Int(32), kind),
make_const(Int(32), static_cast<int>(kind)),
value};
return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
......
......@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator {
} else {
alloc_size = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
alloc_size *= op->extents[i];
alloc_size = alloc_size * op->extents[i];
}
alloc_size = ir::Simplify(alloc_size);
}
if (rw.write_count) {
......
......@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
combo_size += make_const(Int(32), 1);
combo_size = combo_size + make_const(Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
......
......@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
CHECK(is_positive_const(op->extent));
int lanes = 0;
bool succ = arith::GetConstInt(op->extent, &lanes);
if (!succ || lanes < 1) {
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
namespace {
using namespace tvm::ir;
......
......@@ -63,7 +63,7 @@ def test_check():
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {})
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
......@@ -137,4 +137,3 @@ if __name__ == "__main__":
test_check()
test_deduce_basic()
test_deduce_complex()
......@@ -8,7 +8,7 @@ def test_const():
def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
y = tvm.var("x")
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
......
import tvm
def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x) for x in args])
y = f(*args)
if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
check(lambda x, y: x + y, 3, 4)
check(lambda x, y: x * y, 3, 12)
check(lambda x, y: x * y - 10, 3, 12)
check(lambda x, y: x - y % 10, 3, 12)
check(lambda x, y: x // y + 10, 100, 12)
check(lambda x, y: x & y + 10, 112, 128)
check(lambda x, y: x > y, 112, 128)
check(lambda x, y: x < y, 112, 128)
check(lambda x, y: x <= y, 112, 128)
check(lambda x, y: x >= y, 112, 128)
check(lambda x, y: (x | y) ^ 10, 112, 128)
def test_const_fold2():
x = tvm.var("x")
assert (x + 0).same_as(x)
assert (0 + x).same_as(x)
assert (x - 0).same_as(x)
assert (x % 1).value == 0
assert (x * 1).same_as(x)
assert (1 * x).same_as(x)
assert isinstance((1 / x), tvm.expr.Div)
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
......@@ -15,7 +15,7 @@ def test_make_smap():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
z = tvm.expr.Add(x, y)
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap]))
arr = tvm.load_json(json_str)
......
......@@ -53,7 +53,6 @@ def test_canonical():
assert (tvm.ir_pass.Equal(ret1, ret2))
if __name__ == "__main__":
test_modular()
test_bound()
test_basic()
test_simplify()
......
......@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(dtype, fill_value);
Expr ev = cast(dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype;
}
......@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(x->dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << x->dtype;
}
Expr ev = cast(x->dtype, fill_value);
return compute(x->shape, [&](const Array<Var>& i) {
return ev;
}, name, tag);
......
......@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x,
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top);
const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left);
const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom);
const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right);
const int64_t *padding_h0 = as_const_int(pad_top);
const int64_t *padding_w0 = as_const_int(pad_left);
const int64_t *padding_h1 = as_const_int(pad_bottom);
const int64_t *padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
......
......@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
......@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with ib.if_scope(j > 0):
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
cls_id[0] = tvm.select(temp > score[0], j, cls_id[0])
score[0] = tvm.make.Max(temp, score[0])
score[0] = tvm.max(temp, score[0])
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0
# [id, prob, xmin, ymin, xmax, ymax]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment