mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge pull request #290 from you-n-g/online_srv
init version of online serving and rolling
This commit is contained in:
BIN
docs/_static/img/online_serving.png
vendored
Normal file
BIN
docs/_static/img/online_serving.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 440 KiB |
@@ -14,6 +14,9 @@ Serializable Class
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
|
||||
89
docs/advanced/task_management.rst
Normal file
89
docs/advanced/task_management.rst
Normal file
@@ -0,0 +1,89 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import C
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
|
||||
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
46
docs/component/online.rst
Normal file
46
docs/component/online.rst
Normal file
@@ -0,0 +1,46 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
|
||||
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -42,6 +42,7 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -50,6 +51,7 @@ Document Structure
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -154,6 +154,70 @@ Record Template
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
@@ -162,4 +226,7 @@ Serializable
|
||||
--------------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
159
examples/model_rolling/task_manager_rolling.py
Normal file
159
examples/model_rolling/task_manager_rolling.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region=REG_CN,
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=[task_xgboost_config, task_lgb_config],
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def task_generating(self):
|
||||
print("========== task_generating ==========")
|
||||
tasks = task_generator(
|
||||
tasks=self.task_config,
|
||||
generators=self.rolling_gen, # generate different date segments
|
||||
)
|
||||
pprint(tasks)
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(
|
||||
experiment=self.experiment_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key,
|
||||
rec_filter_func=my_filter,
|
||||
)
|
||||
print(collector())
|
||||
|
||||
def main(self):
|
||||
self.reset()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python task_manager_rolling.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(RollingTaskExample)
|
||||
146
examples/online_srv/online_management_simulate.py
Normal file
146
examples/online_srv/online_management_simulate.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
region (str, optional): the stock region. Defaults to "cn".
|
||||
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
|
||||
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
|
||||
task_db_name (str, optional): database name. Defaults to "rolling_db".
|
||||
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
181
examples/online_srv/rolling_online_management.py
Normal file
181
examples/online_srv/rolling_online_management.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config],
|
||||
add_tasks=[task_lgb_config],
|
||||
):
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.tasks = tasks
|
||||
self.add_tasks = add_tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategies = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def routine(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def add_strategy(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== add strategy ==========")
|
||||
strategies = []
|
||||
for task in self.add_tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.rolling_online_manager.add_strategy(strategies=strategies)
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
self.add_strategy()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python rolling_online_management.py first_run
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python rolling_online_management.py routine
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
91
examples/online_srv/update_online_pred.py
Normal file
91
examples/online_srv/update_online_pred.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"record": {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class UpdatePredExample:
|
||||
def __init__(
|
||||
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(UpdatePredExample)
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
@@ -10,12 +11,13 @@ import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
@@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
- There is a file named `config.yaml` in qlib.
|
||||
|
||||
For example:
|
||||
If your project file system stucuture follows such a pattern
|
||||
|
||||
<project_path>/
|
||||
- config.yaml
|
||||
- ...some folders...
|
||||
- qlib/
|
||||
|
||||
This folder will return <project_path>
|
||||
|
||||
NOTE: link is not supported here.
|
||||
|
||||
|
||||
This method is often used when
|
||||
- user want to use a relative config path instead of hard-coding qlib config path in code
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError:
|
||||
If project path is not found
|
||||
"""
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
if cur_path == cur_path.parent:
|
||||
raise FileNotFoundError("We can't find the project path")
|
||||
cur_path = cur_path.parent
|
||||
|
||||
|
||||
def auto_init(**kwargs):
|
||||
"""
|
||||
This function will init qlib automatically with following priority
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
"""
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
# The type of config is just like original qlib config
|
||||
init_from_yaml_conf(conf_pp, **kwargs)
|
||||
elif conf_type == "ref":
|
||||
# This config type will be more convenient in following scenario
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -33,6 +33,9 @@ class Config:
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.__dict__["_config"].get(key, default)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -131,7 +134,7 @@ _default_config = {
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -140,6 +143,11 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
@@ -310,8 +318,22 @@ class QlibConfig(Config):
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
|
||||
self.reset_qlib_version()
|
||||
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
qlib.__version__ = reset_version
|
||||
else:
|
||||
qlib.__version__ = getattr(qlib, "__version__bak")
|
||||
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
|
||||
# Using __version__bak instead of __version__
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
|
||||
@@ -27,7 +27,7 @@ class Dataset(Serializable):
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
@@ -92,7 +92,7 @@ class DatasetH(Dataset):
|
||||
handler : Union[dict, DataHandler]
|
||||
handler could be:
|
||||
|
||||
- insntance of `DataHandler`
|
||||
- instance of `DataHandler`
|
||||
|
||||
- config of `DataHandler`. Please refer to `DataHandler`
|
||||
|
||||
@@ -112,8 +112,9 @@ class DatasetH(Dataset):
|
||||
'outsample': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
@@ -123,7 +124,7 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
Config of DataHandler, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
@@ -147,11 +148,11 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
init arguments of DataHandler, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
- enable_cache : whether to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
@@ -171,7 +172,10 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
else:
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@@ -199,6 +203,12 @@ class DatasetH(Dataset):
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
|
||||
kwargs :
|
||||
The parameters that kwargs may contain:
|
||||
flt_col : str
|
||||
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
|
||||
This parameter is only supported when it is an instance of TSDatasetH.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
@@ -231,7 +241,7 @@ class TSDataSampler:
|
||||
(T)ime-(S)eries DataSampler
|
||||
This is the result of TSDatasetH
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
@@ -243,7 +253,9 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
|
||||
def __init__(
|
||||
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
|
||||
):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
|
||||
@@ -265,6 +277,11 @@ class TSDataSampler:
|
||||
ffill with previous sample
|
||||
ffill+bfill:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data.
|
||||
None:
|
||||
kepp all data
|
||||
|
||||
"""
|
||||
self.start = start
|
||||
self.end = end
|
||||
@@ -272,23 +289,51 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
|
||||
kwargs = {"object": self.data}
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE:
|
||||
# - append last line with full NaN for better performance in `__getitem__`
|
||||
# - Keep the same dtype will result in a better performance
|
||||
self.data_arr = np.append(
|
||||
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def flt_idx_map(flt_data, idx_map):
|
||||
idx = 0
|
||||
new_idx_map = {}
|
||||
for i, exist in enumerate(flt_data):
|
||||
if exist:
|
||||
new_idx_map[idx] = idx_map[i]
|
||||
idx += 1
|
||||
return new_idx_map
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
- Special sampler will be used (e.g. user want to sample day by day)
|
||||
"""
|
||||
return self.data.index[self.start_idx : self.end_idx]
|
||||
return self.data_index[self.start_idx : self.end_idx]
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
@@ -432,7 +477,7 @@ class TSDatasetH(DatasetH):
|
||||
(T)ime-(S)eries Dataset (H)andler
|
||||
|
||||
|
||||
Covnert the tabular data to Time-Series data
|
||||
Convert the tabular data to Time-Series data
|
||||
|
||||
Requirements analysis
|
||||
|
||||
@@ -461,7 +506,7 @@ class TSDatasetH(DatasetH):
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
@@ -470,6 +515,25 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
return tsds
|
||||
|
||||
@@ -7,7 +7,7 @@ import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -36,7 +36,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
Any order of the index level can be supported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -51,6 +51,9 @@ class DataHandler(Serializable):
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Tips for improving the performance of datahandler
|
||||
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -74,7 +77,7 @@ class DataHandler(Serializable):
|
||||
data_loader : Union[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor.
|
||||
initialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
@@ -125,7 +128,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
Set Up the data in case of running intialization for multiple time
|
||||
Set Up the data in case of running initialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -163,6 +166,7 @@ class DataHandler(Serializable):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -185,6 +189,14 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
- A Dataset learned some processors to process data which is related to data segmentation
|
||||
- It will apply them every time when preparing data.
|
||||
- The learned processor require the dataframe remains the same format when fitting and applying
|
||||
- However the data format will change according to the parameters.
|
||||
- So the processors should be applied to the underlayer data.
|
||||
|
||||
squeeze : bool
|
||||
whether squeeze columns and index
|
||||
@@ -193,8 +205,15 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(self._data, col_set)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -261,6 +280,10 @@ class DataHandler(Serializable):
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
Tips to improving the performance of data handler
|
||||
- To reduce the memory cost
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
"""
|
||||
|
||||
# data key
|
||||
@@ -430,7 +453,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Set up the data in case of running intialization for multiple time
|
||||
Set up the data in case of running initialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -474,6 +497,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -488,12 +512,18 @@ class DataHandlerLP(DataHandler):
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
@@ -13,6 +13,7 @@ from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class DataLoader(abc.ABC):
|
||||
@@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader):
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
|
||||
TODO: What make this module not that easy to use.
|
||||
- For online scenario
|
||||
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
@@ -265,7 +270,7 @@ class DataLoaderDH(DataLoader):
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
LOG.warning(f"instruments[{instruments}] is ignored")
|
||||
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from typing import Union, Text
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: str):
|
||||
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
|
||||
"""
|
||||
get a group of columns from multi-index columns DataFrame
|
||||
|
||||
|
||||
0
qlib/model/ens/__init__.py
Normal file
0
qlib/model/ens/__init__.py
Normal file
115
qlib/model/ens/ensemble.py
Normal file
115
qlib/model/ens/ensemble.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them into an ensemble prediction.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.utils import FLATTEN_TUPLE, flatten_dict
|
||||
|
||||
|
||||
class Ensemble:
|
||||
"""Merge the ensemble_dict into an ensemble object.
|
||||
|
||||
For example: {Rollinga_b: object, Rollingb_c: object} -> object
|
||||
|
||||
When calling this class:
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
|
||||
|
||||
Returns:
|
||||
object: the ensemble object
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `__call__` method.")
|
||||
|
||||
|
||||
class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
"""
|
||||
Extract the object if there is only one key and value in the dict. Make the result more readable.
|
||||
{Only key: Only value} -> Only value
|
||||
|
||||
If there is more than 1 key or less than 1 key, then do nothing.
|
||||
Even you can run this recursively to make dict more readable.
|
||||
|
||||
NOTE: Default runs recursively.
|
||||
|
||||
When calling this class:
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): the dict. The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
dict: the readable dict.
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
|
||||
if not isinstance(ensemble_dict, dict):
|
||||
return ensemble_dict
|
||||
if recursion:
|
||||
tmp_dict = {}
|
||||
for k, v in ensemble_dict.items():
|
||||
tmp_dict[k] = self(v, recursion)
|
||||
ensemble_dict = tmp_dict
|
||||
keys = list(ensemble_dict.keys())
|
||||
if len(keys) == 1:
|
||||
ensemble_dict = ensemble_dict[keys[0]]
|
||||
return ensemble_dict
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
|
||||
|
||||
When calling this class:
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: the complete result of rolling.
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
artifact_list = list(ensemble_dict.values())
|
||||
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
|
||||
artifact = pd.concat(artifact_list)
|
||||
# If there are duplicated predition, use the latest perdiction
|
||||
artifact = artifact[~artifact.index.duplicated(keep="last")]
|
||||
artifact = artifact.sort_index()
|
||||
return artifact
|
||||
|
||||
|
||||
class AverageEnsemble(Ensemble):
|
||||
"""
|
||||
Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.
|
||||
|
||||
When calling this class:
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: the complete result of averaging and standardizing.
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
values = list(ensemble_dict.values())
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
|
||||
results = results.mean(axis=1)
|
||||
results = results.sort_index()
|
||||
return results
|
||||
113
qlib/model/ens/group.py
Normal file
113
qlib/model/ens/group.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Group can group a set of objects based on `group_func` and change them to a dict.
|
||||
After group, we provide a method to reduce them.
|
||||
|
||||
For example:
|
||||
|
||||
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
|
||||
|
||||
"""
|
||||
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
class Group:
|
||||
"""Group the objects based on dict"""
|
||||
|
||||
def __init__(self, group_func=None, ens: Ensemble = None):
|
||||
"""
|
||||
Init Group.
|
||||
|
||||
Args:
|
||||
group_func (Callable, optional): Given a dict and return the group key and one of the group elements.
|
||||
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
|
||||
Defaults to None.
|
||||
|
||||
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
|
||||
"""
|
||||
self._group_func = group_func
|
||||
self._ens_func = ens
|
||||
|
||||
def group(self, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Group a set of objects and change them to a dict.
|
||||
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
|
||||
Returns:
|
||||
dict: grouped dict
|
||||
"""
|
||||
if isinstance(getattr(self, "_group_func", None), Callable):
|
||||
return self._group_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `group_func`.")
|
||||
|
||||
def reduce(self, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Reduce grouped dict.
|
||||
|
||||
For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
|
||||
|
||||
Returns:
|
||||
dict: reduced dict
|
||||
"""
|
||||
if isinstance(getattr(self, "_ens_func", None), Callable):
|
||||
return self._ens_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `_ens_func`.")
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Group the ungrouped_dict into different groups.
|
||||
|
||||
Args:
|
||||
ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
|
||||
|
||||
Returns:
|
||||
dict: grouped_dict like {G1: object, G2: object}
|
||||
n_jobs: how many progress you need.
|
||||
verbose: the print mode for Parallel.
|
||||
"""
|
||||
|
||||
# NOTE: The multiprocessing will raise error if you use `Serializable`
|
||||
# Because the `Serializable` will affect the behaviors of pickle
|
||||
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
|
||||
|
||||
key_l = []
|
||||
job_l = []
|
||||
for key, value in grouped_dict.items():
|
||||
key_l.append(key)
|
||||
job_l.append(delayed(Group.reduce)(self, value))
|
||||
return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
|
||||
|
||||
|
||||
class RollingGroup(Group):
|
||||
"""Group the rolling dict"""
|
||||
|
||||
def group(self, rolling_dict: dict) -> dict:
|
||||
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
|
||||
|
||||
NOTE: There is an assumption which is the rolling key is at the end of the key tuple, because the rolling results always need to be ensemble firstly.
|
||||
|
||||
Args:
|
||||
rolling_dict (dict): an rolling dict. If the key is not a tuple, then do nothing.
|
||||
|
||||
Returns:
|
||||
dict: grouped dict
|
||||
"""
|
||||
grouped_dict = {}
|
||||
for key, values in rolling_dict.items():
|
||||
if isinstance(key, tuple):
|
||||
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
|
||||
return grouped_dict
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(ens=RollingEnsemble())
|
||||
@@ -1,27 +0,0 @@
|
||||
import abc
|
||||
import typing
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> typing.List[dict]:
|
||||
"""
|
||||
generate
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args, kwargs:
|
||||
The info for generating tasks
|
||||
Example 1):
|
||||
input: a specific task template
|
||||
output: rolling version of the tasks
|
||||
Example 2):
|
||||
input: a specific task template
|
||||
output: a set of tasks with different losses
|
||||
|
||||
Returns
|
||||
-------
|
||||
typing.List[dict]:
|
||||
A list of tasks
|
||||
"""
|
||||
pass
|
||||
@@ -1,42 +1,446 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
"""
|
||||
The Trainer will train a list of tasks and return a list of model recorders.
|
||||
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
|
||||
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name):
|
||||
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
|
||||
"""
|
||||
task based training
|
||||
Begin task training to start a recorder and save the task config.
|
||||
|
||||
Args:
|
||||
task_config (dict): the config of a task
|
||||
experiment_name (str): the name of experiment
|
||||
recorder_name (str): the given name will be the recorder name. None for using rid.
|
||||
|
||||
Returns:
|
||||
Recorder: the model recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(**{"hostname": socket.gethostname()})
|
||||
recorder: Recorder = R.get_recorder()
|
||||
return recorder
|
||||
|
||||
|
||||
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
Finish task training with real model fitting and saving.
|
||||
|
||||
Args:
|
||||
rec (Recorder): the recorder will be resumed
|
||||
experiment_name (str): the name of experiment
|
||||
|
||||
Returns:
|
||||
Recorder: the model recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
|
||||
task_config = R.load_object("task")
|
||||
# model & dataset initiation
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
# model training
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# this dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
|
||||
if cls is SignalRecord:
|
||||
rconf = {"model": model, "dataset": dataset, "recorder": rec}
|
||||
else:
|
||||
rconf = {"recorder": rec}
|
||||
r = cls(**kwargs, **rconf)
|
||||
r.generate()
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
Task based training, will be divided into two steps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_config : dict
|
||||
A dict describes a task setting.
|
||||
The config of a task.
|
||||
experiment_name: str
|
||||
The name of experiment
|
||||
|
||||
Returns
|
||||
----------
|
||||
Recorder: The instance of the recorder
|
||||
"""
|
||||
recorder = begin_task_train(task_config, experiment_name)
|
||||
recorder = end_task_train(recorder, experiment_name)
|
||||
return recorder
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
The trainer can train a list of models.
|
||||
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
def __init__(self):
|
||||
self.delay = False
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
# train model
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
def train(self, tasks: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of task definitions, begin training, and return the models.
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
for record in task_config["record"]:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
For Trainer, it finishes real training in this method.
|
||||
For DelayTrainer, it only does some preparation in this method.
|
||||
|
||||
Args:
|
||||
tasks: a list of tasks
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `train` method.")
|
||||
|
||||
def end_train(self, models: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of models, finished something at the end of training if you need.
|
||||
The models may be Recorder, txt file, database, and so on.
|
||||
|
||||
For Trainer, it does some finishing touches in this method.
|
||||
For DelayTrainer, it finishes real training in this method.
|
||||
|
||||
Args:
|
||||
models: a list of models
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
# do nothing if you finished all work in `train` method
|
||||
return models
|
||||
|
||||
def is_delay(self) -> bool:
|
||||
"""
|
||||
If Trainer will delay finishing `end_train`.
|
||||
|
||||
Returns:
|
||||
bool: if DelayTrainer
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
Trainer based on (R)ecorder.
|
||||
It will train a list of tasks and return a list of model recorders in a linear way.
|
||||
|
||||
Assumption: models were defined by `task` and the results will be saved to `Recorder`.
|
||||
"""
|
||||
|
||||
# Those tag will help you distinguish whether the Recorder has finished traning
|
||||
STATUS_KEY = "train_status"
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str, optional): the default name of experiment.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definitions based on `task` dict
|
||||
train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
recs = []
|
||||
for task in tasks:
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Set STATUS_END tag to the recorders.
|
||||
|
||||
Args:
|
||||
recs (list): a list of trained recorders.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainerR(TrainerR):
|
||||
"""
|
||||
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them
|
||||
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
for rec in recs:
|
||||
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
|
||||
continue
|
||||
end_train_func(rec, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
class TrainerRM(Trainer):
|
||||
"""
|
||||
Trainer based on (R)ecorder and Task(M)anager.
|
||||
It can train a list of tasks and return a list of model recorders in a multiprocessing way.
|
||||
|
||||
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
|
||||
"""
|
||||
|
||||
# Those tag will help you distinguish whether the Recorder has finished traning
|
||||
STATUS_KEY = "train_status"
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
|
||||
def train(
|
||||
self,
|
||||
tasks: list,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
This method defaults to a single process, but TaskManager offered a great way to parallel training.
|
||||
Users can customize their train_func to realize multiple processes or even multiple machines.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definitions based on `task` dict
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Set STATUS_END tag to the recorders.
|
||||
|
||||
Args:
|
||||
recs (list): a list of trained recorders.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str = None,
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
tasks,
|
||||
train_func=train_func,
|
||||
experiment_name=experiment_name,
|
||||
after_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
@@ -24,7 +25,9 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, Text, Optional
|
||||
from typing import Union, Tuple, Any, Text, Optional
|
||||
from types import ModuleType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -165,24 +168,25 @@ def parse_field(field):
|
||||
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
|
||||
|
||||
|
||||
def get_module_by_module_path(module_path):
|
||||
def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
"""Load module path
|
||||
|
||||
:param module_path:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
if isinstance(module_path, ModuleType):
|
||||
module = module_path
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
|
||||
@@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
config : [dict, str]
|
||||
similar to config
|
||||
|
||||
module : Python module
|
||||
default_module : Python module or str
|
||||
It should be a python module to load the class type
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
the class object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
@@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> object:
|
||||
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
get initialized instance with config
|
||||
|
||||
@@ -227,13 +237,19 @@ def init_instance_by_config(
|
||||
'model_path': path, # It is optional if module is given
|
||||
}
|
||||
str example.
|
||||
"ClassName": getattr(module, config)() will be used.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, config)() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
module : Python module
|
||||
default_module : Python module
|
||||
Optional. It should be a python module.
|
||||
NOTE: the "module_path" will be override by `module` arguments
|
||||
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
accept_types: Union[type, Tuple[type]]
|
||||
Optional. If the config is a instance of specific type, return the config directly.
|
||||
This will be passed into the second parameter of isinstance.
|
||||
@@ -246,10 +262,14 @@ def init_instance_by_config(
|
||||
if isinstance(config, accept_types):
|
||||
return config
|
||||
|
||||
if module is None:
|
||||
module = get_module_by_module_path(config["module_path"])
|
||||
if isinstance(config, str):
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, module)
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
|
||||
"""get trading date with shift bias wil cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
"""
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future)
|
||||
cal = D.calendar(future=future, freq=freq)
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
@@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
return df.sort_index(axis=axis)
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="."):
|
||||
"""flatten_dict.
|
||||
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep=".") -> dict:
|
||||
"""
|
||||
Flatten a nested dict.
|
||||
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
|
||||
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d :
|
||||
d
|
||||
parent_key :
|
||||
parent_key
|
||||
sep :
|
||||
sep
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
|
||||
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
|
||||
|
||||
Args:
|
||||
d (dict): the dict waiting for flatting
|
||||
parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
|
||||
sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
|
||||
|
||||
Returns:
|
||||
dict: flatten dict
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if sep == FLATTEN_TUPLE:
|
||||
new_key = (parent_key, k) if parent_key else k
|
||||
else:
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
|
||||
@@ -3,16 +3,24 @@
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import typing
|
||||
import dill
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Serializable:
|
||||
"""
|
||||
Serializable behaves like pickle.
|
||||
But it only saves the state whose name **does not** start with `_`
|
||||
Serializable will change the behaviors of pickle.
|
||||
- It only saves the state whose name **does not** start with `_`
|
||||
It provides a syntactic sugar for distinguish the attributes which user doesn't want.
|
||||
- For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk
|
||||
"""
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._dump_all = self.default_dump_all
|
||||
self._exclude = []
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
@@ -33,18 +41,86 @@ class Serializable:
|
||||
@property
|
||||
def exclude(self):
|
||||
"""
|
||||
What attribute will be dumped
|
||||
What attribute will not be dumped
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None):
|
||||
if dump_all is not None:
|
||||
self._dump_all = dump_all
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
if exclude is not None:
|
||||
self._exclude = exclude
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
|
||||
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
|
||||
Parameters
|
||||
----------
|
||||
dump_all : bool
|
||||
will the object dump all object
|
||||
exclude : list
|
||||
What attribute will not be dumped
|
||||
recursive : bool
|
||||
will the configuration be recursive
|
||||
"""
|
||||
|
||||
params = {"dump_all": dump_all, "exclude": exclude}
|
||||
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
attr_name = f"_{k}"
|
||||
setattr(self, attr_name, v)
|
||||
|
||||
if recursive:
|
||||
for obj in self.__dict__.values():
|
||||
# set flag to prevent endless loop
|
||||
self.__dict__[self.FLAG_KEY] = True
|
||||
if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
|
||||
obj.config(**params, recursive=True)
|
||||
del self.__dict__[self.FLAG_KEY]
|
||||
|
||||
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
|
||||
"""
|
||||
Dump self to a pickle file.
|
||||
|
||||
Args:
|
||||
path (Union[Path, str]): the path to dump
|
||||
dump_all (bool, optional): if need to dump all things. Defaults to None.
|
||||
exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
|
||||
"""
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
self.get_backend().dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
"""
|
||||
Load the collector from a filepath.
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled file must be `Collector`
|
||||
|
||||
Returns:
|
||||
Collector: the instance of Collector
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
object = cls.get_backend().load(f)
|
||||
if isinstance(object, cls):
|
||||
return object
|
||||
else:
|
||||
raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!")
|
||||
|
||||
@classmethod
|
||||
def get_backend(cls):
|
||||
"""
|
||||
Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
|
||||
|
||||
Returns:
|
||||
module: pickle or dill module based on pickle_backend
|
||||
"""
|
||||
if cls.pickle_backend == "pickle":
|
||||
return pickle
|
||||
elif cls.pickle_backend == "dill":
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@@ -331,7 +331,7 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
|
||||
0
qlib/workflow/online/__init__.py
Normal file
0
qlib/workflow/online/__init__.py
Normal file
304
qlib/workflow/online/manager.py
Normal file
304
qlib/workflow/online/manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
|
||||
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
|
||||
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
|
||||
There are 4 total situations for using different trainers in different situations:
|
||||
|
||||
|
||||
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
every routine and prepare signals for every routine.
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.log import set_global_logger_level
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow.online.strategy import OnlineStrategy
|
||||
from qlib.workflow.task.collect import MergeCollector
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
"""
|
||||
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
|
||||
It also provides a history recording of which models are online at what time.
|
||||
"""
|
||||
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
|
||||
trainer: Trainer = None,
|
||||
begin_time: Union[str, pd.Timestamp] = None,
|
||||
freq="day",
|
||||
):
|
||||
"""
|
||||
Init OnlineManager.
|
||||
One OnlineManager must have at least one OnlineStrategy.
|
||||
|
||||
Args:
|
||||
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
|
||||
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date.
|
||||
trainer (Trainer): the trainer to train task. None for using TrainerR.
|
||||
freq (str, optional): data frequency. Defaults to "day".
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if not isinstance(strategies, list):
|
||||
strategies = [strategies]
|
||||
self.strategies = strategies
|
||||
self.freq = freq
|
||||
if begin_time is None:
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
|
||||
self.history = {}
|
||||
if trainer is None:
|
||||
trainer = TrainerR()
|
||||
self.trainer = trainer
|
||||
self.signals = None
|
||||
self.status = self.STATUS_NORMAL
|
||||
|
||||
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
|
||||
"""
|
||||
Get tasks from every strategy's first_tasks method and train them.
|
||||
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
|
||||
|
||||
Args:
|
||||
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
"""
|
||||
if strategies is None:
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
task_kwargs: dict = {},
|
||||
model_kwargs: dict = {},
|
||||
signal_kwargs: dict = {},
|
||||
):
|
||||
"""
|
||||
Typical update process for every strategy and record the online history.
|
||||
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.
|
||||
|
||||
If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
|
||||
|
||||
Args:
|
||||
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
|
||||
task_kwargs (dict): the params for `prepare_tasks`
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
signal_kwargs (dict): the params for `prepare_signals`
|
||||
"""
|
||||
if cur_time is None:
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.trainer.is_delay():
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
"""
|
||||
Add some new strategies to OnlineManager.
|
||||
|
||||
Args:
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
|
||||
"""
|
||||
if not isinstance(strategies, list):
|
||||
strategies = [strategies]
|
||||
self.first_train(strategies)
|
||||
self.strategies.extend(strategies)
|
||||
|
||||
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
|
||||
"""
|
||||
After preparing the data of the last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for the next routine.
|
||||
|
||||
NOTE: Given a set prediction, all signals before these prediction end times will be prepared well.
|
||||
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
|
||||
.. note::
|
||||
|
||||
Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
|
||||
Args:
|
||||
prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}.
|
||||
over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: the signals.
|
||||
"""
|
||||
signals = prepare_func(self.get_collector()())
|
||||
old_signals = self.signals
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
|
||||
self.signals = signals
|
||||
return new_signals
|
||||
|
||||
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Get prepared online signals.
|
||||
|
||||
Returns:
|
||||
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
|
||||
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
|
||||
"""
|
||||
return self.signals
|
||||
|
||||
SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(
|
||||
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
|
||||
|
||||
Considering the parallel training, the models and signals can be prepared after all routine simulating.
|
||||
|
||||
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
|
||||
|
||||
Args:
|
||||
end_time: the time the simulation will end
|
||||
frequency: the calendar frequency
|
||||
task_kwargs (dict): the params for `prepare_tasks`
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
signal_kwargs (dict): the params for `prepare_signals`
|
||||
|
||||
Returns:
|
||||
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
|
||||
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
|
||||
"""
|
||||
self.status = self.STATUS_SIMULATING
|
||||
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
|
||||
self.first_train()
|
||||
|
||||
simulate_level = self.SIM_LOG_LEVEL
|
||||
set_global_logger_level(simulate_level)
|
||||
logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
|
||||
|
||||
for cur_time in cal:
|
||||
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
|
||||
self.routine(
|
||||
cur_time,
|
||||
task_kwargs=task_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
signal_kwargs=signal_kwargs,
|
||||
)
|
||||
# delay prepare the models and signals
|
||||
if self.trainer.is_delay():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
|
||||
# FIXME: get logging level firstly and restore it here
|
||||
set_global_logger_level(logging.DEBUG)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
self.status = self.STATUS_NORMAL
|
||||
return self.get_signals()
|
||||
|
||||
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
|
||||
"""
|
||||
Prepare all models and signals if something is waiting for preparation.
|
||||
|
||||
Args:
|
||||
model_kwargs: the params for `end_train`
|
||||
signal_kwargs: the params for `prepare_signals`
|
||||
"""
|
||||
last_models = {}
|
||||
signals_time = D.calendar()[0]
|
||||
need_prepare = False
|
||||
for cur_time, strategy_models in self.history.items():
|
||||
self.cur_time = cur_time
|
||||
|
||||
for strategy, models in strategy_models.items():
|
||||
# only new online models need to prepare
|
||||
if last_models.setdefault(strategy, set()) != set(models):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
|
||||
strategy.tool.reset_online_tag(models)
|
||||
need_prepare = True
|
||||
last_models[strategy] = set(models)
|
||||
|
||||
if need_prepare:
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
self.logger.warn(
|
||||
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
need_prepare = False
|
||||
signals_time = self.signals.index.get_level_values("datetime").max()
|
||||
211
qlib/workflow/online/strategy.py
Normal file
211
qlib/workflow/online/strategy.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineStrategy module is an element of online serving.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
|
||||
|
||||
class OnlineStrategy:
|
||||
"""
|
||||
OnlineStrategy is working with `Online Manager <#Online Manager>`_, responding to how the tasks are generated, the models are updated and signals are prepared.
|
||||
"""
|
||||
|
||||
def __init__(self, name_id: str):
|
||||
"""
|
||||
Init OnlineStrategy.
|
||||
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id.
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
"""
|
||||
self.name_id = name_id
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.tool = OnlineTool()
|
||||
|
||||
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest)..
|
||||
Return the new tasks waiting for training.
|
||||
|
||||
You can find the last online models by OnlineTool.online_models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
|
||||
"""
|
||||
Select some models from trained models and set them to online models.
|
||||
This is a typical implementation to online all trained models, you can override it to implement the complex method.
|
||||
You can find the last online models by OnlineTool.online_models if you still need them.
|
||||
|
||||
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
|
||||
|
||||
Args:
|
||||
models (list): a list of models.
|
||||
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
|
||||
|
||||
Returns:
|
||||
List[object]: a list of online models.
|
||||
"""
|
||||
if not trained_models:
|
||||
return self.tool.online_models()
|
||||
self.tool.reset_online_tag(trained_models)
|
||||
return trained_models
|
||||
|
||||
def first_tasks(self) -> List[dict]:
|
||||
"""
|
||||
Generate a series of tasks firstly and return them.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `first_tasks` method.")
|
||||
|
||||
def get_collector(self) -> Collector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.
|
||||
|
||||
For example:
|
||||
1) collect predictions in Recorder
|
||||
2) collect signals in a txt file
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
|
||||
class RollingStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always uses the latest rolling model sas online models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name_id: str,
|
||||
task_template: Union[dict, List[dict]],
|
||||
rolling_gen: RollingGen,
|
||||
):
|
||||
"""
|
||||
Init RollingStrategy.
|
||||
|
||||
Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id. Will be also the name of the Experiment.
|
||||
task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
"""
|
||||
super().__init__(name_id=name_id)
|
||||
self.exp_name = self.name_id
|
||||
if not isinstance(task_template, list):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.rg = rolling_gen
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.ta = TimeAdjuster()
|
||||
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.
|
||||
|
||||
Assumption: the models can be distinguished based on the model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your method or use another rec_key_func.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
"""
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
artifacts_collector = RecorderCollector(
|
||||
experiment=self.exp_name,
|
||||
process_list=process_list,
|
||||
rec_key_func=rec_key_func,
|
||||
rec_filter_func=rec_filter_func,
|
||||
artifacts_key=artifacts_key,
|
||||
)
|
||||
|
||||
return artifacts_collector
|
||||
|
||||
def first_tasks(self) -> List[dict]:
|
||||
"""
|
||||
Use rolling_gen to generate different tasks based on task_template.
|
||||
|
||||
Returns:
|
||||
List[dict]: a list of tasks
|
||||
"""
|
||||
return task_generator(
|
||||
tasks=self.task_template,
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
|
||||
def prepare_tasks(self, cur_time) -> List[dict]:
|
||||
"""
|
||||
Prepare new tasks based on cur_time (None for the latest).
|
||||
|
||||
You can find the last online models by OnlineToolR.online_models.
|
||||
|
||||
Returns:
|
||||
List[dict]: a list of new tasks.
|
||||
"""
|
||||
latest_records, max_test = self._list_latest(self.tool.online_models())
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
|
||||
"""
|
||||
List latest recorder form rec_list
|
||||
|
||||
Args:
|
||||
rec_list (List[Recorder]): a list of Recorder
|
||||
|
||||
Returns:
|
||||
List[Recorder], pd.Timestamp: the latest recorders and their test end time
|
||||
"""
|
||||
if len(rec_list) == 0:
|
||||
return rec_list, None
|
||||
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
|
||||
latest_rec = []
|
||||
for rec in rec_list:
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec.append(rec)
|
||||
return latest_rec, max_test
|
||||
160
qlib/workflow/online/update.py
Normal file
160
qlib/workflow/online/update.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Updater is a module to update artifacts such as predictions when the stock data is updating.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.model import Model
|
||||
from qlib.utils import get_date_by_shift
|
||||
from qlib.workflow.recorder import Recorder
|
||||
|
||||
|
||||
class RMDLoader:
|
||||
"""
|
||||
Recorder Model Dataset Loader
|
||||
"""
|
||||
|
||||
def __init__(self, rec: Recorder):
|
||||
self.rec = rec
|
||||
|
||||
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
|
||||
"""
|
||||
Load, config and setup dataset.
|
||||
|
||||
This dataset is for inference.
|
||||
|
||||
Args:
|
||||
start_time :
|
||||
the start_time of underlying data
|
||||
end_time :
|
||||
the end_time of underlying data
|
||||
segments : dict
|
||||
the segments config for dataset
|
||||
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
|
||||
"""
|
||||
if segments is None:
|
||||
segments = {"test": (start_time, end_time)}
|
||||
dataset: DatasetH = self.rec.load_object("dataset")
|
||||
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
|
||||
return dataset
|
||||
|
||||
def get_model(self) -> Model:
|
||||
return self.rec.load_object("params.pkl")
|
||||
|
||||
|
||||
class RecordUpdater(metaclass=ABCMeta):
|
||||
"""
|
||||
Update a specific recorders
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, *args, **kwargs):
|
||||
self.record = record
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs):
|
||||
"""
|
||||
Update info for specific recorder
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class PredUpdater(RecordUpdater):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
|
||||
"""
|
||||
Init PredUpdater.
|
||||
|
||||
Args:
|
||||
record : Recorder
|
||||
to_date :
|
||||
update to prediction to the `to_date`
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical depends.
|
||||
Leave the problem to users to set the length of historical dependency
|
||||
|
||||
.. note::
|
||||
|
||||
the start_time is not included in the hist_ref
|
||||
|
||||
"""
|
||||
# TODO: automate this hist_ref in the future.
|
||||
super().__init__(record=record)
|
||||
|
||||
self.to_date = to_date
|
||||
self.hist_ref = hist_ref
|
||||
self.freq = freq
|
||||
self.rmdl = RMDLoader(rec=record)
|
||||
|
||||
if to_date == None:
|
||||
to_date = D.calendar(freq=freq)[-1]
|
||||
self.to_date = pd.Timestamp(to_date)
|
||||
self.old_pred = record.load_object("pred.pkl")
|
||||
self.last_end = self.old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
"""
|
||||
Load dataset
|
||||
|
||||
Separating this function will make it easier to reuse the dataset
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
"""
|
||||
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
seg = {"test": (start_time, self.to_date)}
|
||||
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
|
||||
return dataset
|
||||
|
||||
def update(self, dataset: DatasetH = None):
|
||||
"""
|
||||
Update the prediction in a recorder.
|
||||
|
||||
Args:
|
||||
DatasetH: the instance of DatasetH. None for reprepare.
|
||||
"""
|
||||
# FIXME: the problem below is not solved
|
||||
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
|
||||
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
|
||||
)
|
||||
return
|
||||
|
||||
# load dataset
|
||||
if dataset is None:
|
||||
# For reusing the dataset
|
||||
dataset = self.prepare_data()
|
||||
|
||||
# Load model
|
||||
model = self.rmdl.get_model()
|
||||
|
||||
new_pred: pd.Series = model.predict(dataset)
|
||||
|
||||
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
|
||||
self.record.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
168
qlib/workflow/online/utils.py
Normal file
168
qlib/workflow/online/utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineTool is a module to set and unset a series of `online` models.
|
||||
The `online` models are some decisive models in some time points, which can be changed with the change of time.
|
||||
This allows us to use efficient submodels as the market-style changing.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class OnlineTool:
|
||||
"""
|
||||
OnlineTool will manage `online` models in an experiment that includes the model recorders.
|
||||
"""
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Init OnlineTool.
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[list, object]):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[list,object]): the model's recorder
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, recorder: object) -> str:
|
||||
"""
|
||||
Given a model recorder and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Object): the model's recorder
|
||||
|
||||
Returns:
|
||||
str: the online tag
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, recorder: Union[list, object]):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[list,object]):
|
||||
the recorder you want to reset to 'online'.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def online_models(self) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of `online` models to to_date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
|
||||
class OnlineToolR(OnlineTool):
|
||||
"""
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
"""
|
||||
super().__init__()
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Set `tag` to the model's recorder to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder) -> str:
|
||||
"""
|
||||
Given a model recorder and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Recorder): an instance of recorder
|
||||
|
||||
Returns:
|
||||
str: the online tag
|
||||
"""
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[Recorder, List]):
|
||||
the recorder you want to reset to 'online'.
|
||||
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(self.ONLINE_TAG, recorder)
|
||||
|
||||
def online_models(self) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of online models to to_date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
if task["dataset"]["class"] == "TSDatasetH":
|
||||
hist_ref = task["dataset"]["kwargs"]["step_len"]
|
||||
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
|
||||
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
@@ -151,6 +151,10 @@ class SignalRecord(RecordTemp):
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.dataset.__class__ = orig_cls
|
||||
@@ -236,6 +240,9 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
|
||||
@@ -39,6 +39,9 @@ class Recorder:
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.info["id"])
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
output = dict()
|
||||
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.info["id"])
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, MLflowRecorder):
|
||||
return self.info["id"] == o.info["id"]
|
||||
return False
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
13
qlib/workflow/task/__init__.py
Normal file
13
qlib/workflow/task/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Task related workflow is implemented in this folder
|
||||
|
||||
A typical task workflow
|
||||
|
||||
| Step | Description |
|
||||
|-----------------------+------------------------------------------------|
|
||||
| TaskGen | Generating tasks. |
|
||||
| TaskManager(optional) | Manage generated tasks |
|
||||
| run task | retrive tasks from TaskManager and run tasks. |
|
||||
"""
|
||||
219
qlib/workflow/task/collect.py
Normal file
219
qlib/workflow/task/collect.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
|
||||
|
||||
class Collector(Serializable):
|
||||
"""The collector to collect different results"""
|
||||
|
||||
pickle_backend = "dill" # use dill to dump user method
|
||||
|
||||
def __init__(self, process_list=[]):
|
||||
"""
|
||||
Init Collector.
|
||||
|
||||
Args:
|
||||
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
|
||||
"""
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
self.process_list = process_list
|
||||
|
||||
def collect(self) -> dict:
|
||||
"""
|
||||
Collect the results and return a dict like {key: things}
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting.
|
||||
|
||||
For example:
|
||||
|
||||
{"prediction": pd.Series}
|
||||
|
||||
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
||||
|
||||
......
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `collect` method.")
|
||||
|
||||
@staticmethod
|
||||
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
|
||||
"""
|
||||
Do a series of processing to the dict returned by collect and return a dict like {key: things}
|
||||
For example, you can group and ensemble.
|
||||
|
||||
Args:
|
||||
collected_dict (dict): the dict return by `collect`
|
||||
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
|
||||
The processor order is the same as the list order.
|
||||
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
||||
|
||||
Returns:
|
||||
dict: the dict after processing.
|
||||
"""
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
result = {}
|
||||
for artifact in collected_dict:
|
||||
value = collected_dict[artifact]
|
||||
for process in process_list:
|
||||
if not callable(process):
|
||||
raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
|
||||
value = process(value, *args, **kwargs)
|
||||
result[artifact] = value
|
||||
return result
|
||||
|
||||
def __call__(self, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Do the workflow including ``collect`` and ``process_collect``
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting and processing.
|
||||
"""
|
||||
collected = self.collect()
|
||||
return self.process_collect(collected, self.process_list, *args, **kwargs)
|
||||
|
||||
|
||||
class MergeCollector(Collector):
|
||||
"""
|
||||
A collector to collect the results of other Collectors
|
||||
|
||||
For example:
|
||||
|
||||
We have 2 collector, which named A and B.
|
||||
A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
|
||||
Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
||||
|
||||
......
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
|
||||
"""
|
||||
Init MergeCollector.
|
||||
|
||||
Args:
|
||||
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
|
||||
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
|
||||
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
|
||||
None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
self.collector_dict = collector_dict
|
||||
self.merge_func = merge_func
|
||||
|
||||
def collect(self) -> dict:
|
||||
"""
|
||||
Collect all results of collector_dict and change the outermost key to a recombination key.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collecting.
|
||||
"""
|
||||
collect_dict = {}
|
||||
for collector_key, collector in self.collector_dict.items():
|
||||
tmp_dict = collector()
|
||||
for key, value in tmp_dict.items():
|
||||
if self.merge_func is not None:
|
||||
collect_dict[self.merge_func(collector_key, key)] = value
|
||||
else:
|
||||
collect_dict[(collector_key, key)] = value
|
||||
return collect_dict
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
ART_KEY_RAW = "__raw"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment,
|
||||
process_list=[],
|
||||
rec_key_func=None,
|
||||
rec_filter_func=None,
|
||||
artifacts_path={"pred": "pred.pkl"},
|
||||
artifacts_key=None,
|
||||
):
|
||||
"""
|
||||
Init RecorderCollector.
|
||||
|
||||
Args:
|
||||
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
|
||||
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
self.experiment = experiment
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
if artifacts_key is None:
|
||||
artifacts_key = list(self.artifacts_path.keys())
|
||||
self.rec_key_func = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self.rec_filter_func = rec_filter_func
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
|
||||
"""
|
||||
Collect different artifacts based on recorder after filtering.
|
||||
|
||||
Args:
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
|
||||
only_exist (bool, optional): if only collect the artifacts when a recorder really has.
|
||||
If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collected like {artifact: {rec_key: object}}
|
||||
"""
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_key
|
||||
if rec_filter_func is None:
|
||||
rec_filter_func = self.rec_filter_func
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
recs = self.experiment.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
if self.ART_KEY_RAW == key:
|
||||
artifact = rec
|
||||
else:
|
||||
try:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
except Exception as e:
|
||||
if only_exist:
|
||||
# only collect existing artifact
|
||||
continue
|
||||
raise e
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
def get_exp_name(self) -> str:
|
||||
"""
|
||||
Get experiment name
|
||||
|
||||
Returns:
|
||||
str: experiment name
|
||||
"""
|
||||
return self.experiment.name
|
||||
231
qlib/workflow/task/gen.py
Normal file
231
qlib/workflow/task/gen.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
from typing import List, Union, Callable
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
def task_generator(tasks, generators) -> list:
|
||||
"""
|
||||
Use a list of TaskGen and a list of task templates to generate different tasks.
|
||||
|
||||
For examples:
|
||||
|
||||
There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template.
|
||||
task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tasks : List[dict] or dict
|
||||
a list of task templates or a single task
|
||||
generators : List[TaskGen] or TaskGen
|
||||
a list of TaskGen or a single TaskGen
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
a list of tasks
|
||||
"""
|
||||
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if isinstance(generators, TaskGen):
|
||||
generators = [generators]
|
||||
|
||||
# generate gen_task_list
|
||||
for gen in generators:
|
||||
new_task_list = []
|
||||
for task in tasks:
|
||||
new_task_list.extend(gen.generate(task))
|
||||
tasks = new_task_list
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The base class for generating different tasks
|
||||
|
||||
Example 1:
|
||||
|
||||
input: a specific task template and rolling steps
|
||||
|
||||
output: rolling version of the tasks
|
||||
|
||||
Example 2:
|
||||
|
||||
input: a specific task template and losses list
|
||||
|
||||
output: a set of tasks with different losses
|
||||
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
Generate different tasks based on a task template
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task: dict
|
||||
a task template
|
||||
|
||||
Returns
|
||||
-------
|
||||
typing.List[dict]:
|
||||
A list of tasks
|
||||
"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
This is just a syntactic sugar for generate
|
||||
"""
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
|
||||
def handler_mod(task: dict, rolling_gen):
|
||||
"""
|
||||
Help to modify the handler end time when using RollingGen
|
||||
|
||||
Args:
|
||||
task (dict): a task template
|
||||
rg (RollingGen): an instance of RollingGen
|
||||
"""
|
||||
try:
|
||||
interval = rolling_gen.ta.cal_interval(
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
|
||||
)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if interval < 0:
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
|
||||
)
|
||||
except KeyError:
|
||||
# Maybe dataset do not have handler, then do nothing.
|
||||
pass
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
|
||||
"""
|
||||
Generate tasks for rolling
|
||||
|
||||
Parameters
|
||||
----------
|
||||
step : int
|
||||
step to rolling
|
||||
rtype : str
|
||||
rolling type (expanding, sliding)
|
||||
ds_extra_mod_func: Callable
|
||||
A method like: handler_mod(task: dict, rg: RollingGen)
|
||||
Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
self.ds_extra_mod_func = ds_extra_mod_func
|
||||
self.ta = TimeAdjuster(future=True)
|
||||
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task: dict
|
||||
A dict describing a task. For example.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
DEFAULT_TASK = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"record": [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[dict]: a list of tasks
|
||||
"""
|
||||
res = []
|
||||
|
||||
prev_seg = None
|
||||
test_end = None
|
||||
while True:
|
||||
t = copy.deepcopy(task)
|
||||
|
||||
# calculate segments
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
else:
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# decide how to shift
|
||||
# expanding only for train data, the segments size of test data and valid data won't change
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# shift the segments data
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
if self.ds_extra_mod_func is not None:
|
||||
self.ds_extra_mod_func(t, self)
|
||||
res.append(t)
|
||||
return res
|
||||
493
qlib/workflow/task/manage.py
Normal file
493
qlib/workflow/task/manage.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
These features can run tasks concurrently and ensure every task will be used only once.
|
||||
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
A task in TaskManager consists of 3 parts
|
||||
- tasks description: the desc will define the task
|
||||
- tasks status: the status of the task
|
||||
- tasks result: A user can get the task with the task description and task result.
|
||||
"""
|
||||
import concurrent
|
||||
import pickle
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List
|
||||
|
||||
import fire
|
||||
import pymongo
|
||||
from bson.binary import Binary
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.errors import InvalidDocument
|
||||
from qlib import auto_init, get_module_logger
|
||||
from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""
|
||||
TaskManager
|
||||
|
||||
Here is what will a task looks like when it created by TaskManager
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'def': pickle serialized task definition. using pickle will make it easier
|
||||
'filter': json-like data. This is for filtering the tasks.
|
||||
'status': 'waiting' | 'running' | 'done'
|
||||
'res': pickle serialized task result,
|
||||
}
|
||||
|
||||
The tasks manager assumes that you will only update the tasks you fetched.
|
||||
The mongo fetch one and update will make it date updating secure.
|
||||
|
||||
.. note::
|
||||
|
||||
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
|
||||
Here are four status which are:
|
||||
|
||||
STATUS_WAITING: waiting for training
|
||||
|
||||
STATUS_RUNNING: training
|
||||
|
||||
STATUS_PART_DONE: finished some step and waiting for next step
|
||||
|
||||
STATUS_DONE: all work done
|
||||
"""
|
||||
|
||||
STATUS_WAITING = "waiting"
|
||||
STATUS_RUNNING = "running"
|
||||
STATUS_DONE = "done"
|
||||
STATUS_PART_DONE = "part_done"
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
"""
|
||||
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self) -> list:
|
||||
"""
|
||||
List the all collection(task_pool) of the db
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = Binary(pickle.dumps(task[k]))
|
||||
return task
|
||||
|
||||
def _decode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = pickle.loads(task[k])
|
||||
return task
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace a old one
|
||||
|
||||
Args:
|
||||
task: old task
|
||||
new_task: new task
|
||||
"""
|
||||
new_task = self._encode_task(new_task)
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
try:
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
except InvalidDocument:
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task):
|
||||
"""
|
||||
Insert a task.
|
||||
|
||||
Args:
|
||||
task: the task waiting for insert
|
||||
|
||||
Returns:
|
||||
pymongo.results.InsertOneResult
|
||||
"""
|
||||
try:
|
||||
insert_result = self.task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
insert_result = self.task_pool.insert_one(task)
|
||||
return insert_result
|
||||
|
||||
def insert_task_def(self, task_def):
|
||||
"""
|
||||
Insert a task to task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_def: dict
|
||||
the task definition
|
||||
|
||||
Returns
|
||||
-------
|
||||
pymongo.results.InsertOneResult
|
||||
"""
|
||||
task = self._encode_task(
|
||||
{
|
||||
"def": task_def,
|
||||
"filter": task_def, # FIXME: catch the raised error
|
||||
"status": self.STATUS_WAITING,
|
||||
}
|
||||
)
|
||||
insert_result = self.insert_task(task)
|
||||
return insert_result
|
||||
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
|
||||
"""
|
||||
If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.
|
||||
If a task is not new, then just query its _id.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_def_l: list
|
||||
a list of task
|
||||
dry_run: bool
|
||||
if insert those new tasks to task pool
|
||||
print_nt: bool
|
||||
if print new task
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
a list of the _id of task_def_l
|
||||
"""
|
||||
new_tasks = []
|
||||
_id_list = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
r = self.task_pool.find_one({"filter": t})
|
||||
except InvalidDocument:
|
||||
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
if not dry_run:
|
||||
insert_result = self.insert_task_def(t)
|
||||
_id_list.append(insert_result.inserted_id)
|
||||
else:
|
||||
_id_list.append(None)
|
||||
else:
|
||||
_id_list.append(self._decode_task(r)["_id"])
|
||||
|
||||
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
|
||||
|
||||
if print_nt: # print new task
|
||||
for t in new_tasks:
|
||||
print(t)
|
||||
|
||||
if dry_run:
|
||||
return []
|
||||
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
|
||||
"""
|
||||
Use query to fetch tasks.
|
||||
|
||||
Args:
|
||||
query (dict, optional): query dict. Defaults to {}.
|
||||
status (str, optional): [description]. Defaults to STATUS_WAITING.
|
||||
|
||||
Returns:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
)
|
||||
# null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority
|
||||
if task is None:
|
||||
return None
|
||||
task["status"] = self.STATUS_RUNNING
|
||||
return self._decode_task(task)
|
||||
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
|
||||
"""
|
||||
Fetch task from task_pool using query with contextmanager
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
if task is not None:
|
||||
self.logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
def task_fetcher_iter(self, query={}):
|
||||
while True:
|
||||
with self.safe_fetch_task(query=query) as task:
|
||||
if task is None:
|
||||
break
|
||||
yield task
|
||||
|
||||
def query(self, query={}, decode=True):
|
||||
"""
|
||||
Query task in collection.
|
||||
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
decode: bool
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
Args:
|
||||
_id (str): _id of a document
|
||||
|
||||
Returns:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
t = self.task_pool.find_one({"_id": ObjectId(_id)})
|
||||
return self._decode_task(t)
|
||||
|
||||
def commit_task_res(self, task, res, status=STATUS_DONE):
|
||||
"""
|
||||
Commit the result to task['res'].
|
||||
|
||||
Args:
|
||||
task ([type]): [description]
|
||||
res (object): the result you want to save
|
||||
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
|
||||
"""
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=STATUS_WAITING):
|
||||
"""
|
||||
Return a task to status. Alway using in error handling.
|
||||
|
||||
Args:
|
||||
task ([type]): [description]
|
||||
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
|
||||
"""
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {"$set": {"status": status}}
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def remove(self, query={}):
|
||||
"""
|
||||
Remove the task using query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
"""
|
||||
Count the tasks in every status.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
|
||||
return status_stat
|
||||
|
||||
def reset_waiting(self, query={}):
|
||||
"""
|
||||
Reset all running task into waiting status. Can be used when some running task exit unexpected.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
query = query.copy()
|
||||
# default query
|
||||
if "status" not in query:
|
||||
query["status"] = self.STATUS_RUNNING
|
||||
return self.reset_status(query=query, status=self.STATUS_WAITING)
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
"""
|
||||
Set priority for task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
The task query from the database
|
||||
priority : int
|
||||
the target priority
|
||||
"""
|
||||
update_dict = {"$set": {"priority": priority}}
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
undone_n = self._get_undone_n(self.task_stat(query))
|
||||
pbar.update(last_undone_n - undone_n)
|
||||
last_undone_n = undone_n
|
||||
if undone_n == 0:
|
||||
break
|
||||
|
||||
def __str__(self):
|
||||
return f"TaskManager({self.task_pool})"
|
||||
|
||||
|
||||
def run_task(
|
||||
task_func: Callable,
|
||||
task_pool: str,
|
||||
query: dict = {},
|
||||
force_release: bool = False,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
|
||||
|
||||
After running this method, here are 4 situations (before_status -> after_status):
|
||||
|
||||
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
|
||||
|
||||
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
|
||||
|
||||
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
|
||||
|
||||
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_func : Callable
|
||||
def (task_def, **kwargs) -> <res which will be committed>
|
||||
the function to run the task
|
||||
task_pool : str
|
||||
the name of the task pool (Collection in MongoDB)
|
||||
query: dict
|
||||
will use this dict to query task_pool when fetching task
|
||||
force_release : bool
|
||||
will the program force to release the resource
|
||||
before_status : str:
|
||||
the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status : str:
|
||||
the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
kwargs
|
||||
the params for `task_func`
|
||||
"""
|
||||
tm = TaskManager(task_pool)
|
||||
|
||||
ever_run = False
|
||||
|
||||
while True:
|
||||
with tm.safe_fetch_task(status=before_status, query=query) as task:
|
||||
if task is None:
|
||||
break
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
# when fetching `WAITING` task, use task["def"] to train
|
||||
if before_status == TaskManager.STATUS_WAITING:
|
||||
param = task["def"]
|
||||
# when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
|
||||
elif before_status == TaskManager.STATUS_PART_DONE:
|
||||
param = task["res"]
|
||||
else:
|
||||
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, param, **kwargs).result()
|
||||
else:
|
||||
res = task_func(param, **kwargs)
|
||||
tm.commit_task_res(task, res, status=after_status)
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This is for using it in cmd
|
||||
# E.g. : `python -m qlib.workflow.task.manage list`
|
||||
auto_init()
|
||||
fire.Fire(TaskManager)
|
||||
258
qlib/workflow/task/utils.py
Normal file
258
qlib/workflow/task/utils.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Some tools for task management.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from qlib.workflow import R
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database
|
||||
from typing import Union
|
||||
|
||||
|
||||
def get_mongodb() -> Database:
|
||||
|
||||
"""
|
||||
Get database in MongoDB, which means you need to declare the address and the name of a database at first.
|
||||
|
||||
For example:
|
||||
|
||||
Using qlib.init():
|
||||
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(..., mongo=mongo_conf)
|
||||
|
||||
After qlib.init():
|
||||
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/",
|
||||
"task_db_name" : "rolling_db"
|
||||
}
|
||||
|
||||
Returns:
|
||||
Database: the Database instance
|
||||
"""
|
||||
try:
|
||||
cfg = C["mongo"]
|
||||
except KeyError:
|
||||
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
|
||||
raise
|
||||
|
||||
client = MongoClient(cfg["task_url"])
|
||||
return client.get_database(name=cfg["task_db_name"])
|
||||
|
||||
|
||||
def list_recorders(experiment, rec_filter_func=None):
|
||||
"""
|
||||
List all recorders which can pass the filter in an experiment.
|
||||
|
||||
Args:
|
||||
experiment (str or Experiment): the name of an Experiment or an instance
|
||||
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: a dict {rid: recorder} after filtering.
|
||||
"""
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
recs = experiment.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
return recs_flt
|
||||
|
||||
|
||||
class TimeAdjuster:
|
||||
"""
|
||||
Find appropriate date and adjust date.
|
||||
"""
|
||||
|
||||
def __init__(self, future=True, end_time=None):
|
||||
self._future = future
|
||||
self.cals = D.calendar(future=future, end_time=end_time)
|
||||
|
||||
def set_end_time(self, end_time=None):
|
||||
"""
|
||||
Set end time. None for use calendar's end time.
|
||||
|
||||
Args:
|
||||
end_time
|
||||
"""
|
||||
self.cals = D.calendar(future=self._future, end_time=end_time)
|
||||
|
||||
def get(self, idx: int):
|
||||
"""
|
||||
Get datetime by index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : int
|
||||
index of the calendar
|
||||
"""
|
||||
if idx >= len(self.cals):
|
||||
return None
|
||||
return self.cals[idx]
|
||||
|
||||
def max(self) -> pd.Timestamp:
|
||||
"""
|
||||
Return the max calendar datetime
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
def align_idx(self, time_point, tp_type="start") -> int:
|
||||
"""
|
||||
Align the index of time_point in the calendar.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_point
|
||||
tp_type : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
index : int
|
||||
"""
|
||||
time_point = pd.Timestamp(time_point)
|
||||
if tp_type == "start":
|
||||
idx = bisect.bisect_left(self.cals, time_point)
|
||||
elif tp_type == "end":
|
||||
idx = bisect.bisect_right(self.cals, time_point) - 1
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return idx
|
||||
|
||||
def cal_interval(self, time_point_A, time_point_B) -> int:
|
||||
"""
|
||||
Calculate the trading day interval (time_point_A - time_point_B)
|
||||
|
||||
Args:
|
||||
time_point_A : time_point_A
|
||||
time_point_B : time_point_B (is the past of time_point_A)
|
||||
|
||||
Returns:
|
||||
int: the interval between A and B
|
||||
"""
|
||||
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
|
||||
|
||||
def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
|
||||
"""
|
||||
Align time_point to trade date of calendar
|
||||
|
||||
Args:
|
||||
time_point
|
||||
Time point
|
||||
tp_type : str
|
||||
time point type (`"start"`, `"end"`)
|
||||
|
||||
Returns:
|
||||
pd.Timestamp
|
||||
"""
|
||||
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
|
||||
|
||||
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
|
||||
"""
|
||||
Align the given date to the trade date
|
||||
|
||||
for example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
|
||||
|
||||
output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
|
||||
'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
|
||||
'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segment
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
|
||||
"""
|
||||
if isinstance(segment, dict):
|
||||
return {k: self.align_seg(seg) for k, seg in segment.items()}
|
||||
elif isinstance(segment, tuple) or isinstance(segment, list):
|
||||
return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
|
||||
"""
|
||||
Truncate the segment based on the test_start date
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segment : tuple
|
||||
time segment
|
||||
test_start
|
||||
days : int
|
||||
The trading days to be truncated
|
||||
the data in this segment may need 'days' data
|
||||
|
||||
Returns
|
||||
---------
|
||||
tuple: new segment
|
||||
"""
|
||||
test_idx = self.align_idx(test_start)
|
||||
if isinstance(segment, tuple):
|
||||
new_seg = []
|
||||
for time_point in segment:
|
||||
tp_idx = min(self.align_idx(time_point), test_idx - days)
|
||||
assert tp_idx > 0
|
||||
new_seg.append(self.get(tp_idx))
|
||||
return tuple(new_seg)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
SHIFT_SD = "sliding"
|
||||
SHIFT_EX = "expanding"
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
Shift the datatime of segment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seg :
|
||||
datetime segment
|
||||
step : int
|
||||
rolling step
|
||||
rtype : str
|
||||
rolling type ("sliding" or "expanding")
|
||||
|
||||
Returns
|
||||
--------
|
||||
tuple: new segment
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError:
|
||||
shift will raise error if the index(both start and end) is out of self.cal
|
||||
"""
|
||||
if isinstance(seg, tuple):
|
||||
start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
|
||||
if rtype == self.SHIFT_SD:
|
||||
start_idx += step
|
||||
end_idx += step
|
||||
elif rtype == self.SHIFT_EX:
|
||||
end_idx += step
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
if start_idx > len(self.cals):
|
||||
raise KeyError("The segment is out of valid calendar")
|
||||
return self.get(start_idx), self.get(end_idx)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
Reference in New Issue
Block a user