diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py
index 90537471e3..a6a56713e9 100644
--- a/qlib/contrib/data/utils/sepdf.py
+++ b/qlib/contrib/data/utils/sepdf.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import pandas as pd
-from typing import Dict, Iterable
+from typing import Dict, Iterable, Union
 
 
 def align_index(df_dict, join):
@@ -80,11 +80,47 @@ def _update_join(self):
             self.join = next(iter(self._df_dict.keys()))
 
     def __getitem__(self, item):
+        # TODO: behave more like pandas when multiindex
         return self._df_dict[item]
 
-    def __setitem__(self, item: str, df: pd.DataFrame):
+    def __setitem__(self, item: Union[str, tuple], df: Union[pd.DataFrame, pd.Series]):
+        """Store ``df`` as a whole group, or as column(s) inside a group.
+
+        Parameters
+        ----------
+        item : str or tuple
+            A non-tuple key replaces the whole group stored under it.
+            A tuple is read as ``(group_key, *column_levels)``.
+        df : pd.DataFrame or pd.Series
+            The data to store.
+
+            - Existing group: ``df`` is passed to the stored frame's own item assignment.
+            - New group, Series: converted with ``to_frame`` and named by the column levels.
+            - New group, DataFrame: a copy is stored whose column labels are prefixed
+              with the column levels; ``df`` itself is left unchanged.
+        """
         # TODO: consider the join behavior
-        self._df_dict[item] = df
+        if not isinstance(item, tuple):
+            self._df_dict[item] = df
+            return
+
+        # NOTE: corner case of MultiIndex
+        _df_dict_key, *col_name = item
+        col_name = tuple(col_name)
+        # A single remaining level names a flat column label, not a 1-tuple.
+        col_key = col_name[0] if len(col_name) == 1 else col_name
+        if _df_dict_key in self._df_dict:
+            self._df_dict[_df_dict_key][col_key] = df
+        elif isinstance(df, pd.Series):
+            self._df_dict[_df_dict_key] = df.to_frame(col_key)
+        else:
+            df_copy = df.copy()  # avoid changing df
+            # Flat column labels are scalars rather than tuples; wrap them so
+            # they can be prefixed the same way as MultiIndex labels.
+            df_copy.columns = pd.MultiIndex.from_tuples(
+                [(*col_name, *(idx if isinstance(idx, tuple) else (idx,))) for idx in df.columns.to_list()]
+            )
+            self._df_dict[_df_dict_key] = df_copy
 
     def __delitem__(self, item: str):
         del self._df_dict[item]
diff --git a/setup.py b/setup.py
index 527fe19626..8780e8be73 100644
--- a/setup.py
+++ b/setup.py
@@ -78,7 +78,7 @@ def get_version(rel_path: str) -> str:
     "dill",
     "dataclasses;python_version<'3.7'",
     "filelock",
-    "jinja2<3.1.0"  # for passing the readthedocs workflow.
+    "jinja2<3.1.0",  # for passing the readthedocs workflow.
 ]
 
 # Numpy include
diff --git a/tests/misc/test_sepdf.py b/tests/misc/test_sepdf.py
new file mode 100644
index 0000000000..6597ddaabe
--- /dev/null
+++ b/tests/misc/test_sepdf.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+import numpy as np
+import pandas as pd
+from qlib.contrib.data.utils.sepdf import SepDataFrame
+
+
+class SepDF(unittest.TestCase):
+    """Tests for item assignment on ``SepDataFrame``."""
+
+    def to_str(self, obj):
+        """Return ``str(obj)`` with all whitespace removed, for layout-insensitive comparison."""
+        return "".join(str(obj).split())
+
+    def test_index_data(self):
+        """Assign a scalar to a column of an existing group, then add a new group."""
+        np.random.seed(42)
+
+        index = [
+            np.array(["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"]),
+            np.array(["one", "two", "one", "two", "one", "two", "one", "two"]),
+        ]
+
+        cols = [
+            np.repeat(np.array(["g1", "g2"]), 2),
+            np.arange(4),
+        ]
+        df = pd.DataFrame(np.random.randn(8, 4), index=index, columns=cols)
+        sdf = SepDataFrame(df_dict={"g2": df["g2"]}, join=None)
+        sdf[("g2", 4)] = 3
+        sdf["g1"] = df["g1"]
+        exp = """
+        {'g2':                 2         3  4
+        bar one  0.647689  1.523030  3
+            two  1.579213  0.767435  3
+        baz one -0.463418 -0.465730  3
+            two -1.724918 -0.562288  3
+        foo one -0.908024 -1.412304  3
+            two  0.067528 -1.424748  3
+        qux one -1.150994  0.375698  3
+            two -0.601707  1.852278  3, 'g1':                 0         1
+        bar one  0.496714 -0.138264
+            two -0.234153 -0.234137
+        baz one -0.469474  0.542560
+            two  0.241962 -1.913280
+        foo one -1.012831  0.314247
+            two  1.465649 -0.225776
+        qux one -0.544383  0.110923
+            two -0.600639 -0.291694}
+        """
+        self.assertEqual(self.to_str(sdf._df_dict), self.to_str(exp))
+
+    def test_new_group_with_flat_columns(self):
+        """A frame with flat columns stored under a new tuple key gets the key's levels as a prefix."""
+        df = pd.DataFrame({"open": [1.0, 2.0], "close": [3.0, 4.0]})
+        sdf = SepDataFrame(df_dict={"g1": df}, join=None)
+        sdf[("g2", "price")] = df
+        self.assertEqual(sdf["g2"].columns.to_list(), [("price", "open"), ("price", "close")])
+        # the assigned frame itself must stay untouched
+        self.assertEqual(df.columns.to_list(), ["open", "close"])
+
+
+if __name__ == "__main__":
+    unittest.main()