mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
recorder refine; signalTemp; fixbug
This commit is contained in:
0
qlib/contrib/eva/__init__.py
Normal file
0
qlib/contrib/eva/__init__.py
Normal file
32
qlib/contrib/eva/alpha.py
Normal file
32
qlib/contrib/eva/alpha.py
Normal file
@@ -0,0 +1,32 @@
|
||||
'''
|
||||
Here is a batch of evaluation functions.
|
||||
|
||||
The interface should be redesigned carefully in the future.
|
||||
'''
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col='datetime', dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({'pred': pred, 'label': label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df['pred'].corr(df['label']))
|
||||
ric = df.groupby(date_col).apply(lambda df: df['pred'].corr(df['label'], method='spearman'))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
@@ -64,7 +64,7 @@ class LGBModel(ModelFT):
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from ...utils import init_instance_by_config
|
||||
from .handler import DataHandler
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@@ -98,9 +100,11 @@ class DatasetH(Dataset):
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments.copy()
|
||||
|
||||
def prepare(
|
||||
self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs
|
||||
) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
def prepare(self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
"""
|
||||
prepare the data for learning and inference
|
||||
|
||||
@@ -111,22 +115,31 @@ class DatasetH(Dataset):
|
||||
Here are some examples
|
||||
1) 'train'
|
||||
2) ['train', 'valid']
|
||||
col_set : [TODO:type]
|
||||
[TODO:description]
|
||||
col_set : str
|
||||
The col_set will be passed to self._handler when fetching data
|
||||
data_key: str
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
[TODO:description]
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
[TODO:description]
|
||||
"""
|
||||
logger = get_module_logger("DatasetH")
|
||||
fetch_kwargs = {"col_set": col_set}
|
||||
fetch_kwargs.update(kwargs)
|
||||
if "data_key"in getfullargspec(self._handler.fetch).args:
|
||||
fetch_kwargs['data_key'] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments]
|
||||
return [self._handler.fetch(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._handler.fetch(slice(*self._segments[segments]), col_set=col_set, **kwargs)
|
||||
return self._handler.fetch(slice(*self._segments[segments]), **fetch_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
@@ -9,9 +9,12 @@ from ..contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..contrib.eva.alpha import calc_ic
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
@@ -22,8 +25,8 @@ class RecordTemp:
|
||||
backtest in a certain format.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -38,7 +41,7 @@ class RecordTemp:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def load(self, name, **kwargs):
|
||||
def load(self, name):
|
||||
"""
|
||||
Load the stored records.
|
||||
|
||||
@@ -46,13 +49,14 @@ class RecordTemp:
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
kwargs
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
# try to load the saved object
|
||||
obj = self.recorder.load_object(name)
|
||||
return obj
|
||||
|
||||
def list(self):
|
||||
"""
|
||||
@@ -62,34 +66,36 @@ class RecordTemp:
|
||||
------
|
||||
A list of all the stored records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list` method.")
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
def check(self, parent=False):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kwargs
|
||||
|
||||
Return
|
||||
Raise
|
||||
------
|
||||
Boolean: whether the records are stored properly.
|
||||
FileExistsError: whether the records are stored properly.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check` method.")
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if parent:
|
||||
# Downcasting have to be done here instead of using `super`
|
||||
flist = self.__class__.__base__.list(self)
|
||||
else:
|
||||
flist = self.list()
|
||||
for item in flist:
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
|
||||
|
||||
# TODO: this can only be run under R's running experiment.
|
||||
class SignalRecord(RecordTemp):
|
||||
"""
|
||||
This is the Signal Record class that generates the signal prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder, **kwargs):
|
||||
super(SignalRecord, self).__init__()
|
||||
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
|
||||
super().__init__(recorder=recorder)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.recorder = recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# generate prediciton
|
||||
@@ -97,6 +103,7 @@ class SignalRecord(RecordTemp):
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
|
||||
logger.info(
|
||||
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -104,35 +111,50 @@ class SignalRecord(RecordTemp):
|
||||
pprint(f"The following are prediction results of the {type(self.model).__name__} model.")
|
||||
pprint(pred.head(5))
|
||||
|
||||
def load(self, name="pred.pkl"):
|
||||
# try to load the saved object
|
||||
pred = self.recorder.load_object(name)
|
||||
return pred
|
||||
# save according label
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
params = dict(self=self.dataset, segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params['data_key']
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
|
||||
def list(self):
|
||||
return ["pred.pkl"]
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
def check(self, **kwargs):
|
||||
artifacts = self.recorder.list_artifacts()
|
||||
for artifact in artifacts:
|
||||
if "pred.pkl" in artifact.path:
|
||||
return True
|
||||
return False
|
||||
def load(self, name="pred.pkl"):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
# TODO
|
||||
class SigAnaRecord(SignalRecord):
|
||||
def __init__(self, recorder, config, **kwargs):
|
||||
pass
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
self.artifact_path_sig = "sig_analysis"
|
||||
|
||||
def generate(self):
|
||||
pass
|
||||
self.check(parent=True)
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std()
|
||||
}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.artifact_path_sig)
|
||||
pprint(metrics)
|
||||
|
||||
def check(self):
|
||||
pass
|
||||
def list(self):
|
||||
return ["{self.artifact_path_sig}/ic.pkl", "{self.artifact_path_sig}/ric.pkl"]
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
@@ -141,26 +163,28 @@ class PortAnaRecord(SignalRecord):
|
||||
"""
|
||||
|
||||
def __init__(self, recorder, config, **kwargs):
|
||||
self.recorder = recorder
|
||||
"""
|
||||
config["strategy"] : dict
|
||||
define the strategy class as well as the kwargs.
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
super().__init__(recorder=recorder)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.backtest_config = config["backtest"]
|
||||
self.strategy = init_instance_by_config(self.strategy_config)
|
||||
self.artifact_path = "portfolio_analysis"
|
||||
self.artifact_path_port = "portfolio_analysis"
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
STRATEGY_CONFIG : dict
|
||||
define the strategy class as well as the kwargs.
|
||||
BACKTEST_CONFIG : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
# check previously stored prediction results
|
||||
assert super().check(), "Make sure the parent process is completed and store the data properly."
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path_port)
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path_port)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
@@ -173,7 +197,7 @@ class PortAnaRecord(SignalRecord):
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path_port)
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -183,24 +207,9 @@ class PortAnaRecord(SignalRecord):
|
||||
pprint("The following are analysis results of the excess return with cost.")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
def load(self, name):
|
||||
# try to load the saved object
|
||||
if self.artifact_path not in name:
|
||||
file_name = re.split(r" |/|\\", name)[-1]
|
||||
name = f"{self.artifact_path}/{file_name}"
|
||||
result = self.recorder.load_object(name)
|
||||
return result
|
||||
|
||||
def list(self):
|
||||
return [
|
||||
f"{self.artifact_path}/report_normal.pkl",
|
||||
f"{self.artifact_path}/positions_normal.pkl",
|
||||
f"{self.artifact_path}/port_analysis.pkl",
|
||||
f"{self.artifact_path_port}/report_normal.pkl",
|
||||
f"{self.artifact_path_port}/positions_normal.pkl",
|
||||
f"{self.artifact_path_port}/port_analysis.pkl",
|
||||
]
|
||||
|
||||
def check(self):
|
||||
artifacts = self.recorder.list_artifacts(self.artifact_path)
|
||||
for artifact in artifacts:
|
||||
if "port_analysis.pkl" in artifact.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -290,7 +290,7 @@ class MLflowRecorder(Recorder):
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return artifacts
|
||||
return [art.path for art in artifacts]
|
||||
|
||||
def list_metrics(self):
|
||||
run = self.client.get_run(self.id)
|
||||
|
||||
Reference in New Issue
Block a user