mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
Some Optimization of online code (#784)
* Some Optimization of online code * more flexible updater and load_object & fix p*_uri * make recorder more friendly * remove unused import
This commit is contained in:
@@ -55,15 +55,6 @@ class ProviderBackendMixin:
|
||||
def backend_obj(self, **kwargs):
|
||||
backend = self.backend if self.backend else self.get_default_backend()
|
||||
backend = copy.deepcopy(backend)
|
||||
|
||||
# set default storage kwargs
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. qlib.init: provider_uri
|
||||
backend_kwargs = backend.setdefault("kwargs", {})
|
||||
provider_uri = backend_kwargs.get("provider_uri", None)
|
||||
provider_uri = C.dpm.provider_uri if provider_uri is None else C.dpm.format_provider_uri(provider_uri)
|
||||
backend_kwargs["provider_uri"] = provider_uri
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
|
||||
@@ -24,9 +24,17 @@ class FileStorageMixin:
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. self._provider_uri : if provider_uri is provided.
|
||||
# 2. provider_uri in qlib.config.C
|
||||
|
||||
@property
|
||||
def provider_uri(self):
|
||||
return C["provider_uri"] if getattr(self, "_provider_uri", None) is None else self._provider_uri
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return C.DataPathManager(self.provider_uri, None)
|
||||
return C.dpm if getattr(self, "_provider_uri", None) is None else C.DataPathManager(self._provider_uri, None)
|
||||
|
||||
@property
|
||||
def support_freq(self) -> List[str]:
|
||||
@@ -62,10 +70,10 @@ class FileStorageMixin:
|
||||
|
||||
|
||||
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
def __init__(self, freq: str, future: bool, provider_uri: dict, **kwargs):
|
||||
def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
|
||||
self.future = future
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.enable_read_cache = True # TODO: make it configurable
|
||||
|
||||
@property
|
||||
@@ -173,9 +181,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
|
||||
INSTRUMENT_END_FIELD = "end_datetime"
|
||||
SYMBOL_FIELD_NAME = "instrument"
|
||||
|
||||
def __init__(self, market: str, freq: str, provider_uri: dict, **kwargs):
|
||||
def __init__(self, market: str, freq: str, provider_uri: dict = None, **kwargs):
|
||||
super(FileInstrumentStorage, self).__init__(market, freq, **kwargs)
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.file_name = f"{market.lower()}.txt"
|
||||
|
||||
def _read_instrument(self) -> Dict[InstKT, InstVT]:
|
||||
@@ -262,9 +270,9 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
|
||||
|
||||
|
||||
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict, **kwargs):
|
||||
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs):
|
||||
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
|
||||
|
||||
def clear(self):
|
||||
|
||||
@@ -122,6 +122,8 @@ class AverageEnsemble(Ensemble):
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
values = list(ensemble_dict.values())
|
||||
# NOTE: this may change the style underlying data!!!!
|
||||
# from pd.DataFrame to pd.Series
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
|
||||
results = results.mean(axis=1)
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
"""
|
||||
Group can group a set of objects based on `group_func` and change them to a dict.
|
||||
After group, we provide a method to reduce them.
|
||||
After group, we provide a method to reduce them.
|
||||
|
||||
For example:
|
||||
For example:
|
||||
|
||||
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
|
||||
@@ -109,5 +109,5 @@ class RollingGroup(Group):
|
||||
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
|
||||
return grouped_dict
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(ens=RollingEnsemble())
|
||||
def __init__(self, ens=RollingEnsemble()):
|
||||
super().__init__(ens=ens)
|
||||
|
||||
@@ -26,6 +26,7 @@ from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
# from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, Tuple, Any, Text, Optional
|
||||
from typing import Dict, Union, Tuple, Any, Text, Optional, Callable
|
||||
from types import ModuleType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
from typing import Dict, Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -214,7 +214,7 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self, **flt_kwargs):
|
||||
def list_recorders(self, **flt_kwargs) -> Dict[str, Recorder]:
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
@@ -325,7 +325,9 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""):
|
||||
def list_recorders(
|
||||
self, max_results: int = UNLIMITED, status: Union[str, None] = None, filter_string: str = ""
|
||||
) -> Dict[str, Recorder]:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -5,11 +5,12 @@ Updater is a module to update artifacts such as predictions when the stock data
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import Dataset, DatasetH
|
||||
from qlib.data.dataset import Dataset, DatasetH, TSDatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.model import Model
|
||||
from qlib.utils import get_date_by_shift
|
||||
@@ -25,7 +26,9 @@ class RMDLoader:
|
||||
def __init__(self, rec: Recorder):
|
||||
self.rec = rec
|
||||
|
||||
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
|
||||
def get_dataset(
|
||||
self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None
|
||||
) -> DatasetH:
|
||||
"""
|
||||
Load, config and setup dataset.
|
||||
|
||||
@@ -39,6 +42,8 @@ class RMDLoader:
|
||||
segments : dict
|
||||
the segments config for dataset
|
||||
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
|
||||
unprepared_dataset: Optional[DatasetH]
|
||||
if user don't want to load dataset from recorder, please specify user's dataset
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
@@ -46,7 +51,10 @@ class RMDLoader:
|
||||
"""
|
||||
if segments is None:
|
||||
segments = {"test": (start_time, end_time)}
|
||||
dataset: DatasetH = self.rec.load_object("dataset")
|
||||
if unprepared_dataset is None:
|
||||
dataset: DatasetH = self.rec.load_object("dataset")
|
||||
else:
|
||||
dataset = unprepared_dataset
|
||||
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
|
||||
return dataset
|
||||
@@ -90,7 +98,16 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
SZ300676 -0.001321
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
|
||||
def __init__(
|
||||
self,
|
||||
record: Recorder,
|
||||
to_date=None,
|
||||
from_date=None,
|
||||
hist_ref: Optional[int] = None,
|
||||
freq="day",
|
||||
fname="pred.pkl",
|
||||
loader_cls: type = RMDLoader,
|
||||
):
|
||||
"""
|
||||
Init PredUpdater.
|
||||
|
||||
@@ -111,11 +128,15 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical depends.
|
||||
Leave the problem to users to set the length of historical dependency
|
||||
If user doesn't specify this parameter, Updater will try to load dataset to automatically determine the hist_ref
|
||||
|
||||
.. note::
|
||||
|
||||
the start_time is not included in the hist_ref
|
||||
|
||||
loader_cls : type
|
||||
the class to load the model and dataset
|
||||
|
||||
"""
|
||||
# TODO: automate this hist_ref in the future.
|
||||
super().__init__(record=record)
|
||||
@@ -124,7 +145,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
self.hist_ref = hist_ref
|
||||
self.freq = freq
|
||||
self.fname = fname
|
||||
self.rmdl = RMDLoader(rec=record)
|
||||
self.rmdl = loader_cls(rec=record)
|
||||
|
||||
latest_date = D.calendar(freq=freq)[-1]
|
||||
if to_date == None:
|
||||
@@ -148,27 +169,50 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
else:
|
||||
self.last_end = get_date_by_shift(from_date, -1, align="right")
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> DatasetH:
|
||||
"""
|
||||
Load dataset
|
||||
- if unprepared_dataset is specified, then prepare the dataset directly
|
||||
- Otherwise,
|
||||
|
||||
Separating this function will make it easier to reuse the dataset
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
"""
|
||||
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
|
||||
# automatically getting the historical dependency if not specified
|
||||
if self.hist_ref is None:
|
||||
dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
|
||||
# Special treatment of historical dependencies
|
||||
if isinstance(dataset, TSDatasetH):
|
||||
hist_ref = dataset.step_len
|
||||
else:
|
||||
hist_ref = 0
|
||||
else:
|
||||
hist_ref = self.hist_ref
|
||||
|
||||
start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq)
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
seg = {"test": (start_time, self.to_date)}
|
||||
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
|
||||
return dataset
|
||||
return self.rmdl.get_dataset(
|
||||
start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset
|
||||
)
|
||||
|
||||
def update(self, dataset: DatasetH = None):
|
||||
def update(self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False) -> Optional[object]:
|
||||
"""
|
||||
Update the data in a recorder.
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
DatasetH: the instance of DatasetH. None for prepare it again.
|
||||
write : bool
|
||||
will the the write action be executed
|
||||
ret_new : bool
|
||||
will the updated data be returned
|
||||
|
||||
Args:
|
||||
DatasetH: the instance of DatasetH. None for reprepare.
|
||||
Returns
|
||||
-------
|
||||
Optional[object]
|
||||
the updated dataset
|
||||
"""
|
||||
# FIXME: the problem below is not solved
|
||||
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
|
||||
@@ -186,7 +230,12 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
# For reusing the dataset
|
||||
dataset = self.prepare_data()
|
||||
|
||||
self.record.save_objects(**{self.fname: self.get_update_data(dataset)})
|
||||
updated_data = self.get_update_data(dataset)
|
||||
|
||||
if write:
|
||||
self.record.save_objects(**{self.fname: updated_data})
|
||||
if ret_new:
|
||||
return updated_data
|
||||
|
||||
@abstractmethod
|
||||
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
|
||||
|
||||
@@ -169,14 +169,8 @@ class OnlineToolR(OnlineTool):
|
||||
exp_name = self._get_exp_name(exp_name)
|
||||
online_models = self.online_models(exp_name=exp_name)
|
||||
for rec in online_models:
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
cls, kwargs = get_callable_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
try:
|
||||
updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref)
|
||||
updater = PredUpdater(rec, to_date=to_date, from_date=from_date)
|
||||
except LoadObjectError as e:
|
||||
# skip the recorder without pred
|
||||
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
|
||||
|
||||
@@ -326,13 +326,15 @@ class MLflowRecorder(Recorder):
|
||||
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
def load_object(self, name, unpickler=pickle.Unpickler):
|
||||
"""
|
||||
Load object such as prediction file or model checkpoint in mlflow.
|
||||
|
||||
Args:
|
||||
name (str): the object name
|
||||
|
||||
unpickler: Supporting using custom unpickler
|
||||
|
||||
Raises:
|
||||
LoadObjectError: if raise some exceptions when load the object
|
||||
|
||||
@@ -344,7 +346,7 @@ class MLflowRecorder(Recorder):
|
||||
try:
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
data = pickle.load(f)
|
||||
data = unpickler(f).load()
|
||||
ar = self.client._tracking_client._get_artifact_repo(self.id)
|
||||
if isinstance(ar, AzureBlobArtifactRepository):
|
||||
# for saving disk space
|
||||
|
||||
@@ -5,12 +5,14 @@
|
||||
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from qlib.log import TimeInspector
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.exp import Experiment
|
||||
from qlib.workflow.recorder import Recorder
|
||||
|
||||
|
||||
class Collector(Serializable):
|
||||
@@ -142,6 +144,7 @@ class RecorderCollector(Collector):
|
||||
artifacts_path={"pred": "pred.pkl"},
|
||||
artifacts_key=None,
|
||||
list_kwargs={},
|
||||
status: Iterable = {Recorder.STATUS_FI},
|
||||
):
|
||||
"""
|
||||
Init RecorderCollector.
|
||||
@@ -156,6 +159,7 @@ class RecorderCollector(Collector):
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
list_kwargs (str): arguments for list_recorders function.
|
||||
status (Iterable): only collect recorders with specific status. None indicating collecting all the recorders
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
@@ -171,6 +175,7 @@ class RecorderCollector(Collector):
|
||||
self.artifacts_key = artifacts_key
|
||||
self.rec_filter_func = rec_filter_func
|
||||
self.list_kwargs = list_kwargs
|
||||
self.status = status
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
|
||||
"""
|
||||
@@ -202,9 +207,19 @@ class RecorderCollector(Collector):
|
||||
elif isinstance(self.experiment, Callable):
|
||||
recs = self.experiment()
|
||||
|
||||
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
|
||||
recs = [
|
||||
rec
|
||||
for rec in recs
|
||||
if (
|
||||
(self.status is None or rec.status in self.status) and (rec_filter_func is None or rec_filter_func(rec))
|
||||
)
|
||||
]
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
status_stat = defaultdict(int)
|
||||
for r in recs:
|
||||
status_stat[r.status] += 1
|
||||
logger.info(f"Nubmer of recorders after filter: {status_stat}")
|
||||
for rec in recs:
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
|
||||
Reference in New Issue
Block a user