Source code for oncvpsp_tools.output

"""Classes for handling ONCVPSP output files."""

from collections import UserList
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generic, Iterable, List, Optional, TypeVar, Union

import matplotlib.pyplot as plt
import numpy as np

from .input import ONCVPSPInput

T = TypeVar("T")


@dataclass
class ONCVPSPOutputData:
    """Generic class for storing data from an ONCVPSP output file."""

    x: np.ndarray
    y: np.ndarray
    xlabel: str = "radius ($a_0$)"
    label: str = ""
    info: Dict[str, Any] = field(default_factory=dict)

    def __eq__(self, other):
        """Check if two :class:`ONCVPSPOutputData` objects are equal."""
        if not isinstance(other, ONCVPSPOutputData):
            return False
        return (
            np.allclose(self.x, other.x)
            and np.allclose(self.y, other.y)
            and self.xlabel == other.xlabel
            and self.label == other.label
            and self.info == other.info
        )

    def plot(self, ax=None, **kwargs):
        """Plot the data."""
        if ax is None:
            _, ax = plt.subplots()
        if "label" not in kwargs:
            kwargs["label"] = ", ".join([f"{k}={v}" for k, v in self.info.items()])
        if (
            "ls" not in kwargs
            and "linestyle" not in kwargs
            and self.info.get("kind", None) == "pseudo"
        ):
            kwargs["ls"] = "--"
        ax.plot(self.x, self.y, **kwargs)
        ax.set_xlabel(self.xlabel)
        ax.set_xlim([self.x.min(), self.x.max()])
        if self.label:
            ax.set_title(self.label)
        return ax

    @classmethod
    def from_str(cls, string: str, identifier: str, xcol: int, ycol: int, **kwargs):
        """Create an :class:`ONCVPSPOutputData` object from a string."""
        relevant_lines = [
            line.strip().split()
            for line in string.split("\n")
            if line.strip().startswith(identifier)
        ]

        x = np.array([float(line[xcol]) for line in relevant_lines])
        y = np.array([float(line[ycol]) for line in relevant_lines])

        return cls(x, y, **kwargs)

    @classmethod
    def from_file(cls, filename: Union[Path, str], identifier: str, xcol: int, ycol: int, **kwargs):
        """Create an :class:`ONCVPSPOutputData` object from a file."""
        filename = Path(filename)

        with open(filename, "r") as f:
            lines = f.read()

        return cls.from_str(lines, identifier, xcol, ycol, **kwargs)


class ONCVPSPOutputDataList(UserList, Generic[T]):
    """Generic class for a list of ONCVPSPOutputData objects, with a few extra functionalities."""

    label: str

    def __init__(self, data, label: str = ""):
        """Create an :class:`ONCVPSPOutputDataList` object."""
        super().__init__(data)
        self.label = label

    def plot(self, ax=None, kwargs_list: Optional[List[Dict[str, Any]]] = None, **kwargs):
        """Plot all the data in the list."""
        if kwargs_list is None:
            kwargs_list = [{} for _ in self.data]
        for i, (data, specific_kwargs) in enumerate(zip(self.data, kwargs_list)):
            # Make the colors match for entries that only differ by info['kind']
            if ax and "color" not in specific_kwargs and "color" not in kwargs:
                # Get the previous colors used and the matching info dictionaries
                colors = [line.get_color() for line in ax.get_lines()[-i:]]
                infos = [{k: v for k, v in d.info.items() if k != "kind"} for d in self.data[:i]]

                # Use the same color if the dictionaries match (ignoring the 'kind' key)
                for info, color in zip(infos, colors):
                    if info == {k: v for k, v in data.info.items() if k != "kind"}:
                        specific_kwargs["color"] = color
                        break

            ax = data.plot(ax, **specific_kwargs, **kwargs)
        ax.legend()
        ax.set_title(self.label)

        # Set xlimits to the largest range of x values
        ax.set_xlim([min([d.x.min() for d in self.data]), max([d.x.max() for d in self.data])])

        return ax

    @classmethod
    def from_str(
        cls,
        label: str,
        string: str,
        identifiers,
        xcol: int,
        ycols: Iterable[int],
        kwargs_list: Optional[List[Dict[str, Any]]] = None,
    ):
        """Create an :class:`ONCVPSPOutputDataList` object from a string."""
        if kwargs_list is None:
            kwargs_list = [{} for _ in identifiers]
        oncvlist = cls(
            [
                ONCVPSPOutputData.from_str(string, identifier, xcol, ycol, **kwargs)
                for identifier, ycol, kwargs in zip(identifiers, ycols, kwargs_list)
            ]
        )
        oncvlist.label = label
        return oncvlist


