Skip to content

Commit e5d1092

Browse files
committed
stack lists of tensors in BatchFeature, improve error messages, add tests
1 parent 5b4d72c commit e5d1092

File tree

4 files changed

+176
-12
lines changed

4 files changed

+176
-12
lines changed

src/transformers/feature_extraction_utils.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,18 @@ class BatchFeature(UserDict):
6767
tensor_type (`Union[None, str, TensorType]`, *optional*):
6868
You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
6969
initialization.
70+
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
71+
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
7072
"""
7173

72-
def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
74+
def __init__(
75+
self,
76+
data: Optional[dict[str, Any]] = None,
77+
tensor_type: Union[None, str, TensorType] = None,
78+
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
79+
):
7380
super().__init__(data)
74-
self.convert_to_tensors(tensor_type=tensor_type)
81+
self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)
7582

7683
def __getitem__(self, item: str) -> Any:
7784
"""
@@ -110,6 +117,14 @@ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] =
110117
import torch
111118

112119
def as_tensor(value):
120+
if torch.is_tensor(value):
121+
return value
122+
123+
# stack list of tensors if tensor_type is PyTorch (# torch.tensor() does not support list of tensors)
124+
if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
125+
return torch.stack(value)
126+
127+
# convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
113128
if isinstance(value, (list, tuple)) and len(value) > 0:
114129
if isinstance(value[0], np.ndarray):
115130
value = np.array(value)
@@ -138,14 +153,20 @@ def as_tensor(value, dtype=None):
138153
is_tensor = is_numpy_array
139154
return is_tensor, as_tensor
140155

141-
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
156+
def convert_to_tensors(
157+
self,
158+
tensor_type: Optional[Union[str, TensorType]] = None,
159+
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
160+
):
142161
"""
143162
Convert the inner content to tensors.
144163
145164
Args:
146165
tensor_type (`str` or [`~utils.TensorType`], *optional*):
147166
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
148167
`None`, no modification is done.
168+
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
169+
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
149170
"""
150171
if tensor_type is None:
151172
return self
@@ -154,18 +175,26 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non
154175

155176
# Do the tensor conversion in batch
156177
for key, value in self.items():
178+
# Skip keys explicitly marked for no conversion
179+
if skip_tensor_conversion and key in skip_tensor_conversion:
180+
continue
181+
157182
try:
158183
if not is_tensor(value):
159184
tensor = as_tensor(value)
160-
161185
self[key] = tensor
162-
except: # noqa E722
186+
except Exception as e:
163187
if key == "overflowing_values":
164-
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
188+
raise ValueError(
189+
f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
190+
f"Original error: {str(e)}"
191+
) from e
165192
raise ValueError(
166-
"Unable to create tensor, you should probably activate padding "
167-
"with 'padding=True' to have batched tensors with the same length."
168-
)
193+
f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
194+
f"You can try:\n"
195+
f" 1. Use padding=True to ensure all outputs have the same shape\n"
196+
f" 2. Set return_tensors=None to return Python objects instead of tensors"
197+
) from e
169198

170199
return self
171200

