diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index b96a5eb2d..7a4ff0d8a 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -57,7 +57,7 @@ jobs:
       - name: Test with pytest
         env:
           MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
-          #MP_API_ENDPOINT: https://api-preview.materialsproject.org/
+          # MP_API_ENDPOINT: https://api-preview.materialsproject.org/
         run: |
           pip install -e .
           pytest -n auto -x --cov=mp_api --cov-report=xml
diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py
index 793d0902f..81e80bd88 100644
--- a/mp_api/client/core/client.py
+++ b/mp_api/client/core/client.py
@@ -7,8 +7,10 @@
 import inspect
 import itertools
+import logging
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -21,7 +23,10 @@
 from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
 from urllib.parse import quote, urljoin
 
+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
@@ -31,7 +36,7 @@
 from urllib3.util.retry import Retry
 
 from mp_api.client.core.settings import MAPIClientSettings
-from mp_api.client.core.utils import load_json, validate_ids
+from mp_api.client.core.utils import MPDataset, load_json, validate_ids
 
 try:
     import boto3
@@ -58,6 +63,14 @@
 SETTINGS = MAPIClientSettings()  # type: ignore
 
+hdlr = logging.StreamHandler()
+fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
+hdlr.setFormatter(fmt)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(hdlr)
+
 
 class _DictLikeAccess(BaseModel):
     """Define a pydantic mix-in which permits dict-like access to model fields."""
 
@@ -83,6 +96,7 @@ class BaseRester:
     document_model: type[BaseModel] | None = None
     supports_versions: bool = False
     primary_key: str = "material_id"
