Skip to content

Commit

Permalink
Make sepdf more like DataFrame (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Apr 28, 2022
1 parent 701b18a commit 9d0a8f6
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
25 changes: 22 additions & 3 deletions qlib/contrib/data/utils/sepdf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Dict, Iterable
from typing import Dict, Iterable, Union


def align_index(df_dict, join):
Expand Down Expand Up @@ -80,11 +80,30 @@ def _update_join(self):
self.join = next(iter(self._df_dict.keys()))

def __getitem__(self, item):
    """Return the entry stored under ``item`` in ``self._df_dict``.

    The key is looked up as-is; a missing key propagates the mapping's
    ``KeyError``.
    """
    # TODO: behave more like pandas when multiindex
    frames = self._df_dict
    return frames[item]

def __setitem__(self, item: str, df: Union[pd.DataFrame, pd.Series]):
    """Store ``df`` under ``item``, accepting pandas-style tuple keys.

    Parameters
    ----------
    item
        A non-tuple ``item`` is used directly as the key of
        ``self._df_dict``. A tuple ``item`` is split into the
        ``self._df_dict`` key (its first element) and a column name (the
        remaining elements).
    df
        The value to store.

    For a tuple ``item``:

    * if the key already exists, ``df`` is assigned to the named column of
      the stored object;
    * otherwise a ``pd.Series`` is converted to a one-column frame named by
      the column name, and any other value is copied and its columns are
      prefixed with the column name levels (the caller's object is left
      untouched).
    """
    # TODO: consider the join behavior
    if not isinstance(item, tuple):
        self._df_dict[item] = df
        return

    # NOTE: corner case of MultiIndex
    _df_dict_key, *col_name = item
    col_name = tuple(col_name)
    # A single remaining level is a plain column label, not a 1-tuple.
    col_key = col_name[0] if len(col_name) == 1 else col_name

    if _df_dict_key in self._df_dict:
        self._df_dict[_df_dict_key][col_key] = df
    elif isinstance(df, pd.Series):
        self._df_dict[_df_dict_key] = df.to_frame(col_key)
    else:
        df_copy = df.copy()  # avoid changing df
        # Column labels are tuples only when `df` already has MultiIndex
        # columns. Flat labels (ints, multi-character strings, ...) must be
        # wrapped instead of star-unpacked, otherwise ints raise TypeError
        # and strings are split into single characters.
        df_copy.columns = pd.MultiIndex.from_tuples(
            [
                (*col_name, *(idx if isinstance(idx, tuple) else (idx,)))
                for idx in df.columns.to_list()
            ]
        )
        self._df_dict[_df_dict_key] = df_copy

def __delitem__(self, item: str):
    """Remove the entry stored under ``item`` from ``self._df_dict``.

    A missing key propagates the mapping's ``KeyError``.
    """
    frames = self._df_dict
    del frames[item]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_version(rel_path: str) -> str:
"dill",
"dataclasses;python_version<'3.7'",
"filelock",
"jinja2<3.1.0" # for passing the readthedocs workflow.
"jinja2<3.1.0", # for passing the readthedocs workflow.
]

# Numpy include
Expand Down
54 changes: 54 additions & 0 deletions tests/misc/test_sepdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
import pandas as pd
from qlib.contrib.data.utils.sepdf import SepDataFrame


class SepDF(unittest.TestCase):
    """Checks on item assignment for ``SepDataFrame``."""

    def to_str(self, obj):
        """Return ``str(obj)`` with every whitespace character removed."""
        pieces = str(obj).split()
        return "".join(pieces)

    def test_index_data(self):
        """Assign via a tuple key and a plain key, then compare the stored dict's text."""
        # Seed first so the random frame below is reproducible.
        np.random.seed(42)

        row_labels = [
            np.array(["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"]),
            np.array(["one", "two", "one", "two", "one", "two", "one", "two"]),
        ]
        col_labels = [
            np.repeat(np.array(["g1", "g2"]), 2),
            np.arange(4),
        ]
        df = pd.DataFrame(np.random.randn(8, 4), index=row_labels, columns=col_labels)

        sdf = SepDataFrame(df_dict={"g2": df["g2"]}, join=None)
        # Tuple key: first element selects the group, the rest names the column.
        sdf[("g2", 4)] = 3
        # Plain key: the frame is stored under the key directly.
        sdf["g1"] = df["g1"]

        expected = """
{'g2': 2 3 4
bar one 0.647689 1.523030 3
two 1.579213 0.767435 3
baz one -0.463418 -0.465730 3
two -1.724918 -0.562288 3
foo one -0.908024 -1.412304 3
two 0.067528 -1.424748 3
qux one -1.150994 0.375698 3
two -0.601707 1.852278 3, 'g1': 0 1
bar one 0.496714 -0.138264
two -0.234153 -0.234137
baz one -0.469474 0.542560
two 0.241962 -1.913280
foo one -1.012831 0.314247
two 1.465649 -0.225776
qux one -0.544383 0.110923
two -0.600639 -0.291694}
"""
        # Both sides go through to_str, so layout whitespace is not compared.
        actual_text = self.to_str(sdf._df_dict)
        expected_text = self.to_str(expected)
        self.assertEqual(actual_text, expected_text)


# Run the test cases in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 9d0a8f6

Please sign in to comment.