src/transformers/image_processing_utils_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,6 @@ def _preprocess(
932932
if do_pad:
933933
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
934934

935-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
936935
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
937936

938937
def to_dict(self):

src/transformers/models/gemma3/image_processing_gemma3_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def _preprocess(
231231
processed_images_grouped[shape] = stacked_images
232232

233233
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
234-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
235234
return BatchFeature(
236235
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
237236
)

tests/utils/test_feature_extraction_utils.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,156 @@
2020
from pathlib import Path
2121

2222
import httpx
23+
import numpy as np
2324

2425
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
25-
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
26+
from transformers.feature_extraction_utils import BatchFeature
27+
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test, require_torch
28+
from transformers.utils import is_torch_available
2629

2730

2831
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
2932

3033
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
3134

3235

36+
if is_torch_available():
37+
import torch
38+
39+
3340
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
3441

3542

43+
class BatchFeatureTester(unittest.TestCase):
44+
"""Tests for the BatchFeature class and tensor conversion."""
45+
46+
def test_batch_feature_basic_access_and_no_conversion(self):
47+
"""Test basic dict/attribute access and no conversion when tensor_type=None."""
48+
data = {"input_values": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}
49+
batch = BatchFeature(data)
50+
51+
# Dict-style and attribute-style access
52+
self.assertEqual(batch["input_values"], [[1, 2, 3], [4, 5, 6]])
53+
self.assertEqual(batch.labels, [0, 1])
54+
55+
# No conversion without tensor_type
56+
self.assertIsInstance(batch["input_values"], list)
57+
58+
@require_torch
59+
def test_batch_feature_numpy_conversion(self):
60+
"""Test conversion to numpy arrays from lists and existing numpy arrays."""
61+
# From lists
62+
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="np")
63+
self.assertIsInstance(batch["input_values"], np.ndarray)
64+
self.assertEqual(batch["input_values"].shape, (2, 3))
65+
66+
# From numpy arrays (should remain numpy)
67+
numpy_data = np.array([[1, 2, 3], [4, 5, 6]])
68+
batch_arrays = BatchFeature({"input_values": numpy_data}, tensor_type="np")
69+
np.testing.assert_array_equal(batch_arrays["input_values"], numpy_data)
70+
71+
# From list of numpy arrays with same shape should stack
72+
numpy_data = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9], [10, 11, 12]])]
73+
batch_stacked = BatchFeature({"input_values": numpy_data}, tensor_type="np")
74+
np.testing.assert_array_equal(
75+
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
76+
)
77+
78+
# from tensor
79+
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
80+
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="np")
81+
np.testing.assert_array_equal(batch_tensor["input_values"], tensor.numpy())
82+
83+
# from list of tensors with same shape should stack
84+
tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])]
85+
batch_stacked = BatchFeature({"input_values": tensors}, tensor_type="np")
86+
self.assertIsInstance(batch_stacked["input_values"], np.ndarray)
87+
np.testing.assert_array_equal(
88+
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
89+
)
90+
91+
@require_torch
92+
def test_batch_feature_pytorch_conversion(self):
93+
"""Test conversion to PyTorch tensors from various input types."""
94+
# From lists
95+
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="pt")
96+
self.assertIsInstance(batch["input_values"], torch.Tensor)
97+
self.assertEqual(batch["input_values"].shape, (2, 3))
98+
99+
# from tensor (should be returned as-is)
100+
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
101+
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="pt")
102+
torch.testing.assert_close(batch_tensor["input_values"], tensor)
103+
104+
# From numpy arrays
105+
batch_numpy = BatchFeature({"input_values": np.array([[1, 2]])}, tensor_type="pt")
106+
self.assertIsInstance(batch_numpy["input_values"], torch.Tensor)
107+
108+
# List of same-shape tensors should stack
109+
tensors = [torch.randn(3, 10, 10) for _ in range(3)]
110+
batch_stacked = BatchFeature({"pixel_values": tensors}, tensor_type="pt")
111+
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
112+
113+
# List of same-shape numpy arrays should stack
114+
numpy_arrays = [np.random.randn(3, 10, 10) for _ in range(3)]
115+
batch_stacked = BatchFeature({"pixel_values": numpy_arrays}, tensor_type="pt")
116+
self.assertIsInstance(batch_stacked["pixel_values"], torch.Tensor)
117+
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
118+
119+
@require_torch
120+
def test_batch_feature_error_handling(self):
121+
"""Test clear error messages for common conversion failures."""
122+
# Ragged tensors (different shapes)
123+
data_ragged = {"values": [torch.randn(3, 224, 224), torch.randn(3, 448, 448)]}
124+
with self.assertRaises(ValueError) as context:
125+
BatchFeature(data_ragged, tensor_type="pt")
126+
error_msg = str(context.exception)
127+
self.assertIn("stack expects each tensor to be equal size", error_msg.lower())
128+
self.assertIn("return_tensors=None", error_msg)
129+
130+
# Ragged numpy arrays (different shapes)
131+
data_ragged = {"values": [np.random.randn(3, 224, 224), np.random.randn(3, 448, 448)]}
132+
with self.assertRaises(ValueError) as context:
133+
BatchFeature(data_ragged, tensor_type="np")
134+
error_msg = str(context.exception)
135+
self.assertIn("inhomogeneous", error_msg.lower())
136+
self.assertIn("return_tensors=None", error_msg)
137+
138+
# Unconvertible type (dict)
139+
data_dict = {"values": [[1, 2]], "metadata": {"key": "val"}}
140+
with self.assertRaises(ValueError) as context:
141+
BatchFeature(data_dict, tensor_type="pt")
142+
self.assertIn("metadata", str(context.exception))
143+
144+
@require_torch
145+
def test_batch_feature_skip_tensor_conversion(self):
146+
"""Test skip_tensor_conversion parameter for metadata fields."""
147+
import torch
148+
149+
data = {"pixel_values": [[1, 2, 3]], "num_crops": [1, 2], "sizes": [(224, 224)]}
150+
batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["num_crops", "sizes"])
151+
152+
# pixel_values should be converted
153+
self.assertIsInstance(batch["pixel_values"], torch.Tensor)
154+
# num_crops and sizes should remain as lists
155+
self.assertIsInstance(batch["num_crops"], list)
156+
self.assertIsInstance(batch["sizes"], list)
157+
158+
@require_torch
159+
def test_batch_feature_convert_to_tensors_method(self):
160+
"""Test convert_to_tensors method can be called after initialization."""
161+
import torch
162+
163+
data = {"input_values": [[1, 2, 3]], "metadata": [1, 2]}
164+
batch = BatchFeature(data) # No conversion initially
165+
self.assertIsInstance(batch["input_values"], list)
166+
167+
# Convert with skip parameter
168+
batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"])
169+
self.assertIsInstance(batch["input_values"], torch.Tensor)
170+
self.assertIsInstance(batch["metadata"], list)
171+
172+
36173
class FeatureExtractorUtilTester(unittest.TestCase):
37174
def test_cached_files_are_used_when_internet_is_down(self):
38175
# A mock response for an HTTP head request to emulate server down

0 commit comments

Comments
 (0)