diff --git a/dpdata/system.py b/dpdata/system.py index 0748db28..f1273104 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,5 +1,6 @@ # %% import glob +import hashlib import os import warnings from copy import deepcopy @@ -19,7 +20,13 @@ from dpdata.driver import Driver, Minimizer from dpdata.format import Format from dpdata.plugin import Plugin -from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names +from dpdata.utils import ( + add_atom_names, + elements_index_map, + remove_pbc, + sort_atom_names, + utf8len, +) def load_format(fmt): @@ -562,6 +569,42 @@ def uniq_formula(self): ] ) + @property + def short_formula(self) -> str: + """Return the short formula of this system. Elements with zero number + will be removed. + """ + return "".join( + [ + f"{symbol}{numb}" + for symbol, numb in zip( + self.data["atom_names"], self.data["atom_numbs"] + ) + if numb + ] + ) + + @property + def formula_hash(self) -> str: + """Return the hash of the formula of this system.""" + return hashlib.sha256(self.formula.encode("utf-8")).hexdigest() + + @property + def short_name(self) -> str: + """Return the short name of this system (no more than 255 bytes), in + the following order: + - formula + - short_formula + - formula_hash. + """ + formula = self.formula + if utf8len(formula) <= 255: + return formula + short_formula = self.short_formula + if utf8len(short_formula) <= 255: + return short_formula + return self.formula_hash + def extend(self, systems): """Extend a system list to this system. @@ -1247,7 +1290,9 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs): def to_fmt_obj(self, fmtobj, directory, *args, **kwargs): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for fn, ss in zip( - fmtobj.to_multi_systems(self.systems.keys(), directory, **kwargs), + fmtobj.to_multi_systems( + [ss.short_name for ss in self.systems.values()], directory, **kwargs + ), self.systems.values(), ): ss.to_fmt_obj(fmtobj, fn, *args, **kwargs) diff --git a/dpdata/utils.py b/dpdata/utils.py index da726179..cf4a109e 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -99,3 +99,8 @@ def uniq_atom_names(data): sum(ii == data["atom_types"]) for ii in range(len(data["atom_names"])) ] return data + + +def utf8len(s: str) -> int: + """Return the byte length of a string.""" + return len(s.encode("utf-8")) diff --git a/tests/test_multisystems.py b/tests/test_multisystems.py index 172c2ad4..2bda13a9 100644 --- a/tests/test_multisystems.py +++ b/tests/test_multisystems.py @@ -1,7 +1,9 @@ import os +import tempfile import unittest from itertools import permutations +import numpy as np from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems from context import dpdata @@ -200,5 +202,37 @@ def setUp(self): self.atom_names = ["C", "H", "O"] +class TestLongFilename(unittest.TestCase): + def test_long_filename1(self): + system = dpdata.System( + data={ + "atom_names": [f"TYPE{ii}" for ii in range(200)], + "atom_numbs": [1] + [0 for _ in range(199)], + "atom_types": np.arange(1), + "coords": np.zeros((1, 1, 3)), + "orig": np.zeros(3), + "cells": np.zeros((1, 3, 3)), + } + ) + ms = dpdata.MultiSystems(system) + with tempfile.TemporaryDirectory() as tmpdir: + ms.to_deepmd_npy(tmpdir) + + def test_long_filename2(self): + system = dpdata.System( + data={ + "atom_names": [f"TYPE{ii}" for ii in range(200)], + "atom_numbs": [1 for _ in range(200)], + "atom_types": np.arange(200), + "coords": np.zeros((1, 200, 3)), + "orig": np.zeros(3), + "cells": np.zeros((1, 3, 3)), + } + ) + ms = dpdata.MultiSystems(system) + with tempfile.TemporaryDirectory() as tmpdir: + ms.to_deepmd_npy(tmpdir) + + if __name__ == "__main__": unittest.main()