Skip to content
This repository has been archived by the owner on Jan 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #83 from chrisburr/fix-bugs
Browse files Browse the repository at this point in the history
Remove depricated use of pandas and ensure column order is correct
  • Loading branch information
chrisburr authored Jul 23, 2019
2 parents 57991a4 + ea2ec6b commit 76e820d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
14 changes: 10 additions & 4 deletions root_pandas/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
A module that extends pandas to support the ROOT data format.
"""
from collections import Counter

import numpy as np
from numpy.lib.recfunctions import append_fields
Expand Down Expand Up @@ -95,11 +96,11 @@ def get_nonscalar_columns(array):
def get_matching_variables(branches, patterns, fail=True):
# Convert branches to a set to make x "in branches" O(1) on average
branches = set(branches)
patterns = set(patterns)
# Find any trivial matches
selected = list(branches.intersection(patterns))
selected = sorted(branches.intersection(patterns),
key=lambda s: patterns.index(s))
# Any matches that weren't trivial need to be looped over...
for pattern in patterns.difference(selected):
for pattern in set(patterns).difference(selected):
found = False
# Avoid using fnmatch if the pattern if possible
if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern):
Expand Down Expand Up @@ -317,7 +318,7 @@ def convert_to_dataframe(array, start_index=None):
# Filter to remove __index__ columns
columns = [c for c in array.dtype.names if c in df.columns]
assert len(columns) == len(df.columns), (columns, df.columns)
df = df.reindex_axis(columns, axis=1, copy=False)
df = df.reindex(columns, axis=1, copy=False)

# Convert categorical columns back to categories
for c in df.columns:
Expand Down Expand Up @@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
else:
raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode))

column_name_counts = Counter(df.columns)
if max(column_name_counts.values()) > 1:
raise ValueError('DataFrame contains duplicated column names: ' +
' '.join({k for k, v in column_name_counts.items() if v > 1}))

from root_numpy import array2tree
# We don't want to modify the user's DataFrame here, so we make a shallow copy
df_ = df.copy(deep=False)
Expand Down
2 changes: 1 addition & 1 deletion root_pandas/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
'version_info',
]

__version__ = '0.6.1'
__version__ = '0.7.0'
version = __version__
version_info = tuple(__version__.split('.'))
17 changes: 17 additions & 0 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,20 @@ def test_issue_63():
assert all(len(df) == 1 for df in result)
os.remove('tmp_1.root')
os.remove('tmp_2.root')


def test_issue_80():
df = pd.DataFrame({'a': [1, 2], 'b': [4, 5]})
df.columns = ['a', 'a']
try:
root_pandas.to_root(df, '/tmp/example.root')
except ValueError as e:
assert 'DataFrame contains duplicated column names' in e.args[0]
else:
raise Exception('ValueError is expected')


def test_issue_82():
variables = ['MET_px', 'MET_py', 'EventWeight']
df = root_pandas.read_root('http://scikit-hep.org/uproot/examples/HZZ.root', 'events', columns=variables)
assert list(df.columns) == variables

0 comments on commit 76e820d

Please sign in to comment.