 import itertools
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from importlib.metadata import PackageNotFoundError, version
 from json import JSONDecodeError
 from math import ceil
-from typing import (
-    TYPE_CHECKING,
-    ForwardRef,
-    Optional,
-    get_args,
-)
+from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
 from urllib.parse import quote, urljoin
 
+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
 from mp_api.client.core.settings import MAPIClientSettings
-from mp_api.client.core.utils import load_json, validate_ids
+from mp_api.client.core.utils import MPDataset, load_json, validate_ids
 
 try:
     import boto3
@@ -71,6 +70,7 @@ class BaseRester:
     document_model: type[BaseModel] | None = None
     supports_versions: bool = False
     primary_key: str = "material_id"
+    delta_backed: bool = False
 
     def __init__(
         self,
@@ -85,6 +85,8 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the REST API helper class.
 
@@ -116,6 +118,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'materialsproject_datasets' in the user's home directory.
+            force_renew: Whether to overwrite an existing local dataset.
         """
         # TODO: think about how to migrate from PMG_MAPI_KEY
         self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -129,6 +134,8 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
         self.db_version = BaseRester._get_database_version(self.endpoint)
 
         if self.suffix:
@@ -212,7 +219,7 @@ def _get_database_version(endpoint):
         remains unchanged and available for querying via its task_id.
 
         The database version is set as a date in the format YYYY_MM_DD,
-        where "_DD" may be optional. An additional numerical or `postN` suffix
+        predicate "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.
 
         Returns: database version as a string
@@ -356,10 +363,7 @@ def _patch_resource(
             raise MPRestError(str(ex))
 
     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.
 
@@ -463,6 +467,12 @@ def _query_resource(
             url += "/"
 
         if query_s3:
+            pbar_message = (  # type: ignore
+                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                if self.document_model is not None
+                else "Retrieving documents"
+            )
+
             db_version = self.db_version.replace(".", "-")
             if "/" not in self.suffix:
                 suffix = self.suffix
@@ -474,9 +484,14 @@ def _query_resource(
             suffix = suffix.replace("_", "-")
 
             # Check if user has access to GNoMe
+            # temporarily suppress tqdm progress bars during the access check
+            re_enable = not self.mute_progress_bars
+            self.mute_progress_bars = True
             has_gnome_access = bool(
                 self._submit_requests(
-                    url=urljoin(self.endpoint, "materials/summary/"),
+                    url=urljoin(
+                        "https://api.materialsproject.org/", "materials/summary/"
+                    ),
                     criteria={
                         "batch_id": "gnome_r2scan_statics",
                         "_fields": "material_id",
@@ -489,21 +504,147 @@ def _query_resource(
                 .get("meta", {})
                 .get("total_doc", 0)
             )
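+            # restore the user's original progress-bar preference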
+            self.mute_progress_bars = not re_enable
 
-            # Paginate over all entries in the bucket.
-            # TODO: change when a subset of entries needed from DB
             if "tasks" in suffix:
-                bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                bucket_suffix, prefix = ("parsed", "core/tasks/")
             else:
                 bucket_suffix = "build"
                 prefix = f"collections/{db_version}/{suffix}"
 
-            # only include prefixes accessible to user
-            # i.e. append `batch_id=others/core` to `prefix`
-            if not has_gnome_access:
-                prefix += "/batch_id=others"
-
             bucket = f"materialsproject-{bucket_suffix}"
+
+            if self.delta_backed:
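+                # cache the dataset locally under <local_dataset_cache>/<bucket_suffix>/<prefix>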
+                target_path = (
+                    self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
+                )
+                os.makedirs(target_path, exist_ok=True)
+
+                if DeltaTable.is_deltatable(target_path):
+                    if self.force_renew:
+                        shutil.rmtree(target_path)
+                        warnings.warn(
+                            f"Regenerating {suffix} dataset at {target_path}...",
+                            MPLocalDatasetWarning,
+                        )
+                        os.makedirs(target_path, exist_ok=True)
+                    else:
+                        warnings.warn(
+                            f"Dataset for {suffix} already exists at {target_path}, delete or move the existing "
+                            "dataset, or re-run the search query with MPRester(force_renew=True).",
+                            MPLocalDatasetWarning,
+                        )
+
+                        return {
+                            "data": MPDataset(
+                                path=target_path,
+                                document_model=self.document_model,
+                                use_document_model=self.use_document_model,
+                            )
+                        }
+
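+                # open the public Delta table with anonymous (unsigned) S3 reads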
+                tbl = DeltaTable(
+                    f"s3a://{bucket}/{prefix}",
+                    storage_options={
+                        "AWS_SKIP_SIGNATURE": "true",
+                        "AWS_REGION": "us-east-1",
+                    },
+                )
+
+                controlled_batch_str = ",".join(
+                    [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
+                )
+
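+                # filter out access-controlled batches (e.g. GNoMe) unless the user has access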
+                predicate = (
+                    " WHERE batch_id NOT IN ("  # don't delete leading space
+                    + controlled_batch_str
+                    + ")"
+                    if not has_gnome_access
+                    else ""
+                )
+
+                builder = QueryBuilder().register("tbl", tbl)
+
+                # Setup progress bar
+                num_docs_needed = pa.table(
+                    builder.execute("SELECT COUNT(*) FROM tbl").read_all()
+                )[0][0].as_py()
+
+                # TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
+                # -> need to modify BatchIdQuery operator to handle root level
+                #    batch_id, not only builder_meta.batch_id
+                # if not has_gnome_access:
+                #     num_docs_needed = self.count(
+                #         {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
+                #     )
+
+                pbar = (
+                    tqdm(
+                        desc=pbar_message,
+                        total=num_docs_needed,
+                    )
+                    if not self.mute_progress_bars
+                    else None
+                )
+
+                iterator = builder.execute("SELECT * FROM tbl" + predicate)
+
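+                # write zstd-compressed parquet files; these are converted to a Delta table below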
+                file_options = ds.ParquetFileFormat().make_write_options(
+                    compression="zstd"
+                )
+
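+                # flush buffered record batches to a group of parquet files on disk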
+                def _flush(accumulator, group):
+                    ds.write_dataset(
+                        accumulator,
+                        base_dir=target_path,
+                        format="parquet",
+                        basename_template=f"group-{group}-"
+                        + "part-{i}.zstd.parquet",
+                        existing_data_behavior="overwrite_or_ignore",
+                        max_rows_per_group=1024,
+                        file_options=file_options,
+                    )
+
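+                # stream pages from the table, buffering rows until the flush threshold is reached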
+                group = 1
+                size = 0
+                accumulator = []
+                for page in iterator:
+                    # convert arro3 record batch to a pyarrow record batch for the pyarrow dataset writer
+                    accumulator.append(pa.record_batch(page))
+                    page_size = page.num_rows
+                    size += page_size
+
+                    if pbar is not None:
+                        pbar.update(page_size)
+
+                    if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
+                        _flush(accumulator, group)
+                        group += 1
+                        size = 0
+                        accumulator = []
+
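+                # flush any remaining buffered batches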
+                if accumulator:
+                    _flush(accumulator, group + 1)
+
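+                # register the written parquet files as a local Delta table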
+                convert_to_deltalake(target_path)
+
+                warnings.warn(
+                    f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
+                    "the table according to your usage patterns prior to running intensive workloads, "
+                    "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
+                    MPLocalDatasetWarning,
+                )
+
+                return {
+                    "data": MPDataset(
+                        path=target_path,
+                        document_model=self.document_model,
+                        use_document_model=self.use_document_model,
+                    )
+                }
+
+            # Paginate over all entries in the bucket.
+            # TODO: change when a subset of entries needed from DB
             paginator = self.s3_client.get_paginator("list_objects_v2")
             pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
 
@@ -540,11 +681,6 @@ def _query_resource(
             }
 
             # Setup progress bar
-            pbar_message = (  # type: ignore
-                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                if self.document_model is not None
-                else "Retrieving documents"
-            )
             num_docs_needed = int(self.count())
             pbar = (
                 tqdm(
@@ -1372,3 +1508,7 @@ class MPRestError(Exception):
 
 class MPRestWarning(Warning):
     """Raised when a query is malformed but interpretable."""
+
+
+class MPLocalDatasetWarning(Warning):
+    """Raised when unrecoverable actions are performed on a local dataset."""