handler demo cache #606

Merged
merged 7 commits
Nov 8, 2021
Changes from all commits
2 changes: 2 additions & 0 deletions examples/data_demo/README.md
@@ -0,0 +1,2 @@
# Introduction
The examples in this folder demonstrate some common usages of the data-related modules of Qlib.
53 changes: 53 additions & 0 deletions examples/data_demo/data_cache_demo.py
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show that the data modules of Qlib are serializable, so users can dump processed data to disk to avoid duplicated data preprocessing
"""

from copy import deepcopy
from pathlib import Path
from pprint import pprint
import subprocess
import yaml
from qlib.log import TimeInspector

from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config

# resolve paths relative to this file so the demo can be run from any directory
DIRNAME = Path(__file__).resolve().parent

if __name__ == "__main__":
init()

config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"

# 1) show original time
with TimeInspector.logt("The original time without handler cache:"):
subprocess.run(f"qrun {config_path}", shell=True)

# 2) dump handler
task_config = yaml.safe_load(config_path.open())
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf)
hd_path = DIRNAME / "handler.pkl"
hd.to_pickle(hd_path, dump_all=True)

# 3) create new task with handler cache
new_task_config = deepcopy(task_config)
new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
pprint(new_task_config)  # show the updated task config
Collaborator
This line of code is incomplete: new_task_config

new_task_path = DIRNAME / "new_task.yaml"
print("The location of the new task", new_task_path)

# save new task
with new_task_path.open("w") as f:
yaml.safe_dump(new_task_config, f)

# 4) train model with new task
with TimeInspector.logt("The time for task with handler cache:"):
subprocess.run(f"qrun {new_task_path}", shell=True)
59 changes: 59 additions & 0 deletions examples/data_demo/data_mem_resuse_demo.py
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show that processed data (the data handler) can be kept in memory and reused across tasks to avoid duplicated data preprocessing
"""

from copy import deepcopy
from pathlib import Path
from pprint import pprint
import subprocess

import yaml

from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
from qlib.model.trainer import task_train
from qlib.utils import init_instance_by_config

# resolve paths relative to this file so the demo can be run from any directory
DIRNAME = Path(__file__).resolve().parent

if __name__ == "__main__":
init()

repeat = 2
exp_name = "data_mem_reuse_demo"

config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
task_config = yaml.safe_load(config_path.open())

# 1) without using processed data in memory
with TimeInspector.logt("The original time without reusing processed data in memory:"):
for i in range(repeat):
task_train(task_config["task"], experiment_name=exp_name)

# 2) prepare processed data in memory.
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
pprint(hd_conf)
hd: DataHandlerLP = init_instance_by_config(hd_conf)

# 3) with reusing processed data in memory
new_task = deepcopy(task_config["task"])
new_task["dataset"]["kwargs"]["handler"] = hd
print(new_task)

with TimeInspector.logt("The time with reusing processed data in memory:"):
# this saves the time of reloading and processing the data from disk (in `DataHandlerLP`)
# the backtest phase still takes a considerable amount of time
for i in range(repeat):
task_train(new_task, experiment_name=exp_name)

# 4) users can change the other parts of the task while keeping the processed data (the handler) in memory
new_task = deepcopy(task_config["task"])
new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
with TimeInspector.logt("The time with reusing processed data in memory:"):
task_train(new_task, experiment_name=exp_name)
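Steps 2)-4) boil down to building datasets around one in-memory handler. A minimal sketch, assuming a prepared `DataHandlerLP` instance `hd` as above and Qlib already initialized (the segment dates are placeholders):

    from qlib.data.dataset import DatasetH

    ds_a = DatasetH(handler=hd, segments={"train": ("2010-01-01", "2013-12-31")})
    ds_b = DatasetH(handler=hd, segments={"train": ("2014-01-01", "2016-12-31")})
    # both datasets reuse the processed data held by `hd`; nothing is recomputed

Because only the `segments` differ, the expensive preprocessing inside the handler runs once.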
122 changes: 84 additions & 38 deletions qlib/model/trainer.py
@@ -16,6 +16,7 @@
import re
from typing import Callable, List

from tqdm.auto import tqdm
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
@@ -25,6 +26,48 @@
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task

# from qlib.data.dataset.weight import Reweighter


def _log_task_info(task_config: dict):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})


def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# FIXME: resume reweighter after merging data selection
# reweighter: Reweighter = task_config.get("reweighter", None)
# model training
# auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference, so the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# fill placeholders
placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
task_config = fill_placeholder(task_config, placeholder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict):  # a single record dict is allowed; normalize it to a list
records = [records]
for record in records:
# Some recorders require the parameters `model` and `dataset`.
# Try to pass them to the initialization function automatically
# to make defining the task easier.
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
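# A typical `record` section consumed by the loop above, sketched in the shape
# of the benchmark workflow YAMLs (SignalRecord and SigAnaRecord are real
# classes in qlib.workflow.record_temp):
#
#     record:
#         - class: SignalRecord
#           module_path: qlib.workflow.record_temp
#         - class: SigAnaRecord
#           module_path: qlib.workflow.record_temp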


def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
@@ -39,11 +82,8 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
_log_task_info(task_config)
return R.get_recorder()


def fill_placeholder(config: dict, config_extend: dict):
@@ -100,38 +140,11 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# model training
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference, so the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# fill placeholders
placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
task_config = fill_placeholder(task_config, placeholder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict):  # normalize the data format to a list
records = [records]

for record in records:
# Some recorders require the parameters `model` and `dataset`.
# Try to pass them to the initialization function automatically
# to make defining the task easier.
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
_exe_task(task_config)
return rec


def task_train(task_config: dict, experiment_name: str) -> Recorder:
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
Task based training, will be divided into two steps.

@@ -141,14 +154,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
The config of a task.
experiment_name: str
The name of the experiment
recorder_name: str
The name of the recorder

Returns
----------
Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
_log_task_info(task_config)
_exe_task(task_config)
return R.get_recorder()
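# Usage sketch (illustrative names): with a workflow YAML like the ones in the
# data demos, the `task` section can be passed straight to `task_train`:
#
#     task_config = yaml.safe_load(config_path.open())["task"]
#     rec = task_train(task_config, experiment_name="demo_exp")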


class Trainer:
@@ -204,6 +220,30 @@ def is_delay(self) -> bool:
def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))

def has_worker(self) -> bool:
"""
Some trainers have a backend worker to support parallel training.
This method tells whether the worker is enabled.

Returns
-------
bool:
whether the worker is enabled

"""
return False

def worker(self):
"""
Start the worker.

Raises
------
NotImplementedError:
If the worker is not supported
"""
raise NotImplementedError("Please implement the `worker` method")
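# Usage sketch of the worker pattern (illustrative names; `TrainerRM` below
# reports has_worker() == True). Training is split between a dispatcher and
# one or more worker processes:
#
#     trainer = TrainerRM(experiment_name="exp", task_pool="pool")
#     # in one or more separate worker processes: trainer.worker()
#     recs = trainer.train(tasks)  # dispatched tasks are consumed by the workers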


class TrainerR(Trainer):
"""
@@ -252,7 +292,7 @@ def train(self, tasks: list, train_func: Callable = None, experiment_name: str =
if experiment_name is None:
experiment_name = self.experiment_name
recs = []
for task in tasks:
for task in tqdm(tasks):
rec = train_func(task, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
@@ -457,6 +497,9 @@ def worker(
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)

def has_worker(self) -> bool:
return True


class DelayTrainerRM(TrainerRM):
"""
@@ -579,3 +622,6 @@ def worker(self, end_train_func=None, experiment_name: str = None):
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)

def has_worker(self) -> bool:
return True