+    delta_backed: bool = False
 
     def __init__(
         self,
@@ -97,6 +111,8 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the REST API helper class.
 
@@ -128,6 +144,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'mp_datasets' in the user's home directory
+            force_renew: Option to overwrite existing local dataset
         """
         # TODO: think about how to migrate from PMG_MAPI_KEY
         self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -141,7 +160,12 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
-        self.db_version = BaseRester._get_database_version(self.endpoint)
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
+        (
+            self.db_version,
+            self.access_controlled_batch_ids,
+        ) = BaseRester._get_heartbeat_info(self.endpoint)
 
         if self.suffix:
             self.endpoint = urljoin(self.endpoint, self.suffix)
@@ -216,8 +240,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):  # pragma: no cover
 
     @staticmethod
    @cache
-    def _get_database_version(endpoint):
-        """The Materials Project database is periodically updated and has a
+    def _get_heartbeat_info(endpoint) -> tuple[str, str]:
+        """DB version:
+        The Materials Project database is periodically updated and has a
         database version associated with it. When the database is updated,
         consolidated data (information about "a material") may and does change,
         while calculation data about a specific calculation task
@@ -227,9 +252,24 @@ def _get_database_version(endpoint):
         where "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.
 
-        Returns: database version as a string
+        Access Controlled Datasets:
+            Certain contributions to the Materials Project have access
+            control restrictions that require explicit agreement to the
+            Terms of Use for the respective datasets prior to access being
+            granted.
+
+            A full list of the Terms of Use for all contributions in the
+            Materials Project is available at:
+
+            https://next-gen.materialsproject.org/about/terms
+
+        Returns:
+            tuple with database version as a string and a comma-separated
+            string with all calculation batch identifiers that have access
+            restrictions
         """
-        return requests.get(url=endpoint + "heartbeat").json()["db_version"]
+        response = requests.get(url=endpoint + "heartbeat").json()
+        return response["db_version"], response["access_controlled_batch_ids"]
 
     def _post_resource(
         self,
@@ -368,10 +408,7 @@ def _patch_resource(
             raise MPRestError(str(ex))
 
     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.
 
@@ -471,6 +508,12 @@ def _query_resource(
                     url += "/"
 
             if query_s3:
+                pbar_message = (  # type: ignore
+                    f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                    if self.document_model is not None
+                    else "Retrieving documents"
+                )
+
                 db_version = self.db_version.replace(".", "-")
                 if "/" not in self.suffix:
                     suffix = self.suffix
@@ -481,15 +524,169 @@ def _query_resource(
                     suffix = infix if suffix == "core" else suffix
                     suffix = suffix.replace("_", "-")
 
-                # Paginate over all entries in the bucket.
-                # TODO: change when a subset of entries needed from DB
+                # Check if user has access to GNoMe
+                # temp suppress tqdm
+                re_enable = not self.mute_progress_bars
+                self.mute_progress_bars = True
+                has_gnome_access = bool(
+                    self._submit_requests(
+                        url=urljoin(self.base_endpoint, "materials/summary/"),
+                        criteria={
+                            "batch_id": "gnome_r2scan_statics",
+                            "_fields": "material_id",
+                        },
+                        use_document_model=False,
+                        num_chunks=1,
+                        chunk_size=1,
+                        timeout=timeout,
+                    )
+                    .get("meta", {})
+                    .get("total_doc", 0)
+                )
+                self.mute_progress_bars = not re_enable
+
                 if "tasks" in suffix:
-                    bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                    bucket_suffix, prefix = ("parsed", "core/tasks/")
                 else:
                     bucket_suffix = "build"
                     prefix = f"collections/{db_version}/{suffix}"
                 bucket = f"materialsproject-{bucket_suffix}"
+
+                if self.delta_backed:
+                    target_path = str(
+                        self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
+                    )
+                    os.makedirs(target_path, exist_ok=True)
+
+                    if DeltaTable.is_deltatable(target_path):
+                        if self.force_renew:
+                            shutil.rmtree(target_path)
+                            logger.warning(
+                                f"Regenerating {suffix} dataset at {target_path}..."
+                            )
+                            os.makedirs(target_path, exist_ok=True)
+                        else:
+                            logger.warning(
+                                f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
+                            )
+                            logger.info(
+                                "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
+                                "to refresh local dataset.",
+                            )
+
+                            return {
+                                "data": MPDataset(
+                                    path=target_path,
+                                    document_model=self.document_model,
+                                    use_document_model=self.use_document_model,
+                                )
+                            }
+
+                    tbl = DeltaTable(
+                        f"s3a://{bucket}/{prefix}",
+                        storage_options={
+                            "AWS_SKIP_SIGNATURE": "true",
+                            "AWS_REGION": "us-east-1",
+                        },
+                    )
+
+                    controlled_batch_str = ",".join(
+                        [f"'{tag}'" for tag in self.access_controlled_batch_ids]
+                    )
+
+                    predicate = (
+                        " WHERE batch_id NOT IN ("  # don't delete leading space
+                        + controlled_batch_str
+                        + ")"
+                        if not has_gnome_access
+                        else ""
+                    )
+
+                    builder = QueryBuilder().register("tbl", tbl)
+
+                    # Setup progress bar
+                    num_docs_needed = pa.table(
+                        builder.execute("SELECT COUNT(*) FROM tbl").read_all()
+                    )[0][0].as_py()
+
+                    if not has_gnome_access:
+                        num_docs_needed = self.count(
+                            {"batch_id_neq_any": self.access_controlled_batch_ids}
+                        )
+
+                    pbar = (
+                        tqdm(
+                            desc=pbar_message,
+                            total=num_docs_needed,
+                        )
+                        if not self.mute_progress_bars
+                        else None
+                    )
+
+                    iterator = builder.execute("SELECT * FROM tbl" + predicate)
+
+                    file_options = ds.ParquetFileFormat().make_write_options(
+                        compression="zstd"
+                    )
+
+                    def _flush(accumulator, group):
+                        ds.write_dataset(
+                            accumulator,
+                            base_dir=target_path,
+                            format="parquet",
+                            basename_template=f"group-{group}-"
+                            + "part-{i}.zstd.parquet",
+                            existing_data_behavior="overwrite_or_ignore",
+                            max_rows_per_group=1024,
+                            file_options=file_options,
+                        )
+
+                    group = 1
+                    size = 0
+                    accumulator = []
+                    for page in iterator:
+                        # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
+                        rg = pa.record_batch(page)
+                        accumulator.append(rg)
+                        page_size = page.num_rows
+                        size += rg.get_total_buffer_size()
+
+                        if pbar is not None:
+                            pbar.update(page_size)
+
+                        if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
+                            _flush(accumulator, group)
+                            group += 1
+                            size = 0
+                            accumulator.clear()
+
+                    if accumulator:
+                        _flush(accumulator, group + 1)
+
+                    if pbar is not None:
+                        pbar.close()
+
+                    logger.info(f"Dataset for {suffix} written to {target_path}")
+                    logger.info("Converting to DeltaTable...")
+
+                    convert_to_deltalake(target_path)
+
+                    logger.info(
+                        "Consult the delta-rs and pyarrow documentation for advanced usage: "
+                        "delta-io.github.io/delta-rs/, arrow.apache.org/docs/python/"
+                    )
+
+                    return {
+                        "data": MPDataset(
+                            path=target_path,
+                            document_model=self.document_model,
+                            use_document_model=self.use_document_model,
+                        )
+                    }
+
+                # Paginate over all entries in the bucket.
+                # TODO: change when a subset of entries needed from DB
                 paginator = self.s3_client.get_paginator("list_objects_v2")
                 pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
@@ -526,11 +723,6 @@ def _query_resource(
                 }
 
                 # Setup progress bar
-                pbar_message = (  # type: ignore
-                    f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                    if self.document_model is not None
-                    else "Retrieving documents"
-                )
                 num_docs_needed = int(self.count())
                 pbar = (
                     tqdm(
diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py
index 200b67785..09926fe82 100644
--- a/mp_api/client/core/settings.py
+++ b/mp_api/client/core/settings.py
@@ -1,5 +1,6 @@
 import os
 from multiprocessing import cpu_count
+from pathlib import Path
 from typing import List
 
 from pydantic import Field
@@ -50,6 +51,7 @@ class MAPIClientSettings(BaseSettings):
             "condition_mixing_media",
             "condition_heating_atmosphere",
             "operations",
+            "batch_id_neq_any",
             "_fields",
         ],
         description="List API query parameters that do not support parallel requests.",
@@ -87,4 +89,14 @@ class MAPIClientSettings(BaseSettings):
         _MAX_LIST_LENGTH, description="Maximum length of query parameter list"
     )
 
+    LOCAL_DATASET_CACHE: Path = Field(
+        Path("~/mp_datasets").expanduser(),
+        description="Target directory for downloading full datasets",
+    )
+
+    DATASET_FLUSH_THRESHOLD: int = Field(
+        int(2.75 * 1024**3),
+        description="Threshold bytes to accumulate in memory before flushing dataset to disk",
+    )
+
     model_config = SettingsConfigDict(env_prefix="MPRESTER_")
diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py
index c2d03fec2..b549d5b2f 100644
--- a/mp_api/client/core/utils.py
+++ b/mp_api/client/core/utils.py
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
 import re
+from functools import cached_property
+from itertools import chain
 from typing import TYPE_CHECKING, Literal
 
 import orjson
+import pyarrow.dataset as ds
+from deltalake import DeltaTable
 from emmet.core import __version__ as _EMMET_CORE_VER
 from monty.json import MontyDecoder
 from packaging.version import parse as parse_version
+from pydantic._internal._model_construction import ModelMetaclass
 
 from mp_api.client.core.settings import MAPIClientSettings
 
@@ -124,3 +129,68 @@ def validate_monty(cls, v, _):
     monty_cls.validate_monty_v2 = classmethod(validate_monty)
 
     return monty_cls
+
+
+class MPDataset:
+    def __init__(self, path, document_model, use_document_model):
+        """Convenience wrapper for pyarrow datasets stored on disk."""
+        self._start = 0
+        self._path = path
+        self._document_model = document_model
+        self._dataset = ds.dataset(path)
+        self._row_groups = list(
+            chain.from_iterable(
+                [
+                    fragment.split_by_row_group()
+                    for fragment in self._dataset.get_fragments()
+                ]
+            )
+        )
+        self._use_document_model = use_document_model
+
+    @property
+    def pyarrow_dataset(self) -> ds.Dataset:
+        return self._dataset
+
+    @property
+    def pydantic_model(self) -> ModelMetaclass:
+        return self._document_model
+
+    @property
+    def use_document_model(self) -> bool:
+        return self._use_document_model
+
+    @use_document_model.setter
+    def use_document_model(self, value: bool):
+        self._use_document_model = value
+
+    @cached_property
+    def delta_table(self) -> DeltaTable:
+        return DeltaTable(self._path)
+
+    @cached_property
+    def num_chunks(self) -> int:
+        return len(self._row_groups)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            _take = list(range(start, stop, step))
+            ds_slice = self._dataset.take(_take).to_pylist(maps_as_pydicts="strict")
+            return (
+                [self._document_model(**_row) for _row in ds_slice]
+                if self._use_document_model
+                else ds_slice
+            )
+
+        _row = self._dataset.take([idx]).to_pylist(maps_as_pydicts="strict")[0]
+        return self._document_model(**_row) if self._use_document_model else _row
+
+    def __len__(self) -> int:
+        return self._dataset.count_rows()
+
+    def __iter__(self):
+        current = self._start
+        while current < len(self):
+            yield self[current]
+            current += 1
diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py
index e6f65c964..8f2752cb1 100644
--- a/mp_api/client/mprester.py
+++ b/mp_api/client/mprester.py
@@ -134,6 +134,8 @@ def __init__(
         session: Session | None = None,
         headers: dict | None = None,
         mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = _MAPI_SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the MPRester.
 
@@ -168,6 +170,9 @@ def __init__(
             session: Session object to use. By default (None), the client will create one.
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to mute progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to "mp_datasets" in the user's home directory
+            force_renew: Option to overwrite existing local dataset
         """
         # SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
 
@@ -193,6 +198,8 @@ def __init__(
         self.use_document_model = use_document_model
         self.monty_decode = monty_decode
         self.mute_progress_bars = mute_progress_bars
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
         self._contribs = None
 
         self._deprecated_attributes = [
@@ -268,6 +275,8 @@ def __init__(
                 use_document_model=self.use_document_model,
                 headers=self.headers,
                 mute_progress_bars=self.mute_progress_bars,
+                local_dataset_cache=self.local_dataset_cache,
+                force_renew=self.force_renew,
             )
             for cls in self._all_resters
             if cls.suffix in core_suffix
@@ -294,6 +303,8 @@ def __init__(
                 use_document_model=self.use_document_model,
                 headers=self.headers,
                 mute_progress_bars=self.mute_progress_bars,
+                local_dataset_cache=self.local_dataset_cache,
+                force_renew=self.force_renew,
             )  # type: BaseRester
             setattr(
                 self,
@@ -324,6 +335,8 @@ def __core_custom_getattr(_self, _attr, _rester_map):
                 use_document_model=self.use_document_model,
                 headers=self.headers,
                 mute_progress_bars=self.mute_progress_bars,
+                local_dataset_cache=self.local_dataset_cache,
+                force_renew=self.force_renew,
             )  # type: BaseRester
 
             setattr(
diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py
index c78650780..a879a93c4 100644
--- a/mp_api/client/routes/materials/tasks.py
+++ b/mp_api/client/routes/materials/tasks.py
@@ -3,7 +3,11 @@
 from datetime import datetime
 from typing import TYPE_CHECKING
 
+import pyarrow as pa
+from deltalake import DeltaTable, QueryBuilder
+from emmet.core.mpid import AlphaID
 from emmet.core.tasks import CoreTaskDoc
+from emmet.core.trajectory import RelaxTrajectory
 
 from mp_api.client.core import BaseRester, MPRestError
 from mp_api.client.core.utils import validate_ids
@@ -16,6 +20,7 @@ class TaskRester(BaseRester):
     suffix: str = "materials/tasks"
     document_model: type[BaseModel] = CoreTaskDoc  # type: ignore
     primary_key: str = "task_id"
+    delta_backed = True
 
     def get_trajectory(self, task_id):
         """Returns a Trajectory object containing the geometry of the
@@ -26,16 +31,30 @@ def get_trajectory(self, task_id):
             task_id (str): Task ID
         """
 
-        traj_data = self._query_resource_data(
-            {"task_ids": [task_id]}, suburl="trajectory/", use_document_model=False
-        )[0].get(
-            "trajectories", None
-        )  # type: ignore
+        as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1]
 
-        if traj_data is None:
+        traj_tbl = DeltaTable(
+            "s3a://materialsproject-parsed/core/trajectories/",
+            storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"},
+        )
+
+        traj_data = pa.table(
+            QueryBuilder()
+            .register("traj", traj_tbl)
+            .execute(
+                f"""
+                SELECT *
+                FROM traj
+                WHERE identifier='{as_alpha}'
+                """
+            )
+            .read_all()
+        ).to_pylist(maps_as_pydicts="strict")
+
+        if not traj_data:
             raise MPRestError(f"No trajectory data for {task_id} found")
 
-        return traj_data
+        return RelaxTrajectory(**traj_data[0]).to_pmg().as_dict()
 
     def search(
         self,
diff --git a/pyproject.toml b/pyproject.toml
index afefc1e06..9370f21ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,8 @@ dependencies = [
     "smart_open",
     "boto3",
     "orjson >= 3.10,<4",
+    "pyarrow >= 20.0.0",
+    "deltalake >= 1.2.0",
 ]
 dynamic = ["version"]
 
diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt
index d3ff66cb4..f817130ab 100644
--- a/requirements/requirements-ubuntu-latest_py3.11.txt
+++ b/requirements/requirements-ubuntu-latest_py3.11.txt
@@ -6,6 +6,8 @@
 #
 annotated-types==0.7.0
     # via pydantic
+arro3-core==0.6.5
+    # via deltalake
 bibtexparser==1.4.3
     # via pymatgen
 blake3==1.0.8
@@ -24,7 +26,11 @@ contourpy==1.3.3
     # via matplotlib
 cycler==0.12.1
     # via matplotlib
-emmet-core==0.86.0
+deltalake==1.2.1
+    # via mp-api (pyproject.toml)
+deprecated==1.3.1
+    # via deltalake
+emmet-core==0.86.0rc1
     # via mp-api (pyproject.toml)
 fonttools==4.60.1
     # via matplotlib
@@ -81,6 +87,8 @@ pillow==12.0.0
     # via matplotlib
 plotly==6.4.0
     # via pymatgen
+pyarrow==22.0.0
+    # via mp-api (pyproject.toml)
 pybtex==0.25.1
     # via emmet-core
 pydantic==2.12.4
@@ -145,6 +153,7 @@ tqdm==4.67.1
     # via pymatgen
 typing-extensions==4.15.0
     # via
+    #   arro3-core
     #   blake3
     #   emmet-core
     #   mp-api (pyproject.toml)
@@ -165,4 +174,6 @@ urllib3==2.5.0
     #   botocore
     #   requests
 wrapt==2.0.0
-    # via smart-open
+    # via
+    #   deprecated
+    #   smart-open
diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt
index 59ac2f166..0a7c52cdb 100644
--- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt
+++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt
@@ -8,6 +8,8 @@ alabaster==1.0.0
     # via sphinx
 annotated-types==0.7.0
     # via pydantic
+arro3-core==0.6.5
+    # via deltalake
 arrow==1.4.0
     # via isoduration
 ase==3.26.0
@@ -54,6 +56,10 @@ cycler==0.12.1
     # via matplotlib
 decorator==5.2.1
     # via ipython
+deltalake==1.2.1
+    # via mp-api (pyproject.toml)
+deprecated==1.3.1
+    # via deltalake
 distlib==0.4.0
     # via virtualenv
 dnspython==2.8.0
@@ -271,7 +277,9 @@ pubchempy==1.0.5
 pure-eval==0.2.3
     # via stack-data
 pyarrow==22.0.0
-    # via emmet-core
+    # via
+    #   emmet-core
+    #   mp-api (pyproject.toml)
 pybtex==0.25.1
     # via
     #   emmet-core
@@ -491,6 +499,7 @@ types-setuptools==80.9.0.20250822
     # via mp-api (pyproject.toml)
 typing-extensions==4.15.0
     # via
+    #   arro3-core
     #   blake3
     #   bravado
     #   emmet-core
@@ -534,4 +543,6 @@ wcwidth==0.2.14
 webcolors==25.10.0
     # via jsonschema
 wrapt==2.0.0
-    # via smart-open
+    # via
+    #   deprecated
+    #   smart-open
diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt
index 29ee10749..831da0d10 100644
--- a/requirements/requirements-ubuntu-latest_py3.12.txt
+++ b/requirements/requirements-ubuntu-latest_py3.12.txt
@@ -6,6 +6,8 @@
 #
 annotated-types==0.7.0
     # via pydantic
+arro3-core==0.6.5
+    # via deltalake
 bibtexparser==1.4.3
     # via pymatgen
 blake3==1.0.8
@@ -24,7 +26,11 @@ contourpy==1.3.3
     # via matplotlib
 cycler==0.12.1
     # via matplotlib
-emmet-core==0.86.0
+deltalake==1.2.1
+    # via mp-api (pyproject.toml)
+deprecated==1.3.1
+    # via deltalake
+emmet-core==0.86.0rc1
     # via mp-api (pyproject.toml)
 fonttools==4.60.1
     # via matplotlib
@@ -81,6 +87,8 @@ pillow==12.0.0
     # via matplotlib
 plotly==6.4.0
     # via pymatgen
+pyarrow==22.0.0
+    # via mp-api (pyproject.toml)
 pybtex==0.25.1
     # via emmet-core
 pydantic==2.12.4
@@ -164,4 +172,6 @@ urllib3==2.5.0
     #   botocore
     #   requests
 wrapt==2.0.0
-    # via smart-open
+    # via
+    #   deprecated
+    #   smart-open
diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt
index 9b4c609a5..ebf193dce 100644
--- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt
+++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt
@@ -8,6 +8,8 @@ alabaster==1.0.0
     # via sphinx
 annotated-types==0.7.0
     # via pydantic
+arro3-core==0.6.5
+    # via deltalake
 arrow==1.4.0
     # via isoduration
 ase==3.26.0
@@ -54,6 +56,10 @@ cycler==0.12.1
     # via matplotlib
 decorator==5.2.1
     # via ipython
+deltalake==1.2.1
+    # via mp-api (pyproject.toml)
+deprecated==1.3.1
+    # via deltalake
 distlib==0.4.0
     # via virtualenv
 dnspython==2.8.0
@@ -271,7 +277,9 @@ pubchempy==1.0.5
 pure-eval==0.2.3
     # via stack-data
 pyarrow==22.0.0
-    # via emmet-core
+    # via
+    #   emmet-core
+    #   mp-api (pyproject.toml)
 pybtex==0.25.1
     # via
     #   emmet-core
@@ -532,4 +540,6 @@ wcwidth==0.2.14
 webcolors==25.10.0
     # via jsonschema
 wrapt==2.0.0
-    # via smart-open
+    # via
+    #   deprecated
+    #   smart-open
diff --git a/tests/materials/test_tasks.py b/tests/materials/test_tasks.py
index b35dfd938..1ddf12c58 100644
--- a/tests/materials/test_tasks.py
+++ b/tests/materials/test_tasks.py
@@ -1,8 +1,9 @@
 import os
 
-from core_function import client_search_testing
-import pytest
+import pytest
+from core_function import client_search_testing
 from emmet.core.utils import utcnow
+
 from mp_api.client.routes.materials.tasks import TaskRester
 
 
@@ -53,7 +54,6 @@ def test_client(rester):
 
 
 def test_get_trajectories(rester):
-    trajectories = [traj for traj in rester.get_trajectory("mp-149")]
+    trajectory = rester.get_trajectory("mp-149")
 
-    for traj in trajectories:
-        assert ("@module", "pymatgen.core.trajectory") in traj.items()
+    assert trajectory["@module"] == "pymatgen.core.trajectory"