from pathlib import Path

import httpx
+import numpy as np

from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test, require_torch
+from transformers.utils import is_torch_available


sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))

from test_module.custom_feature_extraction import CustomFeatureExtractor  # noqa E402


+if is_torch_available():
+    import torch
+
+
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")


+class BatchFeatureTester(unittest.TestCase):
+    """Tests for the BatchFeature class and tensor conversion."""
+
+    def test_batch_feature_basic_access_and_no_conversion(self):
+        """Test basic dict/attribute access and no conversion when tensor_type=None."""
+        data = {"input_values": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}
+        batch = BatchFeature(data)
+
+        # Dict-style and attribute-style access
+        self.assertEqual(batch["input_values"], [[1, 2, 3], [4, 5, 6]])
+        self.assertEqual(batch.labels, [0, 1])
+
+        # No conversion without tensor_type
+        self.assertIsInstance(batch["input_values"], list)
+
+    @require_torch
+    def test_batch_feature_numpy_conversion(self):
+        """Test conversion to numpy arrays from lists, numpy arrays, and torch tensors."""
+        # From lists
+        batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="np")
+        self.assertIsInstance(batch["input_values"], np.ndarray)
+        self.assertEqual(batch["input_values"].shape, (2, 3))
+
+        # From numpy arrays (should remain numpy)
+        numpy_data = np.array([[1, 2, 3], [4, 5, 6]])
+        batch_arrays = BatchFeature({"input_values": numpy_data}, tensor_type="np")
+        np.testing.assert_array_equal(batch_arrays["input_values"], numpy_data)
+
+        # From list of numpy arrays with same shape should stack
+        numpy_data = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9], [10, 11, 12]])]
+        batch_stacked = BatchFeature({"input_values": numpy_data}, tensor_type="np")
+        np.testing.assert_array_equal(
+            batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
+        )
+
+        # From tensor
+        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
+        batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="np")
+        np.testing.assert_array_equal(batch_tensor["input_values"], tensor.numpy())
+
+        # From list of tensors with same shape should stack
+        tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])]
+        batch_stacked = BatchFeature({"input_values": tensors}, tensor_type="np")
+        self.assertIsInstance(batch_stacked["input_values"], np.ndarray)
+        np.testing.assert_array_equal(
+            batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
+        )
+
+    @require_torch
+    def test_batch_feature_pytorch_conversion(self):
+        """Test conversion to PyTorch tensors from various input types."""
+        # From lists
+        batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="pt")
+        self.assertIsInstance(batch["input_values"], torch.Tensor)
+        self.assertEqual(batch["input_values"].shape, (2, 3))
+
+        # From tensor (should be returned as-is)
+        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
+        batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="pt")
+        torch.testing.assert_close(batch_tensor["input_values"], tensor)
+
+        # From numpy arrays
+        batch_numpy = BatchFeature({"input_values": np.array([[1, 2]])}, tensor_type="pt")
+        self.assertIsInstance(batch_numpy["input_values"], torch.Tensor)
+
+        # List of same-shape tensors should stack
+        tensors = [torch.randn(3, 10, 10) for _ in range(3)]
+        batch_stacked = BatchFeature({"pixel_values": tensors}, tensor_type="pt")
+        self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
+
+        # List of same-shape numpy arrays should stack
+        numpy_arrays = [np.random.randn(3, 10, 10) for _ in range(3)]
+        batch_stacked = BatchFeature({"pixel_values": numpy_arrays}, tensor_type="pt")
+        self.assertIsInstance(batch_stacked["pixel_values"], torch.Tensor)
+        self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
+
+    @require_torch
+    def test_batch_feature_error_handling(self):
+        """Test clear error messages for common conversion failures."""
+        # Ragged tensors (different shapes)
+        data_ragged = {"values": [torch.randn(3, 224, 224), torch.randn(3, 448, 448)]}
+        with self.assertRaises(ValueError) as context:
+            BatchFeature(data_ragged, tensor_type="pt")
+        error_msg = str(context.exception)
+        self.assertIn("stack expects each tensor to be equal size", error_msg.lower())
+        self.assertIn("return_tensors=None", error_msg)
+
+        # Ragged numpy arrays (different shapes)
+        data_ragged = {"values": [np.random.randn(3, 224, 224), np.random.randn(3, 448, 448)]}
+        with self.assertRaises(ValueError) as context:
+            BatchFeature(data_ragged, tensor_type="np")
+        error_msg = str(context.exception)
+        self.assertIn("inhomogeneous", error_msg.lower())
+        self.assertIn("return_tensors=None", error_msg)
+
+        # Unconvertible type (dict)
+        data_dict = {"values": [[1, 2]], "metadata": {"key": "val"}}
+        with self.assertRaises(ValueError) as context:
+            BatchFeature(data_dict, tensor_type="pt")
+        self.assertIn("metadata", str(context.exception))
+
+    @require_torch
+    def test_batch_feature_skip_tensor_conversion(self):
+        """Test skip_tensor_conversion parameter for metadata fields."""
+        import torch
+
+        data = {"pixel_values": [[1, 2, 3]], "num_crops": [1, 2], "sizes": [(224, 224)]}
+        batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["num_crops", "sizes"])
+
+        # pixel_values should be converted
+        self.assertIsInstance(batch["pixel_values"], torch.Tensor)
+        # num_crops and sizes should remain as lists
+        self.assertIsInstance(batch["num_crops"], list)
+        self.assertIsInstance(batch["sizes"], list)
+
+    @require_torch
+    def test_batch_feature_convert_to_tensors_method(self):
+        """Test that convert_to_tensors can be called after initialization."""
+        import torch
+
+        data = {"input_values": [[1, 2, 3]], "metadata": [1, 2]}
+        batch = BatchFeature(data)  # No conversion initially
+        self.assertIsInstance(batch["input_values"], list)
+
+        # Convert with skip parameter
+        batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"])
+        self.assertIsInstance(batch["input_values"], torch.Tensor)
+        self.assertIsInstance(batch["metadata"], list)
+
+
class FeatureExtractorUtilTester(unittest.TestCase):
    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down