diff --git a/src/atomate2/common/jobs/transform.py b/src/atomate2/common/jobs/transform.py
new file mode 100644
index 0000000000..a7618b24e4
--- /dev/null
+++ b/src/atomate2/common/jobs/transform.py
@@ -0,0 +1,272 @@
+"""Utility jobs to apply transformations as a job."""
+
+from __future__ import annotations
+
+import os
+import tarfile
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from jobflow import Maker, job
+from pymatgen.transformations.advanced_transformations import SQSTransformation
+
+from atomate2.common.schemas.transform import SQSTask, TransformTask
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from pymatgen.core import Structure
+    from pymatgen.transformations.transformation_abc import AbstractTransformation
+
+
+@dataclass
+class Transformer(Maker):
+    """Apply a pymatgen transformation, as a job.
+
+    For many of the standard and advanced transformations,
+    this should "just work" by supplying the transformation.
+    """
+
+    transformation: AbstractTransformation
+    name: str = "pymatgen transformation maker"
+
+    @job
+    def make(
+        self, structure: Structure, **kwargs
+    ) -> TransformTask | list[TransformTask]:
+        """Evaluate the transformation.
+
+        Parameters
+        ----------
+        structure : Structure to transform
+        **kwargs : to pass to the `apply_transformation` method
+
+        Returns
+        -------
+        list of TransformTask, if `self.transformation.is_one_to_many`
+            and a list of structures was returned (many structures are
+            produced from a single transformation)
+
+        TransformTask, otherwise
+        """
+        transformed_structure = self.transformation.apply_transformation(
+            structure, **kwargs
+        )
+        # A one-to-many transformation can still return a single structure,
+        # e.g. when no ranked list is requested, so also check what came back.
+        returned_many = isinstance(transformed_structure, list)
+        if self.transformation.is_one_to_many and returned_many:
+            return [
+                TransformTask(
+                    input_structure=structure,
+                    final_structure=dct["structure"],
+                    transformation=dct.get("transformation") or self.transformation,
+                )
+                for dct in transformed_structure
+            ]
+        return TransformTask(
+            input_structure=structure,
+            final_structure=transformed_structure,
+            transformation=self.transformation,
+        )
+
+
+@dataclass
+class SQS(Transformer):
+    """Generate special quasi-random structures (SQSs)."""
+
+    name: str = "SQS"
+
+    # `default_factory` must be a zero-argument callable, hence the lambda.
+    transformation: SQSTransformation = field(
+        default_factory=lambda: SQSTransformation(
+            scaling=1,
+            search_time=60,
+            directory=Path(".") / "sqs_runs",
+            remove_duplicate_structures=True,
+            best_only=True,
+        )
+    )
+
+    @staticmethod
+    def check_structure(
+        structure: Structure, scaling: int | Sequence[int]
+    ) -> Structure:
+        """Ensure that a disordered structure and scaling factor(s) are sensible.
+
+        Parameters
+        ----------
+        structure : Structure
+            Disordered structure to check. It is copied, not modified.
+        scaling : int or sequence of three int
+            Scaling factor(s) used to build the supercell.
+
+        Returns
+        -------
+        Structure
+            A copy of `structure` with oxidation states removed.
+
+        Raises
+        ------
+        ValueError
+            If the structure is ordered, `scaling` is malformed, or the
+            supercell cannot hold a whole number of atoms of each element.
+        """
+        struct = structure.copy()
+        struct.remove_oxidation_states()
+        if struct.is_ordered:
+            raise ValueError("Your structure is likely ordered!")
+
+        if isinstance(scaling, int):
+            nsites = scaling * len(struct)
+        elif (
+            hasattr(scaling, "__len__")
+            and all(isinstance(sf, int) for sf in scaling)
+            and len(scaling) == 3
+        ):
+            nsites = len(struct * scaling)
+        else:
+            raise ValueError(
+                "`scaling` must be an int or sequence of three int, "
+                f"found {type(scaling)}."
+            )
+
+        # `composition` sums the occupancies over all sites of the input cell,
+        # so each amount is scaled by the number of cell copies, not by `nsites`.
+        num_copies = nsites // len(struct)
+        num_sites: dict[str, int | float] = {
+            str(element): count * num_copies
+            for element, count in struct.composition.items()
+        }
+
+        if not all(
+            abs(num_sites[element] - round(num_sites[element])) < 1e-3
+            for element in num_sites
+        ):
+            raise ValueError(
+                f"Incompatible supercell number of sites {nsites} "
+                f"for composition {struct.composition}"
+            )
+        return struct
+
+    @job
+    def make(  # type: ignore[override]
+        self,
+        structure: Structure,
+        return_ranked_list: bool | int = False,
+        archive_instances: bool = False,
+    ) -> SQSTask:
+        """Perform a parallelized SQS search.
+
+        For Monte Carlo methods, mcsqs and icet-monte_carlo, this
+        executes parallel SQS searches from the same starting structure.
+
+        For the icet-enumeration method, this divides the labor of
+        searching through a list of structures.
+
+        Parameters
+        ----------
+        structure : Structure
+            Disordered structure to order.
+        return_ranked_list: bool | int = False
+            Whether to return a list of SQS structures ranked by objective function
+            (bool), or how many to return (int). False returns only the best.
+        archive_instances : bool = False
+            For mcsqs only: whether to pack the files in the working directory
+            of the transformation into a tarball and then delete them.
+
+        Returns
+        -------
+        SQSTask
+            The best SQS structure, its objective (if saved), and
+            the ranked SQS structures (if saved).
+        """
+        original_directory = os.getcwd()
+
+        valid_struct = self.check_structure(structure, self.transformation.scaling)
+        if return_ranked_list and self.transformation.instances == 1:
+            raise ValueError(
+                "`return_ranked_list` should only be used for parallel MCSQS runs."
+            )
+
+        best_objective: float | str | None = None
+        try:
+            sqs_structs = self.transformation.apply_transformation(
+                valid_struct, return_ranked_list=return_ranked_list
+            )
+
+            if return_ranked_list:
+                best_sqs = sqs_structs[0]["structure"]
+                best_objective = sqs_structs[0]["objective_function"]
+            else:
+                best_sqs = sqs_structs
+
+                if (
+                    self.transformation.sqs_method == "mcsqs"
+                    and (mcsqs_corr_file := Path("bestcorr.out")).exists()
+                ):
+                    objective = mcsqs_corr_file.read_text()
+                    objective = objective.split("Objective_function=")[-1].strip()
+                    # keep the "Perfect_match" marker as a string, it is
+                    # converted to `found_perfect_match` further below
+                    if objective.lower() == "perfect_match":
+                        best_objective = objective
+                    else:
+                        best_objective = float(objective)
+        finally:
+            # MCSQS caller changes the directory, always restore it
+            os.chdir(original_directory)
+
+        if archive_instances and self.transformation.sqs_method == "mcsqs":
+            # MCSQS is the only SQS maker which requires a working directory
+            mcsqs_dir = Path(self.transformation.directory)
+            archive_name = str(self.transformation.directory)
+            if archive_name.endswith(os.path.sep):
+                archive_name = archive_name[:-1]
+            archive_name += ".tar.gz"
+
+            # add files to tarball
+            files = [path for path in mcsqs_dir.iterdir() if path.is_file()]
+            with tarfile.open(archive_name, "w:gz") as tarball:
+                for path in files:
+                    tarball.add(path)
+
+            # cleanup
+            for path in files:
+                path.unlink()
+
+            # `rmdir` (not `unlink`) removes a directory, and only if empty
+            if not any(mcsqs_dir.iterdir()):
+                mcsqs_dir.rmdir()
+
+        # For MCSQS, check whether the `perfect_match` was found
+        # otherwise, SQSTask will throw a validation error
+        found_perfect_match = False
+        if (
+            isinstance(best_objective, str)
+            and best_objective.lower() == "perfect_match"
+        ):
+            best_objective = None
+            found_perfect_match = True
+
+        sqs_structures = None
+        sqs_scores = None
+        if isinstance(sqs_structs, list) and len(sqs_structs) > 1:
+            sqs_structures = [entry["structure"] for entry in sqs_structs[1:]]
+            sqs_scores = [entry["objective_function"] for entry in sqs_structs[1:]]
+            for i, score in enumerate(sqs_scores):
+                if isinstance(score, str) and score.lower() == "perfect_match":
+                    sqs_scores[i] = None
+                    found_perfect_match = True
+
+        return SQSTask(
+            transformation=self.transformation,
+            input_structure=structure,
+            final_structure=best_sqs,
+            final_objective=best_objective,
+            sqs_structures=sqs_structures,
+            sqs_scores=sqs_scores,
+            sqs_method=self.transformation.sqs_method,
+            found_perfect_match=found_perfect_match,
+        )
diff --git a/src/atomate2/common/schemas/transform.py b/src/atomate2/common/schemas/transform.py
new file mode 100644
index 0000000000..203a2f987d
--- /dev/null
+++ b/src/atomate2/common/schemas/transform.py
@@ -0,0 +1,70 @@
+"""Define schemas for SQS runs."""
+
+from pydantic import BaseModel, Field
+
+try:
+    from emmet.core.types.enums import ValueEnum
+except ImportError:
+    from emmet.core.utils import ValueEnum
+
+from pymatgen.core import Structure
+from pymatgen.transformations.transformation_abc import AbstractTransformation
+
+
+class SQSMethod(ValueEnum):
+    """Define possible SQS methods used."""
+
+    MCSQS = "mcsqs"
+    ICET_ENUM = "icet-enumeration"
+    ICET_MCSQS = "icet-monte_carlo"
+
+
+class TransformTask(BaseModel):
+    """Schematize a transformation run."""
+
+    transformation: AbstractTransformation = Field(
+        description="The transformation applied to a structure."
+    )
+
+    final_structure: Structure = Field(
+        description="The structure after the transformation."
+    )
+
+    input_structure: Structure = Field(
+        description="The structure before the transformation."
+    )
+
+
+class SQSTask(TransformTask):
+    """Structure the output of SQS runs."""
+
+    sqs_method: SQSMethod | None = Field(None, description="The SQS protocol used.")
+    final_objective: float | None = Field(
+        None,
+        description=(
+            "The minimum value of the SQS objective function, "
+            "corresponding to the structure in `final_structure`. "
+            "If None, but `found_perfect_match` is True, then the "
+            "ideal SQS structure was found."
+        ),
+    )
+    sqs_structures: list[Structure] | None = Field(
+        None, description="A list of other good SQS candidates."
+    )
+    sqs_scores: list[float | None] | None = Field(
+        None,
+        description=(
+            "The objective function values for the structures in `sqs_structures`. "
+            "If any value is `None` and `found_perfect_match` is True, then the "
+            "ideal SQS structure was found."
+        ),
+    )
+    found_perfect_match: bool = Field(
+        default=False,
+        description="Whether the lowest possible SQS objective was attained.",
+    )
+
+    @property
+    def all_structures(self) -> list[Structure]:
+        """Return all structures, not just the most optimal SQS structure."""
+        return [self.final_structure, *(self.sqs_structures or [])]
diff --git a/tests/common/jobs/test_transform.py b/tests/common/jobs/test_transform.py
new file mode 100644
index 0000000000..dbad36d855
--- /dev/null
+++ b/tests/common/jobs/test_transform.py
@@ -0,0 +1,123 @@
+"""Test transformation jobs."""
+
+try:
+    import icet
+except ImportError:
+    icet = None
+
+import numpy as np
+import pytest
+from jobflow import Flow, run_locally
+from pymatgen.core import Structure
+from pymatgen.transformations.advanced_transformations import SQSTransformation
+from pymatgen.transformations.standard_transformations import (
+    OrderDisorderedStructureTransformation,
+    OxidationStateDecorationTransformation,
+)
+
+from atomate2.common.jobs.transform import SQS, Transformer
+from atomate2.common.schemas.transform import SQSTask, TransformTask
+
+
+@pytest.fixture(scope="module")
+def simple_alloy() -> Structure:
+    """Hexagonal close-packed 50-50 Mg-Al alloy."""
+    return Structure(
+        3.5
+        * np.array(
+            [
+                [0.5, -(3.0 ** (0.5)) / 2.0, 0.0],
+                [0.5, 3.0 ** (0.5) / 2.0, 0.0],
+                [0.0, 0.0, 8 ** (0.5) / 3.0],
+            ]
+        ),
+        [{"Mg": 0.5, "Al": 0.5}, {"Mg": 0.5, "Al": 0.5}],
+        [[0.0, 0.0, 0.0], [1.0 / 3.0, 2.0 / 3.0, 0.5]],
+    )
+
+
+def test_simple_and_advanced():
+    # simple disordered zincblende structure
+    structure = Structure(
+        3.8 * np.array([[0.0, 0.5, 0.5], [0.5, 0.0, 0.5], [0.5, 0.5, 0.0]]),
+        ["Zn", {"S": 0.75, "Se": 0.25}],
+        [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]],
+    ).to_conventional()
+
+    oxi_dict = {"Zn": 2, "S": -2, "Se": -2}
+    oxi_job = Transformer(
+        name="oxistate", transformation=OxidationStateDecorationTransformation(oxi_dict)
+    ).make(structure)
+
+    odst_job = Transformer(
+        name="odst", transformation=OrderDisorderedStructureTransformation()
+    ).make(oxi_job.output.final_structure, return_ranked_list=2)
+
+    flow = Flow([oxi_job, odst_job])
+    resp = run_locally(flow)
+
+    oxi_state_output = resp[oxi_job.uuid][1].output
+    assert isinstance(oxi_state_output, TransformTask)
+
+    # check correct assignment of oxidation states
+    assert all(
+        specie.oxi_state == oxi_dict.get(specie.element.value)
+        for site in oxi_state_output.final_structure
+        for specie in site.species
+    )
+
+    odst_output = resp[odst_job.uuid][1].output
+    # return_ranked_list = 2, so should get 2 output docs
+    assert len(odst_output) == 2
+    assert all(isinstance(doc, TransformTask) for doc in odst_output)
+    assert all(doc.final_structure.is_ordered for doc in odst_output)
+
+
+@pytest.mark.skipif(
+    icet is None, reason="`icet` must be installed to perform this test."
+)
+def test_sqs(tmp_dir, simple_alloy):
+    # Probably most common use case - just get one "best" SQS
+    sqs_trans = SQSTransformation(
+        scaling=4,
+        best_only=False,
+        sqs_method="icet-enumeration",
+    )
+    job = SQS(transformation=sqs_trans).make(simple_alloy)
+
+    output = run_locally(job)[job.uuid][1].output
+    assert isinstance(output, SQSTask)
+    assert output.final_structure.composition.as_dict() == {"Mg": 4, "Al": 4}
+    assert isinstance(output.final_structure, Structure)
+    assert output.final_structure.is_ordered
+    assert all(
+        getattr(output, attr) is None for attr in ("sqs_structures", "sqs_scores")
+    )
+    assert isinstance(output.transformation, SQSTransformation)
+
+    # Now simulate retrieving multiple SQSes
+    sqs_trans = SQSTransformation(
+        scaling=4,
+        best_only=False,
+        sqs_method="icet-monte_carlo",
+        instances=3,
+        icet_sqs_kwargs={"n_steps": 10},  # only 10-step search
+        remove_duplicate_structures=False,  # needed just to simulate output
+    )
+
+    # return up to the two best structures
+    job = SQS(transformation=sqs_trans).make(simple_alloy, return_ranked_list=2)
+    output = run_locally(job)[job.uuid][1].output
+
+    assert isinstance(output, SQSTask)
+
+    # return_ranked_list - 1 structures and objective functions should be here
+    assert all(
+        len(getattr(output, attr)) == 1 for attr in ("sqs_structures", "sqs_scores")
+    )
+
+    assert all(
+        struct.composition.as_dict() == {"Mg": 4, "Al": 4}
+        and isinstance(struct, Structure)
+        for struct in output.sqs_structures
+    )