Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
from .audio import Audio
from .image import Image, encode_pil_image
-from .nifti import Nifti
+from .nifti import Nifti, encode_nibabel_image
from .pdf import Pdf, encode_pdfplumber_pdf
from .translation import Translation, TranslationVariableLanguages
from .video import Video
Expand Down Expand Up @@ -307,6 +307,9 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
import pdfplumber

+if config.NIBABEL_AVAILABLE and "nibabel" in sys.modules:
+import nibabel as nib

if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules:
from torchcodec.decoders import AudioDecoder, VideoDecoder

Expand Down Expand Up @@ -380,6 +383,8 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
return encode_pil_image(obj), True
elif config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules and isinstance(obj, pdfplumber.pdf.PDF):
return encode_pdfplumber_pdf(obj), True
+elif config.NIBABEL_AVAILABLE and "nibabel" in sys.modules and isinstance(obj, nib.analyze.AnalyzeImage):
+return encode_nibabel_image(obj, force_bytes=True), True
elif isinstance(obj, pd.Series):
return (
_cast_to_python_objects(
Expand Down
5 changes: 3 additions & 2 deletions src/datasets/features/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.BinaryA
return array_cast(storage, self.pa_type)


-def encode_nibabel_image(img: "nib.Nifti1Image") -> dict[str, Optional[Union[str, bytes]]]:
+def encode_nibabel_image(img: "nib.Nifti1Image", force_bytes: bool = False) -> dict[str, Optional[Union[str, bytes]]]:
"""
Encode a nibabel image object into a dictionary.

Expand All @@ -309,11 +309,12 @@ def encode_nibabel_image(img: "nib.Nifti1Image") -> dict[str, Optional[Union[str

Args:
img: A nibabel image object (e.g., Nifti1Image).
+force_bytes: If `True`, always serialize to bytes even if a file path exists. Needed to upload bytes properly.

Returns:
dict: A dictionary with "path" or "bytes" field.
"""
-if hasattr(img, "file_map") and img.file_map is not None:
+if hasattr(img, "file_map") and img.file_map is not None and not force_bytes:
filename = img.file_map["image"].filename
return {"path": filename, "bytes": None}

Expand Down
Loading