1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

Some Optimization of online code (#784)

* Some Optimization of online code

* more flexible updater and load_object & fix p*_uri

* make recorder more friendly

* remove unused import
This commit is contained in:
you-n-g
2022-01-03 15:52:03 +08:00
committed by GitHub
parent e76b409d9a
commit 03cce8c908
11 changed files with 113 additions and 49 deletions

View File

@@ -55,15 +55,6 @@ class ProviderBackendMixin:
def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)
# set default storage kwargs
# NOTE: provider_uri priority
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
# 2. qlib.init: provider_uri
backend_kwargs = backend.setdefault("kwargs", {})
provider_uri = backend_kwargs.get("provider_uri", None)
provider_uri = C.dpm.provider_uri if provider_uri is None else C.dpm.format_provider_uri(provider_uri)
backend_kwargs["provider_uri"] = provider_uri
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)

View File

@@ -24,9 +24,17 @@ class FileStorageMixin:
"""
# NOTE: provider_uri priority
# 1. self._provider_uri : if provider_uri is provided.
# 2. provider_uri in qlib.config.C
@property
def provider_uri(self):
return C["provider_uri"] if getattr(self, "_provider_uri", None) is None else self._provider_uri
@property
def dpm(self):
return C.DataPathManager(self.provider_uri, None)
return C.dpm if getattr(self, "_provider_uri", None) is None else C.DataPathManager(self._provider_uri, None)
@property
def support_freq(self) -> List[str]:
@@ -62,10 +70,10 @@ class FileStorageMixin:
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, provider_uri: dict, **kwargs):
def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
self.future = future
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.enable_read_cache = True # TODO: make it configurable
@property
@@ -173,9 +181,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"
def __init__(self, market: str, freq: str, provider_uri: dict, **kwargs):
def __init__(self, market: str, freq: str, provider_uri: dict = None, **kwargs):
super(FileInstrumentStorage, self).__init__(market, freq, **kwargs)
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.file_name = f"{market.lower()}.txt"
def _read_instrument(self) -> Dict[InstKT, InstVT]:
@@ -262,9 +270,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict, **kwargs):
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
def clear(self):

View File

@@ -122,6 +122,8 @@ class AverageEnsemble(Ensemble):
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())
# NOTE: this may change the style underlying data!!!!
# from pd.DataFrame to pd.Series
results = pd.concat(values, axis=1)
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
results = results.mean(axis=1)

View File

@@ -3,9 +3,9 @@
"""
Group can group a set of objects based on `group_func` and change them to a dict.
After group, we provide a method to reduce them.
After group, we provide a method to reduce them.
For example:
For example:
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
@@ -109,5 +109,5 @@ class RollingGroup(Group):
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
return grouped_dict
def __init__(self):
super().__init__(ens=RollingEnsemble())
def __init__(self, ens=RollingEnsemble()):
super().__init__(ens=ens)

View File

@@ -26,6 +26,7 @@ from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
# from qlib.data.dataset.weight import Reweighter

View File

