update TSDataSampler, refining the memory layout of the data array to speed up NN training (#1342)

* update TSDataSampler

* reformat code with black

* use pre-commit to reformat the code

* Add documents

* More docstring

* More Safety

Co-authored-by: Young <afe.young@gmail.com>
peteryang1 and you-n-g committed Nov 11, 2022
1 parent 3b471a0 commit a82cc0b
Showing 1 changed file with 127 additions and 24 deletions.
qlib/data/dataset/__init__.py
@@ -82,7 +82,11 @@ class DatasetH(Dataset):
"""

def __init__(
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
self,
handler: Union[Dict, DataHandler],
segments: Dict[Text, Tuple],
fetch_kwargs: Dict = {},
**kwargs,
):
"""
Setup the underlying data.
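For orientation, here is a minimal usage sketch of this constructor; the handler config and segment dates are illustrative placeholders, not taken from this diff:

    from qlib.data.dataset import DatasetH

    # Hypothetical example; the handler dict and segment ranges are placeholders.
    dataset = DatasetH(
        handler={"class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": {}},
        segments={
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
    )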
@@ -284,18 +288,77 @@ class TSDataSampler:
- For performance reasons, this sampler converts the dataframe into arrays. This could result in a different data type
Indices design:
TSDataSampler has an index mechanism to help users query time-series data efficiently.
The definition of the related variables:
data_arr: np.ndarray
The original data. It contains all the original data.
Queries are usually for the time series of a specific stock.
To leverage this characteristic and speed up querying, the MultiIndex of data_arr is rearranged in (instrument, datetime) order.
data_index: pd.MultiIndex with index order <instrument, datetime>
It has the same length as `idx_map`; their elements are expected to be aligned.
idx_map: np.ndarray
The indexable data. It originates from data_arr and is then filtered by 1) `start` and `end`, and 2) `flt_data`.
The extra data in data_arr is useful in the following cases:
1) creating meaningful time-series data before `start` instead of padding it with zeros
2) some data are excluded by `flt_data` (e.g. there is no <X, y> sample pair for that index), but they are still used as history in X
Finally, it will look like:
array([[ 0, 0],
[ 1, 0],
[ 2, 0],
...,
[241, 348],
[242, 348],
[243, 348]], dtype=int32)
It lists all indexable data (some data used only as history in time series may not be indexable); the values are the corresponding row and col in idx_df
idx_df: pd.DataFrame
It aims to map the <datetime, instrument> key to the original position in data_arr.
For example, it may look like this (NOTE: the indices for an instrument's time series are contiguous in memory):
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
datetime
2017-01-03 0 242 473 717 NaN 974 ...
2017-01-04 1 243 474 718 NaN 975 ...
2017-01-05 2 244 475 719 NaN 976 ...
2017-01-06 3 245 476 720 NaN 977 ...
With these two indices (idx_map, idx_df) and the original data (data_arr), we can make the following queries fast (implemented in __getitem__):
(1) Get the i-th indexable sample (time series): (indexable sample index) -> [idx_map] -> (row, col) -> [idx_df] -> (index in data_arr)
(2) Get a specific sample by <datetime, instrument>: (<datetime, instrument>, i.e. <row, col>) -> [idx_df] -> (index in data_arr)
(3) Get the indices of a time series of data: (get the <row, col>, refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for the time series)
"""

# Please refer to the docstring of TSDataSampler for the definition of following attributes
data_arr: np.ndarray
data_index: pd.MultiIndex
idx_map: np.ndarray
idx_df: pd.DataFrame

def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
self,
data: pd.DataFrame,
start,
end,
step_len: int,
fillna_type: str = "none",
dtype=None,
flt_data=None,
):
"""
Build a dataset which looks like torch.utils.data.Dataset.
Parameters
----------
data : pd.DataFrame
The raw tabular data
The raw tabular data whose index order is <"datetime", "instrument">
start :
The indexable start time
end :
@@ -311,7 +374,7 @@ def __init__(
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
a column of data (True or False) used to filter data; its index order is <"datetime", "instrument">
None:
keep all data
@@ -321,7 +384,10 @@
self.step_len = step_len
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data = data.swaplevel().sort_index().copy()
data.drop(
data.columns, axis=1, inplace=True
) # `data` is useless once the re-sorted copy has been made; drop its columns in place to free that memory and avoid keeping three big dataframes alive at once (data, self.data, self.data_arr)
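A minimal sketch of why the in-place column drop is used rather than `del data`: the caller may still hold a reference, so only dropping the columns actually releases the underlying blocks (toy example, not from the diff):

    import numpy as np
    import pandas as pd

    data = pd.DataFrame(np.random.rand(4, 3), columns=list("abc"))
    caller_ref = data  # an outside reference keeps the object alive, so `del data` would free nothing
    data.drop(data.columns, axis=1, inplace=True)
    print(caller_ref.shape)  # (4, 0) -- the column blocks are gone even though the object survives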

kwargs = {"object": self.data}
if dtype is not None:
@@ -332,7 +398,9 @@ def __init__(
# - append last line with full NaN for better performance in `__getitem__`
- Keeping the same dtype will result in better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
self.data_arr,
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN

@@ -347,19 +415,36 @@ def __init__(
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.swaplevel()
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)

self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
self.idx_map, self.data_index = self.slice_idx_map_and_data_index(
self.idx_map, self.idx_df, self.data_index, start, end
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance

self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory

@staticmethod
def slice_idx_map_and_data_index(
idx_map,
idx_df,
data_index,
start,
end,
):
assert (
len(idx_map) == data_index.shape[0]
) # make sure idx_map and data_index have the same length, so an index into idx_map can also be used on data_index

start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end))

time_filter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx)
return idx_map[time_filter_idx], data_index[time_filter_idx]
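Continuing with toy values, the time filter above works purely on the row coordinate of idx_map; start_row_idx/end_row_idx stand in for what idx_df.index.slice_locs would return:

    import numpy as np

    idx_map = np.array([[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1], [3, 1]])
    start_row_idx, end_row_idx = 1, 3  # e.g. slicing the date range 2017-01-04 .. 2017-01-05
    keep = (idx_map[:, 0] >= start_row_idx) & (idx_map[:, 0] < end_row_idx)
    print(idx_map[keep])  # [[1 0] [2 0] [1 1] [2 1]]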

@staticmethod
def idx_map2arr(idx_map):
# pytorch data sampler will have better memory control without large dict or list
@@ -394,7 +479,7 @@ def get_index(self):
Get the pandas index of the data; it will be useful in the following scenarios:
- a special sampler will be used (e.g. the user wants to sample day by day)
"""
return self.data_index[self.start_idx : self.end_idx]
return self.data_index.swaplevel() # to align with the MultiIndex order of the original data received by __init__
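A hedged sketch of the day-by-day sampling scenario mentioned above; `sampler` stands for a constructed TSDataSampler, and the grouping is plain pandas rather than a qlib API:

    import pandas as pd

    index = sampler.get_index()  # MultiIndex back in <datetime, instrument> order
    # positions of all samples belonging to the same trading day, e.g. for a daily batch sampler
    daily_positions = pd.Series(range(len(index)), index=index).groupby(level="datetime").apply(list)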

def config(self, **kwargs):
# Config the attributes
@@ -409,25 +494,33 @@ def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
Parameters
----------
data : pd.DataFrame
The dataframe with <datetime, DataFrame>
A DataFrame with index in order <instrument, datetime>
RSQR5 RESI5 WVMA5 LABEL0
instrument datetime
SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056
2017-01-04 0.884545 -0.110597 -1.059332 -0.030139
2017-01-05 0.507540 -0.535493 -1.099665 -0.644983
2017-01-06 -1.267771 -0.669685 -1.636733 0.295366
2017-01-09 0.339346 0.074317 -0.984989 0.765540
Returns
-------
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2017-01-03 0 242 473 717 NaN 974 ...
2017-01-04 1 243 474 718 NaN 975 ...
2017-01-05 2 244 475 719 NaN 976 ...
2017-01-06 3 245 476 720 NaN 977 ...
2) the second element: {<original index>: <row, col>}
"""
# object dtype, in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
idx_df = lazy_sort_index(idx_df, axis=1).T

idx_map = {}
for i, (_, row) in enumerate(idx_df.iterrows()):
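The visible part of build_index can be reproduced standalone, assuming lazy_sort_index behaves like sort_index when the input is unsorted (toy data):

    import pandas as pd

    index = pd.MultiIndex.from_product(
        [["SH600000", "SH600008"], pd.to_datetime(["2017-01-03", "2017-01-04"])],
        names=["instrument", "datetime"],
    )
    data = pd.DataFrame({"f": range(4)}, index=index)
    idx_s = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
    idx_df = idx_s.unstack().sort_index().sort_index(axis=1).T
    print(idx_df)
    # instrument  SH600000 SH600008
    # datetime
    # 2017-01-03         0        2
    # 2017-01-04         1        3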
@@ -485,11 +578,11 @@ def _get_row_col(self, idx) -> Tuple[int]:
"""
# Get the right row number `i` and col number `j` in idx_df
if isinstance(idx, (int, np.integer)):
real_idx = self.start_idx + idx
if self.start_idx <= real_idx < self.end_idx:
real_idx = idx
if 0 <= real_idx < len(self.idx_map):
i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good
else:
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})")
elif isinstance(idx, tuple):
# <TSDataSampler object>["datetime", "instruments"]
date, inst = idx
Expand Down Expand Up @@ -532,15 +625,18 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):
# precision problems. It will not cause any problems in my tests at least
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)

data = self.data_arr[indices]
if (np.diff(indices) == 1).all(): # use slicing instead of fancy indexing to speed things up
data = self.data_arr[indices[0] : indices[-1] + 1]
else:
data = self.data_arr[indices]
if isinstance(idx, mtit):
# if we get multiple indices, an additional dimension should be added.
# <sample_idx, step_idx, feature_idx>
data = data.reshape(-1, self.step_len, *data.shape[1:])
return data
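The gain from the contiguity check is that basic slicing returns a view while fancy indexing allocates a copy; a standalone micro-sketch (sizes illustrative):

    import numpy as np

    arr = np.random.rand(1_000_000, 20)
    indices = np.arange(500, 560)  # a contiguous time-series window
    if (np.diff(indices) == 1).all():
        window = arr[indices[0] : indices[-1] + 1]  # basic slice: a view, no copy
    else:
        window = arr[indices]  # fancy indexing: allocates and copies
    assert window.base is arr  # the fast path did not copy

Because the memory layout was rearranged into (instrument, datetime) order, most per-instrument time-series queries hit the slice branch.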

def __len__(self):
return self.end_idx - self.start_idx
return len(self.idx_map)


class TSDatasetH(DatasetH):
@@ -611,7 +707,14 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
else:
flt_data = None

tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
tsds = TSDataSampler(
data=data,
start=start,
end=end,
step_len=self.step_len,
dtype=dtype,
flt_data=flt_data,
)
return tsds
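A hedged end-to-end sketch of how the returned sampler is typically consumed in NN training; the segment name, batch size, and the surrounding dataset object are placeholders:

    from torch.utils.data import DataLoader

    # `ts_dataset` stands for a configured TSDatasetH instance
    tsds = ts_dataset.prepare("train")  # -> TSDataSampler
    loader = DataLoader(tsds, batch_size=256, shuffle=True)
    for batch in loader:  # batch has shape <batch_size, step_len, num_features>
        ...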


