5757logger = logging .getLogger (__name__ )
5858
5959
60+ def filter_predictions (object_prediction_list , exclude_classes_by_name , exclude_classes_by_id ):
61+ return [
62+ obj_pred
63+ for obj_pred in object_prediction_list
64+ if obj_pred .category .name not in (exclude_classes_by_name or [])
65+ and obj_pred .category .id not in (exclude_classes_by_id or [])
66+ ]
67+
68+
6069def get_prediction (
6170 image ,
6271 detection_model ,
6372 shift_amount : list = [0 , 0 ],
6473 full_shape = None ,
6574 postprocess : Optional [PostprocessPredictions ] = None ,
6675 verbose : int = 0 ,
76+ exclude_classes_by_name : Optional [List [str ]] = None ,
77+ exclude_classes_by_id : Optional [List [int ]] = None ,
6778) -> PredictionResult :
6879 """
6980 Function for performing prediction for given image using given detection_model.
@@ -81,7 +92,12 @@ def get_prediction(
8192 verbose: int
8293 0: no print (default)
8394 1: print prediction duration
84-
95+ exclude_classes_by_name: Optional[List[str]]
96+ None: if no classes are excluded
97+ List[str]: set of classes to exclude using its/their class label name/s
98+ exclude_classes_by_id: Optional[List[int]]
99+ None: if no classes are excluded
100+ List[int]: set of classes to exclude using one or more IDs
85101 Returns:
86102 A dict with fields:
87103 object_prediction_list: a list of ObjectPrediction
@@ -105,6 +121,7 @@ def get_prediction(
105121 full_shape = full_shape ,
106122 )
107123 object_prediction_list : List [ObjectPrediction ] = detection_model .object_prediction_list
124+ object_prediction_list = filter_predictions (object_prediction_list , exclude_classes_by_name , exclude_classes_by_id )
108125
109126 # postprocess matching predictions
110127 if postprocess is not None :
@@ -142,6 +159,8 @@ def get_sliced_prediction(
142159 auto_slice_resolution : bool = True ,
143160 slice_export_prefix : Optional [str ] = None ,
144161 slice_dir : Optional [str ] = None ,
162+ exclude_classes_by_name : Optional [List [str ]] = None ,
163+ exclude_classes_by_id : Optional [List [int ]] = None ,
145164) -> PredictionResult :
146165 """
147166 Function for slice image + get predicion for each slice + combine predictions in full image.
@@ -191,7 +210,12 @@ def get_sliced_prediction(
191210 Prefix for the exported slices. Defaults to None.
192211 slice_dir: str
193212 Directory to save the slices. Defaults to None.
194-
213+ exclude_classes_by_name: Optional[List[str]]
214+ None: if no classes are excluded
215+ List[str]: set of classes to exclude using its/their class label name/s
216+ exclude_classes_by_id: Optional[List[int]]
217+ None: if no classes are excluded
218+ List[int]: set of classes to exclude using one or more IDs
195219 Returns:
196220 A Dict with fields:
197221 object_prediction_list: a list of sahi.prediction.ObjectPrediction
@@ -257,6 +281,8 @@ def get_sliced_prediction(
257281 slice_image_result .original_image_height ,
258282 slice_image_result .original_image_width ,
259283 ],
284+ exclude_classes_by_name = exclude_classes_by_name ,
285+ exclude_classes_by_id = exclude_classes_by_id ,
260286 )
261287 # convert sliced predictions to full predictions
262288 for object_prediction in prediction_result .object_prediction_list :
@@ -278,6 +304,8 @@ def get_sliced_prediction(
278304 slice_image_result .original_image_width ,
279305 ],
280306 postprocess = None ,
307+ exclude_classes_by_name = exclude_classes_by_name ,
308+ exclude_classes_by_id = exclude_classes_by_id ,
281309 )
282310 object_prediction_list .extend (prediction_result .object_prediction_list )
283311
@@ -380,6 +408,8 @@ def predict(
380408 verbose : int = 1 ,
381409 return_dict : bool = False ,
382410 force_postprocess_type : bool = False ,
411+ exclude_classes_by_name : Optional [List [str ]] = None ,
412+ exclude_classes_by_id : Optional [List [int ]] = None ,
383413 ** kwargs ,
384414):
385415 """
@@ -466,6 +496,12 @@ def predict(
466496 If True, returns a dict with 'export_dir' field.
467497 force_postprocess_type: bool
468498 If True, auto postprocess check will e disabled
499+ exclude_classes_by_name: Optional[List[str]]
500+ None: if no classes are excluded
501+ List[str]: set of classes to exclude using its/their class label name/s
502+ exclude_classes_by_id: Optional[List[int]]
503+ None: if no classes are excluded
504+ List[int]: set of classes to exclude using one or more IDs
469505 """
470506 # assert prediction type
471507 if no_standard_prediction and no_sliced_prediction :
@@ -574,6 +610,8 @@ def predict(
574610 postprocess_match_threshold = postprocess_match_threshold ,
575611 postprocess_class_agnostic = postprocess_class_agnostic ,
576612 verbose = 1 if verbose else 0 ,
613+ exclude_classes_by_name = exclude_classes_by_name ,
614+ exclude_classes_by_id = exclude_classes_by_id ,
577615 )
578616 object_prediction_list = prediction_result .object_prediction_list
579617 if prediction_result .durations_in_seconds :
@@ -587,6 +625,8 @@ def predict(
587625 full_shape = None ,
588626 postprocess = None ,
589627 verbose = 0 ,
628+ exclude_classes_by_name = exclude_classes_by_name ,
629+ exclude_classes_by_id = exclude_classes_by_id ,
590630 )
591631 object_prediction_list = prediction_result .object_prediction_list
592632
@@ -753,6 +793,8 @@ def predict_fiftyone(
753793 postprocess_match_threshold : float = 0.5 ,
754794 postprocess_class_agnostic : bool = False ,
755795 verbose : int = 1 ,
796+ exclude_classes_by_name : Optional [List [str ]] = None ,
797+ exclude_classes_by_id : Optional [List [int ]] = None ,
756798):
757799 """
758800 Performs prediction for all present images in given folder.
@@ -811,6 +853,12 @@ def predict_fiftyone(
811853 verbose: int
812854 0: no print
813855 1: print slice/prediction durations, number of slices, model loading/file exporting durations
856+ exclude_classes_by_name: Optional[List[str]]
857+ None: if no classes are excluded
858+ List[str]: set of classes to exclude using its/their class label name/s
859+ exclude_classes_by_id: Optional[List[int]]
860+ None: if no classes are excluded
861+ List[int]: set of classes to exclude using one or more IDs
814862 """
815863 check_requirements (["fiftyone" ])
816864
@@ -863,6 +911,8 @@ def predict_fiftyone(
863911 postprocess_match_metric = postprocess_match_metric ,
864912 postprocess_class_agnostic = postprocess_class_agnostic ,
865913 verbose = verbose ,
914+ exclude_classes_by_name = exclude_classes_by_name ,
915+ exclude_classes_by_id = exclude_classes_by_id ,
866916 )
867917 durations_in_seconds ["slice" ] += prediction_result .durations_in_seconds ["slice" ]
868918 else :
@@ -874,6 +924,8 @@ def predict_fiftyone(
874924 full_shape = None ,
875925 postprocess = None ,
876926 verbose = 0 ,
927+ exclude_classes_by_name = exclude_classes_by_name ,
928+ exclude_classes_by_id = exclude_classes_by_id ,
877929 )
878930 durations_in_seconds ["prediction" ] += prediction_result .durations_in_seconds ["prediction" ]
879931
0 commit comments