import json
import os
from itertools import islice

import pyarrow as pa

import datasets
from datasets.builder import Key
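
# Builder for evaluation logs: each log directory contains a "samples/" folder
# holding one JSON file per (sample, epoch), named "{sample_idx}_epoch_{epoch_idx}.json".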


logger = datasets.utils.logging.get_logger(__name__)


class Eval(datasets.GeneratorBasedBuilder):
    # Number of samples to read up front when inferring the dataset's features.
    NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5

    def _info(self):
        # Features are not known in advance; they are inferred in `_split_generators`.
        return datasets.DatasetInfo()

    def _split_generators(self, dl_manager):
        """Handle str, list, and dict values in `data_files`."""
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        # Extract any archives lazily, while iterating, rather than up front.
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, logs in data_files.items():
            if isinstance(logs, str):
                logs = [logs]
            # One file iterator per log directory (or archive) in this split.
            logs_files = [dl_manager.iter_files(log) for log in logs]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"logs_files": logs_files}))
            if not self.info.features:
                # Infer features once, from the first few samples of this split's first log.
                first_examples = list(
                    islice(self._iter_samples_from_log_files(logs_files[0]), self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)
                )
                pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
                inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
                self.info.features = datasets.Features.from_arrow_schema(inferred_arrow_schema)

        return splits

    def _sort_samples_key(self, sample_path: str):
        # File names look like "{sample_idx}_epoch_{epoch_idx}.json".
        (sample_idx_str, epoch_idx_str) = os.path.splitext(os.path.basename(sample_path))[0].split("_epoch_")
        return (int(epoch_idx_str), int(sample_idx_str))
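        # e.g. "12_epoch_3.json" -> (3, 12): files sort epoch-major, then by
        # sample index within an epoch.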

    def _iter_samples_from_log_files(self, log_files: list[str]):
        # Keep only files that live directly inside a "samples" directory.
        sample_files = [log_file for log_file in log_files if os.path.basename(os.path.dirname(log_file)) == "samples"]
        sample_files.sort(key=self._sort_samples_key)
        for sample_file in sample_files:
            with open(sample_file) as f:
                sample = json.load(f)
            # JSON-encode dicts (and each element of a list) so every field
            # carries a simple, consistent type.
            for field in sample:
                if isinstance(sample[field], dict):
                    sample[field] = json.dumps(sample[field])
                elif isinstance(sample[field], list):
                    sample[field] = [json.dumps(x) for x in sample[field]]
            yield sample

    def _generate_examples(self, logs_files):
        for file_idx, log_files in enumerate(logs_files):
            for sample_idx, sample in enumerate(self._iter_samples_from_log_files(log_files)):
                # The (file index, sample index) pair keeps example keys unique across logs.
                yield Key(file_idx, sample_idx), sample
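

# --- Usage sketch (illustrative, not from the original commit) ---------------
# Assuming a log directory laid out as
#
#   logs/samples/0_epoch_0.json
#   logs/samples/1_epoch_0.json
#   logs/samples/0_epoch_1.json
#
# the builder can be driven directly. The paths below are hypothetical, and
# `data_files` is a standard `DatasetBuilder` keyword argument:
#
#   builder = Eval(data_files={"test": ["logs/samples/0_epoch_0.json",
#                                       "logs/samples/1_epoch_0.json",
#                                       "logs/samples/0_epoch_1.json"]})
#   builder.download_and_prepare()
#   ds = builder.as_dataset(split="test")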