mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Merge pull request #466 from you-n-g/online_hotfix
Online bug fix, enhancement & docs for dataset, workflow, trainer ...
This commit is contained in:
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
|
||||
skip_if_reg = kwargs.pop("skip_if_reg", False)
|
||||
if skip_if_reg and C.registered:
|
||||
# if we reinitialize Qlib during running an experiment `R.start`.
|
||||
# it will result in loss of the recorder
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
@@ -243,6 +243,8 @@ class TSDataSampler:
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
|
||||
data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
more powerful subclasses.
|
||||
@@ -309,11 +311,19 @@ class TSDataSampler:
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
if isinstance(flt_data, pd.DataFrame):
|
||||
assert len(flt_data.columns) == 1
|
||||
flt_data = flt_data.iloc[:, 0]
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
@@ -341,7 +351,7 @@ class TSDataSampler:
|
||||
setattr(self, k, v)
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
|
||||
"""
|
||||
The relation of the data
|
||||
|
||||
@@ -352,9 +362,15 @@ class TSDataSampler:
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
{<index>: <prev_index or None>}
|
||||
# get the previous index of a line given index
|
||||
Tuple[pd.DataFrame, dict]:
|
||||
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
|
||||
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
|
||||
datetime
|
||||
2021-01-11 0 1 2 3 4 5 ...
|
||||
2021-01-12 4146 4147 4148 4149 4150 4151 ...
|
||||
2021-01-13 8293 8294 8295 8296 8297 8298 ...
|
||||
2021-01-14 12441 12442 12443 12444 12445 12446 ...
|
||||
2) the second element: {<original index>: <row, col>}
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
|
||||
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
# this feature name conflicts with the attribute with Logger
|
||||
# rename it to avoid some corner cases that result in comparing `str` and `int`
|
||||
self.__level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
logger.setLevel(self.__level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
self.__level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelFT(Model):
|
||||
|
||||
# Finetune model based on previous trained model
|
||||
with R.start(experiment_name="finetune model"):
|
||||
recorder = R.get_recorder(rid, experiment_name="init models")
|
||||
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
|
||||
model = recorder.load_object("init_model")
|
||||
model.finetune(dataset, num_boost_round=10)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ There are two steps in each Trainer including ``train``(make model recorder) and
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
@@ -153,6 +153,9 @@ class Trainer:
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list:
|
||||
return self.end_train(self.train(*args, **kwargs))
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -286,7 +289,9 @@ class TrainerRM(Trainer):
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
def __init__(
|
||||
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
@@ -294,11 +299,16 @@ class TrainerRM(Trainer):
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -340,15 +350,16 @@ class TrainerRM(Trainer):
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not self.is_delay():
|
||||
tm.wait(query=query)
|
||||
@@ -411,6 +422,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
skip_run_task: bool = False,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
@@ -420,10 +432,15 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -477,14 +494,15 @@ class DelayTrainerRM(TrainerRM):
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
|
||||
@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
|
||||
return pred_left, pred_right
|
||||
|
||||
|
||||
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
|
||||
"""
|
||||
Time slicing in Qlib or Pandas is a frequently-used action.
|
||||
However, user often input all kinds of data format to represent time.
|
||||
This function will help user to convert these inputs into a uniform format which is friendly to time slicing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : Union[None, str, pd.Timestamp]
|
||||
original time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, pd.Timestamp]:
|
||||
"""
|
||||
if t is None:
|
||||
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
|
||||
return t
|
||||
else:
|
||||
return pd.Timestamp(t)
|
||||
|
||||
|
||||
def can_use_cache():
|
||||
res = True
|
||||
r = get_redis_connection()
|
||||
|
||||
@@ -216,9 +216,9 @@ class QlibRecorder:
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -263,7 +263,7 @@ class QlibRecorder:
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
exp = R.get_exp('test1')
|
||||
exp = R.get_exp(experiment_name='test1')
|
||||
|
||||
# Case 3
|
||||
exp = R.get_exp() -> a default experiment.
|
||||
@@ -288,7 +288,9 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
return self.exp_manager.get_exp(
|
||||
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
|
||||
)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -332,7 +334,9 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
@@ -385,7 +389,7 @@ class QlibRecorder:
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -213,11 +214,15 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self):
|
||||
def list_recorders(self, **flt_kwargs):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
flt_kwargs : dict
|
||||
filter recorders by conditions
|
||||
e.g. list_recorders(status=Recorder.STATUS_FI)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
@@ -320,11 +325,21 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
max_results : int
|
||||
the number limitation of the results
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
"""
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
if status is None or recorder.status == status:
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
|
||||
return recorders
|
||||
|
||||
@@ -109,7 +109,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
|
||||
@@ -190,7 +190,7 @@ class ExpManager:
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
|
||||
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
|
||||
if experiment_id is not None:
|
||||
try:
|
||||
# NOTE: the mlflow's experiment_id must be str type...
|
||||
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
|
||||
exp = self.client.get_experiment(experiment_id)
|
||||
if exp.lifecycle_stage.upper() == "DELETED":
|
||||
raise MlflowException("No valid experiment has been found.")
|
||||
|
||||
@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
|
||||
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
So this module provides a series of methods to control this process.
|
||||
|
||||
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
@@ -31,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
"""
|
||||
@@ -113,6 +113,8 @@ class OnlineManager(Serializable):
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
|
||||
# start.
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
@@ -148,8 +150,6 @@ class OnlineManager(Serializable):
|
||||
models_list = []
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
@@ -158,6 +158,11 @@ class OnlineManager(Serializable):
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
# The online model may changes in the above processes
|
||||
# So updating the predictions of online models should be the last step
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
@@ -236,7 +241,7 @@ class OnlineManager(Serializable):
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(
|
||||
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
|
||||
|
||||
@@ -52,6 +52,12 @@ class OnlineStrategy:
|
||||
|
||||
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
|
||||
|
||||
**NOTE**:
|
||||
Current implementation is very naive. Here is a more complex situation which is more closer to the
|
||||
practical scenarios.
|
||||
1. Train new models at the day before `test_start` (at time stamp `T`)
|
||||
2. Switch models at the `test_start` (at time timestamp `T + 1` typically)
|
||||
|
||||
Args:
|
||||
models (list): a list of models.
|
||||
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
|
||||
|
||||
@@ -136,7 +136,7 @@ class PredUpdater(RecordUpdater):
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
if start_time > self.to_date:
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
|
||||
@@ -192,6 +193,7 @@ class RecorderCollector(Collector):
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
@@ -205,7 +207,13 @@ class RecorderCollector(Collector):
|
||||
# only collect existing artifact
|
||||
continue
|
||||
raise e
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
# give user some warning if the values are overridden
|
||||
cdd = collect_dict.setdefault(key, {})
|
||||
if rec_key in cdd:
|
||||
logger.warning(
|
||||
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
|
||||
)
|
||||
cdd[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
|
||||
import abc
|
||||
import copy
|
||||
from typing import List, Union, Callable
|
||||
|
||||
from qlib.utils import transform_end_date
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
@@ -199,7 +201,7 @@ class RollingGen(TaskGen):
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
@@ -272,10 +272,10 @@ class TaskManager:
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
|
||||
if task is not None:
|
||||
self.logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
self.return_task(task, status=status) # return task as the original status
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
@@ -411,7 +411,11 @@ class TaskManager:
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
return (
|
||||
task_stat.get(self.STATUS_WAITING, 0)
|
||||
+ task_stat.get(self.STATUS_RUNNING, 0)
|
||||
+ task_stat.get(self.STATUS_PART_DONE, 0)
|
||||
)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
@@ -429,7 +433,7 @@ class TaskManager:
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
Reference in New Issue
Block a user