rcnn.py 3.62 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
"""Faster R-CNN and Mask R-CNN operations."""
from . import _make


def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
    """ROI align operator.

    Parameters
    ----------
    data : relay.Expr
        4-D input feature map. With the default ``layout='NCHW'`` its shape is
        [batch, channel, height, width].

    rois : relay.Expr
        2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
        [batch_index, w_start, h_start, w_end, h_end]

    pooled_size : list/tuple of two ints
        Spatial output size (height, width) of every pooled region.

    spatial_scale : float
        Ratio of the input feature map's height (or width) to the raw image's height
        (or width). Equals the reciprocal of the total stride of the convolutional
        layers, so it should be in the range (0.0, 1.0].

    sample_ratio : int, optional
        Sampling ratio of ROI align. The default of -1 selects an adaptive sampling size.

    layout : str, optional
        Layout of ``data``; defaults to 'NCHW'. Forwarded unchanged to the operator
        constructor.
        NOTE(review): which other layouts the backend accepts is not visible here — confirm.

    Returns
    -------
    output : relay.Expr
        4-D tensor. For the 'NCHW' layout its shape is
        [num_roi, channel, pooled_size[0], pooled_size[1]].
    """
    # Arguments are forwarded positionally; presumably this order must match the
    # registered constructor's signature — keep them in sync.
    return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
33 34


35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
def roi_pool(data, rois, pooled_size, spatial_scale, layout='NCHW'):
    """ROI pool operator.

    Parameters
    ----------
    data : relay.Expr
        4-D input feature map. With the default ``layout='NCHW'`` its shape is
        [batch, channel, height, width].

    rois : relay.Expr
        2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
        [batch_index, w_start, h_start, w_end, h_end]

    pooled_size : list/tuple of two ints
        Spatial output size (height, width) of every pooled region.

    spatial_scale : float
        Ratio of the input feature map's height (or width) to the raw image's height
        (or width). Equals the reciprocal of the total stride of the convolutional
        layers, so it should be in the range (0.0, 1.0].

    layout : str, optional
        Layout of ``data``; defaults to 'NCHW'. Forwarded unchanged to the operator
        constructor.
        NOTE(review): which other layouts the backend accepts is not visible here — confirm.

    Returns
    -------
    output : relay.Expr
        4-D tensor. For the 'NCHW' layout its shape is
        [num_roi, channel, pooled_size[0], pooled_size[1]].
    """
    # Arguments are forwarded positionally; presumably this order must match the
    # registered constructor's signature — keep them in sync.
    return _make.roi_pool(data, rois, pooled_size, spatial_scale, layout)


62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
def proposal(cls_prob,
             bbox_pred,
             im_info,
             scales,
             ratios,
             feature_stride,
             threshold,
             rpn_pre_nms_top_n,
             rpn_post_nms_top_n,
             rpn_min_size,
             iou_loss):
    """Build a region-proposal (RPN proposal) operator expression.

    Parameters
    ----------
    cls_prob : relay.Expr
        4-D tensor with shape [batch, 2 * num_anchors, height, width].

    bbox_pred : relay.Expr
        4-D tensor with shape [batch, 4 * num_anchors, height, width].

    im_info : relay.Expr
        2-D tensor with shape [batch, 3] whose last dimension is laid out as
        [im_height, im_width, im_scale].

    scales : list/tuple of float
        Scales of the anchor windows.

    ratios : list/tuple of float
        Aspect ratios of the anchor windows.

    feature_stride : int
        Size of the receptive field of each unit in the RPN's convolution layer, e.g. the
        product of all strides prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top-scoring boxes to which NMS is applied; -1 uses all boxes.

    rpn_post_nms_top_n : int
        Number of top-scoring boxes kept after NMS is applied to the RPN proposals.

    rpn_min_size : int
        Minimum height or width of a proposal.

    iou_loss : bool
        Whether IoU loss is used.

    Returns
    -------
    output : relay.Expr
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5] whose last dimension is laid
        out as [batch_index, w_start, h_start, w_end, h_end].
    """
    # Tensor operands come first, followed by the operator attributes; the combined
    # sequence is forwarded positionally in exactly the declared parameter order.
    tensor_inputs = (cls_prob, bbox_pred, im_info)
    op_attrs = (scales, ratios, feature_stride, threshold,
                rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss)
    return _make.proposal(*(tensor_inputs + op_attrs))