diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index ed7a0978b4e1..f359fd9e8d00 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -67,11 +67,18 @@ class BatchFeature(UserDict): tensor_type (`Union[None, str, TensorType]`, *optional*): You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at initialization. + skip_tensor_conversion (`list[str]` or `set[str]`, *optional*): + List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified. """ - def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + def __init__( + self, + data: Optional[dict[str, Any]] = None, + tensor_type: Union[None, str, TensorType] = None, + skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None, + ): super().__init__(data) - self.convert_to_tensors(tensor_type=tensor_type) + self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion) def __getitem__(self, item: str) -> Any: """ @@ -110,6 +117,14 @@ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = import torch def as_tensor(value): + if torch.is_tensor(value): + return value + + # stack list of tensors if tensor_type is PyTorch (# torch.tensor() does not support list of tensors) + if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]): + return torch.stack(value) + + # convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy if isinstance(value, (list, tuple)) and len(value) > 0: if isinstance(value[0], np.ndarray): value = np.array(value) @@ -138,7 +153,11 @@ def as_tensor(value, dtype=None): is_tensor = is_numpy_array return is_tensor, as_tensor - def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + def convert_to_tensors( + self, + tensor_type: Optional[Union[str, TensorType]] = None, + skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None, + ): """ Convert the inner content to tensors. @@ -146,6 +165,8 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non tensor_type (`str` or [`~utils.TensorType`], *optional*): The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If `None`, no modification is done. + skip_tensor_conversion (`list[str]` or `set[str]`, *optional*): + List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified. """ if tensor_type is None: return self @@ -154,18 +175,26 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non # Do the tensor conversion in batch for key, value in self.items(): + # Skip keys explicitly marked for no conversion + if skip_tensor_conversion and key in skip_tensor_conversion: + continue + try: if not is_tensor(value): tensor = as_tensor(value) - self[key] = tensor - except: # noqa E722 + except Exception as e: if key == "overflowing_values": - raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + f"Unable to create tensor for '{key}' with overflowing values of different lengths. " + f"Original error: {str(e)}" + ) from e raise ValueError( - "Unable to create tensor, you should probably activate padding " - "with 'padding=True' to have batched tensors with the same length." 
- ) + f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n" + f"You can try:\n" + f" 1. Use padding=True to ensure all outputs have the same shape\n" + f" 2. Set return_tensors=None to return Python objects instead of tensors" + ) from e return self diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 762b7160f155..e1e6d935aa10 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -932,7 +932,6 @@ def _preprocess( if do_pad: processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def to_dict(self): diff --git a/src/transformers/models/beit/image_processing_beit_fast.py b/src/transformers/models/beit/image_processing_beit_fast.py index 5d89120283a5..b739a1fab579 100644 --- a/src/transformers/models/beit/image_processing_beit_fast.py +++ b/src/transformers/models/beit/image_processing_beit_fast.py @@ -163,7 +163,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py index 76a76b4b0a47..b4ca1caf0d3d 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py @@ -251,10 +251,8 @@ def _preprocess( processed_images, processed_masks = self.pad( processed_images, return_mask=True, disable_grouping=disable_grouping ) - processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks data["pixel_mask"] = processed_masks - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images data["pixel_values"] = processed_images return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py b/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py index afdd683e2312..70aa8d71ef19 100644 --- a/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +++ b/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py @@ -262,7 +262,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors diff --git a/src/transformers/models/convnext/image_processing_convnext_fast.py b/src/transformers/models/convnext/image_processing_convnext_fast.py index 035b92f8b7d2..fb122b41f2ca 100644 --- a/src/transformers/models/convnext/image_processing_convnext_fast.py +++ b/src/transformers/models/convnext/image_processing_convnext_fast.py @@ -162,7 +162,6 @@ def _preprocess( 
processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py index 6eaa15d827d9..aab51fdf9679 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py @@ -171,7 +171,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py index 0a1efea13cff..ef2ee384b736 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py @@ -207,9 +207,6 @@ def _preprocess( ) high_res_processed_images_grouped[shape] = stacked_high_res_images high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index) - high_res_processed_images = ( - torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images - ) resized_images_grouped = {} for shape, stacked_high_res_padded_images in high_res_padded_images.items(): @@ -233,7 +230,6 @@ def _preprocess( ) processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images}, diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 4501ed7810d2..622458328977 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -888,9 +888,6 @@ def _preprocess( ) high_res_processed_images_grouped[shape] = stacked_high_res_images high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index) - high_res_processed_images = ( - torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images - ) resized_images_grouped = {} for shape, stacked_high_res_padded_images in high_res_padded_images.items(): @@ -914,7 +911,6 @@ def _preprocess( ) processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images}, diff --git 
a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py index bc621e0ffc26..93cf889d43ee 100644 --- a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py +++ b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py @@ -94,7 +94,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py b/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py index 7c080485ed00..3f3a2334ab5c 100644 --- a/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +++ b/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py @@ -88,7 +88,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/donut/image_processing_donut_fast.py b/src/transformers/models/donut/image_processing_donut_fast.py index 9a150f4df75f..f27c9491cb59 100644 --- a/src/transformers/models/donut/image_processing_donut_fast.py +++ b/src/transformers/models/donut/image_processing_donut_fast.py @@ -231,7 +231,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/dpt/image_processing_dpt_fast.py b/src/transformers/models/dpt/image_processing_dpt_fast.py index ba0a6d28c56c..06fd884afaf7 100644 --- a/src/transformers/models/dpt/image_processing_dpt_fast.py +++ b/src/transformers/models/dpt/image_processing_dpt_fast.py @@ -225,8 +225,7 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature(data={"pixel_values": processed_images}) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ diff --git a/src/transformers/models/dpt/modular_dpt.py b/src/transformers/models/dpt/modular_dpt.py index d99160653557..ea7a789a536d 100644 --- a/src/transformers/models/dpt/modular_dpt.py +++ b/src/transformers/models/dpt/modular_dpt.py @@ -228,8 +228,7 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature(data={"pixel_values": processed_images}) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def 
post_process_depth_estimation( self, diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py index 74a5d4577c91..1584e2a782ad 100644 --- a/src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr_fast.py @@ -153,9 +153,8 @@ def _preprocess( stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs] # Return in same format as slow processor - image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs - return BatchFeature(data={"pixel_values": image_pairs}) + return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors) def post_process_keypoint_matching( self, diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py b/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py index 5f3439aaa273..93e1237f061c 100644 --- a/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py +++ b/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py @@ -178,7 +178,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py index b6bf7f3aa13e..ec85d98f7d39 100644 --- a/src/transformers/models/eomt/image_processing_eomt_fast.py +++ b/src/transformers/models/eomt/image_processing_eomt_fast.py @@ -162,8 +162,7 @@ def _preprocess_image_like_inputs( ) ignore_index = kwargs.pop("ignore_index", None) images_kwargs = kwargs.copy() - processed_images, patch_offsets = self._preprocess(images, **images_kwargs) - outputs = BatchFeature({"pixel_values": processed_images}) + outputs = self._preprocess(images, **images_kwargs) if segmentation_maps is not None: processed_segmentation_maps = self._prepare_image_like_inputs( @@ -183,9 +182,9 @@ def _preprocess_image_like_inputs( } ) - processed_segmentation_maps, _ = self._preprocess( + processed_segmentation_maps = self._preprocess( images=processed_segmentation_maps, **segmentation_maps_kwargs - ) + ).pixel_values processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(torch.int64) # Convert to list of binary masks and labels mask_labels, class_labels = [], [] @@ -208,8 +207,8 @@ def _preprocess_image_like_inputs( outputs["mask_labels"] = mask_labels outputs["class_labels"] = class_labels - if patch_offsets: - outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] + if outputs.patch_offsets: + outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in outputs.patch_offsets] return outputs @@ -274,11 +273,13 @@ def _preprocess( stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std ) processed_images_grouped[shape] = stacked_images - images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(images, dim=0) if return_tensors else images - - return processed_images, patch_offsets + return BatchFeature( + 
data={"pixel_values": processed_images, "patch_offsets": patch_offsets}, + tensor_type=return_tensors, + skip_tensor_conversion=["patch_offsets"], + ) def merge_image_patches( self, diff --git a/src/transformers/models/flava/image_processing_flava_fast.py b/src/transformers/models/flava/image_processing_flava_fast.py index 0dfbd07f17a7..8959bf3c2bb8 100644 --- a/src/transformers/models/flava/image_processing_flava_fast.py +++ b/src/transformers/models/flava/image_processing_flava_fast.py @@ -306,7 +306,6 @@ def _preprocess_image( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return processed_images @@ -397,7 +396,6 @@ def _preprocess( mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, ) masks = [mask_generator() for _ in range(len(images))] - masks = torch.stack(masks, dim=0) if return_tensors else masks data["bool_masked_pos"] = masks return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index e86352af1bf5..c1a2d492c5df 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -94,7 +94,7 @@ class FuyuBatchFeature(BatchFeature): The outputs dictionary from the processors contains a mix of tensors and lists of tensors. """ - def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs): """ Convert the inner content to tensors. diff --git a/src/transformers/models/gemma3/image_processing_gemma3_fast.py b/src/transformers/models/gemma3/image_processing_gemma3_fast.py index bfb58be2a8e1..a6ad4d9c67a2 100644 --- a/src/transformers/models/gemma3/image_processing_gemma3_fast.py +++ b/src/transformers/models/gemma3/image_processing_gemma3_fast.py @@ -231,7 +231,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors ) diff --git a/src/transformers/models/glpn/image_processing_glpn_fast.py b/src/transformers/models/glpn/image_processing_glpn_fast.py index a906dc29c271..07f737a37f96 100644 --- a/src/transformers/models/glpn/image_processing_glpn_fast.py +++ b/src/transformers/models/glpn/image_processing_glpn_fast.py @@ -107,7 +107,6 @@ def _preprocess( processed_groups[shape] = stacked_images processed_images = reorder_images(processed_groups, grouped_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def post_process_depth_estimation(self, outputs, target_sizes=None): diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py index 210a18a406be..be183ae79415 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py @@ -189,7 +189,6 @@ def _preprocess( 
processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py b/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py index 1be050b5ecf9..44aff7c91245 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py @@ -164,12 +164,8 @@ def _preprocess( input_ids = reorder_images(input_ids_grouped, grouped_images_index) - return BatchFeature( - data={"input_ids": torch.stack(input_ids, dim=0) if return_tensors else input_ids}, - tensor_type=return_tensors, - ) + return BatchFeature(data={"input_ids": input_ids}, tensor_type=return_tensors) - pixel_values = torch.stack(pixel_values, dim=0) if return_tensors else pixel_values return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors) def to_dict(self): diff --git a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py index f2c49925ef19..ba5933769af9 100644 --- a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py @@ -84,7 +84,6 @@ def _preprocess( processed_videos_grouped[shape] = stacked_videos processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) - processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos return BatchFeature(data={"pixel_values": processed_videos}, tensor_type=return_tensors) diff --git a/src/transformers/models/internvl/video_processing_internvl.py b/src/transformers/models/internvl/video_processing_internvl.py index a544bb08815a..de82e04602dc 100644 --- a/src/transformers/models/internvl/video_processing_internvl.py +++ b/src/transformers/models/internvl/video_processing_internvl.py @@ -140,7 +140,6 @@ def _preprocess( processed_videos_grouped[shape] = stacked_videos processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) - processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors) diff --git a/src/transformers/models/janus/image_processing_janus_fast.py b/src/transformers/models/janus/image_processing_janus_fast.py index b8e032786bf4..0a176ca9818e 100644 --- a/src/transformers/models/janus/image_processing_janus_fast.py +++ b/src/transformers/models/janus/image_processing_janus_fast.py @@ -180,7 +180,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py index d892436ea652..3e7b61c1c818 100644 --- a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py 
+++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py @@ -264,8 +264,8 @@ def _preprocess( encoded_outputs = BatchFeature( data={ - "flattened_patches": torch.stack(flattened_patches, dim=0) if return_tensors else flattened_patches, - "attention_mask": torch.stack(attention_masks, dim=0) if return_tensors else attention_masks, + "flattened_patches": flattened_patches, + "attention_mask": attention_masks, "width": width, "height": height, "rows": rows, diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py index 2d6e6bc21cb3..a74b3f02c118 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py @@ -101,7 +101,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py index b614c5ec9449..0553b2b3a1a6 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py @@ -115,7 +115,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images data = BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/lightglue/image_processing_lightglue_fast.py b/src/transformers/models/lightglue/image_processing_lightglue_fast.py index e99237cc104d..2785ac652730 100644 --- a/src/transformers/models/lightglue/image_processing_lightglue_fast.py +++ b/src/transformers/models/lightglue/image_processing_lightglue_fast.py @@ -174,9 +174,8 @@ def _preprocess( stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs] # Return in same format as slow processor - image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs - return BatchFeature(data={"pixel_values": image_pairs}) + return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors) def post_process_keypoint_matching( self, diff --git a/src/transformers/models/llama4/image_processing_llama4_fast.py b/src/transformers/models/llama4/image_processing_llama4_fast.py index ef44786f7c66..00c58cede8ff 100644 --- a/src/transformers/models/llama4/image_processing_llama4_fast.py +++ b/src/transformers/models/llama4/image_processing_llama4_fast.py @@ -419,10 +419,9 @@ def _preprocess( ) grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1) processed_images = reorder_images(grouped_processed_images, grouped_images_index) - aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index) + aspect_ratios = reorder_images(grouped_aspect_ratios, grouped_images_index) processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images - aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list return 
BatchFeature( data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors ) diff --git a/src/transformers/models/llava/image_processing_llava_fast.py b/src/transformers/models/llava/image_processing_llava_fast.py index 66ccb49c3671..e2f941d7ac49 100644 --- a/src/transformers/models/llava/image_processing_llava_fast.py +++ b/src/transformers/models/llava/image_processing_llava_fast.py @@ -149,7 +149,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index cc5dc756b237..936d88fba086 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -260,7 +260,6 @@ def _preprocess( if do_pad: processed_images = self._pad_for_batching(processed_images) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors ) diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index beb1c1b982e0..02f108105d00 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -279,7 +279,6 @@ def _preprocess( if do_pad: processed_images = self._pad_for_batching(processed_images) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images}, tensor_type=return_tensors, diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index dd714def07c2..cb837861f9a1 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -211,7 +211,6 @@ def _preprocess( if do_pad: processed_images = self._pad_for_batching(processed_images) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images}, tensor_type=return_tensors, diff --git a/src/transformers/models/mask2former/image_processing_mask2former_fast.py b/src/transformers/models/mask2former/image_processing_mask2former_fast.py index f8d176dcf042..dfd91ffca490 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former_fast.py +++ b/src/transformers/models/mask2former/image_processing_mask2former_fast.py @@ -387,10 +387,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index) encoded_inputs = BatchFeature( - data={ - "pixel_values": 
torch.stack(processed_images, dim=0) if return_tensors else processed_images, - "pixel_mask": torch.stack(processed_pixel_masks, dim=0) if return_tensors else processed_pixel_masks, - }, + data={"pixel_values": processed_images, "pixel_mask": processed_pixel_masks}, tensor_type=return_tensors, ) if segmentation_maps is not None: diff --git a/src/transformers/models/maskformer/image_processing_maskformer_fast.py b/src/transformers/models/maskformer/image_processing_maskformer_fast.py index 2be8ca8f16a9..bbe1dd39857f 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer_fast.py +++ b/src/transformers/models/maskformer/image_processing_maskformer_fast.py @@ -391,10 +391,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index) encoded_inputs = BatchFeature( - data={ - "pixel_values": torch.stack(processed_images, dim=0) if return_tensors else processed_images, - "pixel_mask": torch.stack(processed_pixel_masks, dim=0) if return_tensors else processed_pixel_masks, - }, + data={"pixel_values": processed_images, "pixel_mask": processed_pixel_masks}, tensor_type=return_tensors, ) if segmentation_maps is not None: diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py index 2c8329a034c1..f29035e44422 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py @@ -180,7 +180,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) # Stack all processed images if return_tensors is specified - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py index 81d745c2b54d..5edb9a6dd015 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py @@ -182,7 +182,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) # Stack all processed images if return_tensors is specified - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/nougat/image_processing_nougat_fast.py b/src/transformers/models/nougat/image_processing_nougat_fast.py index b059688d0046..e5b60b5ffe8e 100644 --- a/src/transformers/models/nougat/image_processing_nougat_fast.py +++ b/src/transformers/models/nougat/image_processing_nougat_fast.py @@ -290,7 +290,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/ovis2/image_processing_ovis2_fast.py b/src/transformers/models/ovis2/image_processing_ovis2_fast.py index 
ea618e073526..a8ec773ddb59 100644 --- a/src/transformers/models/ovis2/image_processing_ovis2_fast.py +++ b/src/transformers/models/ovis2/image_processing_ovis2_fast.py @@ -213,7 +213,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images, "grids": grids}, tensor_type=return_tensors) diff --git a/src/transformers/models/owlv2/image_processing_owlv2_fast.py b/src/transformers/models/owlv2/image_processing_owlv2_fast.py index d31173c997c4..2fda6f16cbf9 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2_fast.py +++ b/src/transformers/models/owlv2/image_processing_owlv2_fast.py @@ -336,8 +336,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/owlv2/modular_owlv2.py b/src/transformers/models/owlv2/modular_owlv2.py index 590fa5b4b31c..3d2012e71a6f 100644 --- a/src/transformers/models/owlv2/modular_owlv2.py +++ b/src/transformers/models/owlv2/modular_owlv2.py @@ -205,8 +205,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/perceiver/image_processing_perceiver_fast.py b/src/transformers/models/perceiver/image_processing_perceiver_fast.py index 72cb17cd40cd..5f103bf03233 100644 --- a/src/transformers/models/perceiver/image_processing_perceiver_fast.py +++ b/src/transformers/models/perceiver/image_processing_perceiver_fast.py @@ -113,7 +113,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py index 03ff515e63af..8c169a0b6804 100644 --- a/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py +++ b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py @@ -307,7 +307,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_images = [p[None] if p.ndim == 3 else p for p in processed_images] # add tiles dimension if needed - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/poolformer/image_processing_poolformer_fast.py b/src/transformers/models/poolformer/image_processing_poolformer_fast.py index 594d076a924c..31c73fadb628 100644 --- a/src/transformers/models/poolformer/image_processing_poolformer_fast.py +++ 
b/src/transformers/models/poolformer/image_processing_poolformer_fast.py @@ -231,7 +231,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/sam/image_processing_sam_fast.py b/src/transformers/models/sam/image_processing_sam_fast.py index fa824daee4be..efdee9e232e6 100644 --- a/src/transformers/models/sam/image_processing_sam_fast.py +++ b/src/transformers/models/sam/image_processing_sam_fast.py @@ -267,7 +267,6 @@ def _preprocess( if do_pad: processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature( data={"pixel_values": processed_images, "reshaped_input_sizes": reshaped_input_sizes}, tensor_type=return_tensors, diff --git a/src/transformers/models/segformer/image_processing_segformer_fast.py b/src/transformers/models/segformer/image_processing_segformer_fast.py index d3dc35e609de..ec9a070f23b9 100644 --- a/src/transformers/models/segformer/image_processing_segformer_fast.py +++ b/src/transformers/models/segformer/image_processing_segformer_fast.py @@ -168,7 +168,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) # Stack images into a single tensor if return_tensors is set - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/segformer/modular_segformer.py b/src/transformers/models/segformer/modular_segformer.py index 6bbbe9ecd4fd..1fcb7d840f71 100644 --- a/src/transformers/models/segformer/modular_segformer.py +++ b/src/transformers/models/segformer/modular_segformer.py @@ -140,7 +140,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) # Stack images into a single tensor if return_tensors is set - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/smolvlm/video_processing_smolvlm.py b/src/transformers/models/smolvlm/video_processing_smolvlm.py index 09751486f0ae..9e08c6a85c41 100644 --- a/src/transformers/models/smolvlm/video_processing_smolvlm.py +++ b/src/transformers/models/smolvlm/video_processing_smolvlm.py @@ -331,7 +331,6 @@ def _preprocess( processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) pixel_attention_mask = reorder_videos(processed_padded_mask_grouped, grouped_videos_index) - processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos data = {"pixel_values": processed_videos} if do_pad: diff --git a/src/transformers/models/superglue/image_processing_superglue_fast.py b/src/transformers/models/superglue/image_processing_superglue_fast.py index ffbcc7ce9508..f751928471ca 100644 --- a/src/transformers/models/superglue/image_processing_superglue_fast.py +++ b/src/transformers/models/superglue/image_processing_superglue_fast.py @@ -161,9 +161,8 @@ def _preprocess( stacked_pairs = [torch.stack(pair, 
dim=0) for pair in image_pairs] # Return in same format as slow processor - image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs - return BatchFeature(data={"pixel_values": image_pairs}) + return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors) def post_process_keypoint_matching( self, diff --git a/src/transformers/models/superpoint/image_processing_superpoint_fast.py b/src/transformers/models/superpoint/image_processing_superpoint_fast.py index 3750441fc9f0..24638c1892ea 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint_fast.py +++ b/src/transformers/models/superpoint/image_processing_superpoint_fast.py @@ -110,8 +110,7 @@ def _preprocess( stacked_images = self.rescale(stacked_images, rescale_factor) processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature(data={"pixel_values": processed_images}) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def post_process_keypoint_detection( self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, list[tuple]] diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py index f85c124041bb..82fe6b71ee38 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py @@ -97,7 +97,6 @@ def _preprocess( stacked_images = self.pad(stacked_images, size_divisor=size_divisor) processed_image_grouped[shape] = stacked_images processed_images = reorder_images(processed_image_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/textnet/image_processing_textnet_fast.py b/src/transformers/models/textnet/image_processing_textnet_fast.py index eba6e14e64bc..574aafba4d3e 100644 --- a/src/transformers/models/textnet/image_processing_textnet_fast.py +++ b/src/transformers/models/textnet/image_processing_textnet_fast.py @@ -137,7 +137,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py index 54f54d18cc89..b8719d9e358d 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py @@ -152,7 +152,6 @@ def _preprocess( processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/vitpose/image_processing_vitpose_fast.py 
b/src/transformers/models/vitpose/image_processing_vitpose_fast.py index ec5fadbfd6c1..c9dcb959431a 100644 --- a/src/transformers/models/vitpose/image_processing_vitpose_fast.py +++ b/src/transformers/models/vitpose/image_processing_vitpose_fast.py @@ -156,7 +156,6 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) # Stack into batch tensor - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py index 852ee161aff1..facc7d744d39 100644 --- a/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py +++ b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py @@ -171,9 +171,7 @@ def _preprocess( if do_normalize: stacked_images = self.normalize(stacked_images, image_mean, image_std) resized_images_grouped[shape] = stacked_images - resized_images = reorder_images(resized_images_grouped, grouped_images_index) - - processed_images = torch.stack(resized_images, dim=0) if return_tensors else resized_images + processed_images = reorder_images(resized_images_grouped, grouped_images_index) return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py index d73bbce889f1..094e526d5d8b 100644 --- a/src/transformers/video_processing_utils.py +++ b/src/transformers/video_processing_utils.py @@ -442,7 +442,6 @@ def _preprocess( processed_videos_grouped[shape] = stacked_videos processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) - processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors) diff --git a/tests/models/dac/test_feature_extraction_dac.py b/tests/models/dac/test_feature_extraction_dac.py index c995485d3311..d71cb0370895 100644 --- a/tests/models/dac/test_feature_extraction_dac.py +++ b/tests/models/dac/test_feature_extraction_dac.py @@ -207,7 +207,7 @@ def test_truncation_and_padding(self): # force no pad with self.assertRaisesRegex( ValueError, - "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + r"Unable to convert output[\s\S]*padding=True", ): truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values diff --git a/tests/models/dia/test_feature_extraction_dia.py b/tests/models/dia/test_feature_extraction_dia.py index 9a6f797d5346..990e6ead9c59 100644 --- a/tests/models/dia/test_feature_extraction_dia.py +++ b/tests/models/dia/test_feature_extraction_dia.py @@ -223,7 +223,7 @@ def test_truncation_and_padding(self): # force no pad with self.assertRaisesRegex( ValueError, - "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + r"Unable to convert output[\s\S]*padding=True", ): truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values diff --git a/tests/models/encodec/test_feature_extraction_encodec.py b/tests/models/encodec/test_feature_extraction_encodec.py index 2823b0099372..c3850dd71358 100644 --- 
a/tests/models/encodec/test_feature_extraction_encodec.py +++ b/tests/models/encodec/test_feature_extraction_encodec.py @@ -221,7 +221,7 @@ def test_truncation_and_padding(self): # force no pad with self.assertRaisesRegex( ValueError, - "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + r"Unable to convert output[\s\S]*padding=True", ): truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values @@ -232,7 +232,7 @@ def test_truncation_and_padding(self): feature_extractor.chunk_length_s = None with self.assertRaisesRegex( ValueError, - "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + r"Unable to convert output[\s\S]*padding=True", ): truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values @@ -244,7 +244,7 @@ def test_truncation_and_padding(self): feature_extractor.overlap = None with self.assertRaisesRegex( ValueError, - "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + r"Unable to convert output[\s\S]*padding=True", ): truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values diff --git a/tests/models/owlv2/test_image_processing_owlv2.py b/tests/models/owlv2/test_image_processing_owlv2.py index 0299404a75b7..eef7d5771522 100644 --- a/tests/models/owlv2/test_image_processing_owlv2.py +++ b/tests/models/owlv2/test_image_processing_owlv2.py @@ -127,7 +127,7 @@ def test_image_processor_integration_test(self): pixel_values = processor(image, return_tensors="pt").pixel_values mean_value = round(pixel_values.mean().item(), 4) - self.assertEqual(mean_value, 0.2353) + self.assertEqual(mean_value, -0.2303) @slow def test_image_processor_integration_test_resize(self): diff --git a/tests/utils/test_feature_extraction_utils.py b/tests/utils/test_feature_extraction_utils.py index b0a6a193d10d..01e511bc289a 100644 --- a/tests/utils/test_feature_extraction_utils.py +++ b/tests/utils/test_feature_extraction_utils.py @@ -20,9 +20,12 @@ from pathlib import Path import httpx +import numpy as np from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor -from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test +from transformers.feature_extraction_utils import BatchFeature +from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test, require_torch +from transformers.utils import is_torch_available sys.path.append(str(Path(__file__).parent.parent.parent / "utils")) @@ -30,9 +33,143 @@ from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402 +if is_torch_available(): + import torch + + SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures") +class BatchFeatureTester(unittest.TestCase): + """Tests for the BatchFeature class and tensor conversion.""" + + def test_batch_feature_basic_access_and_no_conversion(self): + """Test basic dict/attribute access and no conversion when tensor_type=None.""" + data = {"input_values": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]} + batch = BatchFeature(data) + + # Dict-style and attribute-style access + self.assertEqual(batch["input_values"], [[1, 2, 3], [4, 5, 6]]) + self.assertEqual(batch.labels, [0, 1]) + + # No conversion without tensor_type + 
self.assertIsInstance(batch["input_values"], list) + + @require_torch + def test_batch_feature_numpy_conversion(self): + """Test conversion to numpy arrays from lists and existing numpy arrays.""" + # From lists + batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="np") + self.assertIsInstance(batch["input_values"], np.ndarray) + self.assertEqual(batch["input_values"].shape, (2, 3)) + + # From numpy arrays (should remain numpy) + numpy_data = np.array([[1, 2, 3], [4, 5, 6]]) + batch_arrays = BatchFeature({"input_values": numpy_data}, tensor_type="np") + np.testing.assert_array_equal(batch_arrays["input_values"], numpy_data) + + # From list of numpy arrays with same shape should stack + numpy_data = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9], [10, 11, 12]])] + batch_stacked = BatchFeature({"input_values": numpy_data}, tensor_type="np") + np.testing.assert_array_equal( + batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + ) + + # from tensor + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) + batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="np") + np.testing.assert_array_equal(batch_tensor["input_values"], tensor.numpy()) + + # from list of tensors with same shape should stack + tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])] + batch_stacked = BatchFeature({"input_values": tensors}, tensor_type="np") + self.assertIsInstance(batch_stacked["input_values"], np.ndarray) + np.testing.assert_array_equal( + batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + ) + + @require_torch + def test_batch_feature_pytorch_conversion(self): + """Test conversion to PyTorch tensors from various input types.""" + # From lists + batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="pt") + self.assertIsInstance(batch["input_values"], torch.Tensor) + self.assertEqual(batch["input_values"].shape, (2, 3)) + + # from tensor (should be returned as-is) + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) + batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="pt") + torch.testing.assert_close(batch_tensor["input_values"], tensor) + + # From numpy arrays + batch_numpy = BatchFeature({"input_values": np.array([[1, 2]])}, tensor_type="pt") + self.assertIsInstance(batch_numpy["input_values"], torch.Tensor) + + # List of same-shape tensors should stack + tensors = [torch.randn(3, 10, 10) for _ in range(3)] + batch_stacked = BatchFeature({"pixel_values": tensors}, tensor_type="pt") + self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10)) + + # List of same-shape numpy arrays should stack + numpy_arrays = [np.random.randn(3, 10, 10) for _ in range(3)] + batch_stacked = BatchFeature({"pixel_values": numpy_arrays}, tensor_type="pt") + self.assertIsInstance(batch_stacked["pixel_values"], torch.Tensor) + self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10)) + + @require_torch + def test_batch_feature_error_handling(self): + """Test clear error messages for common conversion failures.""" + # Ragged tensors (different shapes) + data_ragged = {"values": [torch.randn(3, 224, 224), torch.randn(3, 448, 448)]} + with self.assertRaises(ValueError) as context: + BatchFeature(data_ragged, tensor_type="pt") + error_msg = str(context.exception) + self.assertIn("stack expects each tensor to be equal size", error_msg.lower()) + self.assertIn("return_tensors=None", error_msg) + + # Ragged numpy arrays 
(different shapes) + data_ragged = {"values": [np.random.randn(3, 224, 224), np.random.randn(3, 448, 448)]} + with self.assertRaises(ValueError) as context: + BatchFeature(data_ragged, tensor_type="np") + error_msg = str(context.exception) + self.assertIn("inhomogeneous", error_msg.lower()) + self.assertIn("return_tensors=None", error_msg) + + # Unconvertible type (dict) + data_dict = {"values": [[1, 2]], "metadata": {"key": "val"}} + with self.assertRaises(ValueError) as context: + BatchFeature(data_dict, tensor_type="pt") + self.assertIn("metadata", str(context.exception)) + + @require_torch + def test_batch_feature_skip_tensor_conversion(self): + """Test skip_tensor_conversion parameter for metadata fields.""" + import torch + + data = {"pixel_values": [[1, 2, 3]], "num_crops": [1, 2], "sizes": [(224, 224)]} + batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["num_crops", "sizes"]) + + # pixel_values should be converted + self.assertIsInstance(batch["pixel_values"], torch.Tensor) + # num_crops and sizes should remain as lists + self.assertIsInstance(batch["num_crops"], list) + self.assertIsInstance(batch["sizes"], list) + + @require_torch + def test_batch_feature_convert_to_tensors_method(self): + """Test convert_to_tensors method can be called after initialization.""" + import torch + + data = {"input_values": [[1, 2, 3]], "metadata": [1, 2]} + batch = BatchFeature(data) # No conversion initially + self.assertIsInstance(batch["input_values"], list) + + # Convert with skip parameter + batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"]) + self.assertIsInstance(batch["input_values"], torch.Tensor) + self.assertIsInstance(batch["metadata"], list) + + class FeatureExtractorUtilTester(unittest.TestCase): def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down
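
Usage notes for the `BatchFeature` changes above, as hedged sketches rather than part of the patch. First, the new `skip_tensor_conversion` argument on `BatchFeature.__init__`: keys listed there stay plain Python objects even when `tensor_type` is set, which is how the EoMT processor keeps its ragged `patch_offsets` out of tensor conversion. Key names and shapes below are illustrative, assuming torch is installed.

```python
import torch

from transformers.feature_extraction_utils import BatchFeature

# "patch_offsets" is ragged metadata that cannot be stacked into a tensor.
data = {
    "pixel_values": [torch.rand(3, 224, 224) for _ in range(2)],
    "patch_offsets": [[(0, 0), (0, 224)], [(0, 0)]],
}
batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["patch_offsets"])

assert isinstance(batch["pixel_values"], torch.Tensor)  # stacked to (2, 3, 224, 224)
assert isinstance(batch["patch_offsets"], list)         # left untouched
```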
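The reason the per-processor `torch.stack(processed_images, dim=0) if return_tensors else processed_images` lines could be dropped: with the `as_tensor` change above, `BatchFeature` itself stacks a list of same-shaped tensors when `tensor_type` is given, and passes the list through unchanged when it is `None`. A sketch with illustrative shapes:

```python
import torch

from transformers.feature_extraction_utils import BatchFeature

processed_images = [torch.rand(3, 224, 224) for _ in range(4)]

# return_tensors="pt": the list of same-shaped tensors is stacked by BatchFeature.
batch = BatchFeature({"pixel_values": processed_images}, tensor_type="pt")
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 224, 224])

# return_tensors=None: the list is returned as-is, matching the old behaviour.
batch = BatchFeature({"pixel_values": processed_images}, tensor_type=None)
print(type(batch["pixel_values"]))  # <class 'list'>
```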
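The same should hold for NumPy output, mirroring `test_batch_feature_numpy_conversion` above: a list of same-shaped arrays or torch tensors is stacked into a single `np.ndarray` when `tensor_type="np"` (exact behaviour for torch inputs depends on NumPy's array conversion, but the new test asserts it).

```python
import numpy as np
import torch

from transformers.feature_extraction_utils import BatchFeature

tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])]
batch = BatchFeature({"input_values": tensors}, tensor_type="np")

assert isinstance(batch["input_values"], np.ndarray)
assert batch["input_values"].shape == (2, 2, 3)
```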
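When stacking is impossible because shapes are ragged, the new error path raises a `ValueError` that names the offending key, carries the underlying exception, and suggests `padding=True` or `return_tensors=None`; this is what the relaxed `r"Unable to convert output[\s\S]*padding=True"` regexes in the audio feature-extraction tests now match. A sketch of what callers see:

```python
import torch

from transformers.feature_extraction_utils import BatchFeature

ragged = {"pixel_values": [torch.rand(3, 224, 224), torch.rand(3, 448, 448)]}
try:
    BatchFeature(ragged, tensor_type="pt")
except ValueError as err:
    # "Unable to convert output 'pixel_values' ... 1. Use padding=True ... 2. Set return_tensors=None ..."
    print(err)
```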
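Finally, `convert_to_tensors` accepts the same `skip_tensor_conversion` argument, so conversion can be deferred and applied after construction, as exercised by `test_batch_feature_convert_to_tensors_method` above:

```python
import torch

from transformers.feature_extraction_utils import BatchFeature

batch = BatchFeature({"input_values": [[1, 2, 3]], "metadata": [1, 2]})  # tensor_type=None: no conversion yet
batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"])

assert isinstance(batch["input_values"], torch.Tensor)
assert isinstance(batch["metadata"], list)
```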