Commit 13531b7

Fix OBB prediction and update Ultralytics demo notebook (#1126)

1 parent fa705eb · commit 13531b7

9 files changed: +315 -186 lines changed

README.md

Lines changed: 13 additions & 7 deletions

@@ -21,7 +21,7 @@
   <br>
   <a href="https://ieeexplore.ieee.org/document/9897990"><img src="https://img.shields.io/badge/DOI-10.1109%2FICIP46576.2022.9897990-orange.svg" alt="ci"></a>
   <br>
-  <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_yolov5.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+  <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
   <a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img src="https://raw.githubusercontent.com/obss/sahi/main/resources/hf_spaces_badge.svg" alt="HuggingFace Spaces"></a>

 </div>
@@ -43,7 +43,7 @@ Object detection and instance segmentation are by far the most important applica

 ## <div align="center">Quick Start Examples</div>

-[📜 List of publications that cite SAHI (currently 200+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)
+[📜 List of publications that cite SAHI (currently 300+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)

 [🏆 List of competition winners that used SAHI](https://github.com/obss/sahi/discussions/688)

@@ -55,11 +55,15 @@ Object detection and instance segmentation are by far the most important applica

 - [Pretrained weights and ICIP 2022 paper files](https://github.com/fcakyon/small-object-detection-benchmark)

-- [Visualizing and Evaluating SAHI predictions with FiftyOne](https://voxel51.com/blog/how-to-detect-small-objects/) (2024) (NEW)
+- [2025 Video Tutorial](https://www.youtube.com/watch?v=ILqMBah5ZvI) (RECOMMENDED)
+
+- [Visualizing and Evaluating SAHI predictions with FiftyOne](https://voxel51.com/blog/how-to-detect-small-objects/)

 - ['Exploring SAHI' Research Article from 'learnopencv.com'](https://learnopencv.com/slicing-aided-hyper-inference/)

-- ['VIDEO TUTORIAL: Slicing Aided Hyper Inference for Small Object Detection - SAHI'](https://www.youtube.com/watch?v=UuOjJKxn-M8&t=270s) (RECOMMENDED)
+- [Slicing Aided Hyper Inference Explained by Encord](https://encord.com/blog/slicing-aided-hyper-inference-explained/)
+
+- ['VIDEO TUTORIAL: Slicing Aided Hyper Inference for Small Object Detection - SAHI'](https://www.youtube.com/watch?v=UuOjJKxn-M8&t=270s)

 - [Video inference support is live](https://github.com/obss/sahi/discussions/626)

@@ -77,11 +81,13 @@ Object detection and instance segmentation are by far the most important applica

 - `YOLOX` + `SAHI` demo: <a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img src="https://raw.githubusercontent.com/obss/sahi/main/resources/hf_spaces_badge.svg" alt="sahi-yolox"></a>

-- `YOLO11` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolov8"></a> (NEW)
+- `YOLO12` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolo12"></a> (NEW)

-- `RT-DETR` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_rtdetr.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-rtdetr"></a> (NEW)
+- `YOLO11-OBB` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolo11-obb"></a> (NEW)

-- `YOLOv8` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolov8"></a>
+- `YOLO11` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolo11"></a>
+
+- `RT-DETR` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_rtdetr.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-rtdetr"></a> (NEW)

 - `DeepSparse` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_sparse_yolov5.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-deepsparse"></a>

demo/inference_for_ultralytics.ipynb

Lines changed: 209 additions & 108 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "sahi"
-version = "0.11.21"
+version = "0.11.22"
 readme = "README.md"
 description = "A vision library for performing sliced inference on large images/small objects"
 requires-python = ">=3.8"

sahi/auto_model.py

Lines changed: 3 additions & 1 deletion

@@ -39,7 +39,7 @@ def from_pretrained(

         Args:
             model_type: str
-                Name of the detection framework (example: "yolov5", "mmdet", "detectron2")
+                Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")
             model_path: str
                 Path of the detection model (ex. 'model.pt')
             config_path: str
@@ -58,8 +58,10 @@ def from_pretrained(
                 If True, automatically loads the model at initialization
             image_size: int
                 Inference input size.
+
         Returns:
             Returns an instance of a DetectionModel
+
         Raises:
             ImportError: If given {model_type} framework is not installed
         """

sahi/models/torchvision.py

Lines changed: 3 additions & 3 deletions

@@ -178,13 +178,13 @@ def _create_object_prediction_list_from_original_predictions(

         for ind in range(len(boxes)):
             if masks is not None:
-                mask = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
+                segmentation = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
             else:
-                mask = None
+                segmentation = None

             object_prediction = ObjectPrediction(
                 bbox=boxes[ind],
-                segmentation=mask,
+                segmentation=segmentation,
                 category_id=int(category_ids[ind]),
                 category_name=self.category_mapping[str(int(category_ids[ind]))],
                 shift_amount=shift_amount,

sahi/models/ultralytics.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 from sahi.models.base import DetectionModel
 from sahi.prediction import ObjectPrediction
 from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
-from sahi.utils.cv import get_coco_segmentation_from_bool_mask, get_coco_segmentation_from_obb_points
+from sahi.utils.cv import get_coco_segmentation_from_bool_mask
 from sahi.utils.import_utils import check_requirements

 logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ def _create_object_prediction_list_from_original_predictions(
                 segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
             else:  # is_obb
                 obb_points = masks_or_points[pred_ind]  # Get OBB points for this prediction
-                segmentation = get_coco_segmentation_from_obb_points(obb_points)
+                segmentation = [obb_points.reshape(-1).tolist()]

             if len(segmentation) == 0:
                 continue
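The fix drops the closed-polygon helper in favor of flattening the four OBB corners directly into one open 8-value COCO segment. A standalone sketch of the same transform; the corner values are made up:

import numpy as np

# four (x, y) corners of an oriented box, as ultralytics.engine.results.OBB yields them
obb_points = np.array([[10.0, 10.0], [50.0, 12.0], [48.0, 40.0], [8.0, 38.0]])

# new behavior: flatten to a single open polygon, no repeated closing point
segmentation = [obb_points.reshape(-1).tolist()]
print(segmentation)  # [[10.0, 10.0, 50.0, 12.0, 48.0, 40.0, 8.0, 38.0]]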

sahi/predict.py

Lines changed: 8 additions & 3 deletions

@@ -113,6 +113,9 @@ def get_prediction(
     time_end = time.time() - time_start
     durations_in_seconds["prediction"] = time_end

+    if full_shape is None:
+        full_shape = [image_as_pil.height, image_as_pil.width]
+
     # process prediction
     time_start = time.time()
     # works only with 1 batch
@@ -239,19 +242,21 @@ def get_sliced_prediction(
         overlap_width_ratio=overlap_width_ratio,
         auto_slice_resolution=auto_slice_resolution,
     )
+    from sahi.models.ultralytics import UltralyticsDetectionModel

     num_slices = len(slice_image_result)
     time_end = time.time() - time_start
     durations_in_seconds["slice"] = time_end

+    if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
+        # Only NMS is supported for OBB model outputs
+        postprocess_type = "NMS"
+
     # init match postprocess instance
     if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
         raise ValueError(
             f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
         )
-    elif postprocess_type == "UNIONMERGE":
-        # deprecated in v0.9.3
-        raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
     postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
     postprocess = postprocess_constructor(
         match_threshold=postprocess_match_threshold,
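Together the two hunks let an OBB checkpoint run through sliced inference without manual tuning: get_prediction derives full_shape from the input image, and get_sliced_prediction forces NMS since only NMS is supported for OBB outputs. A hedged end-to-end sketch; "yolo11n-obb.pt" and "demo.jpg" are placeholders:

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

detection_model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="yolo11n-obb.pt",  # placeholder OBB checkpoint
    confidence_threshold=0.25,
)

# any requested postprocess_type is overridden to "NMS" internally for OBB models
result = get_sliced_prediction(
    "demo.jpg",                   # placeholder image path
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
result.export_visuals(export_dir="outputs/")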

sahi/utils/cv.py

Lines changed: 76 additions & 55 deletions

@@ -540,68 +540,88 @@ def visualize_object_predictions(
     # set text_size for category names
     text_size = text_size or rect_th / 3

-    # add masks to image if present
+    # add masks or obb polygons to image if present
     for object_prediction in object_prediction_list:
         # deepcopy object_prediction_list so that original is not altered
         object_prediction = object_prediction.deepcopy()
-        # visualize masks if present
-        if object_prediction.mask is not None:
-            # deepcopy mask so that original is not altered
-            mask = object_prediction.mask.bool_mask
-            # set color
-            if colors is not None:
-                color = colors(object_prediction.category.id)
-            # draw mask
-            rgb_mask = apply_color_mask(mask, color or (0, 0, 0))
-            image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
-
-    # add bboxes to image if present
-    for object_prediction in object_prediction_list:
-        # deepcopy object_prediction_list so that original is not altered
-        object_prediction = object_prediction.deepcopy()
-
-        bbox = object_prediction.bbox.to_xyxy()
-        category_name = object_prediction.category.name
-        score = object_prediction.score.value
-
+        # arange label to be displayed
+        label = f"{object_prediction.category.name}"
+        if not hide_conf:
+            label += f" {object_prediction.score.value:.2f}"
         # set color
         if colors is not None:
             color = colors(object_prediction.category.id)
-        # set bbox points
-        point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
-        # visualize boxes
-        cv2.rectangle(
-            image,
-            point1,
-            point2,
-            color=color or (0, 0, 0),
-            thickness=rect_th,
-        )
-
-        if not hide_labels:
-            # arange bounding box text location
-            label = f"{category_name}"
-
-            if not hide_conf:
-                label += f" {score:.2f}"
-
-            box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
-                0
-            ]  # label width, height
-            outside = point1[1] - box_height - 3 >= 0  # label fits outside box
-            point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
-            # add bounding box text
-            cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA)  # filled
-            cv2.putText(
+        # visualize masks or obb polygons if present
+        has_mask = object_prediction.mask is not None
+        is_obb_pred = False
+        if has_mask:
+            segmentation = object_prediction.mask.segmentation
+            if len(segmentation) == 1 and len(segmentation[0]) == 8:
+                is_obb_pred = True
+
+            if is_obb_pred:
+                points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
+                cv2.polylines(image, [points], isClosed=True, color=color or (0, 0, 0), thickness=rect_th)
+
+                if not hide_labels:
+                    lowest_point = points[points[:, :, 1].argmax()][0]
+                    box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
+                    outside = lowest_point[1] - box_height - 3 >= 0
+                    text_bg_point1 = (
+                        lowest_point[0],
+                        lowest_point[1] - box_height - 3 if outside else lowest_point[1] + 3,
+                    )
+                    text_bg_point2 = (lowest_point[0] + box_width, lowest_point[1])
+                    cv2.rectangle(
+                        image, text_bg_point1, text_bg_point2, color or (0, 0, 0), thickness=-1, lineType=cv2.LINE_AA
+                    )
+                    cv2.putText(
+                        image,
+                        label,
+                        (lowest_point[0], lowest_point[1] - 2 if outside else lowest_point[1] + box_height + 2),
+                        0,
+                        text_size,
+                        (255, 255, 255),
+                        thickness=text_th,
+                    )
+            else:
+                # draw mask
+                rgb_mask = apply_color_mask(object_prediction.mask.bool_mask, color or (0, 0, 0))
+                image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
+
+        # add bboxes to image if is_obb_pred=False
+        if not is_obb_pred:
+            bbox = object_prediction.bbox.to_xyxy()
+
+            # set bbox points
+            point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
+            # visualize boxes
+            cv2.rectangle(
                 image,
-                label,
-                (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
-                0,
-                text_size,
-                (255, 255, 255),
-                thickness=text_th,
+                point1,
+                point2,
+                color=color or (0, 0, 0),
+                thickness=rect_th,
             )

+            if not hide_labels:
+                box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
+                    0
+                ]  # label width, height
+                outside = point1[1] - box_height - 3 >= 0  # label fits outside box
+                point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
+                # add bounding box text
+                cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA)  # filled
+                cv2.putText(
+                    image,
+                    label,
+                    (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
+                    0,
+                    text_size,
+                    (255, 255, 255),
+                    thickness=text_th,
+                )
+
     # export if output_dir is present
     if output_dir is not None:
         # export image with predictions
@@ -614,7 +634,7 @@ def visualize_object_predictions(
     return {"image": image, "elapsed_time": elapsed_time}


-def get_coco_segmentation_from_bool_mask(bool_mask):
+def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> List[List[float]]:
     """
     Convert boolean mask to coco segmentation format
     [
@@ -712,9 +732,10 @@ def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[f
         obb_points: np.ndarray
             OBB points tensor from ultralytics.engine.results.OBB
             Shape: (4, 2) containing 4 points with (x,y) coordinates each
+
     Returns:
         List[List[float]]: Polygon points in COCO format
-        [[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1], [...], ...]
+        [[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]
     """
     # Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
     points = obb_points.reshape(-1).tolist()
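The visualizer now recognizes an OBB prediction by its segmentation shape (one segment, eight values) and draws the rotated outline with cv2.polylines instead of shading a bitmap mask. A self-contained sketch of that heuristic; the canvas, color, and coordinates are made up:

import cv2
import numpy as np

image = np.zeros((100, 100, 3), dtype=np.uint8)   # placeholder canvas
segmentation = [[10, 10, 80, 15, 75, 70, 5, 65]]  # one open 4-corner polygon

# same check as visualize_object_predictions: one segment with eight values => OBB
is_obb_pred = len(segmentation) == 1 and len(segmentation[0]) == 8
if is_obb_pred:
    points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
    cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)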

tests/test_ultralyticsmodel.py

Lines changed: 0 additions & 6 deletions

@@ -219,12 +219,6 @@ def test_yolo11_obb(self):
         # Verify segmentation is a list of points
         self.assertTrue(isinstance(coco_segmentation, list))
         self.assertGreater(len(coco_segmentation), 0)
-        # Verify each segment is a valid closed polygon
-        for segment in coco_segmentation:
-            self.assertEqual(len(segment), 10)  # 4 points + 1 closing point (x,y coordinates)
-            # Verify polygon is closed (first point equals last point)
-            self.assertEqual(segment[0], segment[-2])  # x coordinate
-            self.assertEqual(segment[1], segment[-1])  # y coordinate


 if __name__ == "__main__":
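The deleted assertions encoded the old closed 10-value polygon format. A hedged sketch of the equivalent check for the new open format, not part of this commit; the dummy coco_segmentation stands in for the model output used in test_yolo11_obb:

# dummy value standing in for the OBB segmentation produced by the model
coco_segmentation = [[10.0, 10.0, 50.0, 12.0, 48.0, 40.0, 8.0, 38.0]]
for segment in coco_segmentation:
    # each segment is now an open 4-corner polygon: 8 values, no closing point
    assert len(segment) == 8  # x1, y1, x2, y2, x3, y3, x4, y4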
