Skip to content

Commit

Permalink
fix gat dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Jun 28, 2021
1 parent 8709dde commit 70cb932
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gats_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@

class DailyBatchSampler(Sampler):
def __init__(self, data_source):

self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_index[0] = 0

Expand Down

0 comments on commit 70cb932

Please sign in to comment.