Commits (29)
9d2048e  exclude gnome for full downloads if needed (tschaume, Mar 5, 2025)
505ddfe  query s3 for trajectories (tsmathis, Oct 23, 2025)
aee0f8c  add deltalake query support (tsmathis, Oct 23, 2025)
d5a25b1  linting + mistaken sed replace on 'where' (tsmathis, Oct 23, 2025)
2de051d  return trajectory as pmg dict (tsmathis, Oct 23, 2025)
7d0b8b7  update trajectory test (tsmathis, Oct 23, 2025)
7195adf  correct docstrs (tsmathis, Oct 23, 2025)
33b787f  Merge branch 'main' into deltalake (tschaume, Oct 24, 2025)
2664fcd  get access controlled batch ids from heartbeat (tsmathis, Nov 3, 2025)
b498a76  refactor (tsmathis, Nov 4, 2025)
7da6984  Merge branch 'main' into deltalake (tschaume, Nov 4, 2025)
948c108  auto dependency upgrades (invalid-email-address, Nov 5, 2025)
b0aed4f  Update testing.yml (tschaume, Nov 5, 2025)
a35bcb7  rm overlooked access of removed settings param (tsmathis, Nov 5, 2025)
9460601  refactor: consolidate requests to heartbeat for meta info (tsmathis, Nov 5, 2025)
05f1d0e  lint (tsmathis, Nov 5, 2025)
e685445  fix incomplete docstr (tsmathis, Nov 5, 2025)
bb0b238  typo (tsmathis, Nov 5, 2025)
dc0c949  Merge branch 'main' into deltalake (tsmathis, Nov 10, 2025)
fb84d73  revert testing endpoint (tsmathis, Nov 10, 2025)
5bdacf5  no parallel on batch_id_neq_any (tsmathis, Nov 10, 2025)
7ee5515  more resilient dataset path expansion (tsmathis, Nov 12, 2025)
ae7674d  missed field annotation update (tsmathis, Nov 12, 2025)
5538c74  coerce Path to str for deltalake lib (tsmathis, Nov 12, 2025)
f39c0d3  flush based on bytes (tsmathis, Nov 14, 2025)
a965255  iterate over individual rows for local dataset (tsmathis, Nov 14, 2025)
03b38e7  missed bounds check for updated iteration behavior (tsmathis, Nov 14, 2025)
3a44b4f  opt for module level logging over warnings lib (tsmathis, Nov 14, 2025)
b2a832f  lint (tsmathis, Nov 14, 2025)
200 changes: 181 additions & 19 deletions mp_api/client/core/client.py
@@ -9,6 +9,7 @@
import itertools
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -18,15 +19,13 @@
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import (
TYPE_CHECKING,
ForwardRef,
Optional,
get_args,
)
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import pyarrow as pa
import pyarrow.dataset as ds
import requests
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
@@ -36,7 +35,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import MPDataset, load_json, validate_ids

try:
import boto3
@@ -71,6 +70,7 @@ class BaseRester:
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
@@ -85,6 +85,8 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
):
"""Initialize the REST API helper class.

@@ -116,6 +118,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'materialsproject_datasets' in the user's home directory
force_renew: Option to overwrite existing local dataset
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -129,6 +134,8 @@
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew
self.db_version = BaseRester._get_database_version(self.endpoint)

if self.suffix:
@@ -356,10 +363,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

@@ -463,6 +467,12 @@ def _query_resource(
url += "/"

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
@@ -473,15 +483,168 @@
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(
"https://api.materialsproject.org/", "materials/summary/"
),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = (
self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
warnings.warn(
f"Regenerating {suffix} dataset at {target_path}...",
MPLocalDatasetWarning,
)
os.makedirs(target_path, exist_ok=True)
else:
warnings.warn(
f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
"or re-run search query with MPRester(force_renew=True)",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

# TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
# -> need to modify BatchIdQuery operator to handle root level
# batch_id, not only builder_meta.batch_id
# if not has_gnome_access:
# num_docs_needed = self.count(
# {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
# )

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(accumulator, group):
ds.write_dataset(
accumulator,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
accumulator.append(pa.record_batch(page))
page_size = page.num_rows
size += page_size

if pbar is not None:
pbar.update(page_size)

if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group)
group += 1
size = 0
accumulator = []

if accumulator:
_flush(accumulator, group + 1)

convert_to_deltalake(target_path)

warnings.warn(
f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
"the table according to your usage patterns prior to running intensive workloads, "
"see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -518,11 +681,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
@@ -1350,3 +1508,7 @@ class MPRestError(Exception):

class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""


class MPLocalDatasetWarning(Warning):
"""Raised when unrecoverable actions are performed on a local dataset."""
14 changes: 14 additions & 0 deletions mp_api/client/core/settings.py
@@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings):
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

LOCAL_DATASET_CACHE: str = Field(
os.path.expanduser("~") + "/mp_datasets",
esoteric-ephemera (Collaborator) commented on Nov 12, 2025:
may want to change to just os.path.expanduser("~/mp_datasets") so that os can resolve non-unix-like separators. Or just use pathlib.Path("~/mp_datasets").expanduser()

Collaborator (Author) replied:
good point (commit 7ee5515, "more resilient dataset path expansion")
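
For reference, a small sketch of the three spellings discussed in this thread (variable names are illustrative, not part of the PR):

```python
# Illustrative comparison of the three path spellings discussed above.
import os
from pathlib import Path

cache_concat = os.path.expanduser("~") + "/mp_datasets"   # current default: manual "/" join
cache_os = os.path.expanduser("~/mp_datasets")            # expand the full path in one call
cache_pathlib = Path("~/mp_datasets").expanduser()        # pathlib variant, returns a Path

print(cache_concat, cache_os, cache_pathlib, sep="\n")
```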

description="Target directory for downloading full datasets",
)

DATASET_FLUSH_THRESHOLD: int = Field(
100000,
Collaborator commented:
Could we make this a byte threshold in memory with pyarrow.Table.get_total_buffer_size? Would be an overestimate, but that's probably safe for this case.

Collaborator (Author) replied:
Not sure if that would work exactly, since the in-memory accumulator is a Python list of pa.RecordBatch objects. I'll look around for something that's more predictable for the flush threshold than just the number of rows, since row sizes can vary drastically across different data products.

Collaborator (Author) replied:
After some looking, RecordBatch also has get_total_buffer_size(). What do you think a good threshold would be in this case? For the first 100k rows of the tasks table I got 2770781904 bytes (~2.7 GB).

Collaborator (Author) replied:
The corresponding on-disk size (compressed with zstd) for that first 100k rows is 422 MB.

Collaborator replied:
2.5-2.75 GB spill is probably good.
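
A minimal sketch of the byte-based flush discussed in this thread, under the assumption that the accumulator holds pyarrow RecordBatch objects (the FLUSH_BYTES value and helper names are illustrative, not the PR's implementation):

```python
# Accumulate record batches and spill to parquet once the total buffer size
# crosses a byte threshold instead of a row-count threshold.
import pyarrow as pa
import pyarrow.dataset as ds

FLUSH_BYTES = 2_750_000_000  # ~2.75 GB in-memory spill threshold


def _write_group(batches: list[pa.RecordBatch], group: int, target_path: str) -> None:
    """Write the accumulated record batches to disk as one parquet 'group'."""
    ds.write_dataset(
        batches,
        base_dir=target_path,
        format="parquet",
        basename_template=f"group-{group}-" + "part-{i}.zstd.parquet",
        existing_data_behavior="overwrite_or_ignore",
        file_options=ds.ParquetFileFormat().make_write_options(compression="zstd"),
    )


def write_by_bytes(record_batches, target_path: str) -> None:
    """Flush whenever the accumulated buffer size reaches FLUSH_BYTES."""
    accumulator: list[pa.RecordBatch] = []
    size, group = 0, 1
    for batch in record_batches:
        accumulator.append(batch)
        # get_total_buffer_size() overestimates resident memory, which errs on the safe side.
        size += batch.get_total_buffer_size()
        if size >= FLUSH_BYTES:
            _write_group(accumulator, group, target_path)
            accumulator, size, group = [], 0, group + 1
    if accumulator:
        _write_group(accumulator, group, target_path)
```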

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
)

ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field(
["gnome_r2scan_statics"], description="Batch ids with access restrictions"
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")
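
As a usage note, the env_prefix means the new settings can also be supplied through environment variables; a hedged sketch with illustrative values:

```python
# Because MAPIClientSettings uses env_prefix="MPRESTER_", the new fields can
# be overridden via environment variables before the settings are read.
import os

os.environ["MPRESTER_LOCAL_DATASET_CACHE"] = "/tmp/mp_datasets"   # example value
os.environ["MPRESTER_DATASET_FLUSH_THRESHOLD"] = "50000"          # example value

from mp_api.client.core.settings import MAPIClientSettings

settings = MAPIClientSettings()
print(settings.LOCAL_DATASET_CACHE, settings.DATASET_FLUSH_THRESHOLD)
```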
64 changes: 64 additions & 0 deletions mp_api/client/core/utils.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import re
from functools import cached_property
from itertools import chain
from typing import TYPE_CHECKING, Literal

import orjson
import pyarrow.dataset as ds
from deltalake import DeltaTable
from emmet.core import __version__ as _EMMET_CORE_VER
from monty.json import MontyDecoder
from packaging.version import parse as parse_version
from pydantic._internal._model_construction import ModelMetaclass

from mp_api.client.core.settings import MAPIClientSettings

@@ -124,3 +129,62 @@ def validate_monty(cls, v, _):
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls


class MPDataset:
def __init__(self, path, document_model, use_document_model):
"""Convenience wrapper for pyarrow datasets stored on disk."""
self._start = 0
self._path = path
self._document_model = document_model
self._dataset = ds.dataset(path)
self._row_groups = list(
chain.from_iterable(
[
fragment.split_by_row_group()
for fragment in self._dataset.get_fragments()
]
)
)
self._use_document_model = use_document_model

@property
def pyarrow_dataset(self) -> ds.Dataset:
return self._dataset

@property
def pydantic_model(self) -> ModelMetaclass:
return self._document_model

@property
def use_document_model(self) -> bool:
return self._use_document_model

@use_document_model.setter
def use_document_model(self, value: bool):
self._use_document_model = value

@cached_property
def delta_table(self) -> DeltaTable:
return DeltaTable(self._path)

@cached_property
def num_chunks(self) -> int:
return len(self._row_groups)

def __getitem__(self, idx):
return list(
map(
lambda x: self._document_model(**x) if self._use_document_model else x,
self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"),
)
)

def __len__(self) -> int:
return self.num_chunks

def __iter__(self):
current = self._start
while current < self.num_chunks:
yield self[current]
current += 1
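
For context, a hypothetical usage sketch of the MPDataset wrapper defined above; the local dataset path and the column name are assumptions about a previously downloaded dataset:

```python
# Hypothetical usage of MPDataset. The path and the "material_id" column are
# assumptions; whether documents are hydrated depends on use_document_model.
import os

from mp_api.client.core.utils import MPDataset

dataset = MPDataset(
    path=os.path.expanduser("~/mp_datasets/build/collections/2025-11-10/summary"),
    document_model=None,       # or the corresponding emmet document model
    use_document_model=False,  # plain dicts instead of pydantic documents
)

print(len(dataset))       # number of row-group chunks, not number of rows
first_chunk = dataset[0]  # list of dicts for the first parquet row group

# Chunk-wise iteration: each yielded item is one row group's worth of rows
for chunk in dataset:
    for row in chunk:
        pass  # process a single document dict

# Drop down to the underlying pyarrow dataset or Delta table for analytics
table = dataset.pyarrow_dataset.to_table(columns=["material_id"])
delta = dataset.delta_table
```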