
Commit 0ec4d87

Add inspect_ai eval logs support (#7899)
add inspect_ai eval format
1 parent 1a330f3 commit 0ec4d87

File tree

src/datasets/arrow_dataset.py
src/datasets/data_files.py
src/datasets/packaged_modules/__init__.py
src/datasets/packaged_modules/eval/__init__.py
src/datasets/packaged_modules/eval/eval.py
src/datasets/utils/file_utils.py

6 files changed: +72 -2 lines changed

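Taken together, this commit registers inspect_ai's .eval log format as a new packaged builder, so eval logs load like any other supported format. A minimal usage sketch (the glob is hypothetical; "logs" is the split name the new default pattern produces):

from datasets import load_dataset

# Load inspect_ai eval logs via the new "eval" packaged builder.
# The data_files path is illustrative; local or remote .eval files both work.
ds = load_dataset("eval", data_files={"logs": "logs/**/*.eval"})
print(ds["logs"][0])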

src/datasets/arrow_dataset.py

Lines changed: 2 additions & 2 deletions

@@ -689,10 +689,10 @@ def __len__(self) -> int:
         return len(self.source)
 
     def __repr__(self):
-        return "Column(" + repr(list(self[:5])) + ")"
+        return "Column(" + repr(list(self[:5]))[:-1] + (", ...])" if len(self) > 5 else "])")
 
     def __str__(self):
-        return "Column(" + str(list(self[:5])) + ")"
+        return "Column(" + str(list(self[:5]))[:-1] + (", ...])" if len(self) > 5 else "])")
 
     def __eq__(self, value):
         if isinstance(value, Column):
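This arrow_dataset.py change is a small preview fix: Column's repr and str previously printed the first five items as if they were the whole column; they now append an ellipsis when the column is longer. An illustration (assuming ds["x"] returns a Column view, as it does in recent datasets versions):

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3, 4, 5, 6]})
print(repr(ds["x"]))
# Before this commit: Column([1, 2, 3, 4, 5])
# After this commit:  Column([1, 2, 3, 4, 5, ...])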

src/datasets/data_files.py

Lines changed: 3 additions & 0 deletions

@@ -94,8 +94,11 @@ class EmptyDatasetError(FileNotFoundError):
     Split.TRAIN: ["**"],
 }
 
+DEFAULT_PATTERNS_LOGS = {"logs": ["**/*.eval"]}
+
 ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED]
 ALL_DEFAULT_PATTERNS = [
+    DEFAULT_PATTERNS_LOGS,
     DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
     DEFAULT_PATTERNS_SPLIT_IN_FILENAME,
     DEFAULT_PATTERNS_ALL,
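DEFAULT_PATTERNS_LOGS groups every .eval file, at any directory depth, into a single "logs" split, and because it comes first in ALL_DEFAULT_PATTERNS it takes precedence whenever a dataset directory contains .eval files. A sketch with a hypothetical layout:

from datasets import load_dataset

# my-eval-logs/
#   2025-01-01T00-00-00_my-task_abc123.eval
#   archive/2025-01-02T00-00-00_my-task_def456.eval
#
# "**/*.eval" matches both files, so both land in one "logs" split:
ds = load_dataset("./my-eval-logs")
print(ds)  # DatasetDict with a single "logs" split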

src/datasets/packaged_modules/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,7 @@
 from .audiofolder import audiofolder
 from .cache import cache
 from .csv import csv
+from .eval import eval
 from .hdf5 import hdf5
 from .imagefolder import imagefolder
 from .json import json
@@ -51,6 +52,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
     "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
     "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
+    "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
 }
 
 # get importable module names and hash for caching
@@ -82,6 +84,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     ".xml": ("xml", {}),
     ".hdf5": ("hdf5", {}),
     ".h5": ("hdf5", {}),
+    ".eval": ("eval", {}),
 }
 _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
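These three additions wire the new builder into the existing discovery machinery: it is imported alongside the other packaged modules, hashed for caching the same way, and mapped to the .eval extension so module inference picks it up automatically. A quick check against the (private) lookup table shown above:

from datasets.packaged_modules import _EXTENSION_TO_MODULE

# ".eval" now routes to the "eval" builder, just as ".csv" routes to csv
# and ".h5" to hdf5.
assert _EXTENSION_TO_MODULE[".eval"] == ("eval", {})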

src/datasets/packaged_modules/eval/__init__.py

Whitespace-only changes.

src/datasets/packaged_modules/eval/eval.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+import json
+import os
+from itertools import islice
+
+import pyarrow as pa
+
+import datasets
+from datasets.builder import Key
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class Eval(datasets.GeneratorBasedBuilder):
+    NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
+
+    def _info(self):
+        return datasets.DatasetInfo()
+
+    def _split_generators(self, dl_manager):
+        """We handle string, list and dicts in datafiles"""
+        if not self.config.data_files:
+            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+        dl_manager.download_config.extract_on_the_fly = True
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        splits = []
+        for split_name, logs in data_files.items():
+            if isinstance(logs, str):
+                logs = [logs]
+            logs_files = [dl_manager.iter_files(log) for log in logs]
+            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"logs_files": logs_files}))
+        if not self.info.features:
+            first_examples = list(
+                islice(self._iter_samples_from_log_files(logs_files[0]), self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)
+            )
+            pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
+            inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
+            self.info.features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
+
+        return splits
+
+    def _sort_samples_key(self, sample_path: str):
+        # looks like "{sample_idx}_epoch_{epoch_idx}"
+        (sample_idx_str, epoch_idx_str) = os.path.splitext(os.path.basename(sample_path))[0].split("_epoch_")
+        return (int(epoch_idx_str), int(sample_idx_str))
+
+    def _iter_samples_from_log_files(self, log_files: list[str]):
+        sample_files = [log_file for log_file in log_files if os.path.basename(os.path.dirname(log_file)) == "samples"]
+        sample_files.sort(key=self._sort_samples_key)
+        for sample_file in sample_files:
+            with open(sample_file) as f:
+                sample = json.load(f)
+            for field in sample:
+                if isinstance(sample[field], dict):
+                    sample[field] = json.dumps(sample[field])
+                if isinstance(sample[field], list):
+                    sample[field] = [json.dumps(x) for x in sample[field]]
+            yield sample
+
+    def _generate_examples(self, logs_files):
+        for file_idx, log_files in enumerate(logs_files):
+            for sample_idx, sample in enumerate(self._iter_samples_from_log_files(log_files)):
+                yield Key(file_idx, sample_idx), sample
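Per the builder above, an extracted .eval archive holds one JSON file per sample under a samples/ directory, with basenames of the form "{sample_idx}_epoch_{epoch_idx}"; nested dict and list fields are re-serialized as JSON strings so every row fits a flat Arrow schema. A sketch with a synthetic log (the field names are invented for illustration; real inspect_ai logs define their own):

import json
import zipfile

from datasets import load_dataset

# Build a minimal archive in the layout the builder reads. An .eval file is
# a zip (see the file_utils.py change below) with samples under "samples/".
with zipfile.ZipFile("demo.eval", "w") as zf:
    for sample_idx in (1, 2):
        sample = {"input": f"question {sample_idx}", "scores": {"accuracy": 1.0}}
        zf.writestr(f"samples/{sample_idx}_epoch_1.json", json.dumps(sample))

ds = load_dataset("eval", data_files={"logs": "demo.eval"})
print(ds["logs"][0])  # "scores" arrives as a JSON-encoded string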

src/datasets/utils/file_utils.py

Lines changed: 1 addition & 0 deletions

@@ -479,6 +479,7 @@ def readline(f: io.RawIOBase):
     },
     # archive compression
     "zip": "zip",
+    "eval": "zip",
 }
 SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
     extension.lstrip("."): fs_class.protocol
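This one-line mapping is what lets download_and_extract unpack .eval archives: they are ordinary zip files under a different extension, so pointing the extension at the existing "zip" protocol is enough. Continuing the synthetic example above:

import zipfile

# The demo.eval archive from the earlier sketch opens as a plain zip.
with zipfile.ZipFile("demo.eval") as zf:
    print(zf.namelist())  # ['samples/1_epoch_1.json', 'samples/2_epoch_1.json']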
