diff --git a/sahi/models/detectron2.py b/sahi/models/detectron2.py index 0752eaedf..5f67507ca 100644 --- a/sahi/models/detectron2.py +++ b/sahi/models/detectron2.py @@ -114,6 +114,7 @@ def _create_object_prediction_list_from_original_predictions( Size of the full image after shifting, should be in the form of List[[height, width],[height, width],...] """ + original_predictions = self._original_predictions # compatilibty for sahi v0.8.15 @@ -122,55 +123,56 @@ def _create_object_prediction_list_from_original_predictions( if full_shape_list is not None and isinstance(full_shape_list[0], int): full_shape_list = [full_shape_list] + # detectron2 DefaultPredictor supports single image + shift_amount = shift_amount_list[0] + full_shape = None if full_shape_list is None else full_shape_list[0] + # parse boxes, masks, scores, category_ids from predictions - boxes = original_predictions["instances"].pred_boxes.tensor.tolist() - scores = original_predictions["instances"].scores.tolist() - category_ids = original_predictions["instances"].pred_classes.tolist() + boxes = original_predictions["instances"].pred_boxes.tensor + scores = original_predictions["instances"].scores + category_ids = original_predictions["instances"].pred_classes # check if predictions contain mask try: - masks = original_predictions["instances"].pred_masks.tolist() + masks = original_predictions["instances"].pred_masks except AttributeError: masks = None - # create object_prediction_list - object_prediction_list_per_image = [] - object_prediction_list = [] - - # detectron2 DefaultPredictor supports single image - shift_amount = shift_amount_list[0] - full_shape = None if full_shape_list is None else full_shape_list[0] - - for ind in range(len(boxes)): - score = scores[ind] - if score < self.confidence_threshold: - continue - - category_id = category_ids[ind] - - if masks is None: - bbox = boxes[ind] - mask = None - else: - mask = np.array(masks[ind]) - - # check if mask is valid - # https://github.com/obss/sahi/issues/389 - if get_bbox_from_bool_mask(mask) is None: - continue - else: - bbox = None - - object_prediction = ObjectPrediction( - bbox=bbox, - bool_mask=mask, - category_id=category_id, - category_name=self.category_mapping[str(category_id)], - shift_amount=shift_amount, - score=score, - full_shape=full_shape, - ) - object_prediction_list.append(object_prediction) + # filter predictions with low confidence + high_confidence_mask = scores >= self.confidence_threshold + boxes = boxes[high_confidence_mask] + scores = scores[high_confidence_mask] + category_ids = category_ids[high_confidence_mask] + if masks is not None: + masks = masks[high_confidence_mask] + + if masks is not None: + object_prediction_list = [ + ObjectPrediction( + bbox=box.tolist() if mask is None else None, + bool_mask=mask.detach().cpu().numpy() if mask is not None else None, + category_id=category_id.item(), + category_name=self.category_mapping[str(category_id.item())], + shift_amount=shift_amount, + score=score.item(), + full_shape=full_shape, + ) + for box, score, category_id, mask in zip(boxes, scores, category_ids, masks) + if mask is None or get_bbox_from_bool_mask(mask.detach().cpu().numpy()) is not None + ] + else: + object_prediction_list = [ + ObjectPrediction( + bbox=box.tolist(), + bool_mask=None, + category_id=category_id.item(), + category_name=self.category_mapping[str(category_id.item())], + shift_amount=shift_amount, + score=score.item(), + full_shape=full_shape, + ) + for box, score, category_id in zip(boxes, scores, category_ids) + ] # detectron2 DefaultPredictor supports single image object_prediction_list_per_image = [object_prediction_list]