Fix Models #483

Merged · 5 commits · Sep 30, 2021
1 change: 1 addition & 0 deletions README.md
@@ -310,6 +310,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.

- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` argument can take any number of the models listed above (the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- **NOTE**: Each baseline has different environment dependencies; please make sure that your Python version aligns with its requirements (e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`).

## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supports *Linux* for now; other OSes will be supported in the future. It also does not yet support running the same model multiple times in parallel; this will be addressed in future development.)
2 changes: 1 addition & 1 deletion examples/benchmarks/TFT/README.md
@@ -8,7 +8,7 @@
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.

### Notes
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run on a GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``; for details, please see the source.
3 changes: 1 addition & 2 deletions examples/benchmarks/TFT/requirements.txt
@@ -1,3 +1,2 @@
tensorflow-gpu==1.15.0
numpy == 1.19.4
pandas==1.1.0
pandas==1.1.0
25 changes: 24 additions & 1 deletion examples/benchmarks/TFT/tft.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ def extract_numerical_data(data):
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
print("Training completed at {}.".format(dte.datetime.now()))
# ===========================Training Process===========================

def predict(self, dataset):
@@ -289,3 +291,24 @@ def finetune(self, dataset: DatasetH):
dataset for finetuning
"""
pass

def to_pickle(self, path: Union[Path, str]):
"""
Tensorflow models can't be pickled directly,
so the model data must be saved separately.

**TODO**: Please implement the function to load the files

Parameters
----------
path : Union[Path, str]
the target path to dump to
"""
# save tensorflow model
# path = Path(path)
# path.mkdir(parents=True)
# self.model.save(path)

# save qlib model wrapper
path = Path(path)  # accept str paths as well, so `path / "qlib_model"` works
self.model = None
super(TFTModel, self).to_pickle(path / "qlib_model")
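Since the TensorFlow graph itself can't be pickled, the method above nulls out `self.model` before delegating to the base class's `to_pickle`. A minimal standalone sketch of that pattern (class and attribute names here are illustrative, not qlib's actual API):

```python
import pickle
from pathlib import Path


class ModelWrapper:
    """Illustrative wrapper around a backend model that pickle can't handle."""

    def __init__(self, model):
        self.model = model  # e.g. a TF1 graph/session, or any unpicklable object

    def to_pickle(self, path):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        # The backend model would be saved here via its own serializer
        # (e.g. a TensorFlow checkpoint); omitted in this sketch.
        backend, self.model = self.model, None
        try:
            # With `model` set to None, the wrapper pickles cleanly.
            with (path / "qlib_model").open("wb") as f:
                pickle.dump(self, f)
        finally:
            self.model = backend  # keep the in-memory object usable
```

Restoring `self.model` afterwards is a design choice the PR does not make (it leaves the attribute as `None`); the `try/finally` here just shows one way to keep the live object usable after dumping.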
1 change: 0 additions & 1 deletion qlib/contrib/model/pytorch_gats_ts.py
@@ -27,7 +27,6 @@

class DailyBatchSampler(Sampler):
def __init__(self, data_source):

self.data_source = data_source
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
19 changes: 19 additions & 0 deletions qlib/utils/serial.py
@@ -122,3 +122,22 @@ def get_backend(cls):
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")

@staticmethod
def general_dump(obj, path: Union[Path, str]):
"""
A general dumping method for objects

Parameters
----------
obj : object
the object to be dumped
path : Union[Path, str]
the target path to which the data will be dumped
"""
path = Path(path)
if isinstance(obj, Serializable):
obj.to_pickle(path)
else:
with path.open("wb") as f:
pickle.dump(obj, f)
5 changes: 3 additions & 2 deletions qlib/workflow/recorder.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs
from pathlib import Path
@@ -307,8 +308,8 @@ def save_objects(self, local_path=None, artifact_path=None, **kwargs):
else:
temp_dir = Path(tempfile.mkdtemp()).resolve()
for name, data in kwargs.items():
with (temp_dir / name).open("wb") as f:
pickle.dump(data, f)
path = temp_dir / name
Serializable.general_dump(data, path)
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
shutil.rmtree(temp_dir)

2 changes: 1 addition & 1 deletion setup.py
@@ -45,7 +45,7 @@
"statsmodels",
"xlrd>=1.0.0",
"plotly==4.12.0",
"matplotlib==3.3",
"matplotlib>=3.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
"mlflow>=1.12.1",