Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization of _create_object_prediction_list_from_original_predictions for Detectron2 model. Significant speed improvement. #865

Merged
merged 4 commits into from
May 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions sahi/models/detectron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _create_object_prediction_list_from_original_predictions(
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""

original_predictions = self._original_predictions

# compatibility for sahi v0.8.15
Expand All @@ -122,55 +123,56 @@ def _create_object_prediction_list_from_original_predictions(
if full_shape_list is not None and isinstance(full_shape_list[0], int):
full_shape_list = [full_shape_list]

# detectron2 DefaultPredictor supports single image
shift_amount = shift_amount_list[0]
full_shape = None if full_shape_list is None else full_shape_list[0]

# parse boxes, masks, scores, category_ids from predictions
boxes = original_predictions["instances"].pred_boxes.tensor.tolist()
scores = original_predictions["instances"].scores.tolist()
category_ids = original_predictions["instances"].pred_classes.tolist()
boxes = original_predictions["instances"].pred_boxes.tensor
scores = original_predictions["instances"].scores
category_ids = original_predictions["instances"].pred_classes

# check if predictions contain mask
try:
masks = original_predictions["instances"].pred_masks.tolist()
masks = original_predictions["instances"].pred_masks
except AttributeError:
masks = None

# create object_prediction_list
object_prediction_list_per_image = []
object_prediction_list = []

# detectron2 DefaultPredictor supports single image
shift_amount = shift_amount_list[0]
full_shape = None if full_shape_list is None else full_shape_list[0]

for ind in range(len(boxes)):
score = scores[ind]
if score < self.confidence_threshold:
continue

category_id = category_ids[ind]

if masks is None:
bbox = boxes[ind]
mask = None
else:
mask = np.array(masks[ind])

# check if mask is valid
# https://github.com/obss/sahi/issues/389
if get_bbox_from_bool_mask(mask) is None:
continue
else:
bbox = None

object_prediction = ObjectPrediction(
bbox=bbox,
bool_mask=mask,
category_id=category_id,
category_name=self.category_mapping[str(category_id)],
shift_amount=shift_amount,
score=score,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
# filter out predictions with low confidence
high_confidence_mask = scores >= self.confidence_threshold
boxes = boxes[high_confidence_mask]
scores = scores[high_confidence_mask]
category_ids = category_ids[high_confidence_mask]
if masks is not None:
masks = masks[high_confidence_mask]

if masks is not None:
object_prediction_list = [
ObjectPrediction(
bbox=box.tolist() if mask is None else None,
bool_mask=mask.detach().cpu().numpy() if mask is not None else None,
category_id=category_id.item(),
category_name=self.category_mapping[str(category_id.item())],
shift_amount=shift_amount,
score=score.item(),
full_shape=full_shape,
)
for box, score, category_id, mask in zip(boxes, scores, category_ids, masks)
if mask is None or get_bbox_from_bool_mask(mask.detach().cpu().numpy()) is not None
]
else:
object_prediction_list = [
ObjectPrediction(
bbox=box.tolist(),
bool_mask=None,
category_id=category_id.item(),
category_name=self.category_mapping[str(category_id.item())],
shift_amount=shift_amount,
score=score.item(),
full_shape=full_shape,
)
for box, score, category_id in zip(boxes, scores, category_ids)
]

# detectron2 DefaultPredictor supports single image
object_prediction_list_per_image = [object_prediction_list]
Expand Down