1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

handler demo cache (#606)

* handler demo  cache

* Update data_cache_demo.py

* example to reusing processed data in memory

* Skip dumping task of task_train

* FIX Black

Co-authored-by: Wangwuyi123 <51237097+Wangwuyi123@users.noreply.github.com>
This commit is contained in:
you-n-g
2021-11-08 17:33:10 +08:00
committed by GitHub
parent fdbc666678
commit a2be6e28e9
4 changed files with 198 additions and 38 deletions

View File

@@ -0,0 +1,2 @@
# Introduction
The examples in this folder demonstrate common usages of the data-related modules of Qlib.

View File

@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show that the data modules of Qlib are serializable: users can dump processed data to disk to avoid duplicated data preprocessing
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess
import yaml
from qlib.log import TimeInspector
from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
# For general purpose, we use a path relative to this file
DIRNAME = Path(__file__).absolute().resolve().parent

if __name__ == "__main__":
    # Demo flow: time a plain `qrun`, pickle the fully processed handler,
    # then time a second `qrun` whose config points at the pickled handler.
    init()

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"

    # 1) show original time
    with TimeInspector.logt("The original time without handler cache:"):
        # Pass the command as an argument list (no shell) so that a path
        # containing spaces or shell metacharacters is handled correctly.
        subprocess.run(["qrun", str(config_path)])

    # 2) dump handler
    # Use a context manager so the config file handle is closed deterministically.
    with config_path.open() as f:
        task_config = yaml.safe_load(f)
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)
    hd_path = DIRNAME / "handler.pkl"
    # dump_all=True is passed so that the processed data is written out with the handler
    hd.to_pickle(hd_path, dump_all=True)

    # 3) create new task with handler cache
    new_task_config = deepcopy(task_config)
    # Replace the handler config with a `file://` reference to the pickled handler
    new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
    new_task_path = DIRNAME / "new_task.yaml"
    print("The location of the new task", new_task_path)

    # save new task
    with new_task_path.open("w") as f:
        yaml.safe_dump(new_task_config, f)

    # 4) train model with new task
    with TimeInspector.logt("The time for task with handler cache:"):
        subprocess.run(["qrun", str(new_task_path)])

