Skip to content

Commit

Permalink
Merge pull request #32 from QEF/develop
Browse files Browse the repository at this point in the history
Update to version 1.3
  • Loading branch information
brunato authored Dec 24, 2021
2 parents 3f57d3e + eed50e8 commit 3bc68aa
Show file tree
Hide file tree
Showing 111 changed files with 861 additions and 573 deletions.
3 changes: 2 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ source = qeschema/
[report]
exclude_lines =
# Exclude not implemented features
raise NotImplementedError
pragma: no cover
raise NotImplementedError()

# Exclude lines where yaml library is not installed
except ImportError\:
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.6', '3.7', '3.8']
python-version: ['3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand All @@ -28,6 +28,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Lint with flake8
run: |
flake8 qeschema --max-line-length=100 --statistics
- name: Run tests
run: |
python -m unittest
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ opEn-Source Package for Research in Electronic Structure, Simulation and Optimiz
Requirements
------------

* Python_ 3.6+
* Python_ 3.7+
* xmlschema_ (Python library for processing XML Schema based documents)

.. _Python: http://www.python.org/
Expand Down Expand Up @@ -56,6 +56,7 @@ Authors
Davide Brunato
Pietro Delugas
Giovanni Borghi
Alexandr Fonari


License
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Davide Brunato, Pietro Delugas'

# The full version, including alpha/beta/rc tags
release = '1.2.1'
release = '1.3.0'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ by qeschema.

>>> import qeschema
>>> doc = qeschema.PwDocument(schema='qes.xsd')
>>> doc.read('tests/examples/pw/Al001_relax_bfgs.xml')
>>> doc.read('tests/resources/pw/Al001_relax_bfgs.xml')
>>> pw_data = doc.to_dict()
>>> control_variables = pw_data['qes:espresso']['input']['control_variables']

Expand All @@ -28,6 +28,6 @@ the desired dictionary:

>>> import qeschema
>>> doc = qeschema.PwDocument()
>>> doc.read('tests/examples/pw/CO_bgfs_relax.xml')
>>> doc.read('tests/resources/pw/CO_bgfs_relax.xml')
>>> path = './/input/atomic_species'
>>> bsdict = doc.to_dict(path=path)
4 changes: 2 additions & 2 deletions qeschema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from .exceptions import QESchemaError, XmlDocumentError
from .utils import set_logger

__version__ = '1.2.1'
__version__ = '1.3.0'

__all__ = [
'XmlDocument', 'QeDocument', 'PwDocument', 'PhononDocument', 'NebDocument',
'TdDocument', 'TdSpectrumDocument', 'RawInputConverter', 'PwInputConverter',
'PhononInputConverter', 'TdInputConverter', 'TdSpectrumInputConverter',
'NebInputConverter', 'QESchemaError', 'XmlDocumentError', 'set_logger'
'NebInputConverter', 'QESchemaError', 'XmlDocumentError', 'set_logger', 'hdf5'
]
48 changes: 24 additions & 24 deletions qeschema/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,12 @@ def _build_maps(dict_element, path):
path_key = '/'.join((path, key))
if isinstance(item, dict):
_build_maps(item, path_key)
continue

if isinstance(item, str):
elif isinstance(item, str):
invariant_map[path_key] = item
logger.debug("Added one-to-one association: '{0}'='{1}'".format(path_key, item))
continue
logger.debug("Added one-to-one association: %r=%r", path_key, item)

if isinstance(item, tuple) or isinstance(item, list):
elif isinstance(item, (tuple, list)):
try:
variant_map[path_key] = _check_variant(*item)
logger.debug("Added single-variant mapping: %r=%r",
Expand All @@ -90,10 +88,11 @@ def _build_maps(dict_element, path):
for variant in item:
if isinstance(variant, str):
invariant_map[path_key] = variant
elif isinstance(item, tuple) or isinstance(item, list):
elif isinstance(variant, (tuple, list)):
variants.append(_check_variant(*variant))
else:
raise TypeError("Expect a tuple, list or string! {0}".format(variant))
raise TypeError(f"Expect a tuple, list or string! {variant!r}")

variant_map[path_key] = tuple(variants)
logger.debug("Added multi-variant mapping: %r=%r",
path_key, variant_map[path_key])
Expand All @@ -107,7 +106,7 @@ def _build_maps(dict_element, path):
# Check inconsistencies between maps
for items in variant_map.values():
for value in items:
logger.debug("Check value: {0}".format(value))
logger.debug("Check value: %r", value)
if isinstance(value, str) and value in invariant_map.inverse():
raise ValueError("A variant is also in invariant map! "
"'%s': '%s'" % (invariant_map.getkey(value), value))
Expand Down Expand Up @@ -163,7 +162,7 @@ def set_path(self, path, tag, node_dict):
if len(node_dict) != 1:
raise ValueError("The node_dict argument must contains exactly "
"one element! {0}".format(node_dict))
logger.debug("Set input with path '{0}' and node dict '{1}'".format(path, node_dict))
logger.debug("Set input with path %r and node dict %r", path, node_dict)
_path, _, keyword = path.rpartition('/')
value = node_dict[tag]
if isinstance(value, dict) and keyword != tag:
Expand All @@ -178,7 +177,7 @@ def set_path(self, path, tag, node_dict):
)

if value is None:
logger.debug("Skip element '%s': None value!" % path)
logger.debug("Skip element %r: None value!", path)
return

# Set the target parameter if the path is in invariant_map dictionary
Expand All @@ -200,16 +199,16 @@ def set_parameter(self, path, value):
raise ValueError("Wrong value {!r} for invariant parameter {!r}".format(target, path))

self._input[namelist][name] = to_fortran(value)
logger.debug("Set {0}[{1}]={2}".format(namelist, name, self._input[namelist][name]))
logger.debug("Set %s[%s]=%s", namelist, name, self._input[namelist][name])

def add_kwarg(self, path, tag, node_dict):
if isinstance(self.variant_map[path][0], str):
target_items = list([self.variant_map[path]])[:2]
else:
target_items = self.variant_map[path]
for target, _get_qe_input, _ in target_items:
logger.debug("Add argument to '{0}'".format(target))
logger.debug("Argument's conversion function: {0}".format(_get_qe_input))
logger.debug("Add argument to %r", target)
logger.debug("Argument's conversion function: %r", _get_qe_input)
group, name = self.target_pattern.match(target).groups()
if name is not None:
try:
Expand Down Expand Up @@ -251,36 +250,37 @@ def get_qe_input(self):

lines.append('&%s' % namelist)
for name, value in sorted(_input[namelist].items(), key=lambda x: x[0].lower()):
logger.debug("Add input for parameter %s[%r] with value %r", namelist, name, value)
logger.debug("Add input for parameter %s[%r] with value %r",
namelist, name, value)

if isinstance(value, dict):
# Variant conversion: apply to_fortran_input function with saved arguments
try:
to_fortran_input = value['_get_qe_input']
except KeyError:
logger.debug(
'No conversion function for parameter %s[%r], skip ... ', namelist, name
)
logger.debug('No conversion function for parameter %s[%r], skip ... ',
namelist, name)
continue

if callable(to_fortran_input):
lines.extend(to_fortran_input(name, **value))
else:
logger.error(
'Parameter %s[%r] conversion function is not callable!', namelist, name
)
logger.error('Parameter %s[%r] conversion function is not callable!',
namelist, name)
else:
# Simple invariant conversion
lines.append(' {0}={1}'.format(name, value))
lines.append('/')

for card in self.input_cards:
logger.debug("Add card: %s" % card)
logger.debug("Add card %r", card)
card_args = _input[card]
logger.debug("Card arguments: {0}".format(card_args))
logger.debug("Card arguments: %r", card_args)

if card not in OPTIONAL_CARDS and \
('_get_qe_input' not in card_args or
not callable(card_args['_get_qe_input'])):
logger.error("Missing conversion function for card '%s'" % card)
logger.error("Missing conversion function for card %r", card)

_get_qe_input = card_args.get('_get_qe_input', None)

Expand Down Expand Up @@ -490,7 +490,7 @@ class PwInputConverter(RawInputConverter):
'fix_area': ("CELL[cell_dofree]", options.get_cell_dofree, None),
'fix_xy': ("CELL[cell_dofree]", options.get_cell_dofree, None),
'isotropic': ("CELL[cell_dofree]", options.get_cell_dofree, None),
'cell_do_free': ("CELL[cell_dofree]",options.get_cell_dofree, None),
'cell_do_free': ("CELL[cell_dofree]", options.get_cell_dofree, None),
},
'symmetry_flags': {
'nosym': "SYSTEM[nosym]",
Expand Down
13 changes: 11 additions & 2 deletions qeschema/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def get_fortran_input(self, use_defaults=True):
rel_path = path.replace(input_path, '.')
xsd_element = schema_root.find(path)
if xsd_element is None:
logger.error("%r doesn't match any element!" % path)
logger.error("%r doesn't match any element!", path)
continue
else:
value = xsd_element.decode(elem, use_defaults=use_defaults)
Expand Down Expand Up @@ -594,6 +594,7 @@ def get_forces(self):
if elem is not None:
forces = self.schema.find(path).decode(elem)
path = './/output//atomic_positions'
breakpoint()
atomic_positions = self.schema.find(path).decode(self.find(path))
atoms = atomic_positions.get('atom', [])
if not isinstance(atoms, list):
Expand Down Expand Up @@ -622,7 +623,15 @@ def get_ks_eigenvalues(self):
:return: nested list of KS eigenvalues for each k_point in Hartree Units
"""
path = './/output//ks_energies/eigenvalues'
return [self.schema.find(path).decode(e)['$'] for e in self.findall(path)]
eigenvalues = []
for e in self.findall(path):
obj = self.schema.find(path).decode(e)
if isinstance(obj, dict):
eigenvalues.append(obj['$']) # pragma: no cover
else:
eigenvalues.append(obj)

return eigenvalues

@requires_xml_data
def get_total_energy(self):
Expand Down
89 changes: 89 additions & 0 deletions qeschema/hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#
# Copyright (c), 2021, Quantum Espresso Foundation and SISSA (Scuola
# Internazionale Superiore di Studi Avanzati). All rights reserved.
# This file is distributed under the terms of the MIT License. See the
# file 'LICENSE' in the root directory of the present distribution, or
# http://opensource.org/licenses/MIT.
#
import numpy as np
import h5py

__all__ = ['read_charge_file', 'get_wf_attributes', 'get_wavefunctions',
'get_wfc_miller_indices']


def read_charge_file(filename):
"""
Reads a PW charge file in HDF5 format.
:param filename: the name of the HDF5 file to read.
:return: a dictionary describing the content of file \
keys=[nr, ngm_g, gamma_only, rhog_, MillerIndexes]
"""
with h5py.File(filename, "r") as h5f:
MI = h5f.get('MillerIndices')[:]
nr1 = 2 * max(abs(MI[:, 0])) + 1
nr2 = 2 * max(abs(MI[:, 1])) + 1
nr3 = 2 * max(abs(MI[:, 2])) + 1
nr = np.array([nr1, nr2, nr3])
res = dict(h5f.attrs.items())
res.update({'MillInd': MI, 'nr_min': nr})
rhog = h5f['rhotot_g'][:].reshape(res['ngm_g'], 2).dot([1.e0, 1.e0j])
res['rhotot_g'] = rhog
if 'rhodiff_g' in h5f.keys():
rhog = h5f['rhodiff_g'][:].reshape(res['ngm_g'], 2).dot([1.e0, 1.e0j])
res['rhodiff_g'] = rhog
return res


# TODO update to the new format
def get_wf_attributes(filename):
"""
Read attributes from a wfc HDF5 file.
:param filename: the path to the wfc file
:return: a dictionary with all attributes included reciprocal vectors
"""
with h5py.File(filename, "r") as f:
res = dict(f.attrs)
mi_attrs = f.get('MillerIndices').attrs
bg = np.array(mi_attrs.get(x) for x in ['bg1', 'bg2', 'bg3'])
res.update({'bg': bg})
return res


def get_wavefunctions(filename, start_band=None, stop_band=None):
"""
Returns a numpy array with the wave functions for bands from start_band to
stop_band. If not specified starts from 1st band and ends with last one.
Band numbering is Python style starts from 0.abs
:param filename: path to the wfc file
:param start_band: first band to read, default first band in the file
:param stop_band: last band to read, default last band in the file
:return: a numpy array with shape [nbnd,npw]
"""
with h5py.File(filename, "r") as f:
igwx = f.attrs.get('igwx')
if start_band is None:
start_band = 0
if stop_band is None:
stop_band = f.attrs.get('nbnd')
if stop_band == start_band:
stop_band = start_band + 1
res = f.get('evc')[start_band:stop_band, :]

res = np.asarray(x.reshape([igwx, 2]).dot([1.e0, 1.e0j]) for x in res[:])
return res


def get_wfc_miller_indices(filename):
"""
Reads miller indices from the wfc file
:param filename: path to the wfc HDF5 file
:return: a np.array of integers with shape [igwx,3]
"""
with h5py.File(filename, "r") as f:
res = f.get("MillerIndices")[:, :]
return res
18 changes: 0 additions & 18 deletions qeschema/hdf5/__init__.py

This file was deleted.

Loading

0 comments on commit 3bc68aa

Please sign in to comment.