@@ -27,7 +27,7 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, Union, Tuple, Any, Text, Optional
from typing import Dict, Union, Tuple, Any, Text, Optional, Callable
from types import ModuleType
from urllib.parse import urlparse

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from typing import Dict, Union
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
@@ -214,7 +214,7 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")
def list_recorders(self, **flt_kwargs):
def list_recorders(self, **flt_kwargs) -> Dict[str, Recorder]:
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
@@ -325,7 +325,9 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
def list_recorders(
self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""
) -> Dict[str, Recorder]:
"""
Parameters
----------

View File

@@ -5,11 +5,12 @@ Updater is a module to update artifacts such as predictions when the stock data
"""
from abc import ABCMeta, abstractmethod
from typing import Optional
import pandas as pd
from qlib import get_module_logger
from qlib.data import D
from qlib.data.dataset import Dataset, DatasetH
from qlib.data.dataset import Dataset, DatasetH, TSDatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
@@ -25,7 +26,9 @@ class RMDLoader:
def __init__(self, rec: Recorder):
self.rec = rec
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
def get_dataset(
self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None
) -> DatasetH:
"""
Load, config and setup dataset.
@@ -39,6 +42,8 @@ class RMDLoader:
segments : dict
the segments config for dataset
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
unprepared_dataset: Optional[DatasetH]
if user don't want to load dataset from recorder, please specify user's dataset
Returns:
DatasetH: the instance of DatasetH
@@ -46,7 +51,10 @@ class RMDLoader:
"""
if segments is None:
segments = {"test": (start_time, end_time)}
dataset: DatasetH = self.rec.load_object("dataset")
if unprepared_dataset is None:
dataset: DatasetH = self.rec.load_object("dataset")
else:
dataset = unprepared_dataset
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
return dataset
@@ -90,7 +98,16 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
SZ300676 -0.001321
"""
def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
def __init__(
self,
record: Recorder,
to_date=None,
from_date=None,
hist_ref: Optional[int] = None,
freq="day",
fname="pred.pkl",
loader_cls: type = RMDLoader,
):
"""
Init PredUpdater.
@@ -111,11 +128,15 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
hist_ref : int
Sometimes, the dataset will have historical depends.
Leave the problem to users to set the length of historical dependency
If user doesn't specify this parameter, Updater will try to load dataset to automatically determine the hist_ref
.. note::
the start_time is not included in the hist_ref
loader_cls : type
the class to load the model and dataset
"""
# TODO: automate this hist_ref in the future.
super().__init__(record=record)
@@ -124,7 +145,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
self.hist_ref = hist_ref
self.freq = freq
self.fname = fname
self.rmdl = RMDLoader(rec=record)
self.rmdl = loader_cls(rec=record)
latest_date = D.calendar(freq=freq)[-1]
if to_date == None:
@@ -148,27 +169,50 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
else:
self.last_end = get_date_by_shift(from_date, -1, align="right")
def prepare_data(self) -> DatasetH:
def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> DatasetH:
"""
Load dataset
- if unprepared_dataset is specified, then prepare the dataset directly
- Otherwise,
Separating this function will make it easier to reuse the dataset
Returns:
DatasetH: the instance of DatasetH
"""
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
# automatically getting the historical dependency if not specified
if self.hist_ref is None:
dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
# Special treatment of historical dependencies
if isinstance(dataset, TSDatasetH):
hist_ref = dataset.step_len
else:
hist_ref = 0
else:
hist_ref = self.hist_ref
start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
seg = {"test": (start_time, self.to_date)}
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
return dataset
return self.rmdl.get_dataset(
start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset
)
def update(self, dataset: DatasetH = None):
def update(self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False) -> Optional[object]:
"""
Update the data in a recorder.
Parameters
----------
dataset : DatasetH
DatasetH: the instance of DatasetH. None for prepare it again.
write : bool
will the the write action be executed
ret_new : bool
will the updated data be returned
Args:
DatasetH: the instance of DatasetH. None for reprepare.
Returns
-------
Optional[object]
the updated dataset
"""
# FIXME: the problem below is not solved
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
@@ -186,7 +230,12 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
# For reusing the dataset
dataset = self.prepare_data()
self.record.save_objects(**{self.fname: self.get_update_data(dataset)})
updated_data = self.get_update_data(dataset)
if write:
self.record.save_objects(**{self.fname: updated_data})
if ret_new:
return updated_data
@abstractmethod
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:

View File

@@ -169,14 +169,8 @@ class OnlineToolR(OnlineTool):
exp_name = self._get_exp_name(exp_name)
online_models = self.online_models(exp_name=exp_name)
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:
updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref)
updater = PredUpdater(rec, to_date=to_date, from_date=from_date)
except LoadObjectError as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")

View File

@@ -326,13 +326,15 @@ class MLflowRecorder(Recorder):
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
shutil.rmtree(temp_dir)
def load_object(self, name):
def load_object(self, name, unpickler=pickle.Unpickler):
"""
Load object such as prediction file or model checkpoint in mlflow.
Args:
name (str): the object name
unpickler: Supporting using custom unpickler
Raises:
LoadObjectError: if raise some exceptions when load the object
@@ -344,7 +346,7 @@ class MLflowRecorder(Recorder):
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
data = pickle.load(f)
data = unpickler(f).load()
ar = self.client._tracking_client._get_artifact_repo(self.id)
if isinstance(ar, AzureBlobArtifactRepository):
# for saving disk space

View File

@@ -5,12 +5,14 @@
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
"""
from collections import defaultdict
from qlib.log import TimeInspector
from typing import Callable, Dict, List
from typing import Callable, Dict, Iterable, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
from qlib.workflow.exp import Experiment
from qlib.workflow.recorder import Recorder
class Collector(Serializable):
@@ -142,6 +144,7 @@ class RecorderCollector(Collector):
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
list_kwargs={},
status: Iterable = {Recorder.STATUS_FI},
):
"""
Init RecorderCollector.
@@ -156,6 +159,7 @@ class RecorderCollector(Collector):
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
list_kwargs (str): arguments for list_recorders function.
status (Iterable): only collect recorders with specific status. None indicating collecting all the recorders
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
@@ -171,6 +175,7 @@ class RecorderCollector(Collector):
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
self.list_kwargs = list_kwargs
self.status = status
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
@@ -202,9 +207,19 @@ class RecorderCollector(Collector):
elif isinstance(self.experiment, Callable):
recs = self.experiment()
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
recs = [
rec
for rec in recs
if (
(self.status is None or rec.status in self.status) and (rec_filter_func is None or rec_filter_func(rec))
)
]
logger = get_module_logger("RecorderCollector")
status_stat = defaultdict(int)
for r in recs:
status_stat[r.status] += 1
logger.info(f"Nubmer of recorders after filter: {status_stat}")
for rec in recs:
rec_key = self.rec_key_func(rec)
for key in artifacts_key: