Commits (29 total; this diff shows changes from 21 commits)
9d2048e: exclude gnome for full downloads if needed (tschaume, Mar 5, 2025)
505ddfe: query s3 for trajectories (tsmathis, Oct 23, 2025)
aee0f8c: add deltalake query support (tsmathis, Oct 23, 2025)
d5a25b1: linting + mistaken sed replace on 'where' (tsmathis, Oct 23, 2025)
2de051d: return trajectory as pmg dict (tsmathis, Oct 23, 2025)
7d0b8b7: update trajectory test (tsmathis, Oct 23, 2025)
7195adf: correct docstrs (tsmathis, Oct 23, 2025)
33b787f: Merge branch 'main' into deltalake (tschaume, Oct 24, 2025)
2664fcd: get access controlled batch ids from heartbeat (tsmathis, Nov 3, 2025)
b498a76: refactor (tsmathis, Nov 4, 2025)
7da6984: Merge branch 'main' into deltalake (tschaume, Nov 4, 2025)
948c108: auto dependency upgrades (invalid-email-address, Nov 5, 2025)
b0aed4f: Update testing.yml (tschaume, Nov 5, 2025)
a35bcb7: rm overlooked access of removed settings param (tsmathis, Nov 5, 2025)
9460601: refactor: consolidate requests to heartbeat for meta info (tsmathis, Nov 5, 2025)
05f1d0e: lint (tsmathis, Nov 5, 2025)
e685445: fix incomplete docstr (tsmathis, Nov 5, 2025)
bb0b238: typo (tsmathis, Nov 5, 2025)
dc0c949: Merge branch 'main' into deltalake (tsmathis, Nov 10, 2025)
fb84d73: revert testing endpoint (tsmathis, Nov 10, 2025)
5bdacf5: no parallel on batch_id_neq_any (tsmathis, Nov 10, 2025)
7ee5515: more resilient dataset path expansion (tsmathis, Nov 12, 2025)
ae7674d: missed field annotation update (tsmathis, Nov 12, 2025)
5538c74: coerce Path to str for deltalake lib (tsmathis, Nov 12, 2025)
f39c0d3: flush based on bytes (tsmathis, Nov 14, 2025)
a965255: iterate over individual rows for local dataset (tsmathis, Nov 14, 2025)
03b38e7: missed bounds check for updated iteration behavior (tsmathis, Nov 14, 2025)
3a44b4f: opt for module level logging over warnings lib (tsmathis, Nov 14, 2025)
b2a832f: lint (tsmathis, Nov 14, 2025)
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -57,7 +57,7 @@ jobs:
- name: Test with pytest
env:
MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
#MP_API_ENDPOINT: https://api-preview.materialsproject.org/
# MP_API_ENDPOINT: https://api-preview.materialsproject.org/
run: |
pip install -e .
pytest -n auto -x --cov=mp_api --cov-report=xml
217 changes: 199 additions & 18 deletions mp_api/client/core/client.py
@@ -9,6 +9,7 @@
import itertools
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -21,7 +22,10 @@
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import pyarrow as pa
import pyarrow.dataset as ds
import requests
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
@@ -31,7 +35,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import MPDataset, load_json, validate_ids

try:
import boto3
@@ -83,6 +87,7 @@ class BaseRester:
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
@@ -97,6 +102,8 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
):
"""Initialize the REST API helper class.

@@ -128,6 +135,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'mp_datasets' in the user's home directory
force_renew: Option to overwrite existing local dataset
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -141,7 +151,12 @@
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.db_version = BaseRester._get_database_version(self.endpoint)
self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew
(
self.db_version,
self.access_controlled_batch_ids,
) = BaseRester._get_heartbeat_info(self.endpoint)

if self.suffix:
self.endpoint = urljoin(self.endpoint, self.suffix)
@@ -216,8 +231,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

@staticmethod
@cache
def _get_database_version(endpoint):
"""The Materials Project database is periodically updated and has a
def _get_heartbeat_info(endpoint) -> tuple[str, str]:
"""DB version:
The Materials Project database is periodically updated and has a
database version associated with it. When the database is updated,
consolidated data (information about "a material") may and does
change, while calculation data about a specific calculation task
@@ -227,9 +243,24 @@ def _get_database_version(endpoint):
where "_DD" may be optional. An additional numerical or `postN` suffix
might be added if multiple releases happen on the same day.

Returns: database version as a string
Access Controlled Datasets:
Certain contributions to the Materials Project have access
control restrictions that require explicit agreement to the
Terms of Use for the respective datasets prior to access being
granted.

A full list of the Terms of Use for all contributions in the
Materials Project are available at:

https://next-gen.materialsproject.org/about/terms

Returns:
tuple with database version as a string and a comma separated
string with all calculation batch identifiers that have access
restrictions
"""
return requests.get(url=endpoint + "heartbeat").json()["db_version"]
response = requests.get(url=endpoint + "heartbeat").json()
return response["db_version"], response["access_controlled_batch_ids"]
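For reference, the heartbeat payload consumed here can be inspected directly with requests; a minimal sketch, assuming the public production endpoint (substitute MP_API_ENDPOINT if you use a different deployment) and taking only the two key names from the return statement above:

import requests

resp = requests.get("https://api.materialsproject.org/heartbeat", timeout=20).json()
print(resp["db_version"])  # database release identifier
print(resp["access_controlled_batch_ids"])  # batch ids gated behind dataset Terms of Use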

def _post_resource(
self,
@@ -368,10 +399,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

@@ -471,6 +499,12 @@ def _query_resource(
url += "/"

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
@@ -481,15 +515,163 @@
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(self.base_endpoint, "materials/summary/"),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = (
self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
warnings.warn(
f"Regenerating {suffix} dataset at {target_path}...",
MPLocalDatasetWarning,
)
os.makedirs(target_path, exist_ok=True)
else:
warnings.warn(
f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
"or re-run search query with MPRester(force_renew=True)",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in self.access_controlled_batch_ids]
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

if not has_gnome_access:
num_docs_needed = self.count(
{"batch_id_neq_any": self.access_controlled_batch_ids}
)

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(accumulator, group):
ds.write_dataset(
accumulator,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
accumulator.append(pa.record_batch(page))
page_size = page.num_rows
size += page_size

if pbar is not None:
pbar.update(page_size)

if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group)
group += 1
size = 0
accumulator.clear()

if accumulator:
_flush(accumulator, group + 1)

convert_to_deltalake(target_path)

warnings.warn(
f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
"the table according to your usage patterns prior to running intensive workloads, "
"see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -526,11 +708,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
@@ -1359,3 +1536,7 @@ class MPRestError(Exception):

class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""


class MPLocalDatasetWarning(Warning):
"""Raised when unrecoverable actions are performed on a local dataset."""
11 changes: 11 additions & 0 deletions mp_api/client/core/settings.py
@@ -50,6 +50,7 @@ class MAPIClientSettings(BaseSettings):
"condition_mixing_media",
"condition_heating_atmosphere",
"operations",
"batch_id_neq_any",
"_fields",
],
description="List API query parameters that do not support parallel requests.",
@@ -87,4 +88,14 @@ class MAPIClientSettings(BaseSettings):
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

LOCAL_DATASET_CACHE: str = Field(
os.path.expanduser("~") + "/mp_datasets",
esoteric-ephemera (Collaborator) commented on Nov 12, 2025:
may want to change to just os.path.expanduser("~/mp_datasets") so that os can resolve non-unix-like separators. Or just use pathlib.Path("~/mp_datasets").expanduser()

Collaborator (Author) replied:
good point (7ee5515)

description="Target directory for downloading full datasets",
)

DATASET_FLUSH_THRESHOLD: int = Field(
100000,
Collaborator commented:
Could we make this a byte threshold in memory with pyarrow.Table.get_total_buffer_size? Would be an overestimate but that's probably safe for this case

Collaborator (Author) replied:
Not sure if that would work exactly since the in memory accumulator is a pylist of pa.RecordBatchs.
I'll look around for something that's more predictable for the flush threshold than just number of rows since row sizes can vary drastically across different data products.

Collaborator (Author) replied:
After some looking, RecordBatch also has get_total_buffer_size()
What do you think a good threshold would be in this case? For the first 100k rows for the tasks table I got 2770781904 bytes (2.7 GB)

Collaborator (Author) replied:
The corresponding on disk size (compressed w/ zstd) for that first 100k rows is 422M

Collaborator replied:
2.5-2.75 GB spill is probably good
description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")
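A minimal sketch of the byte-based flush floated in the review thread above, assuming the accumulator remains a list of pyarrow RecordBatches and using RecordBatch.get_total_buffer_size() as the in-memory estimate; the ~2.75 GB threshold and the helper name are illustrative, not code from this PR.

import pyarrow as pa

FLUSH_THRESHOLD_BYTES = 2_750_000_000  # roughly the 2.5-2.75 GB spill suggested in the thread

def accumulate_and_flush(iterator, flush):
    """Accumulate pyarrow RecordBatches and flush by in-memory size instead of row count.

    `iterator` yields Arrow-compatible batches (as builder.execute() does in the client) and
    `flush(accumulator, group)` behaves like the local _flush helper in _query_resource.
    """
    group, size, accumulator = 1, 0, []
    for page in iterator:
        batch = pa.record_batch(page)
        accumulator.append(batch)
        size += batch.get_total_buffer_size()  # overestimates the compressed on-disk size
        if size >= FLUSH_THRESHOLD_BYTES:
            flush(accumulator, group)
            group += 1
            size = 0
            accumulator = []
    if accumulator:
        flush(accumulator, group)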