mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
Merge remote-tracking branch 'microsoft/qlib/main' into online_srv
This commit is contained in:
@@ -243,6 +243,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
@@ -298,9 +298,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -337,7 +338,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -364,6 +364,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
|
||||
@@ -27,12 +27,11 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
class HighfreqWorkflow:
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
@@ -146,35 +145,40 @@ class HighfreqWorkflow(object):
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.init(
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
17
examples/rolling_process_data/README.md
Normal file
17
examples/rolling_process_data/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
32
examples/rolling_process_data/rolling_handler.py
Normal file
32
examples/rolling_process_data/rolling_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
data_loader_kwargs={},
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "DataLoaderDH",
|
||||
"kwargs": {**data_loader_kwargs},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
141
examples/rolling_process_data/workflow.py
Normal file
141
examples/rolling_process_data/workflow.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2019-12-31"
|
||||
rolling_cnt = 5
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
handler_config = {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"instruments": self.MARKET,
|
||||
"infer_processors": [],
|
||||
"learn_processors": [],
|
||||
},
|
||||
}
|
||||
pre_handler = init_instance_by_config(handler_config)
|
||||
pre_handler.config(dump_all=True)
|
||||
pre_handler.to_pickle(path)
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
self._init_qlib()
|
||||
self._dump_pre_handler("pre_handler.pkl")
|
||||
pre_handler = self._load_pre_handler("pre_handler.pkl")
|
||||
|
||||
train_start_time = (2010, 1, 1)
|
||||
train_end_time = (2012, 12, 31)
|
||||
valid_start_time = (2013, 1, 1)
|
||||
valid_end_time = (2013, 12, 31)
|
||||
test_start_time = (2014, 1, 1)
|
||||
test_end_time = (2014, 12, 31)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "RollingDataHandler",
|
||||
"module_path": "rolling_handler",
|
||||
"kwargs": {
|
||||
"start_time": datetime(*train_start_time),
|
||||
"end_time": datetime(*test_end_time),
|
||||
"fit_start_time": datetime(*train_start_time),
|
||||
"fit_end_time": datetime(*train_end_time),
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"data_loader_kwargs": {
|
||||
"handler_config": pre_handler,
|
||||
},
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": (datetime(*train_start_time), datetime(*train_end_time)),
|
||||
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
|
||||
"test": (datetime(*test_start_time), datetime(*test_end_time)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": (
|
||||
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
),
|
||||
"valid": (
|
||||
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
|
||||
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
|
||||
),
|
||||
"test": (
|
||||
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
|
||||
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(RollingDataWorkflow)
|
||||
@@ -28,11 +28,17 @@
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
@@ -238,9 +244,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -359,7 +363,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -377,4 +381,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
try:
|
||||
from .catboost_model import CatBoostModel
|
||||
except ModuleNotFoundError:
|
||||
CatBoostModel = None
|
||||
print("Please install necessary libs for CatBoostModel.")
|
||||
try:
|
||||
from .double_ensemble import DEnsembleModel
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
from .pytorch_gats import GATs
|
||||
from .pytorch_gru import GRU
|
||||
from .pytorch_lstm import LSTM
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from catboost import Pool, CatBoost
|
||||
from catboost.utils import get_gpu_device_count
|
||||
|
||||
@@ -62,10 +63,10 @@ class CatBoostModel(Model):
|
||||
evals_result["train"] = list(evals_result["learn"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["validation"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,6 +40,10 @@ class DEnsembleModel(Model):
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
@@ -228,10 +232,10 @@ class DEnsembleModel(Model):
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -61,10 +61,10 @@ class LGBModel(ModelFT):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
|
||||
@@ -84,8 +84,8 @@ class LinearModel(Model):
|
||||
self.coef_ = coef
|
||||
self.intercept_ = 0.0
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.coef_ is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ALSTM_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -264,11 +260,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -83,7 +79,6 @@ class GATs(Model):
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -310,11 +305,11 @@ class GATs(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
index = x_test.index
|
||||
self.GAT_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class GRU(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.gru_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -126,8 +121,8 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
self.logger.info("model:\n{:}".format(self.GRU_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -268,11 +264,11 @@ class LSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.lstm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -280,17 +276,13 @@ class LSTM(Model):
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
import torch
|
||||
@@ -18,7 +19,7 @@ from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
|
||||
@@ -48,8 +49,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
input_dim=360,
|
||||
output_dim=1,
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
@@ -271,13 +272,12 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
@@ -7,13 +7,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -442,11 +438,11 @@ class SFM(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sfm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -459,10 +455,7 @@ class SFM(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
if self.device != "cpu":
|
||||
x_batch = x_batch.to(self.device)
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.sfm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
@@ -6,13 +6,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -217,11 +213,11 @@ class TabnetModel(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -57,8 +57,8 @@ class XGBModel(Model):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -251,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
Gnererate order list according to score_series at trade_date, will not change current.
|
||||
Generate order list according to score_series at trade_date, will not change current.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
@@ -1,18 +1,59 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
"""
|
||||
This is the multiple segments signal record class that generates the signal prediction.
|
||||
This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
if not isinstance(dataset, qlib_dataset.DatasetH):
|
||||
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(predics, pd.Series):
|
||||
predics = predics.to_frame("score")
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# Compute the IC and Rank IC
|
||||
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
|
||||
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
|
||||
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
|
||||
ic_x100, ric_x100 = ic * 100, ric * 100
|
||||
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
|
||||
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
|
||||
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
|
||||
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{save_name}' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
@@ -38,7 +79,7 @@ class SignalMseRecord(SignalRecord):
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,22 +17,28 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish following steps:
|
||||
|
||||
- init the sub instance and the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -39,7 +46,7 @@ class Dataset(Serializable):
|
||||
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
- User load the Dataset object from the disk.
|
||||
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
@@ -47,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -76,44 +83,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
- arguments of DataHandler.init, such as 'enable_cache', etc.
|
||||
|
||||
segment_kwargs : dict
|
||||
Config of segments which is same as 'segments' in DatasetH.setup_data
|
||||
|
||||
"""
|
||||
if handler_kwargs:
|
||||
if not isinstance(handler_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
|
||||
kwargs_init = {}
|
||||
kwargs_conf_data = {}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time"}
|
||||
for k, v in handler_kwargs.items():
|
||||
if k in conf_data_arg:
|
||||
kwargs_conf_data.update({k: v})
|
||||
else:
|
||||
kwargs_init.update({k: v})
|
||||
|
||||
self.handler.conf_data(**kwargs_conf_data)
|
||||
self.handler.init(**kwargs_init)
|
||||
|
||||
if segment_kwargs:
|
||||
if not isinstance(segment_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -144,6 +114,49 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
@@ -433,15 +446,19 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
|
||||
@@ -6,6 +6,7 @@ import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
@@ -16,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index, fetch_df_by_index
|
||||
from .utils import fetch_df_by_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -102,10 +103,10 @@ class DataHandler(Serializable):
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -118,13 +119,16 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
Set Up the data in case of running intialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -412,14 +416,28 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if processor_kwargs is not None:
|
||||
for processor in self.get_all_processors():
|
||||
processor.config(**processor_kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
Set up the data in case of running intialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -434,7 +452,7 @@ class DataHandlerLP(DataHandler):
|
||||
when we call `init` next time
|
||||
"""
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
|
||||
@@ -217,3 +217,64 @@ class StaticDataLoader(DataLoader):
|
||||
join=self.join,
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler_config : dict
|
||||
handler_config will be used to describe the handlers
|
||||
|
||||
.. code-block::
|
||||
|
||||
<handler_config> := {
|
||||
"group_name1": <handler>
|
||||
"group_name2": <handler>
|
||||
}
|
||||
or
|
||||
<handler_config> := <handler>
|
||||
<handler> := DataHandler Instance | DataHandler Config
|
||||
|
||||
fetch_kwargs : dict
|
||||
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
|
||||
|
||||
is_group: bool
|
||||
is_group will be used to describe whether the key of handler_config is group
|
||||
|
||||
"""
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
|
||||
if is_group:
|
||||
self.handlers = {
|
||||
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
|
||||
}
|
||||
else:
|
||||
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
|
||||
|
||||
self.is_group = is_group
|
||||
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
|
||||
self.fetch_kwargs.update(fetch_kwargs)
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
LOG.warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
for grp, dh in self.handlers.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
return df
|
||||
|
||||
11
qlib/data/dataset/processor.py
Executable file → Normal file
11
qlib/data/dataset/processor.py
Executable file → Normal file
@@ -72,6 +72,17 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list and hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, fields_group=None):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import abc
|
||||
from typing import Text, Union
|
||||
from ..utils.serial import Serializable
|
||||
from ..data.dataset import Dataset
|
||||
|
||||
@@ -59,7 +60,7 @@ class Model(BaseModel):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, dataset: Dataset) -> object:
|
||||
def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
|
||||
"""give prediction given Dataset
|
||||
|
||||
Parameters
|
||||
@@ -67,6 +68,9 @@ class Model(BaseModel):
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training.
|
||||
|
||||
segment : Text or slice
|
||||
dataset will use this segment to prepare data. (default=test)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prediction results with certain type such as `pandas.Series`.
|
||||
|
||||
@@ -416,6 +416,12 @@ class QlibRecorder:
|
||||
"""
|
||||
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name: Text):
|
||||
"""
|
||||
Method for loading an object from artifacts in the experiment in the uri.
|
||||
"""
|
||||
return self.get_exp().get_recorder().load_object(name)
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
"""
|
||||
Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
@@ -159,7 +159,10 @@ class Experiment:
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
recorder, is_new = (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
@@ -174,7 +177,10 @@ class Experiment:
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
return (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
|
||||
@@ -159,7 +159,10 @@ class ExpManager:
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
exp, is_new = (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
@@ -172,7 +175,10 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
|
||||
@@ -39,7 +39,13 @@ class RecordTemp:
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
self._recorder = recorder
|
||||
|
||||
@property
|
||||
def recorder(self):
|
||||
if self._recorder is None:
|
||||
raise ValueError("This RecordTemp did not set recorder yet.")
|
||||
return self._recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord):
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"report_normal.pkl": report_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{"positions_normal.pkl": positions_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"order_normal.pkl": order_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
@@ -114,6 +114,8 @@ class IndexBase:
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
if df is None or df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
@@ -184,7 +186,10 @@ class IndexBase:
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
changers_df = self.get_changes()
|
||||
new_df = self.get_new_companies().copy()
|
||||
new_df = self.get_new_companies()
|
||||
if new_df is None or new_df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
new_df = new_df.copy()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
|
||||
if _row.type == self.ADD:
|
||||
|
||||
@@ -35,7 +35,7 @@ WIKI_INDEX_NAME_MAP = {
|
||||
class WIKIIndex(IndexBase):
|
||||
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = "_"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
@@ -123,7 +123,7 @@ class NASDAQ100Index(WIKIIndex):
|
||||
MAX_WORKERS = 16
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if not (set(df.columns) - {"Company", "Ticker"}):
|
||||
if len(df) >= 100 and "Ticker" in df.columns:
|
||||
return df.loc[:, ["Ticker"]].copy()
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,24 +6,11 @@ import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN, C
|
||||
from qlib.utils import drop_nan_by_y_index
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.config import C
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
@@ -166,8 +153,6 @@ def train_with_sigana():
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
@@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData):
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_train"))
|
||||
_suite.addTest(TestAllFlow("test_1_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_0_train_with_sigana"))
|
||||
_suite.addTest(TestAllFlow("test_1_train"))
|
||||
_suite.addTest(TestAllFlow("test_2_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_3_expmanager"))
|
||||
return _suite
|
||||
|
||||
|
||||
|
||||
27
tests/test_contrib_model.py
Normal file
27
tests/test_contrib_model.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
|
||||
from qlib.contrib.model import all_model_classes
|
||||
|
||||
|
||||
class TestAllFlow(unittest.TestCase):
|
||||
def test_0_initialize(self):
|
||||
num = 0
|
||||
for model_class in all_model_classes:
|
||||
if model_class is not None:
|
||||
model = model_class()
|
||||
num += 1
|
||||
print("There are {:}/{:} valid models in total.".format(num, len(all_model_classes)))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_initialize"))
|
||||
return _suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
111
tests/test_contrib_workflow.py
Normal file
111
tests/test_contrib_workflow.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def train_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
sr.generate(dict(valid="valid", test="test"), True)
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
def train_mse():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
sr.generate()
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
def test_0_multiseg(self):
|
||||
uri_path = train_multiseg()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_mse(self):
|
||||
uri_path = train_mse()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_multiseg"))
|
||||
_suite.addTest(TestAllFlow("test_1_mse"))
|
||||
return _suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
Reference in New Issue
Block a user