[docs]@dataclass class ONCVPSPOutput: """Class for the contents of an ``oncvpsp.x`` output file. The :class:`ONCVPSPOutput` class is a dataclass that helps a user interact with output files from ``oncvpsp.x``. Typically, a user will not create a :class:`ONCVPSPOutput` object directly, but rather use the class method :meth:`from_file` as follows:: from oncvpsp_tools import ONCVPSPOutput output = ONCVPSPOutput.from_file("path/to/output") :class:`ONCVPSPOutput` objects -- being a :class:`dataclass` -- have the same attributes as the input parameters (listed below). Use these to interact with the contents of the output file. For example, to plot the semilocal ion pseudopotentials:: output.semilocal_ion_pseudopotentials.plot() :param content: the entire content of the output file :type content: str :param input: the input file used to generate the output :type input: :class:`ONCVPSPInput` :param semilocal_ion_pseudopotentials: the semilocal ion pseudopotentials :type semilocal_ion_pseudopotentials: :class:`ONCVPSPOutputDataList` :param local_pseudopotential: the local pseudopotential :type local_pseudopotential: :class:`ONCVPSPOutputData` :param charge_densities: the charge densities :type charge_densities: :class:`ONCVPSPOutputDataList` :param wavefunctions: the pseudoatomic wavefunctions :type wavefunctions: :class:`ONCVPSPOutputDataList` :param arctan_log_derivatives: the arctan log derivatives :type arctan_log_derivatives: :class:`ONCVPSPOutputDataList` :param projectors: the projectors :type projectors: :class:`ONCVPSPOutputDataList` :param energy_error: the energy error :type energy_error: :class:`ONCVPSPOutputDataList` """ content: str input: ONCVPSPInput semilocal_ion_pseudopotentials: ONCVPSPOutputDataList[ONCVPSPOutputData] local_pseudopotential: ONCVPSPOutputData charge_densities: ONCVPSPOutputDataList[ONCVPSPOutputData] wavefunctions: ONCVPSPOutputDataList[ONCVPSPOutputData] arctan_log_derivatives: ONCVPSPOutputDataList[ONCVPSPOutputData] projectors: ONCVPSPOutputDataList[ONCVPSPOutputData] energy_error: ONCVPSPOutputDataList[ONCVPSPOutputData]
[docs] @classmethod def from_str(cls, content: str): """Create an :class:`ONCVPSPOutput` object from a string.""" splitcontent = content.split("\n") # ONCVPSP input at_ref_str = "# ATOM AND REFERENCE CONFIGURATION" if at_ref_str in splitcontent: istart = splitcontent.index(at_ref_str) else: raise ValueError( "The atom and reference configuration information is missing from this output; this " "suggests that it came from a failed oncvpsp.x calculation" ) input = ONCVPSPInput.from_str("\n".join(splitcontent[istart:])) # Semilocal ion pseudopotentials slp_kwargs = [{"info": {"l": l}} for l in range(input.lmax + 1)] semilocal_ion_pseudopotentials = ONCVPSPOutputDataList.from_str( "semilocal ion pseudopotentials", content, ["!p" for _ in range(input.lmax + 1)], 1, range(3, input.lmax + 4), slp_kwargs, ) # Local pseudopotential local_pseudopotential = ONCVPSPOutputData.from_str( content, "!L", 1, 2, label="local pseudopotential" ) # Charge densities cd_kwargs = [{"info": {"rho": rho}} for rho in ["C", "M", "V"]] charge_densities = ONCVPSPOutputDataList.from_str( "charge densities", content, ["!r ", "!r ", "!r "], 1, [2, 3, 4], cd_kwargs ) # Pseudo and real wavefunctions il_pairs = sorted( list( set( [ line.strip().split()[1] for line in splitcontent if line.strip().startswith("&") ] ) ) ) kinds = ["full", "pseudo"] kwargs = [ {"info": {"kind": kind, "i": int(il[0]), "l": int(il[1])}} for il in il_pairs for kind in kinds ] identifiers = ["& " + il for il in il_pairs for _ in kinds] ycols = [kind_col for _ in range(len(il_pairs)) for kind_col in [3, 4]] wavefunctions = ONCVPSPOutputDataList.from_str( "wavefunctions", content, identifiers, 2, ycols, kwargs ) # Arctan log derivatives identifiers = [f"! {l}" for l in range(4) for kind in kinds] ycols = [kind_col for _ in range(4) for kind_col in [3, 4]] kwargs = [{"info": {"kind": kind, "l": l}} for l in range(4) for kind in kinds] arctan_log_derivatives = ONCVPSPOutputDataList.from_str( "arctan log derivatives", content, identifiers, 2, ycols, kwargs ) # Projectors ls = [proj.l for proj in input.vkb_projectors for _ in range(proj.nproj)] identifiers = [f"!J {l}" for l in ls] ycols = [x + 3 for proj in input.vkb_projectors for x in range(proj.nproj)] kwargs = [ {"info": {"i": i, "l": proj.l}} for proj in input.vkb_projectors for i in range(proj.nproj) ] projectors = ONCVPSPOutputDataList.from_str( "projectors", content, identifiers, 2, ycols, kwargs ) # Energy error per electron identifiers = [f"!C {l}" for l in range(input.lmax + 1)] eepe_kwargs = [ {"info": {"l": l}, "xlabel": "cutoff energy (Ha)"} for l in range(input.lmax + 1) ] eepe = ONCVPSPOutputDataList.from_str( "energy error per electron", content, identifiers, 2, [3 for _ in identifiers], eepe_kwargs, ) return cls( content, input, semilocal_ion_pseudopotentials, local_pseudopotential, charge_densities, wavefunctions, arctan_log_derivatives, projectors, eepe, )
[docs] @classmethod def from_file(cls, filename: str): """Create an :class:`ONCVPSPOutput` object from an ONCVPSP output file.""" with open(filename, "r") as f: content = f.read() return cls.from_str(content)
[docs] def to_str(self) -> str: """Return the contents of the ONCVPSP output file.""" return self.content
[docs] def to_file(self, filename: str) -> None: """Write the contents of the ONCVPSP output file to a file.""" with open(filename, "w") as f: f.write(self.to_str())
[docs] def to_upf(self) -> str: """Return the UPF part of the ONCVPSP output file.""" flines = self.content.split("\n") [istart] = [flines.index(x) for x in flines if "<UPF" in x] [iend] = [flines.index(x) for x in flines if "</UPF" in x] return "\n".join(flines[istart : iend + 1])