11from __future__ import annotations
22
3+ import numpy as np
34import torch
4- from shapely .geometry import box
5- from shapely .strtree import STRtree
5+ from shapely import STRtree , box
66
77from sahi .logger import logger
88from sahi .postprocess .utils import ObjectPredictionList , has_match , merge_object_prediction_pair
@@ -53,31 +53,30 @@ def nms(
5353 Returns:
5454 A list of filtered indexes, Shape: [ ,]
5555 """
56+ if len (predictions ) == 0 :
57+ return []
5658
57- # Extract coordinates and scores as tensors
58- x1 = predictions [:, 0 ]
59- y1 = predictions [:, 1 ]
60- x2 = predictions [:, 2 ]
61- y2 = predictions [:, 3 ]
62- scores = predictions [:, 4 ]
59+ # Ensure predictions are on CPU and convert to numpy
60+ if predictions .device .type != "cpu" :
61+ predictions = predictions .cpu ()
6362
64- # Calculate areas as tensor (vectorized operation)
63+ predictions_np = predictions .numpy ()
64+
65+ # Extract coordinates and scores
66+ x1 = predictions_np [:, 0 ]
67+ y1 = predictions_np [:, 1 ]
68+ x2 = predictions_np [:, 2 ]
69+ y2 = predictions_np [:, 3 ]
70+ scores = predictions_np [:, 4 ]
71+
72+ # Calculate areas
6573 areas = (x2 - x1 ) * (y2 - y1 )
6674
67- # Create Shapely boxes only once
68- boxes = []
69- for i in range (len (predictions )):
70- boxes .append (
71- box (
72- x1 [i ].item (), # Convert only individual values
73- y1 [i ].item (),
74- x2 [i ].item (),
75- y2 [i ].item (),
76- )
77- )
75+ # Create Shapely boxes (vectorized)
76+ boxes = box (x1 , y1 , x2 , y2 )
7877
79- # Sort indices by score (descending) using torch
80- sorted_idxs = torch .argsort (scores , descending = True ). tolist ()
78+ # Sort indices by score (descending)
79+ sorted_idxs = np .argsort (scores )[:: - 1 ]
8180
8281 # Build STRtree
8382 tree = STRtree (boxes )
@@ -91,7 +90,7 @@ def nms(
9190
9291 keep .append (current_idx )
9392 current_box = boxes [current_idx ]
94- current_area = areas [current_idx ]. item () # Convert only when needed
93+ current_area = areas [current_idx ]
9594
9695 # Query potential intersections using STRtree
9796 candidate_idxs = tree .query (current_box )
@@ -108,16 +107,16 @@ def nms(
108107 if scores [candidate_idx ] == scores [current_idx ]:
109108 # Use box coordinates for stable ordering
110109 current_coords = (
111- x1 [current_idx ]. item () ,
112- y1 [current_idx ]. item () ,
113- x2 [current_idx ]. item () ,
114- y2 [current_idx ]. item () ,
110+ x1 [current_idx ],
111+ y1 [current_idx ],
112+ x2 [current_idx ],
113+ y2 [current_idx ],
115114 )
116115 candidate_coords = (
117- x1 [candidate_idx ]. item () ,
118- y1 [candidate_idx ]. item () ,
119- x2 [candidate_idx ]. item () ,
120- y2 [candidate_idx ]. item () ,
116+ x1 [candidate_idx ],
117+ y1 [candidate_idx ],
118+ x2 [candidate_idx ],
119+ y2 [candidate_idx ],
121120 )
122121
123122 # Compare coordinates lexicographically
@@ -130,10 +129,10 @@ def nms(
130129
131130 # Calculate metric
132131 if match_metric == "IOU" :
133- union = current_area + areas [candidate_idx ]. item () - intersection
132+ union = current_area + areas [candidate_idx ] - intersection
134133 metric = intersection / union if union > 0 else 0
135134 elif match_metric == "IOS" :
136- smaller = min (current_area , areas [candidate_idx ]. item () )
135+ smaller = min (current_area , areas [candidate_idx ])
137136 metric = intersection / smaller if smaller > 0 else 0
138137 else :
139138 raise ValueError ("Invalid match_metric" )
0 commit comments