diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 4a61be5e1d..020c736fd3 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -300,10 +300,15 @@ def test_epoch(self, data_x, data_y): return np.mean(losses), np.mean(scores) def fit( - self, dataset: DatasetH, evals_result=dict(), save_path=None, + self, + dataset: DatasetH, + evals_result=dict(), + save_path=None, ): df_train, df_valid, df_test = dataset.prepare( - ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L, + ["train", "valid", "test"], + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: raise ValueError("Empty data from dataset, please check your dataset config.")