29 commits
9d2048e  exclude gnome for full downloads if needed (tschaume, Mar 5, 2025)
505ddfe  query s3 for trajectories (tsmathis, Oct 23, 2025)
aee0f8c  add deltalake query support (tsmathis, Oct 23, 2025)
d5a25b1  linting + mistaken sed replace on 'where' (tsmathis, Oct 23, 2025)
2de051d  return trajectory as pmg dict (tsmathis, Oct 23, 2025)
7d0b8b7  update trajectory test (tsmathis, Oct 23, 2025)
7195adf  correct docstrs (tsmathis, Oct 23, 2025)
33b787f  Merge branch 'main' into deltalake (tschaume, Oct 24, 2025)
2664fcd  get access controlled batch ids from heartbeat (tsmathis, Nov 3, 2025)
b498a76  refactor (tsmathis, Nov 4, 2025)
7da6984  Merge branch 'main' into deltalake (tschaume, Nov 4, 2025)
948c108  auto dependency upgrades (invalid-email-address, Nov 5, 2025)
b0aed4f  Update testing.yml (tschaume, Nov 5, 2025)
a35bcb7  rm overlooked access of removed settings param (tsmathis, Nov 5, 2025)
9460601  refactor: consolidate requests to heartbeat for meta info (tsmathis, Nov 5, 2025)
05f1d0e  lint (tsmathis, Nov 5, 2025)
e685445  fix incomplete docstr (tsmathis, Nov 5, 2025)
bb0b238  typo (tsmathis, Nov 5, 2025)
dc0c949  Merge branch 'main' into deltalake (tsmathis, Nov 10, 2025)
fb84d73  revert testing endpoint (tsmathis, Nov 10, 2025)
5bdacf5  no parallel on batch_id_neq_any (tsmathis, Nov 10, 2025)
7ee5515  more resilient dataset path expansion (tsmathis, Nov 12, 2025)
ae7674d  missed field annotation update (tsmathis, Nov 12, 2025)
5538c74  coerce Path to str for deltalake lib (tsmathis, Nov 12, 2025)
f39c0d3  flush based on bytes (tsmathis, Nov 14, 2025)
a965255  iterate over individual rows for local dataset (tsmathis, Nov 14, 2025)
03b38e7  missed bounds check for updated iteration behavior (tsmathis, Nov 14, 2025)
3a44b4f  opt for module level logging over warnings lib (tsmathis, Nov 14, 2025)
b2a832f  lint (tsmathis, Nov 14, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -57,7 +57,7 @@ jobs:
- name: Test with pytest
env:
MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
#MP_API_ENDPOINT: https://api-preview.materialsproject.org/
# MP_API_ENDPOINT: https://api-preview.materialsproject.org/
run: |
pip install -e .
pytest -n auto -x --cov=mp_api --cov-report=xml
228 changes: 210 additions & 18 deletions mp_api/client/core/client.py
@@ -7,8 +7,10 @@

import inspect
import itertools
import logging
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -21,7 +23,10 @@
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import pyarrow as pa
import pyarrow.dataset as ds
import requests
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
@@ -31,7 +36,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import MPDataset, load_json, validate_ids

try:
import boto3
@@ -58,6 +63,14 @@

SETTINGS = MAPIClientSettings() # type: ignore

hdlr = logging.StreamHandler()
fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
hdlr.setFormatter(fmt)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(hdlr)


class _DictLikeAccess(BaseModel):
"""Define a pydantic mix-in which permits dict-like access to model fields."""
@@ -83,6 +96,7 @@ class BaseRester:
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
@@ -97,6 +111,8 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
):
"""Initialize the REST API helper class.

@@ -128,6 +144,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'mp_datasets' in the user's home directory
force_renew: Option to overwrite existing local dataset
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -141,7 +160,12 @@ def __init__(
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.db_version = BaseRester._get_database_version(self.endpoint)
self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew
(
self.db_version,
self.access_controlled_batch_ids,
) = BaseRester._get_heartbeat_info(self.endpoint)

if self.suffix:
self.endpoint = urljoin(self.endpoint, self.suffix)
@@ -216,8 +240,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

@staticmethod
@cache
def _get_database_version(endpoint):
"""The Materials Project database is periodically updated and has a
def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]:
"""DB version:
The Materials Project database is periodically updated and has a
database version associated with it. When the database is updated,
consolidated data (information about "a material") may and does
change, while calculation data about a specific calculation task
@@ -227,9 +252,24 @@ def _get_database_version(endpoint):
where "_DD" may be optional. An additional numerical or `postN` suffix
might be added if multiple releases happen on the same day.

Returns: database version as a string
Access Controlled Datasets:
Certain contributions to the Materials Project have access
control restrictions that require explicit agreement to the
Terms of Use for the respective datasets prior to access being
granted.

A full list of the Terms of Use for all contributions in the
Materials Project is available at:

https://next-gen.materialsproject.org/about/terms

Returns:
tuple with the database version as a string and a list of all
calculation batch identifiers that have access restrictions
"""
return requests.get(url=endpoint + "heartbeat").json()["db_version"]
response = requests.get(url=endpoint + "heartbeat").json()
return response["db_version"], response["access_controlled_batch_ids"]

def _post_resource(
self,
@@ -368,10 +408,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

@@ -471,6 +508,12 @@ def _query_resource(
url += "/"

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
@@ -481,15 +524,169 @@ def _query_resource(
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(self.base_endpoint, "materials/summary/"),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = str(
self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
logger.warning(
f"Regenerating {suffix} dataset at {target_path}..."
)
os.makedirs(target_path, exist_ok=True)
else:
logger.warning(
f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
)
logger.info(
"Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
"to refresh local dataset.",
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in self.access_controlled_batch_ids]
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

if not has_gnome_access:
num_docs_needed = self.count(
{"batch_id_neq_any": self.access_controlled_batch_ids}
)

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(accumulator, group):
ds.write_dataset(
accumulator,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
rg = pa.record_batch(page)
accumulator.append(rg)
page_size = page.num_rows
size += rg.get_total_buffer_size()

if pbar is not None:
pbar.update(page_size)

if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group)
group += 1
size = 0
accumulator.clear()

if accumulator:
_flush(accumulator, group + 1)

if pbar is not None:
pbar.close()

logger.info(f"Dataset for {suffix} written to {target_path}")
logger.info("Converting to DeltaTable...")

convert_to_deltalake(target_path)

logger.info(
"Consult the delta-rs and pyarrow documentation for advanced usage: "
"delta-io.github.io/delta-rs/, arrow.apache.org/docs/python/"
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -526,11 +723,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
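The consolidated heartbeat request introduced above is easy to reproduce outside the client. A minimal sketch, assuming the public production endpoint; the URL and the printed values are illustrative and not part of this PR:

```python
import requests

# Mirrors the call in BaseRester._get_heartbeat_info; the endpoint is an
# assumption here, swap in api-preview or a local deployment as needed.
endpoint = "https://api.materialsproject.org/"
payload = requests.get(endpoint + "heartbeat").json()

db_version = payload["db_version"]  # "YYYY.MM.DD"-style string, optional "_DD"/"postN" suffix
restricted = payload["access_controlled_batch_ids"]  # batch IDs gated behind Terms of Use

print(db_version, restricted)
```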
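Once the download path above has written the parquet files and run convert_to_deltalake, the cached copy can be reopened with the same delta-rs primitives the client uses, as the final log message suggests. A minimal sketch, assuming the default cache directory and a hypothetical build/collections/<db_version>/<suffix> sub-path (use whatever path the client logged); re-running the search with MPRester(force_renew=True) regenerates the local copy instead of reusing it:

```python
from pathlib import Path

from deltalake import DeltaTable, QueryBuilder

# Hypothetical cached dataset location; mirrors target_path in _query_resource,
# i.e. LOCAL_DATASET_CACHE / "<bucket_suffix>/<prefix>".
target_path = Path("~/mp_datasets/build/collections/2025-11-01/summary").expanduser()

tbl = DeltaTable(str(target_path))  # deltalake expects str paths, hence the coercion in the PR
builder = QueryBuilder().register("tbl", tbl)

# Results stream back as record batches, just like the client's iterator above.
for batch in builder.execute("SELECT * FROM tbl LIMIT 10"):
    print(batch.num_rows)
```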
12 changes: 12 additions & 0 deletions mp_api/client/core/settings.py
@@ -1,5 +1,6 @@
import os
from multiprocessing import cpu_count
from pathlib import Path
from typing import List

from pydantic import Field
@@ -50,6 +51,7 @@ class MAPIClientSettings(BaseSettings):
"condition_mixing_media",
"condition_heating_atmosphere",
"operations",
"batch_id_neq_any",
"_fields",
],
description="List API query parameters that do not support parallel requests.",
@@ -87,4 +89,14 @@ class MAPIClientSettings(BaseSettings):
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

LOCAL_DATASET_CACHE: Path = Field(
Path("~/mp_datasets").expanduser(),
description="Target directory for downloading full datasets",
)

DATASET_FLUSH_THRESHOLD: int = Field(
int(2.75 * 1024**3),
description="Threshold bytes to accumulate in memory before flushing dataset to disk",
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")
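Both new settings follow the existing MPRESTER_ environment prefix, so they can be overridden without touching code. A minimal sketch with illustrative values only; the cache path and threshold below are assumptions, not defaults introduced by this PR:

```python
import os

from mp_api.client.core.settings import MAPIClientSettings

# Illustrative overrides; pydantic-settings maps MPRESTER_<FIELD> to the matching
# field because of env_prefix="MPRESTER_" above.
os.environ["MPRESTER_LOCAL_DATASET_CACHE"] = "/scratch/mp_datasets"  # hypothetical directory
os.environ["MPRESTER_DATASET_FLUSH_THRESHOLD"] = str(1024**3)  # flush roughly every 1 GiB

settings = MAPIClientSettings()  # environment is read at instantiation time
print(settings.LOCAL_DATASET_CACHE, settings.DATASET_FLUSH_THRESHOLD)
```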