diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst index 230a4e9d1..a68c12627 100644 --- a/docs/advanced/task_management.rst +++ b/docs/advanced/task_management.rst @@ -1,4 +1,4 @@ -.. _task_managment: +.. _task_management: ================================= Task Management @@ -10,15 +10,17 @@ Introduction ============= The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``. -To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Running`_ and `Task Collecting`_. +To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_. With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models. -An example of the entire process is shown `here `_. +This whole process can be used in `Online Serving <../component/online.html>`_. + +An example of the entire process is shown `here `_. Task Generating =============== A ``task`` consists of `Model`, `Dataset`, `Record` or anything added by users. -The specific task template(/definition/config) can be viewed in +The specific task template can be viewed in `Task Section <../component/workflow.html#task-section>`_. Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template. @@ -27,15 +29,16 @@ Here is the base class of ``TaskGen``: .. autoclass:: qlib.workflow.task.gen.TaskGen :members: -``Qlib`` provider a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments. -This class allows users to verify the effect of data from different periods on the model in one experiment. +``Qlib`` provides a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments. +This class allows users to verify the effect of data from different periods on the model in one experiment. More information in `here <../reference/api.html#TaskGen>`_. Task Storing =============== To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_. +``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling. Users **MUST** finished the configuration of `MongoDB `_ when using this module. -Users need to provide the URL and database name of ``task`` storing like this. +Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make statement like this. .. code-block:: python @@ -45,13 +48,12 @@ Users need to provide the URL and database name of ``task`` storing like this. "task_db_name" : "rolling_db" # database name } -The CRUD methods of ``task`` can be found in TaskManager. -More methods can be seen in the `Github `_. - .. autoclass:: qlib.workflow.task.manage.TaskManager :members: -Task Running +More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_. + +Task Training =============== After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status. ``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed. @@ -60,14 +62,24 @@ It will run the whole workflow defined by ``task``, which includes *Model*, *Dat .. autofunction:: qlib.workflow.task.manage.run_task +Meanwhile, ``Qlib`` provides a module called ``Trainer``. +``Trainer`` will train a list of tasks and return a list of model recorder. +``Qlib`` offer two kind of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically. +If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough. +More information is in `here <../reference/api.html#Trainer>`_. + Task Collecting =============== -To see the results of ``task`` after running or to update something, ``Qlib`` provides a ``TaskCollector`` to collect the tasks by filter condition (optional). -Here are some methods in this class. +To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way. -.. autoclass:: qlib.workflow.task.collect.TaskCollector - :members: +`Collector <../reference/api.html#Collector>`_ can collect object from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict). -``Qlib`` provides a concrete `example `_, including a whole process of `Task Generating`_ (using `RollingGen `_), `Task Storing`_, `Task Running`_ and `Task Collecting`_. -Besides, the `example `_ uses a ``ModelUpdater`` inherited from ``TaskCollector``, which can update the inferences and retrain the model if it is out of date. -Actually, the model updating can be viewed as a subset of ``Online Serving``. \ No newline at end of file +`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule). +For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object} + +`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble. +For example: {C1: object, C2: object} ---``Ensemble``---> object + +So the hierarchy is ``Collector``'s second step correspond to ``Group``. And ``Group``'s second step correspond to ``Ensemble``. + +For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_ \ No newline at end of file diff --git a/docs/component/data.rst b/docs/component/data.rst index 26f44a076..3cee803e6 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US) +.. note:: + + PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here `_. And then we will use the code to create data cache on our server which other users could use directly. + + Data API ======================== diff --git a/docs/component/online.rst b/docs/component/online.rst new file mode 100644 index 000000000..e25173153 --- /dev/null +++ b/docs/component/online.rst @@ -0,0 +1,41 @@ +.. _online: + +================================= +Online Serving +================================= +.. currentmodule:: qlib + + +Introduction +============= +In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions. +``Online Serving`` is a set of module for online models using latest data, +which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_. + +`Here `_ are several examples for reference, which demonstrate different features of ``Online Serving``. +If you have many models or `task` need to be managed, please consider `Task Management <../advanced/task_management.html>`_. +The `examples `_ maybe based on `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``. + +Online Manager +============= + +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +============= + +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +============= + +.. automodule:: qlib.workflow.online.utils + :members: + +Updater +============= + +.. automodule:: qlib.workflow.online.update + :members: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 274dc8045..803aa97d2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -42,6 +42,7 @@ Document Structure Intraday Trading: Model&Strategy Testing Qlib Recorder: Experiment Management Analysis: Evaluation & Results Analysis + Online Serving: Online Management & Strategy & Tool .. toctree:: :maxdepth: 3 diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 691dff703..edba6228a 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -154,36 +154,71 @@ Record Template .. automodule:: qlib.workflow.record_temp :members: - Task Management ==================== -RollingGen +TaskGen -------------------- -.. autoclass:: qlib.workflow.task.gen.RollingGen +.. automodule:: qlib.workflow.task.gen :members: TaskManager -------------------- -.. autoclass:: qlib.workflow.task.manage.TaskManager +.. automodule:: qlib.workflow.task.manage :members: -TaskCollector +Trainer -------------------- -.. autoclass:: qlib.workflow.task.collect.TaskCollector +.. automodule:: qlib.model.trainer :members: -ModelUpdater +Collector -------------------- -.. autoclass:: qlib.workflow.task.update.ModelUpdater +.. automodule:: qlib.workflow.task.collect :members: -TimeAdjuster +Group -------------------- -.. autoclass:: qlib.workflow.task.utils.TimeAdjuster +.. automodule:: qlib.model.ens.group :members: +Ensemble +-------------------- +.. automodule:: qlib.model.ens.ensemble + :members: + +Utils +-------------------- +.. automodule:: qlib.workflow.task.utils + :members: + + +Online Serving +==================== + + +Online Manager +-------------------- +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +-------------------- +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +-------------------- +.. automodule:: qlib.workflow.online.utils + :members: + +RecordUpdater +-------------------- +.. automodule:: qlib.workflow.online.update + :members: + + Utils ==================== diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index f1e7437fa..c3d965d85 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 | +| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 | ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | @@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 | | GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 | +| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 | - The selected 20 features are based on the feature importance of a lightgbm-based model. - The base model of DoubleEnsemble is LGBM. diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py index c68a192ba..aa1c0dc82 100644 --- a/examples/benchmarks/TFT/data_formatters/base.py +++ b/examples/benchmarks/TFT/data_formatters/base.py @@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC): return -1, -1 def get_column_definition(self): - """"Returns formatted column definition in order expected by the TFT.""" + """Returns formatted column definition in order expected by the TFT.""" column_definition = self._column_definition diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md index 30c2e19db..c07d8a2a0 100644 --- a/examples/highfreq/README.md +++ b/examples/highfreq/README.md @@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows. Run the example by running the following command: ```bash python workflow.py dump_and_load_dataset -``` \ No newline at end of file +``` + +## Benchmarks Performance +### Signal Test +Here are the results of signal test for benchmark models. We will keep updating benchmark models in future. +| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe | +|---|---|---|---|---|---|---|---|---|---| +| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 | diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml new file mode 100644 index 000000000..45c59c670 --- /dev/null +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -0,0 +1,65 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data_1min" + region: cn +market: &market 'csi300' +start_time: &start_time "2020-09-15 00:00:00" +end_time: &end_time "2021-01-18 16:00:00" +train_end_time: &train_end_time "2020-11-15 16:00:00" +valid_start_time: &valid_start_time "2020-11-16 00:00:00" +valid_end_time: &valid_end_time "2020-11-30 16:00:00" +test_start_time: &test_start_time "2020-12-01 00:00:00" +data_handler_config: &data_handler_config + start_time: *start_time + end_time: *end_time + fit_start_time: *start_time + fit_end_time: *train_end_time + instruments: *market + freq: '1min' + infer_processors: + - class: 'RobustZScoreNorm' + kwargs: + fields_group: 'feature' + clip_outlier: false + - class: "Fillna" + kwargs: + fields_group: 'feature' + learn_processors: + - class: 'DropnaLabel' + - class: 'CSRankNorm' + kwargs: + fields_group: 'label' + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +task: + model: + class: "HFLGBModel" + module_path: "qlib.contrib.model.highfreq_gdbt_model" + kwargs: + objective: 'binary' + metric: ['binary_logloss','auc'] + verbosity: -1 + learning_rate: 0.01 + max_depth: 8 + num_leaves: 150 + lambda_l1: 1.5 + lambda_l2: 1 + num_threads: 20 + dataset: + class: "DatasetH" + module_path: "qlib.data.dataset" + kwargs: + handler: + class: "Alpha158" + module_path: "qlib.contrib.data.handler" + kwargs: *data_handler_config + segments: + train: [*start_time, *train_end_time] + valid: [*train_end_time, *valid_end_time] + test: [*test_start_time, *end_time] + record: + - class: "SignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} + - class: "HFSignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} \ No newline at end of file diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index ab3a4eee5..175319885 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -1,24 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example shows how a TrainerRM work based on TaskManager with rolling tasks. +After training, how to collect the rolling results will be showed in task_collecting. +""" + from pprint import pprint -import time import fire import qlib from qlib.config import REG_CN -from qlib.model.trainer import TrainerR, task_train from qlib.workflow import R from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager, run_task +from qlib.workflow.task.manage import TaskManager from qlib.workflow.task.collect import RecorderCollector -from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow -import pandas as pd -from qlib.workflow.task.utils import list_recorders from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerRM -""" -This example shows how a Trainer work based on TaskManager with rolling tasks. -After training, how to collect the rolling results will be showed in task_collecting. -""" data_handler_config = { "start_time": "2008-01-01", @@ -139,11 +138,13 @@ class RollingTaskExample: return True return False - artifact = ens_workflow( - RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter), - RollingGroup(), + collector = RecorderCollector( + experiment=self.experiment_name, + process_list=RollingGroup(), + rec_key_func=rec_key, + rec_filter_func=my_filter, ) - print(artifact) + print(collector()) def main(self): self.reset() diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 1b1fed660..7be46d999 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -1,21 +1,18 @@ -import fire -import qlib -from qlib.model.ens.ensemble import ens_workflow -from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM -from qlib.workflow import R -from qlib.workflow.online.manager import RollingOnlineManager -from qlib.workflow.online.simulator import OnlineSimulator -from qlib.workflow.task.collect import RecorderCollector -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager -from qlib.workflow.task.utils import list_recorders +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ -This examples is about the OnlineManager and OnlineSimulator based on rolling tasks. -The OnlineManager will focus on the updating of your online models. -The OnlineSimulator will focus on the simulating real updating routine of your online models. +This examples is about how can simulate the OnlineManager based on rolling tasks. """ +import fire +import qlib +from qlib.model.trainer import DelayTrainerRM +from qlib.workflow.online.manager import OnlineManager +from qlib.workflow.online.strategy import RollingAverageStrategy +from qlib.workflow.task.gen import RollingGen +from qlib.workflow.task.manage import TaskManager + data_handler_config = { "start_time": "2018-01-01", @@ -86,10 +83,10 @@ class OnlineSimulationExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", - tasks=[task_xgboost_config], # , task_lgb_config] + tasks=[task_xgboost_config, task_lgb_config], ): """ - init OnlineManagerExample. + Init OnlineManagerExample. Args: provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data". @@ -105,6 +102,8 @@ class OnlineSimulationExample: """ self.exp_name = exp_name self.task_pool = task_pool + self.start_time = start_time + self.end_time = end_time mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, @@ -115,62 +114,30 @@ class OnlineSimulationExample: ) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31. self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks - self.rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, - rolling_gen=self.rolling_gen, - trainer=self.trainer, + self.rolling_online_manager = OnlineManager( + RollingAverageStrategy( + exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False + ), + begin_time=self.start_time, need_log=False, - ) # The OnlineManager based on Rolling - self.onlinesimulator = OnlineSimulator( - start_time=start_time, - end_time=end_time, - online_manager=self.rolling_online_manager, ) self.tasks = tasks - # Reset all things to the first status, be careful to save important data - def reset(self): - print("========== reset ==========") - self.task_manager.remove() - - exp = R.get_exp(experiment_name=self.exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) - - for rid in list_recorders( - RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False - ): - exp.delete_recorder(rid) - - # Run this firstly to see the workflow in OnlineManager - def first_train(self): - print("========== first train ==========") - self.reset() - self.rolling_online_manager.first_train(self.tasks) - - # Run this secondly to see the simulating in OnlineSimulator - def simulate(self): - print("========== simulate ==========") - self.onlinesimulator.simulate() - print(self.rolling_online_manager.collect_artifact()) - - print("========== online models ==========") - recs_dict = self.onlinesimulator.online_models() - for time, recs in recs_dict.items(): - print(f"{str(time[0])} to {str(time[1])}:") - for rec in recs: - print(rec.info["id"]) - - print("========== online signals ==========") - print(self.rolling_online_manager.get_signals()) - - # Run this to run all workflow automaticly + # Run this to run all workflow automatically def main(self): - self.first_train() - self.simulate() + print("========== reset ==========") + self.rolling_online_manager.reset() + print("========== simulate ==========") + self.rolling_online_manager.simulate(end_time=self.end_time) + print("========== collect results ==========") + print(self.rolling_online_manager.get_collector()()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) + print("========== online history ==========") + print(self.rolling_online_manager.get_online_history(self.exp_name)) if __name__ == "__main__": - ## to run all workflow automaticly with your own parameters, use the command below + ## to run all workflow automatically with your own parameters, use the command below # python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60 fire.Fire(OnlineSimulationExample) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index d118afe75..25b6fc4da 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -1,22 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example show how OnlineManager works with rolling tasks. +There are two parts including first train and routine. +Firstly, the OnlineManager will finish the first training and set trained models to `online` models. +Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models +""" + import os from pathlib import Path import pickle import fire import qlib from qlib.workflow import R +from qlib.workflow.online.strategy import RollingAverageStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.workflow.online.manager import RollingOnlineManager +from qlib.workflow.online.manager import OnlineManager from qlib.workflow.task.utils import list_recorders from qlib.model.trainer import TrainerRM -""" -This example show how RollingOnlineManager works with rolling tasks. -There are two parts including first train and routine. -Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models. -Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models -""" - data_handler_config = { "start_time": "2013-01-01", "end_time": "2020-09-25", @@ -77,58 +81,75 @@ task_xgboost_config = { class RollingOnlineExample: def __init__( self, - exp_name="rolling_exp", - task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, + tasks=[task_xgboost_config], # , task_lgb_config], ): - self.exp_name = exp_name - self.task_pool = task_pool mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) - self.rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, - rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), - trainer=TrainerRM(self.exp_name, self.task_pool), - ) + self.tasks = tasks + self.rolling_step = rolling_step + strategy = [] + for task in tasks: + name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + strategy.append( + RollingAverageStrategy( + name_id, + task, + RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), + TrainerRM(experiment_name=name_id, task_pool=name_id), + ) + ) - _ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine. + self.rolling_online_manager = OnlineManager(strategy) + self.collector = self.rolling_online_manager.get_collector() + + _ROLLING_MANAGER_PATH = ( + ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. + ) # Reset all things to the first status, be careful to save important data def reset(self): print("========== reset ==========") - TaskManager(self.task_pool).remove() - exp = R.get_exp(experiment_name=self.exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) + for task in self.tasks: + name_id = task["model"]["class"] + "_" + str(self.rolling_step) + TaskManager(name_id).remove() + exp = R.get_exp(experiment_name=name_id) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) - if os.path.exists(self._ROLLING_MANAGER_PATH): - os.remove(self._ROLLING_MANAGER_PATH) + if os.path.exists(self._ROLLING_MANAGER_PATH): + os.remove(self._ROLLING_MANAGER_PATH) - for rid in list_recorders( - RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False - ): - exp.delete_recorder(rid) + for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False): + exp.delete_recorder(rid) def first_run(self): + print("========== reset ==========") + self.rolling_online_manager.reset() print("========== first_run ==========") - self.reset() - self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config]) + self.rolling_online_manager.first_train() + print("========== dump ==========") self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) - print(self.rolling_online_manager.collect_artifact()) + print("========== collect results ==========") + print(self.collector()) def routine(self): - print("========== routine ==========") + print("========== load ==========") with Path(self._ROLLING_MANAGER_PATH).open("rb") as f: self.rolling_online_manager = pickle.load(f) + print("========== routine ==========") self.rolling_online_manager.routine() - print(self.rolling_online_manager.collect_artifact()) + print("========== collect results ==========") + print(self.collector()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) def main(self): self.first_run() @@ -137,11 +158,11 @@ class RollingOnlineExample: if __name__ == "__main__": ####### to train the first version's models, use the command below - # python task_manager_rolling_with_updating.py first_run + # python rolling_online_management.py first_run ####### to update the models and predictions after the trading time, use the command below - # python task_manager_rolling_with_updating.py after_day + # python rolling_online_management.py after_day ####### to define your own parameters, use `--` - # python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40 + # python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40 fire.Fire(RollingOnlineExample) diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index ed2ad6997..6e2725c7a 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -1,16 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example show how OnlineTool works when we need update prediction. +There are two parts including first_train and update_online_pred. +Firstly, we will finish the training and set the trained model to `online` model. +Next, we will finish updating online prediction. +""" import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train -from qlib.workflow.online.manager import OnlineManagerR -from qlib.workflow.task.utils import list_recorders - -""" -This example show how OnlineManager works when we need update prediction. -There are two parts including first_train and update_online_pred. -Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model. -Next, the RollingOnlineManager will finish updating online prediction -""" +from qlib.workflow.online.utils import OnlineToolR data_handler_config = { "start_time": "2008-01-01", @@ -65,15 +66,15 @@ class UpdatePredExample: ): qlib.init(provider_uri=provider_uri, region=region) self.experiment_name = experiment_name - self.online_manager = OnlineManagerR(self.experiment_name) + self.online_tool = OnlineToolR(self.experiment_name) self.task_config = task_config def first_train(self): rec = task_train(self.task_config, experiment_name=self.experiment_name) - self.online_manager.reset_online_tag(rec) # set to online model + self.online_tool.reset_online_tag(rec) # set to online model def update_online_pred(self): - self.online_manager.update_online_pred() + self.online_tool.update_online_pred() def main(self): self.first_train() diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index b87d6afe3..fc30065fd 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -15,7 +15,8 @@ LOG = get_module_logger("backtest") def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters + """ + Parameters ---------- pred : pandas.DataFrame predict should has index and one `score` column @@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy + """ + Update the account and strategy + Parameters ---------- trade_account : Account() diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..97abc2a56 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -128,7 +128,7 @@ class Position: return self.position["cash"] def get_stock_amount_dict(self): - """generate stock amount dict {stock_id : amount of stock} """ + """generate stock amount dict {stock_id : amount of stock}""" d = {} stock_list = self.get_stock_list() for stock_code in stock_list: diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index c68571853..fadef9d16 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -8,6 +8,59 @@ import pandas as pd from typing import Tuple +def calc_long_short_prec( + pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False +) -> Tuple[pd.Series, pd.Series]: + """ + calculate the precision for long and short operation + + + :param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**. + + .. code-block:: python + score + datetime instrument + 2020-12-01 09:30:00 SH600068 0.553634 + SH600195 0.550017 + SH600276 0.540321 + SH600584 0.517297 + SH600715 0.544674 + label : + label + date_col : + date_col + + Returns + ------- + (pd.Series, pd.Series) + long precision and short precision in time level + """ + if is_alpha: + label = label - label.mean(level=date_col) + if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): + raise ValueError("Need more instruments to calculate precision") + + df = pd.DataFrame({"pred": pred, "label": label}) + if dropna: + df.dropna(inplace=True) + + group = df.groupby(level=date_col) + + N = lambda x: int(len(x) * quantile) + # find the top/low quantile of prediction and treat them as long and short target + long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) + short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) + + groupll = long.groupby(date_col) + l_dom = groupll.apply(lambda x: x > 0) + l_c = groupll.count() + + groups = short.groupby(date_col) + s_dom = groups.apply(lambda x: x < 0) + s_c = groups.count() + return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c) + + def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: """calc_ic. diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py new file mode 100644 index 000000000..5a2eeb50a --- /dev/null +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd +import lightgbm as lgb + +from qlib.model.base import ModelFT +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +import warnings + + +class HFLGBModel(ModelFT): + """LightGBM Model for high frequency prediction""" + + def __init__(self, loss="mse", **kwargs): + if loss not in {"mse", "binary"}: + raise NotImplementedError + self.params = {"objective": loss, "verbosity": -1} + self.params.update(kwargs) + self.model = None + + def _cal_signal_metrics(self, y_test, l_cut, r_cut): + """ + Calcaute the signal metrics by daily level + """ + up_pre, down_pre = [], [] + up_alpha_ll, down_alpha_ll = [], [] + for date in y_test.index.get_level_values(0).unique(): + df_res = y_test.loc[date].sort_values("pred") + if int(l_cut * len(df_res)) < 10: + warnings.warn("Warning: threhold is too low or instruments number is not enough") + continue + top = df_res.iloc[: int(l_cut * len(df_res))] + bottom = df_res.iloc[int(r_cut * len(df_res)) :] + + down_precision = len(top[top[top.columns[0]] < 0]) / (len(top)) + up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom)) + + down_alpha = top[top.columns[0]].mean() + up_alpha = bottom[bottom.columns[0]].mean() + + up_pre.append(up_precision) + down_pre.append(down_precision) + up_alpha_ll.append(up_alpha) + down_alpha_ll.append(down_alpha) + + return ( + np.array(up_pre).mean(), + np.array(down_pre).mean(), + np.array(up_alpha_ll).mean(), + np.array(down_alpha_ll).mean(), + ) + + def hf_signal_test(self, dataset: DatasetH, threhold=0.2): + """ + Test the sigal in high frequency test set + """ + if self.model == None: + raise ValueError("Model hasn't been trained yet") + df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + df_test.dropna(inplace=True) + x_test, y_test = df_test["feature"], df_test["label"] + # Convert label into alpha + y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0) + + res = pd.Series(self.model.predict(x_test.values), index=x_test.index) + y_test["pred"] = res + + up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold) + print("===============================") + print("High frequency signal test") + print("===============================") + print("Test set precision: ") + print("Positive precision: {}, Negative precision: {}".format(up_p, down_p)) + print("Test Alpha Average in test set: ") + print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a)) + + def _prepare_data(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_train["feature"], df_valid["label"] + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + l_name = df_train["label"].columns[0] + # Convert label into alpha + df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0) + df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0) + mapping_fn = lambda x: 0 if x < 0 else 1 + df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn) + df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn) + x_train, y_train = df_train["feature"], df_train["label_c"].values + x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train.values, label=y_train) + dvalid = lgb.Dataset(x_valid.values, label=y_valid) + return dtrain, dvalid + + def fit( + self, + dataset: DatasetH, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): + dtrain, dvalid = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + + def predict(self, dataset): + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + return pd.Series(self.model.predict(x_test.values), index=x_test.index) + + def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): + """ + finetune model + + Parameters + ---------- + dataset : DatasetH + dataset for finetuning + num_boost_round : int + number of round to finetune model + verbose_eval : int + verbose level + """ + # Based on existing model and finetune by train more rounds + dtrain, _ = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + init_model=self.model, + valid_sets=[dtrain], + valid_names=["train"], + verbose_eval=verbose_eval, + ) diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index abb68ea60..00985a17c 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -214,7 +214,7 @@ def cumulative_return_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.cumulative_return_graph(positions, report_normal_df, features_df) + qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df) Graph desc: diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index 72a358adc..77743b10c 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -94,7 +94,7 @@ def rank_label_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) + qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index f82e654c4..6b83f0734 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, report_normal_df, _ = backtest(pred_df, strategy, **bparas) - qcr.report_graph(report_normal_df) + qcr.analysis_position.report_graph(report_normal_df) :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**. diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 677e767ee..2d4f546e8 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path class BaseGraph: - """""" + """ """ _name = None diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 12792fbcb..bedf89105 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import logging import pandas as pd +import numpy as np from sklearn.metrics import mean_squared_error from typing import Dict, Text, Any -import numpy as np from ...contrib.eva.alpha import calc_ic from ...workflow.record_temp import RecordTemp @@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord from ...data import dataset as qlib_dataset from ...log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class MultiSegRecord(RecordTemp): diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..c2638e234 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider): # if future calendar not exists, return current calendar if not os.path.exists(fname): get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) @@ -1016,7 +1019,8 @@ class ClientProvider(BaseProvider): self.logger = get_module_logger(self.__class__.__name__) if isinstance(Cal, ClientCalendarProvider): Cal.set_conn(self.client) - Inst.set_conn(self.client) + if isinstance(Inst, ClientInstrumentProvider): + Inst.set_conn(self.client) if hasattr(DatasetD, "provider"): DatasetD.provider.set_conn(self.client) else: diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 2173d87ae..75abc0cbb 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -27,7 +27,7 @@ class Dataset(Serializable): - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - The data could specify the info to caculate the essential data for preparation + The data could specify the info to calculate the essential data for preparation """ self.setup_data(**kwargs) super().__init__() @@ -92,7 +92,7 @@ class DatasetH(Dataset): handler : Union[dict, DataHandler] handler could be: - - insntance of `DataHandler` + - instance of `DataHandler` - config of `DataHandler`. Please refer to `DataHandler` @@ -124,7 +124,7 @@ class DatasetH(Dataset): Parameters ---------- handler_kwargs : dict - Config of DataHanlder, which could include the following arguments: + Config of DataHandler, which could include the following arguments: - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. @@ -148,11 +148,11 @@ class DatasetH(Dataset): Parameters ---------- handler_kwargs : dict - init arguments of DataHanlder, which could include the following arguments: + init arguments of DataHandler, which could include the following arguments: - init_type : Init Type of Handler - - enable_cache : wheter to enable cache + - enable_cache : whether to enable cache """ super().setup_data(**kwargs) @@ -238,7 +238,7 @@ class TSDataSampler: (T)ime-(S)eries DataSampler This is the result of TSDatasetH - It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series + It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series dataset based on tabular data. If user have further requirements for processing data, user could process them based on `TSDataSampler` or create @@ -310,7 +310,7 @@ class TSDataSampler: self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - + del self.data # save memory @staticmethod @@ -472,7 +472,7 @@ class TSDatasetH(DatasetH): (T)ime-(S)eries Dataset (H)andler - Covnert the tabular data to Time-Series data + Convert the tabular data to Time-Series data Requirements analysis diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index f1fa39c3b..63b49d78b 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -36,7 +36,7 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - Any order of the index level can be suported (The order will be implied in the data). + Any order of the index level can be supported (The order will be implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. Example of the data: @@ -77,7 +77,7 @@ class DataHandler(Serializable): data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : - intialize the original data in the constructor. + initialize the original data in the constructor. fetch_orig : bool Return the original data instead of copy if possible. """ @@ -128,7 +128,7 @@ class DataHandler(Serializable): def setup_data(self, enable_cache: bool = False): """ - Set Up the data in case of running intialization for multiple time + Set Up the data in case of running initialization for multiple time It is responsible for maintaining following variable 1) self._data @@ -453,7 +453,7 @@ class DataHandlerLP(DataHandler): def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Set up the data in case of running intialization for multiple time + Set up the data in case of running initialization for multiple time Parameters ---------- diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 1d98d3bc9..fce22ddfc 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -130,7 +130,7 @@ class FilterCol(Processor): class TanhProcess(Processor): - """ Use tanh to process noise data""" + """Use tanh to process noise data""" def __call__(self, df): def tanh_denoise(data): @@ -145,7 +145,7 @@ class TanhProcess(Processor): class ProcessInf(Processor): - """Process infinity """ + """Process infinity""" def __call__(self, df): def replace_inf(data): diff --git a/qlib/log.py b/qlib/log.py index 126acb9d2..e714bc15a 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -12,7 +12,41 @@ from contextlib import contextmanager from .config import C -def get_module_logger(module_name, level: Optional[int] = None): +class MetaLogger(type): + def __new__(cls, name, bases, dict): + wrapper_dict = logging.Logger.__dict__.copy() + for key in wrapper_dict: + if key not in dict and key != "__reduce__": + dict[key] = wrapper_dict[key] + return type.__new__(cls, name, bases, dict) + + +class QlibLogger(metaclass=MetaLogger): + """ + Customized logger for Qlib. + """ + + def __init__(self, module_name): + self.module_name = module_name + self.level = 0 + + @property + def logger(self): + logger = logging.getLogger(self.module_name) + logger.setLevel(self.level) + return logger + + def setLevel(self, level): + self.level = level + + def __getattr__(self, name): + # During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error. + if name in {"__setstate__"}: + raise AttributeError + return self.logger.__getattribute__(name) + + +def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger: """ Get a logger for a specific module. @@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None): module_name = "qlib.{}".format(module_name) # Get logger. - module_logger = logging.getLogger(module_name) + module_logger = QlibLogger(module_name) module_logger.setLevel(level) return module_logger diff --git a/qlib/model/base.py b/qlib/model/base.py index 1ac8f2fc9..12caf5f73 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): @abc.abstractmethod def predict(self, *args, **kwargs) -> object: - """ Make predictions after modeling things """ + """Make predictions after modeling things""" pass def __call__(self, *args, **kwargs) -> object: - """ leverage Python syntactic sugar to make the models' behaviors like functions """ + """leverage Python syntactic sugar to make the models' behaviors like functions""" return self.predict(*args, **kwargs) diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 63f6438c2..1fb14a37b 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -1,36 +1,12 @@ -from abc import abstractmethod -from typing import Callable, Union +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions. +""" + +from typing import Union import pandas as pd -from qlib.workflow.task.collect import Collector -from qlib.utils.serial import Serializable - - -def ens_workflow(collector: Collector, process_list, *args, **kwargs): - """the ensemble workflow based on collector and different dict processors. - - Args: - collector (Collector): the collector to collect the result into {result_key: things} - process_list (list or Callable): the list of processors or the instance of processor to process dict. - The processor order is same as the list order. - For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] - Returns: - dict: the ensemble dict - """ - collect_dict = collector.collect() - if not isinstance(process_list, list): - process_list = [process_list] - - ensemble = {} - for artifact in collect_dict: - value = collect_dict[artifact] - for process in process_list: - if not callable(process): - raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.") - value = process(value, *args, **kwargs) - ensemble[artifact] = value - - return ensemble class Ensemble: @@ -49,21 +25,45 @@ class Ensemble: raise NotImplementedError(f"Please implement the `__call__` method.") +class SingleKeyEnsemble(Ensemble): + + """ + Extract the object if there is only one key and value in dict. Make result more readable. + {Only key: Only value} -> Only value + If there are more than 1 key or less than 1 key, then do nothing. + Even you can run this recursively to make dict more readable. + NOTE: Default run recursively. + """ + + def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object: + if not isinstance(ensemble_dict, dict): + return ensemble_dict + if recursion: + tmp_dict = {} + for k, v in ensemble_dict.items(): + tmp_dict[k] = self(v, recursion) + ensemble_dict = tmp_dict + keys = list(ensemble_dict.keys()) + if len(keys) == 1: + ensemble_dict = ensemble_dict[keys[0]] + return ensemble_dict + + class RollingEnsemble(Ensemble): """Merge the rolling objects in an Ensemble""" - def __call__(self, ensemble_dict: dict): + def __call__(self, ensemble_dict: dict) -> pd.DataFrame: """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. - NOTE: The values of dict must be pd.Dataframe, and have the index "datetime" + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime" Args: - ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}. + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. The key of the dict will be ignored. Returns: - pd.Dataframe: the complete result of rolling. + pd.DataFrame: the complete result of rolling. """ artifact_list = list(ensemble_dict.values()) artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) @@ -72,3 +72,24 @@ class RollingEnsemble(Ensemble): artifact = artifact[~artifact.index.duplicated(keep="last")] artifact = artifact.sort_index() return artifact + + +class AverageEnsemble(Ensemble): + def __call__(self, ensemble_dict: dict): + """ + Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble. + + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime" + + Args: + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. + The key of the dict will be ignored. + + Returns: + pd.DataFrame: the complete result of averaging. + """ + values = list(ensemble_dict.values()) + results = pd.concat(values, axis=1) + results = results.mean(axis=1).to_frame("score") + results = results.sort_index() + return results diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index c80959b0d..d8f174105 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -1,3 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Group can group a set of object based on `group_func` and change them to a dict. +After group, we provide a method to reduce them. + +For example: + +group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} +reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object} + +""" + from qlib.model.ens.ensemble import Ensemble, RollingEnsemble from typing import Callable, Union from joblib import Parallel, delayed @@ -21,20 +35,20 @@ class Group: self._group_func = group_func self._ens_func = ens - def group(self, *args, **kwargs): + def group(self, *args, **kwargs) -> dict: # TODO: such design is weird when `_group_func` is the only configurable part in the class if isinstance(getattr(self, "_group_func", None), Callable): return self._group_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `group_func`.") - def reduce(self, *args, **kwargs): + def reduce(self, *args, **kwargs) -> dict: if isinstance(getattr(self, "_ens_func", None), Callable): return self._ens_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `_ens_func`.") - def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs): + def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict: """Group the ungrouped_dict into different groups. Args: @@ -59,7 +73,7 @@ class Group: class RollingGroup(Group): """group the rolling dict""" - def group(self, rolling_dict: dict): + def group(self, rolling_dict: dict) -> dict: """Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}} NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly. diff --git a/qlib/model/task.py b/qlib/model/task.py deleted file mode 100644 index f29f513a4..000000000 --- a/qlib/model/task.py +++ /dev/null @@ -1,27 +0,0 @@ -import abc -import typing - - -class TaskGen(metaclass=abc.ABCMeta): - @abc.abstractmethod - def __call__(self, *args, **kwargs) -> typing.List[dict]: - """ - generate - - Parameters - ---------- - args, kwargs: - The info for generating tasks - Example 1): - input: a specific task template - output: rolling version of the tasks - Example 2): - input: a specific task template - output: a set of tasks with different losses - - Returns - ------- - typing.List[dict]: - A list of tasks - """ - pass diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index af65c5886..7680674a6 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,58 +1,69 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import copy +""" +The Trainer will train a list of tasks and return a list of model recorder. +There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder). + +This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training. +In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. + +``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically. +""" + +import socket import time -from xxlimited import Str -from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs -from qlib.workflow import R -from qlib.workflow.recorder import Recorder -from qlib.workflow.record_temp import SignalRecord -from qlib.workflow.task.manage import TaskManager, run_task +from typing import Callable, List + from qlib.data.dataset import Dataset from qlib.model.base import Model -import socket +from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config +from qlib.workflow import R +from qlib.workflow.record_temp import SignalRecord +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.manage import TaskManager, run_task -def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder: +def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder: """ - Begin a task training with starting a recorder and saving the task config. + Begin a task training to start a recorder and save the task config. Args: - task_config (dict) - experiment_name (str) + task_config (dict): the config of a task + experiment_name (str): the name of experiment + recorder_name (str): the given name will be the recorder name. None for using rid. Returns: - Recorder + Recorder: the model recorder """ - with R.start(experiment_name=experiment_name, recorder_name=str(time.time())): + with R.start(experiment_name=experiment_name, recorder_name=recorder_name): R.log_params(**flatten_dict(task_config)) R.save_objects(**{"task": task_config}) # keep the original format and datatype - R.set_tags(**{"hostname": socket.gethostname(), "train_status": "begin_task_train"}) + R.set_tags(**{"hostname": socket.gethostname()}) recorder: Recorder = R.get_recorder() return recorder -def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs): +def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: """ - Finished task training with real model fitting and saving. + Finish task training with real model fitting and saving. Args: - rec (Recorder): This recorder will be resumed - experiment_name (str) + rec (Recorder): the recorder will be resumed + experiment_name (str): the name of experiment Returns: - Recorder + Recorder: the model recorder """ - with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True): + with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True): task_config = R.load_object("task") - # model & dataset initiaiton + # model & dataset initiation model: Model = init_instance_by_config(task_config["model"]) dataset: Dataset = init_instance_by_config(task_config["dataset"]) # model training model.fit(dataset) R.save_objects(**{"params.pkl": model}) - # This dataset is saved for online inference. So the concrete data should not be dumped + # this dataset is saved for online inference. So the concrete data should not be dumped dataset.config(dump_all=False, recursive=True) R.save_objects(**{"dataset": dataset}) # generate records: prediction, backtest, and analysis @@ -67,18 +78,18 @@ def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs): rconf = {"recorder": rec} r = cls(**kwargs, **rconf) r.generate() - R.set_tags(**{"train_status": "end_task_train"}) + return rec def task_train(task_config: dict, experiment_name: str) -> Recorder: """ - task based training + Task based training, will be divided into two steps. Parameters ---------- task_config : dict - A dict describes a task setting. + The config of a task. experiment_name: str The name of experiment @@ -96,39 +107,79 @@ class Trainer: The trainer which can train a list of model """ - def train(self, tasks: list, *args, **kwargs): - """Given a list of model definition, begin a training and return the models. + def __init__(self): + self.delay = False + + def train(self, tasks: list, *args, **kwargs) -> list: + """ + Given a list of model definition, begin a training and return the models. + + Args: + tasks: a list of tasks Returns: list: a list of models """ raise NotImplementedError(f"Please implement the `train` method.") - def end_train(self, models, *args, **kwargs): - """Given a list of models, finished something in the end of training if you need. + def end_train(self, models: list, *args, **kwargs) -> list: + """ + Given a list of models, finished something in the end of training if you need. + The models maybe Recorder, txt file, database and so on. + + Args: + models: a list of models Returns: list: a list of models """ + # do nothing if you finished all work in `train` method + return models + + def is_delay(self) -> bool: + """ + If Trainer will delay finishing `end_train`. + + Returns: + bool: if DelayTrainer + """ + return self.delay + + def reset(self): + """ + Reset the Trainer status. + """ pass class TrainerR(Trainer): - """Trainer based on (R)ecorder. + """ + Trainer based on (R)ecorder. + It will train a list of tasks and return a list of model recorder in a linear way. Assumption: models were defined by `task` and the results will saved to `Recorder` """ - def __init__(self, experiment_name, train_func=task_train): + def __init__(self, experiment_name: str, train_func: Callable = task_train): + """ + Init TrainerR. + + Args: + experiment_name (str): the name of experiment. + train_func (Callable, optional): default training method. Defaults to `task_train`. + """ + super().__init__() self.experiment_name = experiment_name self.train_func = train_func - def train(self, tasks: list, train_func=None, *args, **kwargs): - """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + kwargs: the params for train_func. Returns: list: a list of Recorders @@ -137,17 +188,74 @@ class TrainerR(Trainer): train_func = self.train_func recs = [] for task in tasks: - recs.append(train_func(task, self.experiment_name, *args, **kwargs)) + rec = train_func(task, self.experiment_name, **kwargs) + rec.set_tags(**{"train_status": "begin_task_train"}) + recs.append(rec) + return recs + + def end_train(self, recs: list, **kwargs) -> list: + for rec in recs: + rec.set_tags(**{"train_status": "end_task_train"}) + return recs + + +class DelayTrainerR(TrainerR): + """ + A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. + """ + + def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): + """ + Init TrainerRM. + + Args: + experiment_name (str): the name of experiment. + train_func (Callable, optional): default train method. Defaults to `begin_task_train`. + end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + """ + super().__init__(experiment_name, train_func) + self.end_train_func = end_train_func + self.delay = True + + def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]: + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finish real data loading and model fitting. + + Args: + recs (list): a list of Recorder, the tasks have been saved to them + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + kwargs: the params for end_train_func. + + Returns: + list: a list of Recorders + """ + if end_train_func is None: + end_train_func = self.end_train_func + for rec in recs: + end_train_func(rec, **kwargs) + rec.set_tags(**{"train_status": "end_task_train"}) return recs class TrainerRM(Trainer): - """Trainer based on (R)ecorder and Task(M)anager + """ + Trainer based on (R)ecorder and Task(M)anager. + It can train a list of tasks and return a list of model recorder in a multiprocessing way. Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager """ def __init__(self, experiment_name: str, task_pool: str, train_func=task_train): + """ + Init TrainerR. + + Args: + experiment_name (str): the name of experiment. + task_pool (str): task pool name in TaskManager. + train_func (Callable, optional): default training method. Defaults to `task_train`. + """ + super().__init__() self.experiment_name = experiment_name self.task_pool = task_pool self.train_func = train_func @@ -155,20 +263,23 @@ class TrainerRM(Trainer): def train( self, tasks: list, - train_func=None, - before_status=TaskManager.STATUS_WAITING, - after_status=TaskManager.STATUS_DONE, - *args, + train_func: Callable = None, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, **kwargs, - ): - """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + ) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. This method defaults to a single process, but TaskManager offered a great way to parallel training. Users can customize their train_func to realize multiple processes or even multiple machines. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs: the params for train_func. Returns: list: a list of Recorders @@ -183,63 +294,29 @@ class TrainerRM(Trainer): experiment_name=self.experiment_name, before_status=before_status, after_status=after_status, - *args, **kwargs, ) recs = [] for _id in _id_list: - recs.append(tm.re_query(_id)["res"]) + rec = tm.re_query(_id)["res"] + rec.set_tags(**{"train_status": "begin_task_train"}) + recs.append(rec) return recs - -class DelayTrainerR(TrainerR): - """ - A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. - - """ - - def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): - super().__init__(experiment_name, train_func) - self.end_train_func = end_train_func - self.recs = [] - - def train(self, tasks: list, train_func, *args, **kwargs): - """ - Same as `train` of TrainerR, the results will be recorded in self.recs - - Args: - tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. - - Returns: - list: a list of Recorders - """ - self.recs = super().train(tasks, train_func=train_func, *args, **kwargs) - return self.recs - - def end_train(self, recs=None, end_train_func=None): - """ - Given a list of Recorder and return a list of trained Recorder. - This class will finished real data loading and model fitting. - - Args: - recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs. - end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. - - Returns: - list: a list of Recorders - """ - if recs is None: - recs = copy.deepcopy(self.recs) - # the models will be only trained once - self.recs = [] - if end_train_func is None: - end_train_func = self.end_train_func + def end_train(self, recs: list, **kwargs) -> list: for rec in recs: - end_train_func(rec) + rec.set_tags(**{"train_status": "end_task_train"}) return recs + def reset(self): + """ + .. note:: + this method will delete all task in this task_pool! + """ + tm = TaskManager(task_pool=self.task_pool) + tm.remove() + class DelayTrainerRM(TrainerRM): """ @@ -250,28 +327,28 @@ class DelayTrainerRM(TrainerRM): def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train): super().__init__(experiment_name, task_pool, train_func) self.end_train_func = end_train_func + self.delay = True - def train(self, tasks: list, train_func=None, *args, **kwargs): + def train(self, tasks: list, train_func=None, **kwargs): """ - Same as `train` of TrainerRM, the results will be recorded in self.recs - + Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. - + train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func. Returns: list: a list of Recorders """ - return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, *args, **kwargs) + return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs) - def end_train(self, recs, end_train_func=None): + def end_train(self, recs, end_train_func=None, **kwargs): """ Given a list of Recorder and return a list of trained Recorder. - This class will finished real data loading and model fitting. + This class will finish real data loading and model fitting. Args: - recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.. - end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. + recs (list): a list of Recorder, the tasks have been saved to them. + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + kwargs: the params for end_train_func. Returns: list: a list of Recorders @@ -284,5 +361,8 @@ class DelayTrainerRM(TrainerRM): self.task_pool, experiment_name=self.experiment_name, before_status=TaskManager.STATUS_PART_DONE, + **kwargs, ) + for rec in recs: + rec.set_tags(**{"train_status": "end_task_train"}) return recs diff --git a/qlib/portfolio/optimizer/base.py b/qlib/portfolio/optimizer/base.py index 502443869..e3f692014 100644 --- a/qlib/portfolio/optimizer/base.py +++ b/qlib/portfolio/optimizer/base.py @@ -5,9 +5,9 @@ import abc class BaseOptimizer(abc.ABC): - """ Construct portfolio with a optimization related method """ + """Construct portfolio with a optimization related method""" @abc.abstractmethod def __call__(self, *args, **kwargs) -> object: - """ Generate a optimized portfolio allocation """ + """Generate a optimized portfolio allocation""" pass diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 1b775d99a..52d326c2a 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -3,11 +3,12 @@ from pathlib import Path import pickle +from typing import Union class Serializable: """ - Serializable will change the behaviours of pickle. + Serializable will change the behaviors of pickle. - It only saves the state whose name **does not** start with `_` It provides a syntactic sugar for distinguish the attributes which user doesn't want. - For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk @@ -70,7 +71,7 @@ class Serializable: obj.config(**params, recursive=True) del self.__dict__[self.FLAG_KEY] - def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None): + def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None): self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: pickle.dump(self, f) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 46f9c563f..2b2535edc 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -23,7 +23,10 @@ class QlibRecorder: @contextmanager def start( self, + *, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +48,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str name of the experiment one wants to start. + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -57,7 +64,14 @@ class QlibRecorder: resume : bool whether to resume the specific recorder with given name under the given experiment. """ - run = self.start_exp(experiment_name, recorder_name, uri, resume) + run = self.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) try: yield run except Exception as e: @@ -65,7 +79,9 @@ class QlibRecorder: raise e self.end_exp(Recorder.STATUS_FI) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False): + def start_exp( + self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually and the status of the recorder may not be handled properly. Here is the example code: @@ -79,8 +95,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str the name of the experiment to be started + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -93,7 +113,14 @@ class QlibRecorder: ------- An experiment instance being started. """ - return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume) + return self.exp_manager.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) def end_exp(self, recorder_status=Recorder.STATUS_FI): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index dd73f7f52..467c7c3f4 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from pathlib import Path from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Experiment: @@ -39,12 +39,14 @@ class Experiment: output["recorders"] = list(recorders.keys()) return output - def start(self, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): """ Start the experiment and set it to be active. This method will also start a new recorder. Parameters ---------- + recorder_id : str + the id of the recorder to be created. recorder_name : str the name of the recorder to be created. resume : bool @@ -238,14 +240,14 @@ class MLflowExperiment(Experiment): def __repr__(self): return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) - def start(self, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") # Get or create recorder if recorder_name is None: recorder_name = self._default_rec_name # resume the recorder if resume: - recorder, _ = self._get_or_create_rec(recorder_name=recorder_name) + recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) # create a new recorder else: recorder = self.create_recorder(recorder_name) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5275e57d7..04cc3bcb7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -4,7 +4,7 @@ import mlflow from mlflow.exceptions import MlflowException from mlflow.entities import ViewType -import os +import os, logging from pathlib import Path from contextlib import contextmanager from typing import Optional, Text @@ -14,7 +14,7 @@ from ..config import C from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class ExpManager: @@ -33,7 +33,10 @@ class ExpManager: def start_exp( self, + *, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +48,12 @@ class ExpManager: Parameters ---------- + experiment_id : str + id of the active experiment. experiment_name : str name of the active experiment. + recorder_id : str + id of the recorder to be started. recorder_name : str name of the recorder to be started. uri : str @@ -298,7 +305,10 @@ class MLflowExpManager(ExpManager): def start_exp( self, + *, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -308,11 +318,11 @@ class MLflowExpManager(ExpManager): # Create experiment if experiment_name is None: experiment_name = self._default_exp_name - experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) + experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_name, resume) + self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume) return self.active_experiment diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index e107271d0..6c62fbce9 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -1,477 +1,177 @@ -from copy import deepcopy -from operator import index -import pandas as pd -from qlib.model.ens.ensemble import ens_workflow -from qlib.model.ens.group import RollingGroup -from qlib.utils.serial import Serializable -from typing import Dict, List, Union -from qlib import get_module_logger -from qlib.data.data import D -from qlib.model.trainer import Trainer, TrainerR, task_train -from qlib.workflow import R -from qlib.workflow.online.update import PredUpdater -from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import Collector, RecorderCollector -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.utils import TimeAdjuster, list_recorders +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ -This class is a component of online serving, it can manage a series of models dynamically. -With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models. +OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically. + +With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models. In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. So this module provide a series methods to control this process. + +This module also provide a method to simulate `Online Strategy <#Online Strategy>`_ in the history. +Which means you can verify your strategy or find a better one. """ +from typing import Dict, List, Union + +import pandas as pd +from qlib import get_module_logger +from qlib.data.data import D +from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble +from qlib.utils.serial import Serializable +from qlib.workflow.online.strategy import OnlineStrategy +from qlib.workflow.task.collect import HyperCollector + class OnlineManager(Serializable): - - ONLINE_KEY = "online_status" # the online status key in recorder - ONLINE_TAG = "online" # the 'online' model - # NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models. - NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model - OFFLINE_TAG = "offline" # the 'offline' model, not for online serving - - SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment. - - def __init__(self, trainer: Trainer = None, need_log=True): - """ - init OnlineManager. - - Args: - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - need_log (bool, optional): print log or not. Defaults to True. - """ - self.trainer = trainer - self.logger = get_module_logger(self.__class__.__name__) - self.need_log = need_log - self.cur_time = None - - def prepare_signals(self): - """ - After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. - Must use `pass` even though there is nothing to do. - """ - raise NotImplementedError(f"Please implement the `prepare_signals` method.") - - def get_signals(self): - """ - After preparing signals, here is the method to get them. - """ - raise NotImplementedError(f"Please implement the `get_signals` method.") - - def prepare_tasks(self, *args, **kwargs): - """ - After the end of a routine, check whether we need to prepare and train some new tasks. - return the new tasks waiting for training. - """ - raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - - def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs): - """ - Use trainer to train a list of tasks and set the trained model to `tag`. - - Args: - tasks (list): a list of tasks. - tag (str): - `ONLINE_TAG` for first train or additional train - `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` - `OFFLINE_TAG` for train but offline those models - check_func: the method to judge if a model can be online. - The parameter is the model record and return True for online. - None for online every models. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. - - """ - if check_func is None: - check_func = lambda x: True - if len(tasks) > 0: - if self.trainer is not None: - new_models = self.trainer.train(tasks, *args, **kwargs) - if check_func(new_models): - self.set_online_tag(tag, new_models) - if self.need_log: - self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.") - else: - self.logger.warn("No trainer to train new tasks.") - - def update_online_pred(self): - """ - After the end of a routine, update the predictions of online models to latest. - """ - raise NotImplementedError(f"Please implement the `update_online_pred` method.") - - def set_online_tag(self, tag, recorder): - """ - Set `tag` to the model to sign whether online. - - Args: - tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` - """ - raise NotImplementedError(f"Please implement the `set_online_tag` method.") - - def get_online_tag(self): - """ - Given a model and return its online tag. - """ - raise NotImplementedError(f"Please implement the `get_online_tag` method.") - - def reset_online_tag(self, recorders=None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. - - Args: - recorders (List, optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. - - Returns: - list: new online recorder. [] if there is no update. - """ - raise NotImplementedError(f"Please implement the `reset_online_tag` method.") - - def online_models(self): - """ - Return online models. - """ - raise NotImplementedError(f"Please implement the `online_models` method.") - - def first_train(self): - """ - Train a series of models firstly and set some of them into online models. - """ - raise NotImplementedError(f"Please implement the `first_train` method.") - - def get_collector(self): - """ - Return the collector. - - Returns: - Collector - """ - raise NotImplementedError(f"Please implement the `get_collector` method.") - - def delay_prepare(self, rec_dict, *args, **kwargs): - """ - Prepare all models and signals if there are something waiting for prepare. - NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. - - Args: - rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. - """ - for time_segment, recs_list in rec_dict.items(): - self.trainer.end_train(recs_list, *args, **kwargs) - self.reset_online_tag(recs_list) - self.prepare_signals() - signal_max = self.get_signals().index.get_level_values("datetime").max() - if time_segment[1] is not None and signal_max > time_segment[1]: - raise ValueError( - f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}" - ) - - def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs): - """ - The typical update process after a routine, such as day by day or month by month. - update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models - - NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions. - - Args: - cur_time ([type], optional): [description]. Defaults to None. - delay_prepare (bool, optional): [description]. Defaults to False. - *args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config. - - Returns: - [type]: [description] - """ - self.cur_time = cur_time # None for latest date - if not delay_prepare: - self.update_online_pred() - self.prepare_signals() - tasks = self.prepare_tasks(*args, **kwargs) - self.prepare_new_models(tasks, *args, **kwargs) - - return self.reset_online_tag() - - -class OnlineManagerR(OnlineManager): """ - The implementation of OnlineManager based on (R)ecorder. - + OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_. + It also provide a history recording which models are onlined at what time. """ - def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True): - """ - init OnlineManagerR. - - Args: - experiment_name (str): the experiment name. - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - need_log (bool, optional): print log or not. Defaults to True. - """ - if trainer is None: - trainer = TrainerR(experiment_name) - super().__init__(trainer=trainer, need_log=need_log) - self.exp_name = experiment_name - self.signal_rec = None - - def set_online_tag(self, tag, recorder: Union[Recorder, List]): - """ - Set `tag` to the model to sign whether online. - - Args: - tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` - recorder (Union[Recorder, List]) - """ - if isinstance(recorder, Recorder): - recorder = [recorder] - for rec in recorder: - rec.set_tags(**{self.ONLINE_KEY: tag}) - if self.need_log: - self.logger.info(f"Set {len(recorder)} models to '{tag}'.") - - def get_online_tag(self, recorder: Recorder): - """ - Given a model and return its online tag. - - Args: - recorder (Recorder): a instance of recorder - - Returns: - str: the tag - """ - tags = recorder.list_tags() - return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) - - def reset_online_tag(self, recorder: Union[Recorder, List] = None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. - - Args: - recorders (Union[Recorder, List], optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. - - Returns: - list: new online recorder. [] if there is no update. - """ - if recorder is None: - recorder = list( - list_recorders( - self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG - ).values() - ) - if isinstance(recorder, Recorder): - recorder = [recorder] - if len(recorder) == 0: - if self.need_log: - self.logger.info("No 'next online' model, just use current 'online' models.") - return [] - recs = list_recorders(self.exp_name) - self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values())) - self.set_online_tag(OnlineManager.ONLINE_TAG, recorder) - return recorder - - def get_signals(self): - """ - get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) - - Returns: - signals - """ - if self.signal_rec is None: - with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() - signals = None - try: - signals = self.signal_rec.load_object("signals") - except OSError: - self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?") - return signals - - def online_models(self): - """ - Return online models. - - Returns: - list: the list of online models - """ - return list( - list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG).values() - ) - - def update_online_pred(self): - """ - Update all online model predictions to the latest day in Calendar - """ - online_models = self.online_models() - for rec in online_models: - PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update() - - if self.need_log: - self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") - - def prepare_signals(self, over_write=False): - """ - Average the predictions of online models and offer a trading signals every routine. - The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` - Even if the latest signal already exists, the latest calculation result will be overwritten. - NOTE: Given a prediction of a certain time, all signals before this time will be prepared well. - Args: - over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. - """ - if self.signal_rec is None: - with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() - - pred = [] - try: - old_signals = self.signal_rec.load_object("signals") - except OSError: - old_signals = None - - for rec in self.online_models(): - pred.append(rec.load_object("pred.pkl")) - - signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") - signals = signals.sort_index() - if old_signals is not None and not over_write: - old_max = old_signals.index.get_level_values("datetime").max() - new_signals = signals.loc[old_max:] - signals = pd.concat([old_signals, new_signals], axis=0) - else: - new_signals = signals - if self.need_log: - self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.") - self.signal_rec.save_objects(**{"signals": signals}) - - -class RollingOnlineManager(OnlineManagerR): - """An implementation of OnlineManager based on Rolling.""" - def __init__( self, - experiment_name: str, - rolling_gen: RollingGen, - trainer: Trainer = None, + strategy: Union[OnlineStrategy, List[OnlineStrategy]], + begin_time: Union[str, pd.Timestamp] = None, + freq="day", need_log=True, ): """ - init RollingOnlineManager. + Init OnlineManager. + One OnlineManager must have at least one OnlineStrategy. Args: - experiment_name (str): the experiment name. - rolling_gen (RollingGen): a instance of RollingGen - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - collector (Collector, optional): a instance of Collector. Defaults to None. + strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy + begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date. + freq (str, optional): data frequency. Defaults to "day". need_log (bool, optional): print log or not. Defaults to True. """ - if trainer is None: - trainer = TrainerR(experiment_name) - super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log) - self.ta = TimeAdjuster() - self.rg = rolling_gen self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + if not isinstance(strategy, list): + strategy = [strategy] + self.strategy = strategy + self.freq = freq + if begin_time is None: + begin_time = D.calendar(freq=self.freq).max() + self.begin_time = pd.Timestamp(begin_time) + self.cur_time = self.begin_time + self.history = {} - def get_collector(self, rec_key_func=None, rec_filter_func=None): + def first_train(self): """ - Get the instance of collector to collect results. The returned collector must can distinguish results in different models. - Assumption: the models can be distinguished based on model name and rolling test segments. - If you do not want this assumption, please implement your own method or use another rec_key_func. + Run every strategy first_train method and record the online history. + """ + for strategy in self.strategy: + self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") + online_models = strategy.first_train() + self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models + + def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}): + """ + Run typical update process for every strategy and record the online history. + + The typical update process after a routine, such as day by day or month by month. + The process is: Prepare signals -> Prepare tasks -> Prepare online models. Args: - rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None. + task_kwargs (dict): the params for `prepare_tasks` + model_kwargs (dict): the params for `prepare_online_models` """ + if cur_time is None: + cur_time = D.calendar(freq=self.freq).max() + self.cur_time = pd.Timestamp(cur_time) # None for latest date + for strategy in self.strategy: + if self.need_log: + self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") + if not strategy.trainer.is_delay(): + strategy.prepare_signals() + tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) + online_models = strategy.prepare_online_models(tasks, **model_kwargs) + if len(online_models) > 0: + self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models - def rec_key(recorder): - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key - - if rec_key_func is None: - rec_key_func = rec_key - - return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func) - - def collect_artifact(self, rec_key_func=None, rec_filter_func=None): + def get_collector(self) -> HyperCollector: """ - collecting artifact based on the collector and RollingGroup. - - Args: - rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. Returns: - dict: the artifact dict after rolling ensemble + HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable). """ - artifact = ens_workflow( - self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func), RollingGroup() - ) - return artifact + collector_dict = {} + for strategy in self.strategy: + collector_dict[strategy.name_id] = strategy.get_collector() + return HyperCollector(collector_dict, process_list=SingleKeyEnsemble()) - def first_train(self, task_configs: list): + def get_online_history(self, strategy_name_id: str) -> list: """ - Use rolling_gen to generate different tasks based on task_configs and trained them. + Get the online history based on strategy_name_id. Args: - task_configs (list or dict): a list of task configs or a task config + strategy_name_id (str): the name_id of strategy Returns: - Collector: a instance of a Collector. + list: a list like [(begin_time, [online_models])] """ - tasks = task_generator( - tasks=task_configs, - generators=self.rg, # generate different date segment - ) - self.prepare_new_models(tasks, tag=self.ONLINE_TAG) + history_dict = self.history[strategy_name_id] + history = [] + for time in sorted(history_dict): + models = history_dict[time] + history.append((time, models)) + return history + + def delay_prepare(self, delay_kwargs={}): + """ + Prepare all models and signals if there are something waiting for prepare. + + Args: + delay_kwargs: the params for `delay_prepare` + """ + for strategy in self.strategy: + strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs) + + def get_signals(self) -> pd.DataFrame: + """ + Average all strategy signals as the online signals. + + Assumption: the signals from every strategy is pd.DataFrame. Override this function to change. + + Returns: + pd.DataFrame: signals + """ + signals_dict = {} + for strategy in self.strategy: + signals_dict[strategy.name_id] = strategy.get_signals() + return AverageEnsemble()(signals_dict) + + def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector: + """ + Starting from current time, this method will simulate every routine in OnlineManager until end time. + + Considering the parallel training, the models and signals can be perpared after all routine simulating. + + The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``. + + Returns: + HyperCollector: the OnlineManager's collector + """ + cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency) + self.first_train() + for cur_time in cal: + self.logger.info(f"Simulating at {str(cur_time)}......") + self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs) + self.delay_prepare(delay_kwargs=delay_kwargs) + self.logger.info(f"Finished preparing signals") return self.get_collector() - def prepare_tasks(self): + def reset(self): """ - Prepare new tasks based on new date. + This method will reset all strategy! - Returns: - list: a list of new tasks. + **Be careful to use it.** """ - latest_records, max_test = self.list_latest_recorders( - lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG - ) - if max_test is None: - self.logger.warn(f"No latest online recorders, no new tasks.") - return [] - calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time - if self.need_log: - self.logger.info( - f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" - ) - if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: - old_tasks = [] - tasks_tmp = [] - for rid, rec in latest_records.items(): - task = rec.load_object("task") - old_tasks.append(deepcopy(task)) - test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] - # modify the test segment to generate new tasks - task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) - tasks_tmp.append(task) - new_tasks_tmp = task_generator(tasks_tmp, self.rg) - new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] - return new_tasks - return [] - - def list_latest_recorders(self, rec_filter_func=None): - """find latest recorders based on test segments. - - Args: - rec_filter_func (Callable, optional): recorder filter. Defaults to None. - - Returns: - dict, tuple: the latest recorders and the latest date of them - """ - recs_flt = list_recorders(self.exp_name, rec_filter_func) - if len(recs_flt) == 0: - return recs_flt, None - max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values()) - latest_rec = {} - for rid, rec in recs_flt.items(): - if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: - latest_rec[rid] = rec - return latest_rec, max_test + self.cur_time = self.begin_time + self.history = {} + for strategy in self.strategy: + strategy.reset() diff --git a/qlib/workflow/online/simulator.py b/qlib/workflow/online/simulator.py deleted file mode 100644 index d45b7d99d..000000000 --- a/qlib/workflow/online/simulator.py +++ /dev/null @@ -1,72 +0,0 @@ -from qlib.data import D -from qlib import get_module_logger -from qlib.workflow.online.manager import OnlineManager - - -class OnlineSimulator: - """ - To simulate online serving in the past, like a "online serving backtest". - """ - - def __init__( - self, - start_time, - end_time, - online_manager: OnlineManager, - frequency="day", - ): - """ - init OnlineSimulator. - - Args: - start_time (str or pd.Timestamp): the start time of simulating. - end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest. - onlinemanager (OnlineManager): the instance of OnlineManager - frequency (str, optional): the data frequency. Defaults to "day". - """ - self.logger = get_module_logger(self.__class__.__name__) - self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency) - self.start_time = self.cal[0] - self.end_time = self.cal[-1] - self.olm = online_manager - if len(self.cal) == 0: - self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.") - - def simulate(self, *args, **kwargs): - """ - Starting from start time, this method will simulate every routine in OnlineManager. - NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. - - Returns: - Collector: the OnlineManager's collector - """ - self.rec_dict = {} - tmp_begin = self.start_time - tmp_end = None - prev_recorders = self.olm.online_models() - for cur_time in self.cal: - self.logger.info(f"Simulating at {str(cur_time)}......") - recorders = self.olm.routine(cur_time, True, *args, **kwargs) - if len(recorders) == 0: - tmp_end = cur_time - else: - self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders - tmp_begin = cur_time - prev_recorders = recorders - self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders - # finished perparing models (and pred) and signals - self.olm.delay_prepare(self.rec_dict) - self.logger.info(f"Finished preparing signals") - return self.olm.get_collector() - - def online_models(self): - """ - Return a online models dict likes {(begin_time, end_time):[online models]}. - - Returns: - dict - """ - if hasattr(self, "rec_dict"): - return self.rec_dict - self.logger.warn(f"Please call `simulate` firstly when calling `online_models`") - return {} diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py new file mode 100644 index 000000000..0cae11b7f --- /dev/null +++ b/qlib/workflow/online/strategy.py @@ -0,0 +1,339 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +OnlineStrategy is a set of strategy for online serving. +""" + +from copy import deepcopy +from typing import List, Tuple, Union + +import pandas as pd +from qlib.data.data import D +from qlib.log import get_module_logger +from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble +from qlib.model.ens.group import RollingGroup +from qlib.model.trainer import Trainer, TrainerR +from qlib.workflow import R +from qlib.workflow.online.utils import OnlineTool, OnlineToolR +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.utils import TimeAdjuster, list_recorders + + +class OnlineStrategy: + """ + OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared. + """ + + def __init__(self, name_id: str, trainer: Trainer = None, need_log=True): + """ + Init OnlineStrategy. + This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training. + + Args: + name_id (str): a unique name or id + trainer (Trainer, optional): a instance of Trainer. Defaults to None. + need_log (bool, optional): print log or not. Defaults to True. + """ + self.name_id = name_id + self.trainer = trainer + self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + self.tool = OnlineTool() + + def prepare_signals(self, delay: bool = False): + """ + After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. + + NOTE: Given a set prediction, all signals before these prediction end time will be prepared well. + + Args: + delay: bool + If this method was called by `delay_prepare` + """ + raise NotImplementedError(f"Please implement the `prepare_signals` method.") + + def prepare_tasks(self, *args, **kwargs): + """ + After the end of a routine, check whether we need to prepare and train some new tasks. + Return the new tasks waiting for training. + + You can find last online models by OnlineTool.online_models. + """ + raise NotImplementedError(f"Please implement the `prepare_tasks` method.") + + def prepare_online_models(self, tasks, check_func=None, **kwargs): + """ + Use trainer to train a list of tasks and set the trained model to `online`. + + NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them. + + Args: + tasks (list): a list of tasks. + check_func: the method to judge if a model can be online. + The parameter is the model record and return True for online. + None for online every models. + **kwargs: will be passed to end_train which means will be passed to customized train method. + + """ + if check_func is None: + check_func = lambda x: True + online_models = [] + if len(tasks) > 0: + new_models = self.trainer.train(tasks, **kwargs) + for model in new_models: + if check_func(model): + online_models.append(model) + self.tool.reset_online_tag(online_models) + return online_models + + def first_train(self): + """ + Train a series of models firstly and set some of them as online models. + """ + raise NotImplementedError(f"Please implement the `first_train` method.") + + def get_collector(self) -> Collector: + """ + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving. + + + For example: + 1) collect predictions in Recorder + 2) collect signals in .txt file + + Returns: + Collector + """ + raise NotImplementedError(f"Please implement the `get_collector` method.") + + def delay_prepare(self, history: list, **kwargs): + """ + Prepare all models and signals if there are something waiting for prepare. + + Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way. + + Args: + history (list): an online models list likes [begin_time:[online models]]. + **kwargs: will be passed to end_train which means will be passed to customized train method. + """ + for begin_time, recs_list in history: + self.trainer.end_train(recs_list, **kwargs) + self.tool.reset_online_tag(recs_list) + self.prepare_signals(delay=True) + + def get_signals(self): + """ + Get prepared signals. + """ + raise NotImplementedError(f"Please implement the `get_signals` method.") + + def reset(self): + """ + Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation. + """ + pass + + +class RollingAverageStrategy(OnlineStrategy): + + """ + This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models + """ + + def __init__( + self, + name_id: str, + task_template: Union[dict, List[dict]], + rolling_gen: RollingGen, + trainer: Trainer = None, + need_log=True, + signal_exp_name="OnlineManagerSignals", + ): + """ + Init RollingAverageStrategy. + + Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one. + + Args: + name_id (str): a unique name or id. Will be also the name of Experiment. + task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen. + rolling_gen (RollingGen): an instance of RollingGen + trainer (Trainer, optional): a instance of Trainer. Defaults to None. + need_log (bool, optional): print log or not. Defaults to True. + signal_exp_path (str): a specific experiment to save signals of different experiment. + """ + super().__init__(name_id=name_id, trainer=trainer, need_log=need_log) + self.exp_name = self.name_id + if not isinstance(task_template, list): + task_template = [task_template] + self.task_template = task_template + self.signal_exp_name = signal_exp_name + self.rg = rolling_gen + self.tool = OnlineToolR(self.exp_name) + self.ta = TimeAdjuster() + with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): + self.signal_rec = R.get_recorder() # the recorder to record signals + self.signal_rec.save_objects(**{"signals": None}) + + def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): + """ + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must can distinguish results in different models. + Assumption: the models can be distinguished based on model name and rolling test segments. + If you do not want this assumption, please implement your own method or use another rec_key_func. + + Args: + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts. + """ + + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key + + if rec_key_func is None: + rec_key_func = rec_key + + artifacts_collector = RecorderCollector( + experiment=self.exp_name, + process_list=process_list, + rec_key_func=rec_key_func, + rec_filter_func=rec_filter_func, + artifacts_key=artifacts_key, + ) + + return artifacts_collector + + def first_train(self) -> List[Recorder]: + """ + Use rolling_gen to generate different tasks based on task_template and trained them. + + Returns: + List[Recorder]: a list of Recorder. + """ + tasks = task_generator( + tasks=self.task_template, + generators=self.rg, # generate different date segment + ) + return self.prepare_online_models(tasks) + + def prepare_tasks(self, cur_time) -> List[dict]: + """ + Prepare new tasks based on cur_time (None for latest). + + You can find last online models by OnlineToolR.online_models. + + Returns: + List[dict]: a list of new tasks. + """ + latest_records, max_test = self._list_latest(self.tool.online_models()) + if max_test is None: + self.logger.warn(f"No latest online recorders, no new tasks.") + return [] + calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time + if self.need_log: + self.logger.info( + f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" + ) + if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: + old_tasks = [] + tasks_tmp = [] + for rec in latest_records: + task = rec.load_object("task") + old_tasks.append(deepcopy(task)) + test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] + # modify the test segment to generate new tasks + task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) + tasks_tmp.append(task) + new_tasks_tmp = task_generator(tasks_tmp, self.rg) + new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] + return new_tasks + return [] + + def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame: + """ + Average the predictions of online models and offer a trading signals every routine. + The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` + Even if the latest signal already exists, the latest calculation result will be overwritten. + + .. note:: + + Given a prediction of a certain time, all signals before this time will be prepared well. + + Args: + over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. + Returns: + pd.DataFrame: the signals. + """ + if not delay: + self.tool.update_online_pred() + + # Get a collector to average online models predictions + online_collector = self.get_collector( + process_list=[AverageEnsemble()], + rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False, + artifacts_key="pred", + ) + online_results = online_collector() + signals = online_results["pred"] + + old_signals = self.get_signals() + if old_signals is not None and not over_write: + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + if self.need_log: + self.logger.info( + f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}." + ) + self.signal_rec.save_objects(**{"signals": signals}) + return signals + + def get_signals(self) -> object: + """ + Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) + + Returns: + object: signals + """ + signals = self.signal_rec.load_object("signals") + return signals + + def _list_latest(self, rec_list: List[Recorder]): + """ + List latest recorder form rec_list + + Args: + rec_list (List[Recorder]): a list of Recorder + + Returns: + List[Recorder], pd.Timestamp: the latest recorders and its test end time + """ + if len(rec_list) == 0: + return rec_list, None + max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list) + latest_rec = [] + for rec in rec_list: + if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: + latest_rec.append(rec) + return latest_rec, max_test + + def reset(self): + """ + NOTE: This method will delete all recorder in Experiment and reset the Trainer! + """ + self.trainer.reset() + # delete models + exp = R.get_exp(experiment_name=self.exp_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + # delete signals + for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False): + exp.delete_recorder(rid) diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 5b58360d8..ab910ba8d 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -1,18 +1,20 @@ -from typing import Union, List -from qlib.data.dataset import DatasetH -from qlib.workflow import R -from qlib.data import D +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Updater is a module to update artifacts such as predictions, when the stock data is updating. +""" + +from abc import ABCMeta, abstractmethod + import pandas as pd from qlib import get_module_logger -from qlib.workflow import R -from qlib.model import Model -from qlib.model.trainer import task_train -from qlib.workflow.recorder import Recorder -from qlib.workflow.task.utils import list_recorders -from qlib.data.dataset.handler import DataHandlerLP +from qlib.data import D from qlib.data.dataset import DatasetH -from abc import ABCMeta, abstractmethod +from qlib.data.dataset.handler import DataHandlerLP +from qlib.model import Model from qlib.utils import get_date_by_shift +from qlib.workflow.recorder import Recorder class RMDLoader: @@ -25,19 +27,22 @@ class RMDLoader: def get_dataset(self, start_time, end_time, segments=None) -> DatasetH: """ - load, config and setup dataset. + Load, config and setup dataset. - This dataset is for inference + This dataset is for inference. + + Args: + start_time : + the start_time of underlying data + end_time : + the end_time of underlying data + segments : dict + the segments config for dataset + Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time + + Returns: + DatasetH: the instance of DatasetH - Parameters - ---------- - start_time : - the start_time of underlying data - end_time : - the end_time of underlying data - segments : dict - the segments config for dataset - Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time """ if segments is None: segments = {"test": (start_time, end_time)} @@ -52,7 +57,7 @@ class RMDLoader: class RecordUpdater(metaclass=ABCMeta): """ - Updata a specific recorders + Update a specific recorders """ def __init__(self, record: Recorder, need_log=True, *args, **kwargs): @@ -75,17 +80,22 @@ class PredUpdater(RecordUpdater): def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True): """ - Parameters - ---------- - record : Recorder - to_date : - update to prediction to the `to_date` - hist_ref : int - Sometimes, the dataset will have historical depends. - Leave the problem to user to set the length of historical dependancy - NOTE: the start_time is not included in the hist_ref - # TODO: automate this step in the future. + Init PredUpdater. + + Args: + record : Recorder + to_date : + update to prediction to the `to_date` + hist_ref : int + Sometimes, the dataset will have historical depends. + Leave the problem to user to set the length of historical dependency + + .. note:: + + the start_time is not included in the hist_ref + """ + # TODO: automate this hist_ref in the future. super().__init__(record=record, need_log=need_log) self.to_date = to_date @@ -101,9 +111,12 @@ class PredUpdater(RecordUpdater): def prepare_data(self) -> DatasetH: """ - # Load dataset + Load dataset Seperating this function will make it easier to reuse the dataset + + Returns: + DatasetH: the instance of DatasetH """ start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq) start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) @@ -113,9 +126,12 @@ class PredUpdater(RecordUpdater): def update(self, dataset: DatasetH = None): """ - update the precition in a recorder + Update the precition in a recorder + + Args: + DatasetH: the instance of DatasetH. None for reprepare. """ - # FIXME: the problme below is not solved + # FIXME: the problem below is not solved # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. # https://github.com/pytorch/pytorch/issues/16797 diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py new file mode 100644 index 000000000..296ca3ea6 --- /dev/null +++ b/qlib/workflow/online/utils.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +OnlineTool is a module to set and unset a series of `online` models. +The `online` models are some decisive models in some time point, which can be changed with the change of time. +This allows us to use efficient submodels as the market style changing. +""" + +from typing import List, Union + +from qlib.log import get_module_logger +from qlib.workflow.online.update import PredUpdater +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.utils import list_recorders + + +class OnlineTool: + """ + OnlineTool. + """ + + ONLINE_KEY = "online_status" # the online status key in recorder + ONLINE_TAG = "online" # the 'online' model + OFFLINE_TAG = "offline" # the 'offline' model, not for online serving + + def __init__(self, need_log=True): + """ + Init OnlineTool. + + Args: + need_log (bool, optional): print log or not. Defaults to True. + """ + self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + + def set_online_tag(self, tag, recorder: Union[list, object]): + """ + Set `tag` to the model to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[list,object]): the model's recorder + """ + raise NotImplementedError(f"Please implement the `set_online_tag` method.") + + def get_online_tag(self, recorder: object) -> str: + """ + Given a model recorder and return its online tag. + + Args: + recorder (Object): the model's recorder + + Returns: + str: the online tag + """ + raise NotImplementedError(f"Please implement the `get_online_tag` method.") + + def reset_online_tag(self, recorder: Union[list, object]): + """ + Offline all models and set the recorders to 'online'. + + Args: + recorder (Union[list,object]): + the recorder you want to reset to 'online'. + + """ + raise NotImplementedError(f"Please implement the `reset_online_tag` method.") + + def online_models(self) -> list: + """ + Get current `online` models + + Returns: + list: a list of `online` models. + """ + raise NotImplementedError(f"Please implement the `online_models` method.") + + def update_online_pred(self, to_date=None): + """ + Update the predictions of `online` models to a date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest. + + """ + raise NotImplementedError(f"Please implement the `update_online_pred` method.") + + +class OnlineToolR(OnlineTool): + """ + The implementation of OnlineTool based on (R)ecorder. + """ + + def __init__(self, experiment_name: str, need_log=True): + """ + Init OnlineToolR. + + Args: + experiment_name (str): the experiment name. + need_log (bool, optional): print log or not. Defaults to True. + """ + super().__init__(need_log=need_log) + self.exp_name = experiment_name + + def set_online_tag(self, tag, recorder: Union[Recorder, List]): + """ + Set `tag` to the model's recorder to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder + """ + if isinstance(recorder, Recorder): + recorder = [recorder] + for rec in recorder: + rec.set_tags(**{self.ONLINE_KEY: tag}) + if self.need_log: + self.logger.info(f"Set {len(recorder)} models to '{tag}'.") + + def get_online_tag(self, recorder: Recorder) -> str: + """ + Given a model recorder and return its online tag. + + Args: + recorder (Recorder): an instance of recorder + + Returns: + str: the online tag + """ + tags = recorder.list_tags() + return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG) + + def reset_online_tag(self, recorder: Union[Recorder, List]): + """ + Offline all models and set the recorders to 'online'. + + Args: + recorder (Union[Recorder, List]): + the recorder you want to reset to 'online'. + + """ + if isinstance(recorder, Recorder): + recorder = [recorder] + recs = list_recorders(self.exp_name) + self.set_online_tag(self.OFFLINE_TAG, list(recs.values())) + self.set_online_tag(self.ONLINE_TAG, recorder) + + def online_models(self) -> list: + """ + Get current `online` models + + Returns: + list: a list of `online` models. + """ + return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) + + def update_online_pred(self, to_date=None): + """ + Update the predictions of online models to a date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest time in Calendar. + """ + online_models = self.online_models() + for rec in online_models: + PredUpdater(rec, to_date=to_date, need_log=self.need_log).update() + + if self.need_log: + self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 324b790ac..fc71b3f9a 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import re +import re, logging import pandas as pd from pathlib import Path from pprint import pprint @@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class RecordTemp: @@ -166,6 +166,60 @@ class SignalRecord(RecordTemp): return super().load(name) +class HFSignalRecord(SignalRecord): + """ + This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. + """ + + artifact_path = "hg_sig_analysis" + + def __init__(self, recorder, **kwargs): + super().__init__(recorder=recorder) + + def generate(self): + pred = self.load("pred.pkl") + raw_label = self.load("label.pkl") + long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) + ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics = { + "IC": ic.mean(), + "ICIR": ic.mean() / ic.std(), + "Rank IC": ric.mean(), + "Rank ICIR": ric.mean() / ric.std(), + "Long precision": long_pre.mean(), + "Short precision": short_pre.mean(), + } + objects = {"ic.pkl": ic, "ric.pkl": ric} + objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) + long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics.update( + { + "Long-Short Average Return": long_short_r.mean(), + "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(), + } + ) + objects.update( + { + "long_short_r.pkl": long_short_r, + "long_avg_r.pkl": long_avg_r, + } + ) + self.recorder.log_metrics(**metrics) + self.recorder.save_objects(**objects, artifact_path=self.get_path()) + pprint(metrics) + + def list(self): + paths = [ + self.get_path("ic.pkl"), + self.get_path("ric.pkl"), + self.get_path("long_pre.pkl"), + self.get_path("short_pre.pkl"), + self.get_path("long_short_r.pkl"), + self.get_path("long_avg_r.pkl"), + ] + return paths + + class SigAnaRecord(SignalRecord): """ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5915e58da..b9b2fd1b3 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path from datetime import datetime from ..utils.objm import FileManager from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Recorder: diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index b4c81122d..28320e2ce 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,8 +1,12 @@ -from abc import abstractmethod -from typing import Callable, Union +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on. +""" + +from qlib.model.ens.ensemble import SingleKeyEnsemble from qlib.workflow import R -from qlib.workflow.task.utils import list_recorders -from qlib.utils.serial import Serializable import dill as pickle @@ -18,7 +22,7 @@ class Collector: process_list = [process_list] self.process_list = process_list - def collect(self): + def collect(self) -> dict: """Collect the results and return a dict like {key: things} Returns: @@ -35,7 +39,7 @@ class Collector: raise NotImplementedError(f"Please implement the `collect` method.") @staticmethod - def process_collect(collected_dict, process_list=[], *args, **kwargs): + def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict: """do a series of processing to the dict returned by collect and return a dict like {key: things} For example: you can group and ensemble. @@ -60,7 +64,7 @@ class Collector: result[artifact] = value return result - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> dict: """ do the workflow including collect and process_collect @@ -78,7 +82,7 @@ class Collector: filepath (str): the path of file Returns: - bool: if successed + bool: if succeeded """ try: with open(filepath, "wb") as f: @@ -109,6 +113,29 @@ class Collector: raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!") +class HyperCollector(Collector): + """ + A collector to collect the results of other Collectors + """ + + def __init__(self, collector_dict, process_list=[]): + """ + Args: + collector_dict (dict): the dict like {collector_key, Collector} + process_list (list or Callable): the list of processors or the instance of processor to process dict. + NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict. + This can make result more readable. If you want to maintain as it should be, just give a empty process list. + """ + super().__init__(process_list=process_list) + self.collector_dict = collector_dict + + def collect(self) -> dict: + collect_dict = {} + for key, collector in self.collector_dict.items(): + collect_dict[key] = collector() + return collect_dict + + class RecorderCollector(Collector): ART_KEY_RAW = "__raw" @@ -131,10 +158,10 @@ class RecorderCollector(Collector): artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. """ + super().__init__(process_list=process_list) if isinstance(experiment, str): experiment = R.get_exp(experiment_name=experiment) self.experiment = experiment - self.process_list = process_list self.artifacts_path = artifacts_path if rec_key_func is None: rec_key_func = lambda rec: rec.info["id"] @@ -144,7 +171,7 @@ class RecorderCollector(Collector): self.artifacts_key = artifacts_key self._rec_filter_func = rec_filter_func - def collect(self, artifacts_key=None, rec_filter_func=None): + def collect(self, artifacts_key=None, rec_filter_func=None) -> dict: """Collect different artifacts based on recorder after filtering. Args: @@ -180,3 +207,12 @@ class RecorderCollector(Collector): collect_dict.setdefault(key, {})[rec_key] = artifact return collect_dict + + def get_exp_name(self) -> str: + """ + Get experiment name + + Returns: + str: experiment name + """ + return self.experiment.name diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 158bc9916..c4c6bab7f 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """ -this is a task generator +Task generator can generate many tasks based on TaskGen and some task templates. """ import abc import copy @@ -113,7 +113,7 @@ class RollingGen(TaskGen): self.test_key = "test" self.train_key = "train" - def generate(self, task: dict): + def generate(self, task: dict) -> typing.List[dict]: """ Converting the task into a rolling task. @@ -158,6 +158,10 @@ class RollingGen(TaskGen): }, ] } + + Returns + ---------- + typing.List[dict]: a list of tasks """ res = [] @@ -196,16 +200,18 @@ class RollingGen(TaskGen): # update segments of this task t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) - # if end_time < the end of test_segments, then change end_time to allow load more data - if ( - self.modify_end_time - and self.ta.cal_interval( + + try: + interval = self.ta.cal_interval( t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], t["dataset"]["kwargs"]["segments"][self.test_key][1], ) - < 0 - ): - t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) + # if end_time < the end of test_segments, then change end_time to allow load more data + if self.modify_end_time and interval < 0: + t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) + except KeyError: + # Maybe the user dataset has no handler or end_time + pass prev_seg = segments res.append(t) return res diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 9d50d8563..c71be7d39 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -1,31 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + """ -A task consists of 3 parts +TaskManager can fetch unused tasks automatically and manager the lifecycle of a set of tasks with error handling. +These features can run tasks concurrently and ensure every task will be used only once. +Task Manager will store all tasks in `MongoDB `_. +Users **MUST** finished the configuration of `MongoDB `_ when using this module. + +A task in TaskManager consists of 3 parts - tasks description: the desc will define the task - tasks status: the status of the task - tasks result information : A user can get the task with the task description and task result. - """ -from bson.binary import Binary -import pickle -from pymongo.errors import InvalidDocument -from bson.objectid import ObjectId -from contextlib import contextmanager -import qlib -from tqdm.cli import tqdm -import time import concurrent -import pymongo -from qlib.config import C -from .utils import get_mongodb -from qlib import get_module_logger, auto_init +import pickle +import time +from contextlib import contextmanager +from typing import Callable, List + import fire +import pymongo +from bson.binary import Binary +from bson.objectid import ObjectId +from pymongo.errors import InvalidDocument +from qlib import auto_init, get_module_logger +from tqdm.cli import tqdm + +from .utils import get_mongodb class TaskManager: - """TaskManager - here is what will a task looks like when it created by TaskManager + """ + TaskManager + + Here is what will a task looks like when it created by TaskManager .. code-block:: python @@ -42,6 +50,16 @@ class TaskManager: .. note:: Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded + + Here are four status which are: + + STATUS_WAITING: waiting for train + + STATUS_RUNNING: training + + STATUS_PART_DONE: finished some step and waiting for next step + + STATUS_DONE: all work done """ STATUS_WAITING = "waiting" @@ -53,7 +71,7 @@ class TaskManager: def __init__(self, task_pool: str = None): """ - init Task Manager, remember to make the statement of MongoDB url and database name firstly. + Init Task Manager, remember to make the statement of MongoDB url and database name firstly. Parameters ---------- @@ -65,7 +83,7 @@ class TaskManager: self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) - def list(self): + def list(self) -> list: """ list the all collection(task_pool) of the db @@ -92,7 +110,9 @@ class TaskManager: return {k: str(v) for k, v in flt.items()} def replace_task(self, task, new_task): - # assume that the data out of interface was decoded and the data in interface was encoded + """ + Use a new task to replace a old one + """ new_task = self._encode_task(new_task) query = {"_id": ObjectId(task["_id"])} try: @@ -121,7 +141,7 @@ class TaskManager: Returns ------- - + pymongo.results.InsertOneResult """ task = self._encode_task( { @@ -133,9 +153,9 @@ class TaskManager: insert_result = self.insert_task(task) return insert_result - def create_task(self, task_def_l, dry_run=False, print_nt=False): + def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]: """ - if the tasks in task_def_l is new, then insert new tasks into the task_pool + If the tasks in task_def_l is new, then insert new tasks into the task_pool Parameters ---------- @@ -145,6 +165,7 @@ class TaskManager: if insert those new tasks to task pool print_nt: bool if print new task + Returns ------- list @@ -165,7 +186,7 @@ class TaskManager: print(t) if dry_run: - return + return [] _id_list = [] for t in new_tasks: @@ -174,7 +195,17 @@ class TaskManager: return _id_list - def fetch_task(self, query={}, status=STATUS_WAITING): + def fetch_task(self, query={}, status=STATUS_WAITING) -> dict: + """ + Use query to fetch tasks + + Args: + query (dict, optional): query dict. Defaults to {}. + status (str, optional): [description]. Defaults to STATUS_WAITING. + + Returns: + dict: a task(document in collection) after decoding + """ query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) @@ -191,7 +222,7 @@ class TaskManager: @contextmanager def safe_fetch_task(self, query={}, status=STATUS_WAITING): """ - fetch task from task_pool using query with contextmanager + Fetch task from task_pool using query with contextmanager Parameters ---------- @@ -200,7 +231,7 @@ class TaskManager: Returns ------- - + dict: a task(document in collection) after decoding """ task = self.fetch_task(query=query, status=status) try: @@ -231,7 +262,7 @@ class TaskManager: Returns ------- - + dict: a task(document in collection) after decoding """ query = query.copy() if "_id" in query: @@ -240,16 +271,40 @@ class TaskManager: yield self._decode_task(t) def re_query(self, _id): + """ + Use _id to query task. + + Args: + _id (str): _id of a document + + Returns: + dict: a task(document in collection) after decoding + """ t = self.task_pool.find_one({"_id": ObjectId(_id)}) return self._decode_task(t) - def commit_task_res(self, task, res, status=None): + def commit_task_res(self, task, res, status=STATUS_DONE): + """ + Commit the result to task['res']. + + Args: + task ([type]): [description] + res (object): the result you want to save + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE. + """ # A workaround to use the class attribute. if status is None: status = TaskManager.STATUS_DONE self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) - def return_task(self, task, status=None): + def return_task(self, task, status=STATUS_WAITING): + """ + Return a task to status. Alway using in error handling. + + Args: + task ([type]): [description] + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING. + """ if status is None: status = TaskManager.STATUS_WAITING update_dict = {"$set": {"status": status}} @@ -257,7 +312,7 @@ class TaskManager: def remove(self, query={}): """ - remove the task using query + Remove the task using query Parameters ---------- @@ -295,7 +350,7 @@ class TaskManager: def prioritize(self, task, priority: int): """ - set priority for task + Set priority for task Parameters ---------- @@ -331,29 +386,41 @@ class TaskManager: def run_task( - task_func, - task_pool, - force_release=False, - before_status=TaskManager.STATUS_WAITING, - after_status=TaskManager.STATUS_DONE, - *args, + task_func: Callable, + task_pool: str, + force_release: bool = False, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, **kwargs, ): """ While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool + After running this method, here are 4 situations (before_status -> after_status): + + STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param + + STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param + + STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param + + STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param + Parameters ---------- - task_func : def (task_def, *args, **kwargs) -> - the function to run the task + task_func : Callable + def (task_def, **kwargs) -> + the function to run the task task_pool : str the name of the task pool (Collection in MongoDB) - force_release : + force_release : bool will the program force to release the resource - args : - args - kwargs : - kwargs + before_status : str: + the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status : str: + the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs + the params for `task_func` """ tm = TaskManager(task_pool) @@ -364,19 +431,19 @@ def run_task( if task is None: break get_module_logger("run_task").info(task["def"]) - # when fetching `WAITING` task, use task_def to train + # when fetching `WAITING` task, use task["def"] to train if before_status == TaskManager.STATUS_WAITING: param = task["def"] - # when fetching `PART_DONE` task, use task_res to train for the result has been saved + # when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"] elif before_status == TaskManager.STATUS_PART_DONE: param = task["res"] else: raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!") if force_release: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: - res = executor.submit(task_func, param, *args, **kwargs).result() + res = executor.submit(task_func, param, **kwargs).result() else: - res = task_func(param, *args, **kwargs) + res = task_func(param, **kwargs) tm.commit_task_res(task, res, status=after_status) ever_run = True diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index ce8e0dfa3..ed5e1a235 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -1,5 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +""" +Some tools for task management. +""" + import bisect import pandas as pd from qlib.data import D @@ -7,13 +12,14 @@ from qlib.workflow import R from qlib.config import C from qlib.log import get_module_logger from pymongo import MongoClient +from pymongo.database import Database from typing import Union -def get_mongodb(): - """ +def get_mongodb() -> Database: - get database in MongoDB, which means you need to declare the address and the name of database. + """ + Get database in MongoDB, which means you need to declare the address and the name of database. for example: Using qlib.init(): @@ -31,6 +37,8 @@ def get_mongodb(): "task_db_name" : "rolling_db" } + Returns: + Database: the Database instance """ try: cfg = C["mongo"] @@ -43,7 +51,8 @@ def get_mongodb(): def list_recorders(experiment, rec_filter_func=None): - """list all recorders which can pass the filter in a experiment. + """ + List all recorders which can pass the filter in a experiment. Args: experiment (str or Experiment): the name of a Experiment or a instance @@ -65,7 +74,7 @@ def list_recorders(experiment, rec_filter_func=None): class TimeAdjuster: """ - find appropriate date and adjust date. + Find appropriate date and adjust date. """ def __init__(self, future=True, end_time=None): @@ -88,15 +97,15 @@ class TimeAdjuster: return None return self.cals[idx] - def max(self): + def max(self) -> pd.Timestamp: """ Return the max calendar datetime """ return max(self.cals) - def align_idx(self, time_point, tp_type="start"): + def align_idx(self, time_point, tp_type="start") -> int: """ - align the index of time_point in the calendar + Align the index of time_point in the calendar Parameters ---------- @@ -116,9 +125,9 @@ class TimeAdjuster: raise NotImplementedError(f"This type of input is not supported") return idx - def cal_interval(self, time_point_A, time_point_B): + def cal_interval(self, time_point_A, time_point_B) -> int: """ - calculate the trading day interval + Calculate the trading day interval (time_point_A - time_point_B) Args: time_point_A : time_point_A @@ -129,20 +138,22 @@ class TimeAdjuster: """ return self.align_idx(time_point_A) - self.align_idx(time_point_B) - def align_time(self, time_point, tp_type="start"): + def align_time(self, time_point, tp_type="start") -> pd.Timestamp: """ Align time_point to trade date of calendar - Parameters - ---------- - time_point - Time point - tp_type : str - time point type (`"start"`, `"end"`) + Args: + time_point + Time point + tp_type : str + time point type (`"start"`, `"end"`) + + Returns: + pd.Timestamp """ return self.cals[self.align_idx(time_point, tp_type=tp_type)] - def align_seg(self, segment: Union[dict, tuple]): + def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]: """ align the given date to trade date @@ -162,7 +173,7 @@ class TimeAdjuster: Returns ------- - the start and end trade date (pd.Timestamp) between the given start and end date. + Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date. """ if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} @@ -171,7 +182,7 @@ class TimeAdjuster: else: raise NotImplementedError(f"This type of input is not supported") - def truncate(self, segment: tuple, test_start, days: int): + def truncate(self, segment: tuple, test_start, days: int) -> tuple: """ truncate the segment based on the test_start date @@ -183,6 +194,10 @@ class TimeAdjuster: days : int The trading days to be truncated the data in this segment may need 'days' data + + Returns + --------- + tuple: new segment """ test_idx = self.align_idx(test_start) if isinstance(segment, tuple): @@ -198,7 +213,7 @@ class TimeAdjuster: SHIFT_SD = "sliding" SHIFT_EX = "expanding" - def shift(self, seg: tuple, step: int, rtype=SHIFT_SD): + def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: """ shift the datatime of segment @@ -211,6 +226,10 @@ class TimeAdjuster: rtype : str rolling type ("sliding" or "expanding") + Returns + -------- + tuple: new segment + Raises ------ KeyError: diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 33d251dd8..596ff0927 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys, traceback, signal, atexit +import sys, traceback, signal, atexit, logging from . import R from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) # function to handle the experiment when unusual program ending occurs diff --git a/scripts/data_collector/contrib/README.md b/scripts/data_collector/contrib/README.md new file mode 100644 index 000000000..011ff56e6 --- /dev/null +++ b/scripts/data_collector/contrib/README.md @@ -0,0 +1,24 @@ +# Get future trading days + +> `D.calendar(future=True)` will be used + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. +python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day +``` + +## Parameters + +- qlib_dir: qlib data directory +- freq: value from [`day`, `1min`], default `day` + + + diff --git a/scripts/data_collector/contrib/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector.py new file mode 100644 index 000000000..4da62d465 --- /dev/null +++ b/scripts/data_collector/contrib/future_trading_date_collector.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from typing import List +from pathlib import Path + +import fire +import numpy as np +import pandas as pd +from loguru import logger + +# get data from baostock +import baostock as bs + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + + +from data_collector.utils import generate_minutes_calendar_from_daily + + +def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame: + calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt") + if not calendar_path.exists(): + return pd.DataFrame() + return pd.read_csv(calendar_path, header=None) + + +def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"): + calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt")) + + np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8") + logger.info(f"write future calendars success: {calendar_path}") + + +def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]: + print(freq) + if freq == "day": + return date_list + elif freq == "1min": + date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist() + return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list)) + else: + raise ValueError(f"Unsupported freq: {freq}") + + +def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"): + """get future calendar + + Parameters + ---------- + qlib_dir: str or Path + qlib data directory + freq: str + value from ["day", "1min"], by default day + """ + qlib_dir = Path(qlib_dir).expanduser().resolve() + if not qlib_dir.exists(): + raise FileNotFoundError(str(qlib_dir)) + + lg = bs.login() + if lg.error_code != "0": + logger.error(f"login error: {lg.error_msg}") + return + # read daily calendar + daily_calendar = read_calendar_from_qlib(qlib_dir) + end_year = pd.Timestamp.now().year + if daily_calendar.empty: + start_year = pd.Timestamp.now().year + else: + start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year + rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31") + data_list = [] + while (rs.error_code == "0") & rs.next(): + _row_data = rs.get_row_data() + if int(_row_data[1]) == 1: + data_list.append(_row_data[0]) + data_list = sorted(data_list) + date_list = generate_qlib_calendar(data_list, freq=freq) + write_calendar_to_qlib(qlib_dir, date_list, freq=freq) + bs.logout() + logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31") + + +if __name__ == "__main__": + fire.Fire(future_calendar_collector) diff --git a/scripts/data_collector/contrib/requirements.txt b/scripts/data_collector/contrib/requirements.txt new file mode 100644 index 000000000..92dcb2374 --- /dev/null +++ b/scripts/data_collector/contrib/requirements.txt @@ -0,0 +1,5 @@ +baostock +fire +numpy +pandas +loguru diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index e8c9b9dc4..3f4539612 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -10,7 +10,9 @@ import random import requests import functools from pathlib import Path +from typing import Iterable, Tuple +import numpy as np import pandas as pd from lxml import etree from loguru import logger @@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh return res +def generate_minutes_calendar_from_daily( + calendars: Iterable, + freq: str = "1min", + am_range: Tuple[str, str] = ("09:30:00", "11:29:00"), + pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"), +) -> pd.Index: + """generate minutes calendar + + Parameters + ---------- + calendars: Iterable + daily calendar + freq: str + by default 1min + am_range: Tuple[str, str] + AM Time Range, by default China-Stock: ("09:30:00", "11:29:00") + pm_range: Tuple[str, str] + PM Time Range, by default China-Stock: ("13:00:00", "14:59:00") + + """ + daily_format: str = "%Y-%m-%d" + res = [] + for _day in calendars: + for _range in [am_range, pm_range]: + res.append( + pd.date_range( + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}", + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}", + freq=freq, + ) + ) + + return pd.Index(sorted(set(np.hstack(res)))) + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f0e110694..a6e06613e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols +from data_collector.utils import ( + get_calendar_list, + get_hs_stock_symbols, + get_us_stock_symbols, + generate_minutes_calendar_from_daily, +) INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" @@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC): return calendar_list_1d def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index: - res = [] - daily_format = self.DAILY_FORMAT - am_range = self.AM_RANGE - pm_range = self.PM_RANGE - for _day in calendars: - for _range in [am_range, pm_range]: - res.append( - pd.date_range( - f"{_day.strftime(daily_format)} {_range[0]}", - f"{_day.strftime(daily_format)} {_range[1]}", - freq="1min", - ) - ) - - return pd.Index(sorted(set(np.hstack(res)))) + return generate_minutes_calendar_from_daily( + calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE + ) def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor diff --git a/setup.py b/setup.py index c90d7d1c3..92c9ccc0c 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ REQUIRED = [ "scipy>=1.0.0", "requests>=2.18.0", "sacred>=0.7.4", - "python-socketio==3.1.2", + "python-socketio", "redis>=3.0.1", "python-redis-lock>=3.3.1", "schedule>=0.6.0",