View File

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The motivation of this demo
- To show that processed data (e.g. an initialized data handler) can be kept in memory and reused across tasks to avoid duplicated data preprocessing
"""
from copy import deepcopy
from pathlib import Path
import pickle
from pprint import pprint
import subprocess
import yaml
from qlib import init
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import TimeInspector
from qlib.model.trainer import task_train
from qlib.utils import init_instance_by_config
# For general purpose, we use a path relative to this file
DIRNAME = Path(__file__).absolute().resolve().parent

if __name__ == "__main__":
    # Demo flow: time `repeat` trainings that each build the handler from its
    # config, then time trainings that share one handler instance in memory.
    init()

    repeat = 2
    exp_name = "data_mem_reuse_demo"

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
    # Use a context manager so the config file handle is closed deterministically.
    with config_path.open() as f:
        task_config = yaml.safe_load(f)

    # 1) without using processed data in memory
    with TimeInspector.logt("The original time without reusing processed data in memory:"):
        for _ in range(repeat):
            task_train(task_config["task"], experiment_name=exp_name)

    # 2) prepare processed data in memory.
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)

    # 3) with reusing processed data in memory
    new_task = deepcopy(task_config["task"])
    # The handler instance is assigned after the deepcopy so it is shared, not copied
    new_task["dataset"]["kwargs"]["handler"] = hd
    print(new_task)

    with TimeInspector.logt("The time with reusing processed data in memory:"):
        # this will save the time to reload and process data from disk(in `DataHandlerLP`)
        # It still takes a lot of time in the backtest phase
        for _ in range(repeat):
            task_train(new_task, experiment_name=exp_name)

    # 4) Users can change other parts of the task while still reusing the processed data in memory (handler)
    new_task = deepcopy(task_config["task"])
    # Without this assignment the task would fall back to the handler config
    # and the handler would be built again, contradicting the message logged below.
    new_task["dataset"]["kwargs"]["handler"] = hd
    new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
    with TimeInspector.logt("The time with reusing processed data in memory:"):
        task_train(new_task, experiment_name=exp_name)

View File

@@ -16,6 +16,7 @@ import time
import re
from typing import Callable, List
from tqdm.auto import tqdm
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
@@ -25,6 +26,48 @@ from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
# from qlib.data.dataset.weight import Reweighter
def _log_task_info(task_config: dict):
    """
    Log the information of a task through `R`.

    The flattened config is logged as parameters, the config itself is saved
    as the object "task", and the local hostname is set as the tag "hostname".

    Parameters
    ----------
    task_config : dict
        The config of a task.
    """
    flat_params = flatten_dict(task_config)
    R.log_params(**flat_params)
    # the config is saved as-is to keep the original format and datatype
    R.save_objects(task=task_config)
    R.set_tags(hostname=socket.gethostname())
def _exe_task(task_config: dict):
    """
    Execute a task: fit the model, save the artifacts and generate the records.

    The recorder is fetched through `R.get_recorder()` before anything else is done.

    Parameters
    ----------
    task_config : dict
        The config of a task. The keys "model" and "dataset" are required;
        "record" is optional and can be a single dict or a list of dicts.
    """
    rec = R.get_recorder()

    # initialize the model and the dataset from their configs
    model: Model = init_instance_by_config(task_config["model"])
    dataset: Dataset = init_instance_by_config(task_config["dataset"])
    # FIXME: resume reweighter after merging data selection
    # reweighter: Reweighter = task_config.get("reweighter", None)

    # train the model
    # auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
    model.fit(dataset)
    R.save_objects(**{"params.pkl": model})

    # the saved dataset is meant for online inference, so the concrete data should not be dumped
    dataset.config(dump_all=False, recursive=True)
    R.save_objects(dataset=dataset)

    # fill the placeholders with the fitted model and the dataset
    placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
    task_config = fill_placeholder(task_config, placeholder_value)

    # generate records: prediction, backtest, and analysis
    record_confs = task_config.get("record", [])
    if isinstance(record_confs, dict):
        # a single record config is wrapped so that it can be iterated like a list
        record_confs = [record_confs]

    for record_conf in record_confs:
        # Some record templates require the parameters `model` and `dataset`;
        # they are offered through `try_kwargs` to make defining the task easier.
        record_obj = init_instance_by_config(
            record_conf,
            recorder=rec,
            default_module="qlib.workflow.record_temp",
            try_kwargs={"model": model, "dataset": dataset},
        )
        record_obj.generate()
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
@@ -39,11 +82,8 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
_log_task_info(task_config)
return R.get_recorder()
def fill_placeholder(config: dict, config_extend: dict):
@@ -100,38 +140,11 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# model training
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# fill placehorder
placehorder_value = {"<MODEL>": model, "<DATASET>": dataset}
task_config = fill_placeholder(task_config, placehorder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # uniform the data format to list
records = [records]
for record in records:
# Some recorder require the parameter `model` and `dataset`.
# try to automatically pass in them to the initialization function
# to make defining the tasking easier
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
_exe_task(task_config)
return rec
def task_train(task_config: dict, experiment_name: str) -> Recorder:
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
Task based training, will be divided into two steps.
@@ -141,14 +154,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
The config of a task.
experiment_name: str
The name of experiment
recorder_name: str
The name of recorder
Returns
----------
Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
_log_task_info(task_config)
_exe_task(task_config)
return R.get_recorder()
class Trainer:
@@ -204,6 +220,30 @@ class Trainer:
def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))
def has_worker(self) -> bool:
    """
    Tell whether a backend worker is enabled for this trainer.

    Some trainers have a backend worker to support parallel training.
    This implementation always reports that there is none.

    Returns
    -------
    bool:
        always False here
    """
    return False
def worker(self):
    """
    Start the worker.

    This implementation does not start anything and always raises.

    Raises
    ------
    NotImplementedError:
        If the worker is not supported
    """
    # plain string literal: the message has no placeholders, so no f-string is needed
    raise NotImplementedError("Please implement the `worker` method")
class TrainerR(Trainer):
"""
@@ -252,7 +292,7 @@ class TrainerR(Trainer):
if experiment_name is None:
experiment_name = self.experiment_name
recs = []
for task in tasks:
for task in tqdm(tasks):
rec = train_func(task, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
@@ -457,6 +497,9 @@ class TrainerRM(Trainer):
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
def has_worker(self) -> bool:
    """
    Tell whether a backend worker is enabled for this trainer.

    Returns
    -------
    bool:
        always True for this trainer
    """
    return True
class DelayTrainerRM(TrainerRM):
"""
@@ -579,3 +622,6 @@ class DelayTrainerRM(TrainerRM):
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)
def has_worker(self) -> bool:
    """
    Tell whether a backend worker is enabled for this trainer.

    Returns
    -------
    bool:
        always True for this trainer
    """
    return True