mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
solve the conflict
This commit is contained in:
@@ -243,6 +243,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
@@ -298,9 +298,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -337,7 +338,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -364,6 +364,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
|
||||
@@ -27,12 +27,11 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
class HighfreqWorkflow:
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
@@ -146,35 +145,40 @@ class HighfreqWorkflow(object):
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.init(
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
17
examples/rolling_process_data/README.md
Normal file
17
examples/rolling_process_data/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
32
examples/rolling_process_data/rolling_handler.py
Normal file
32
examples/rolling_process_data/rolling_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
data_loader_kwargs={},
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "DataLoaderDH",
|
||||
"kwargs": {**data_loader_kwargs},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
141
examples/rolling_process_data/workflow.py
Normal file
141
examples/rolling_process_data/workflow.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2019-12-31"
|
||||
rolling_cnt = 5
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
handler_config = {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"instruments": self.MARKET,
|
||||
"infer_processors": [],
|
||||
"learn_processors": [],
|
||||
},
|
||||
}
|
||||
pre_handler = init_instance_by_config(handler_config)
|
||||
pre_handler.config(dump_all=True)
|
||||
pre_handler.to_pickle(path)
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
self._init_qlib()
|
||||
self._dump_pre_handler("pre_handler.pkl")
|
||||
pre_handler = self._load_pre_handler("pre_handler.pkl")
|
||||
|
||||
train_start_time = (2010, 1, 1)
|
||||
train_end_time = (2012, 12, 31)
|
||||
valid_start_time = (2013, 1, 1)
|
||||
valid_end_time = (2013, 12, 31)
|
||||
test_start_time = (2014, 1, 1)
|
||||
test_end_time = (2014, 12, 31)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "RollingDataHandler",
|
||||
"module_path": "rolling_handler",
|
||||
"kwargs": {
|
||||
"start_time": datetime(*train_start_time),
|
||||
"end_time": datetime(*test_end_time),
|
||||
"fit_start_time": datetime(*train_start_time),
|
||||
"fit_end_time": datetime(*train_end_time),
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"data_loader_kwargs": {
|
||||
"handler_config": pre_handler,
|
||||
},
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": (datetime(*train_start_time), datetime(*train_end_time)),
|
||||
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
|
||||
"test": (datetime(*test_start_time), datetime(*test_end_time)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": (
|
||||
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
),
|
||||
"valid": (
|
||||
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
|
||||
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
|
||||
),
|
||||
"test": (
|
||||
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
|
||||
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(RollingDataWorkflow)
|
||||
@@ -28,11 +28,17 @@
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
@@ -238,9 +244,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -359,7 +363,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -377,4 +381,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,7 @@ class Position:
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
|
||||
@@ -8,6 +8,59 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
"""
|
||||
calculate the precision for long and short operation
|
||||
|
||||
|
||||
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
|
||||
|
||||
.. code-block:: python
|
||||
score
|
||||
datetime instrument
|
||||
2020-12-01 09:30:00 SH600068 0.553634
|
||||
SH600195 0.550017
|
||||
SH600276 0.540321
|
||||
SH600584 0.517297
|
||||
SH600715 0.544674
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
long precision and short precision in time level
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=date_col)
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
l_dom = groupll.apply(lambda x: x > 0)
|
||||
l_c = groupll.count()
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
s_dom = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
try:
|
||||
from .catboost_model import CatBoostModel
|
||||
except ModuleNotFoundError:
|
||||
CatBoostModel = None
|
||||
print("Please install necessary libs for CatBoostModel.")
|
||||
try:
|
||||
from .double_ensemble import DEnsembleModel
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
from .pytorch_gats import GATs
|
||||
from .pytorch_gru import GRU
|
||||
from .pytorch_lstm import LSTM
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from catboost import Pool, CatBoost
|
||||
from catboost.utils import get_gpu_device_count
|
||||
|
||||
@@ -62,10 +63,10 @@ class CatBoostModel(Model):
|
||||
evals_result["train"] = list(evals_result["learn"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["validation"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,6 +40,10 @@ class DEnsembleModel(Model):
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
@@ -228,10 +232,10 @@ class DEnsembleModel(Model):
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -61,10 +61,10 @@ class LGBModel(ModelFT):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
import warnings
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT):
|
||||
"""LightGBM Model for high frequency prediction"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self.params = {"objective": loss, "verbosity": -1}
|
||||
self.params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
|
||||
"""
|
||||
Calcaute the signal metrics by daily level
|
||||
"""
|
||||
up_pre, down_pre = [], []
|
||||
up_alpha_ll, down_alpha_ll = [], []
|
||||
for date in y_test.index.get_level_values(0).unique():
|
||||
df_res = y_test.loc[date].sort_values("pred")
|
||||
if int(l_cut * len(df_res)) < 10:
|
||||
warnings.warn("Warning: threhold is too low or instruments number is not enough")
|
||||
continue
|
||||
top = df_res.iloc[: int(l_cut * len(df_res))]
|
||||
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
|
||||
|
||||
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
|
||||
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
|
||||
|
||||
down_alpha = top[top.columns[0]].mean()
|
||||
up_alpha = bottom[bottom.columns[0]].mean()
|
||||
|
||||
up_pre.append(up_precision)
|
||||
down_pre.append(down_precision)
|
||||
up_alpha_ll.append(up_alpha)
|
||||
down_alpha_ll.append(down_alpha)
|
||||
|
||||
return (
|
||||
np.array(up_pre).mean(),
|
||||
np.array(down_pre).mean(),
|
||||
np.array(up_alpha_ll).mean(),
|
||||
np.array(down_alpha_ll).mean(),
|
||||
)
|
||||
|
||||
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
|
||||
"""
|
||||
Test the sigal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
# Convert label into alpha
|
||||
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
|
||||
|
||||
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
y_test["pred"] = res
|
||||
|
||||
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
|
||||
print("===============================")
|
||||
print("High frequency signal test")
|
||||
print("===============================")
|
||||
print("Test set precision: ")
|
||||
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
|
||||
print("Test Alpha Average in test set: ")
|
||||
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_train["feature"], df_valid["label"]
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
l_name = df_train["label"].columns[0]
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
mapping_fn = lambda x: 0 if x < 0 else 1
|
||||
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
|
||||
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
|
||||
x_train, y_train = df_train["feature"], df_train["label_c"].values
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
"""
|
||||
finetune model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
num_boost_round : int
|
||||
number of round to finetune model
|
||||
verbose_eval : int
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
|
||||
@@ -84,8 +84,8 @@ class LinearModel(Model):
|
||||
self.coef_ = coef
|
||||
self.intercept_ = 0.0
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.coef_ is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ALSTM_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -264,11 +260,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -83,7 +79,6 @@ class GATs(Model):
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -310,11 +305,11 @@ class GATs(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
index = x_test.index
|
||||
self.GAT_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class GRU(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.gru_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -126,8 +121,8 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
self.logger.info("model:\n{:}".format(self.GRU_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -268,11 +264,11 @@ class LSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.lstm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -280,17 +276,13 @@ class LSTM(Model):
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
import torch
|
||||
@@ -18,7 +19,7 @@ from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
|
||||
@@ -48,8 +49,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
input_dim=360,
|
||||
output_dim=1,
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
@@ -271,13 +272,12 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
@@ -7,13 +7,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -442,11 +438,11 @@ class SFM(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sfm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -459,10 +455,7 @@ class SFM(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
if self.device != "cpu":
|
||||
x_batch = x_batch.to(self.device)
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.sfm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
@@ -6,13 +6,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -217,11 +213,11 @@ class TabnetModel(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -57,8 +57,8 @@ class XGBModel(Model):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -214,7 +214,7 @@ def cumulative_return_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
|
||||
|
||||
Graph desc:
|
||||
|
||||
@@ -94,7 +94,7 @@ def rank_label_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
""""""
|
||||
""" """
|
||||
|
||||
_name = None
|
||||
|
||||
|
||||
413
qlib/contrib/strategy/strategy.py
Normal file
413
qlib/contrib/strategy/strategy.py
Normal file
@@ -0,0 +1,413 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
# TODO: The base strategies will be moved out of contrib to core code
|
||||
class BaseStrategy:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_risk_degree(self, date):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing
|
||||
"""
|
||||
# It will use 95% amount of your total value by default
|
||||
return 0.95
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
DO NOT directly change the state of current
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current state of position.
|
||||
DO NOT directly change the state of current.
|
||||
trade_exchange : Exchange()
|
||||
trade exchange.
|
||||
pred_date : pd.Timestamp
|
||||
predict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update(self, score_series, pred_date, trade_date):
|
||||
"""User can use this method to update strategy state each trade date.
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score.
|
||||
pred_date : pd.Timestamp
|
||||
oredict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
pass
|
||||
|
||||
def init(self, **kwargs):
|
||||
"""Some strategy need to be initial after been implemented,
|
||||
User can use this method to init his strategy with parameters needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_init_args_from_model(self, model, init_date):
|
||||
"""
|
||||
This method only be used in 'online' module, it will generate the *args to initial the strategy.
|
||||
:param
|
||||
mode : model used in 'online' module.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
class StrategyWrapper:
|
||||
"""
|
||||
StrategyWrapper is a wrapper of another strategy.
|
||||
By overriding some methods to make some changes on the basic strategy
|
||||
Cost control and risk control will base on this class.
|
||||
"""
|
||||
|
||||
def __init__(self, inner_strategy):
|
||||
"""__init__
|
||||
|
||||
:param inner_strategy: set the inner strategy.
|
||||
"""
|
||||
self.inner_strategy = inner_strategy
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""__getattr__
|
||||
|
||||
:param name: If no implementation in this method. Call the method in the innter_strategy by default.
|
||||
"""
|
||||
return getattr(self.inner_strategy, name)
|
||||
|
||||
|
||||
class AdjustTimer:
|
||||
"""AdjustTimer
|
||||
Responsible for timing of position adjusting
|
||||
|
||||
This is designed as multiple inheritance mechanism due to:
|
||||
- the is_adjust may need access to the internel state of a strategy.
|
||||
|
||||
- it can be reguard as a enhancement to the existing strategy.
|
||||
"""
|
||||
|
||||
# adjust position in each trade date
|
||||
def is_adjust(self, trade_date):
|
||||
"""is_adjust
|
||||
Return if the strategy can adjust positions on `trade_date`
|
||||
Will normally be used in strategy do trading with trade frequency
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class ListAdjustTimer(AdjustTimer):
|
||||
def __init__(self, adjust_dates=None):
|
||||
"""__init__
|
||||
|
||||
:param adjust_dates: an iterable object, it will return a timelist for trading dates
|
||||
"""
|
||||
if adjust_dates is None:
|
||||
# None indicates that all dates is OK for adjusting
|
||||
self.adjust_dates = None
|
||||
else:
|
||||
self.adjust_dates = {pd.Timestamp(dt) for dt in adjust_dates}
|
||||
|
||||
def is_adjust(self, trade_date):
|
||||
if self.adjust_dates is None:
|
||||
return True
|
||||
return pd.Timestamp(trade_date) in self.adjust_dates
|
||||
|
||||
|
||||
class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
def __init__(self, order_generator_cls_or_obj=OrderGenWInteract, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_date):
|
||||
"""
|
||||
Generate target position from score for this date and the current position.The cash is not considered in the position
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
score : pd.Series
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current of account.
|
||||
trade_exchange : Exchange()
|
||||
exchange.
|
||||
trade_date : pd.Timestamp
|
||||
date.
|
||||
"""
|
||||
# judge if to adjust
|
||||
if not self.is_adjust(trade_date):
|
||||
return []
|
||||
# generate_order_list
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
current_temp = copy.deepcopy(current)
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=score_series, current=current_temp, trade_date=trade_date
|
||||
)
|
||||
|
||||
order_list = self.order_generator.generate_order_list_from_target_weight_position(
|
||||
current=current_temp,
|
||||
trade_exchange=trade_exchange,
|
||||
risk_degree=self.get_risk_degree(trade_date),
|
||||
target_weight_position=target_weight_position,
|
||||
pred_date=pred_date,
|
||||
trade_date=trade_date,
|
||||
)
|
||||
return order_list
|
||||
|
||||
|
||||
class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
def __init__(
|
||||
self,
|
||||
topk,
|
||||
n_drop,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
thresh=1,
|
||||
hold_thresh=1,
|
||||
only_tradable=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
topk : int
|
||||
the number of stocks in the portfolio.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
method_sell : str
|
||||
dropout method_sell, random/bottom.
|
||||
method_buy : str
|
||||
dropout method_buy, random/top.
|
||||
risk_degree : float
|
||||
position percentage of total value.
|
||||
thresh : int
|
||||
minimun holding days since last buy singal of the stock.
|
||||
hold_thresh : int
|
||||
minimum holding days
|
||||
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__()
|
||||
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
self.method_sell = method_sell
|
||||
self.method_buy = method_buy
|
||||
self.risk_degree = risk_degree
|
||||
self.thresh = thresh
|
||||
# self.stock_count['code'] will be the days the stock has been hold
|
||||
# since last buy signal. This is designed for thresh
|
||||
self.stock_count = {}
|
||||
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
def get_risk_degree(self, date):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
"""
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
Generate order list according to score_series at trade_date, will not change current.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current of account.
|
||||
trade_exchange : Exchange()
|
||||
exchange.
|
||||
pred_date : pd.Timestamp
|
||||
predict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
if not self.is_adjust(trade_date):
|
||||
return []
|
||||
|
||||
if self.only_tradable:
|
||||
# If The strategy only consider tradable stock when make decision
|
||||
# It needs following actions to filter stocks
|
||||
def get_first_n(l, n, reverse=False):
|
||||
cur_n = 0
|
||||
res = []
|
||||
for si in reversed(l) if reverse else l:
|
||||
if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date):
|
||||
res.append(si)
|
||||
cur_n += 1
|
||||
if cur_n >= n:
|
||||
break
|
||||
return res[::-1] if reverse else res
|
||||
|
||||
def get_last_n(l, n):
|
||||
return get_first_n(l, n, reverse=True)
|
||||
|
||||
def filter_stock(l):
|
||||
return [si for si in l if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date)]
|
||||
|
||||
else:
|
||||
# Otherwise, the stock will make decision with out the stock tradable info
|
||||
def get_first_n(l, n):
|
||||
return list(l)[:n]
|
||||
|
||||
def get_last_n(l, n):
|
||||
return list(l)[-n:]
|
||||
|
||||
def filter_stock(l):
|
||||
return l
|
||||
|
||||
current_temp = copy.deepcopy(current)
|
||||
# generate order list for this adjust date
|
||||
sell_order_list = []
|
||||
buy_order_list = []
|
||||
# load score
|
||||
cash = current_temp.get_cash()
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
# last position (sorted by score)
|
||||
last = score_series.reindex(current_stock_list).sort_values(ascending=False).index
|
||||
# The new stocks today want to buy **at most**
|
||||
if self.method_buy == "top":
|
||||
today = get_first_n(
|
||||
score_series[~score_series.index.isin(last)].sort_values(ascending=False).index,
|
||||
self.n_drop + self.topk - len(last),
|
||||
)
|
||||
elif self.method_buy == "random":
|
||||
topk_candi = get_first_n(score_series.sort_values(ascending=False).index, self.topk)
|
||||
candi = list(filter(lambda x: x not in last, topk_candi))
|
||||
n = self.n_drop + self.topk - len(last)
|
||||
try:
|
||||
today = np.random.choice(candi, n, replace=False)
|
||||
except ValueError:
|
||||
today = candi
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
# combine(new stocks + last stocks), we will drop stocks from this list
|
||||
# In case of dropping higher score stock and buying lower score stock.
|
||||
comb = score_series.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index
|
||||
|
||||
# Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
|
||||
if self.method_sell == "bottom":
|
||||
sell = last[last.isin(get_last_n(comb, self.n_drop))]
|
||||
elif self.method_sell == "random":
|
||||
candi = filter_stock(last)
|
||||
try:
|
||||
sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
|
||||
except ValueError: # No enough candidates
|
||||
sell = candi
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
# Get the stock list we really want to buy
|
||||
buy = today[: len(sell) + self.topk - len(last)]
|
||||
|
||||
# buy singal: if a stock falls into topk, it appear in the buy_sinal
|
||||
buy_signal = score_series.sort_values(ascending=False).iloc[: self.topk].index
|
||||
|
||||
for code in current_stock_list:
|
||||
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
|
||||
continue
|
||||
if code in sell:
|
||||
# check hold limit
|
||||
if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
|
||||
# can not sell this code
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
continue
|
||||
# sell order
|
||||
sell_amount = current_temp.get_stock_amount(code=code)
|
||||
sell_order = Order(
|
||||
stock_id=code,
|
||||
amount=sell_amount,
|
||||
trade_date=trade_date,
|
||||
direction=Order.SELL, # 0 for sell, 1 for buy
|
||||
factor=trade_exchange.get_factor(code, trade_date),
|
||||
)
|
||||
# is order executable
|
||||
if trade_exchange.check_order(sell_order):
|
||||
sell_order_list.append(sell_order)
|
||||
trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp)
|
||||
# update cash
|
||||
cash += trade_val - trade_cost
|
||||
# sold
|
||||
del self.stock_count[code]
|
||||
else:
|
||||
# no buy signal, but the stock is kept
|
||||
self.stock_count[code] += 1
|
||||
elif code in buy_signal:
|
||||
# NOTE: This is different from the original version
|
||||
# get new buy signal
|
||||
# Only the stock fall in to topk will produce buy signal
|
||||
self.stock_count[code] = 1
|
||||
else:
|
||||
self.stock_count[code] += 1
|
||||
# buy new stock
|
||||
# note the current has been changed
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
|
||||
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
|
||||
# consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
|
||||
# value = value / (1+trade_exchange.open_cost) # set open_cost limit
|
||||
for code in buy:
|
||||
# check is stock suspended
|
||||
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
|
||||
continue
|
||||
# buy order
|
||||
buy_price = trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
|
||||
buy_amount = value / buy_price
|
||||
factor = trade_exchange.quote[(code, trade_date)]["$factor"]
|
||||
buy_amount = trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
|
||||
buy_order = Order(
|
||||
stock_id=code,
|
||||
amount=buy_amount,
|
||||
trade_date=trade_date,
|
||||
direction=Order.BUY, # 1 for buy
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
self.stock_count[code] = 1
|
||||
return sell_order_list + buy_order_list
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
@@ -1,16 +1,60 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import logging
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from typing import Dict, Text, Any
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
"""
|
||||
This is the multiple segments signal record class that generates the signal prediction.
|
||||
This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model, dataset, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
if not isinstance(dataset, qlib_dataset.DatasetH):
|
||||
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(predics, pd.Series):
|
||||
predics = predics.to_frame("score")
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# Compute the IC and Rank IC
|
||||
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
|
||||
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
|
||||
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
|
||||
ic_x100, ric_x100 = ic * 100, ric * 100
|
||||
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
|
||||
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
|
||||
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
|
||||
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
|
||||
save_name, self.recorder.experiment_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
@@ -38,7 +82,7 @@ class SignalMseRecord(SignalRecord):
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
|
||||
@@ -535,6 +535,9 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
@@ -1026,7 +1029,8 @@ class ClientProvider(BaseProvider):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
Inst.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,22 +17,28 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish following steps:
|
||||
|
||||
- init the sub instance and the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -39,7 +46,7 @@ class Dataset(Serializable):
|
||||
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
- User load the Dataset object from the disk.
|
||||
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
@@ -47,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -76,44 +83,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
- arguments of DataHandler.init, such as 'enable_cache', etc.
|
||||
|
||||
segment_kwargs : dict
|
||||
Config of segments which is same as 'segments' in DatasetH.setup_data
|
||||
|
||||
"""
|
||||
if handler_kwargs:
|
||||
if not isinstance(handler_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
|
||||
kwargs_init = {}
|
||||
kwargs_conf_data = {}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time"}
|
||||
for k, v in handler_kwargs.items():
|
||||
if k in conf_data_arg:
|
||||
kwargs_conf_data.update({k: v})
|
||||
else:
|
||||
kwargs_init.update({k: v})
|
||||
|
||||
self.handler.conf_data(**kwargs_conf_data)
|
||||
self.handler.init(**kwargs_init)
|
||||
|
||||
if segment_kwargs:
|
||||
if not isinstance(segment_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -144,6 +114,49 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
@@ -259,7 +272,7 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
@@ -267,7 +280,6 @@ class TSDataSampler:
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
@@ -434,15 +446,19 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
|
||||
@@ -6,6 +6,7 @@ import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
@@ -16,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index, fetch_df_by_index
|
||||
from .utils import fetch_df_by_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -99,10 +100,10 @@ class DataHandler(Serializable):
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -115,13 +116,16 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
Set Up the data in case of running intialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -405,14 +409,28 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if processor_kwargs is not None:
|
||||
for processor in self.get_all_processors():
|
||||
processor.config(**processor_kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
Set up the data in case of running intialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -427,7 +445,7 @@ class DataHandlerLP(DataHandler):
|
||||
when we call `init` next time
|
||||
"""
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
|
||||
@@ -217,3 +217,64 @@ class StaticDataLoader(DataLoader):
|
||||
join=self.join,
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler_config : dict
|
||||
handler_config will be used to describe the handlers
|
||||
|
||||
.. code-block::
|
||||
|
||||
<handler_config> := {
|
||||
"group_name1": <handler>
|
||||
"group_name2": <handler>
|
||||
}
|
||||
or
|
||||
<handler_config> := <handler>
|
||||
<handler> := DataHandler Instance | DataHandler Config
|
||||
|
||||
fetch_kwargs : dict
|
||||
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
|
||||
|
||||
is_group: bool
|
||||
is_group will be used to describe whether the key of handler_config is group
|
||||
|
||||
"""
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
|
||||
if is_group:
|
||||
self.handlers = {
|
||||
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
|
||||
}
|
||||
else:
|
||||
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
|
||||
|
||||
self.is_group = is_group
|
||||
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
|
||||
self.fetch_kwargs.update(fetch_kwargs)
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
LOG.warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
for grp, dh in self.handlers.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
return df
|
||||
|
||||
15
qlib/data/dataset/processor.py
Executable file → Normal file
15
qlib/data/dataset/processor.py
Executable file → Normal file
@@ -72,6 +72,17 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list and hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, fields_group=None):
|
||||
@@ -118,7 +129,7 @@ class FilterCol(Processor):
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
def __call__(self, df):
|
||||
def tanh_denoise(data):
|
||||
@@ -133,7 +144,7 @@ class TanhProcess(Processor):
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
"""Process infinity """
|
||||
"""Process infinity"""
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
|
||||
38
qlib/log.py
38
qlib/log.py
@@ -12,7 +12,41 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
class MetaLogger(type):
|
||||
def __new__(cls, name, bases, dict):
|
||||
wrapper_dict = logging.Logger.__dict__.copy()
|
||||
for key in wrapper_dict:
|
||||
if key not in dict and key != "__reduce__":
|
||||
dict[key] = wrapper_dict[key]
|
||||
return type.__new__(cls, name, bases, dict)
|
||||
|
||||
|
||||
class QlibLogger(metaclass=MetaLogger):
|
||||
"""
|
||||
Customized logger for Qlib.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
if name in {"__setstate__"}:
|
||||
raise AttributeError
|
||||
return self.logger.__getattribute__(name)
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
@@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
# Get logger.
|
||||
module_logger = logging.getLogger(module_name)
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
return module_logger
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import abc
|
||||
from typing import Text, Union
|
||||
from ..utils.serial import Serializable
|
||||
from ..data.dataset import Dataset
|
||||
|
||||
@@ -10,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
"""Make predictions after modeling things"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" leverage Python syntactic sugar to make the models' behaviors like functions """
|
||||
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -59,7 +60,7 @@ class Model(BaseModel):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, dataset: Dataset) -> object:
|
||||
def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
|
||||
"""give prediction given Dataset
|
||||
|
||||
Parameters
|
||||
@@ -67,6 +68,9 @@ class Model(BaseModel):
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training.
|
||||
|
||||
segment : Text or slice
|
||||
dataset will use this segment to prepare data. (default=test)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prediction results with certain type such as `pandas.Series`.
|
||||
|
||||
@@ -5,9 +5,9 @@ import abc
|
||||
|
||||
|
||||
class BaseOptimizer(abc.ABC):
|
||||
""" Construct portfolio with a optimization related method """
|
||||
"""Construct portfolio with a optimization related method"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" Generate a optimized portfolio allocation """
|
||||
"""Generate a optimized portfolio allocation"""
|
||||
pass
|
||||
|
||||
@@ -23,7 +23,10 @@ class QlibRecorder:
|
||||
@contextmanager
|
||||
def start(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
name of the experiment one wants to start.
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -57,7 +64,14 @@ class QlibRecorder:
|
||||
resume : bool
|
||||
whether to resume the specific recorder with given name under the given experiment.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
run = self.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -65,7 +79,9 @@ class QlibRecorder:
|
||||
raise e
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
|
||||
def start_exp(
|
||||
self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False
|
||||
):
|
||||
"""
|
||||
Lower level method for starting an experiment. When use this method, one should end the experiment manually
|
||||
and the status of the recorder may not be handled properly. Here is the example code:
|
||||
@@ -79,8 +95,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
the name of the experiment to be started
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -93,7 +113,14 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance being started.
|
||||
"""
|
||||
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
return self.exp_manager.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
|
||||
def end_exp(self, recorder_status=Recorder.STATUS_FI):
|
||||
"""
|
||||
@@ -202,13 +229,13 @@ class QlibRecorder:
|
||||
|
||||
- no id or name specified, return the active experiment.
|
||||
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name.
|
||||
|
||||
- If `active experiment` not exists:
|
||||
|
||||
- no id or name specified, create a default experiment, and the experiment is set to be active.
|
||||
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be active.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment.
|
||||
|
||||
- Else If '`create`' is False:
|
||||
|
||||
@@ -260,7 +287,7 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create)
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -358,7 +385,7 @@ class QlibRecorder:
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
@@ -416,6 +443,12 @@ class QlibRecorder:
|
||||
"""
|
||||
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name: Text):
|
||||
"""
|
||||
Method for loading an object from artifacts in the experiment in the uri.
|
||||
"""
|
||||
return self.get_exp().get_recorder().load_object(name)
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
"""
|
||||
Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
from pathlib import Path
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Experiment:
|
||||
@@ -39,12 +39,14 @@ class Experiment:
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
"""
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder_id : str
|
||||
the id of the recorder to be created.
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
resume : bool
|
||||
@@ -107,24 +109,24 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
|
||||
specific recorder. When user does not provide recorder id or name, the method will try to return the current
|
||||
active recorder. The `create` argument determines whether the method will automatically create a new recorder
|
||||
according to user's specification if the recorder hasn't been created before
|
||||
according to user's specification if the recorder hasn't been created before.
|
||||
|
||||
* If `create` is True:
|
||||
|
||||
* If `active recorder` exists:
|
||||
|
||||
* no id or name specified, return the active recorder.
|
||||
* if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
|
||||
* if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
|
||||
|
||||
* If `active recorder` not exists:
|
||||
|
||||
* no id or name specified, create a new recorder.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
|
||||
|
||||
* Else If `create` is False:
|
||||
|
||||
@@ -146,6 +148,8 @@ class Experiment:
|
||||
the name of the recorder to be deleted.
|
||||
create : boolean
|
||||
create the recorder if it hasn't been created before.
|
||||
start : boolean
|
||||
start the new recorder if one is created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -159,8 +163,11 @@ class Experiment:
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
recorder, is_new = (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
if is_new and start:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -174,7 +181,10 @@ class Experiment:
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
return (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
@@ -230,14 +240,14 @@ class MLflowExperiment(Experiment):
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# Get or create recorder
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
# resume the recorder
|
||||
if resume:
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
# create a new recorder
|
||||
else:
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import mlflow
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
import os, logging
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
@@ -14,7 +14,7 @@ from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class ExpManager:
|
||||
@@ -33,7 +33,10 @@ class ExpManager:
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class ExpManager:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the active experiment.
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
recorder_id : str
|
||||
id of the recorder to be started.
|
||||
recorder_name : str
|
||||
name of the recorder to be started.
|
||||
uri : str
|
||||
@@ -102,10 +109,9 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
The returned experiment will be active.
|
||||
|
||||
When user specify experiment id and name, the method will try to return the specific experiment.
|
||||
When user does not provide recorder id or name, the method will try to return the current active experiment.
|
||||
@@ -117,12 +123,12 @@ class ExpManager:
|
||||
* If `active experiment` exists:
|
||||
|
||||
* no id or name specified, return the active experiment.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
|
||||
|
||||
* If `active experiment` not exists:
|
||||
|
||||
* no id or name specified, create a default experiment.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
|
||||
|
||||
* Else If `create` is False:
|
||||
|
||||
@@ -144,6 +150,8 @@ class ExpManager:
|
||||
name of the experiment to return.
|
||||
create : boolean
|
||||
create the experiment it if hasn't been created before.
|
||||
start : boolean
|
||||
start the new experiment if one is created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -159,8 +167,11 @@ class ExpManager:
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
if is_new:
|
||||
exp, is_new = (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new and start:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
self.active_experiment.start()
|
||||
@@ -172,7 +183,10 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
@@ -291,7 +305,10 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -301,11 +318,11 @@ class MLflowExpManager(ExpManager):
|
||||
# Create experiment
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name, resume)
|
||||
self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import logging
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
@@ -18,7 +19,7 @@ from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
@@ -41,7 +42,13 @@ class RecordTemp:
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
self._recorder = recorder
|
||||
|
||||
@property
|
||||
def recorder(self):
|
||||
if self._recorder is None:
|
||||
raise ValueError("This RecordTemp did not set recorder yet.")
|
||||
return self._recorder
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -158,6 +165,60 @@ class SignalRecord(RecordTemp):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "hg_sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder)
|
||||
|
||||
def generate(self):
|
||||
pred = self.load("pred.pkl")
|
||||
raw_label = self.load("label.pkl")
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
objects = {"ic.pkl": ic, "ric.pkl": ric}
|
||||
objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre})
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Average Return": long_short_r.mean(),
|
||||
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
|
||||
}
|
||||
)
|
||||
objects.update(
|
||||
{
|
||||
"long_short_r.pkl": long_short_r,
|
||||
"long_avg_r.pkl": long_avg_r,
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
self.get_path("long_short_r.pkl"),
|
||||
self.get_path("long_avg_r.pkl"),
|
||||
]
|
||||
return paths
|
||||
|
||||
|
||||
class SigAnaRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Recorder:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal, atexit
|
||||
import sys, traceback, signal, atexit, logging
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
# function to handle the experiment when unusual program ending occurs
|
||||
|
||||
@@ -66,7 +66,7 @@ class CheckBin:
|
||||
self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
|
||||
if check_fields is None:
|
||||
check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin")))
|
||||
check_fields = list(map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin")))
|
||||
else:
|
||||
check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
|
||||
self.check_fields = list(map(lambda x: x.strip(), check_fields))
|
||||
@@ -91,6 +91,7 @@ class CheckBin:
|
||||
origin_df[self.symbol_field_name] = symbol
|
||||
origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
|
||||
origin_df.index.names = qlib_df.index.names
|
||||
origin_df = origin_df.reindex(qlib_df.index)
|
||||
try:
|
||||
compare = datacompy.Compare(
|
||||
origin_df,
|
||||
|
||||
@@ -46,7 +46,7 @@ class BaseCollector(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
@@ -77,11 +77,11 @@ class BaseCollector(abc.ABC):
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
self.instrument_list = sorted(set(self.get_instrument_list()))
|
||||
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
self.instrument_list = self.instrument_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
@@ -108,8 +108,8 @@ class BaseCollector(abc.ABC):
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
@@ -158,27 +158,27 @@ class BaseCollector(abc.ABC):
|
||||
return _result
|
||||
|
||||
def save_instrument(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
"""save instrument data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
instrument code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
if instrument_path.exists():
|
||||
_old_df = pd.read_csv(instrument_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
@@ -191,38 +191,38 @@ class BaseCollector(abc.ABC):
|
||||
self.mini_symbol_map.pop(symbol)
|
||||
return self.NORMAL_FLAG
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
error_symbol.extend(self.mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector data......")
|
||||
stock_list = self.stock_list
|
||||
instrument_list = self.instrument_list
|
||||
for i in range(self.max_collector_count):
|
||||
if not stock_list:
|
||||
if not instrument_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
@@ -386,9 +386,9 @@ class BaseRun(abc.ABC):
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
|
||||
@@ -416,7 +416,7 @@ class BaseRun(abc.ABC):
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, self.normalize_class_name)
|
||||
yc = Normalize(
|
||||
|
||||
24
scripts/data_collector/contrib/README.md
Normal file
24
scripts/data_collector/contrib/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Get future trading days
|
||||
|
||||
> `D.calendar(future=True)` will be used
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- qlib_dir: qlib data directory
|
||||
- freq: value from [`day`, `1min`], default `day`
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
# get data from baostock
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
|
||||
calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt")
|
||||
if not calendar_path.exists():
|
||||
return pd.DataFrame()
|
||||
return pd.read_csv(calendar_path, header=None)
|
||||
|
||||
|
||||
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
|
||||
calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt"))
|
||||
|
||||
np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
|
||||
logger.info(f"write future calendars success: {calendar_path}")
|
||||
|
||||
|
||||
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
|
||||
print(freq)
|
||||
if freq == "day":
|
||||
return date_list
|
||||
elif freq == "1min":
|
||||
date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
|
||||
return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list))
|
||||
else:
|
||||
raise ValueError(f"Unsupported freq: {freq}")
|
||||
|
||||
|
||||
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
|
||||
"""get future calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str or Path
|
||||
qlib data directory
|
||||
freq: str
|
||||
value from ["day", "1min"], by default day
|
||||
"""
|
||||
qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
if not qlib_dir.exists():
|
||||
raise FileNotFoundError(str(qlib_dir))
|
||||
|
||||
lg = bs.login()
|
||||
if lg.error_code != "0":
|
||||
logger.error(f"login error: {lg.error_msg}")
|
||||
return
|
||||
# read daily calendar
|
||||
daily_calendar = read_calendar_from_qlib(qlib_dir)
|
||||
end_year = pd.Timestamp.now().year
|
||||
if daily_calendar.empty:
|
||||
start_year = pd.Timestamp.now().year
|
||||
else:
|
||||
start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
|
||||
rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
|
||||
data_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
_row_data = rs.get_row_data()
|
||||
if int(_row_data[1]) == 1:
|
||||
data_list.append(_row_data[0])
|
||||
data_list = sorted(data_list)
|
||||
date_list = generate_qlib_calendar(data_list, freq=freq)
|
||||
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
|
||||
bs.logout()
|
||||
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(future_calendar_collector)
|
||||
5
scripts/data_collector/contrib/requirements.txt
Normal file
5
scripts/data_collector/contrib/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
baostock
|
||||
fire
|
||||
numpy
|
||||
pandas
|
||||
loguru
|
||||
51
scripts/data_collector/fund/README.md
Normal file
51
scripts/data_collector/fund/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Collect Fund Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### CN Data
|
||||
|
||||
#### 1d from East Money
|
||||
|
||||
```bash
|
||||
|
||||
# download from eastmoney.com
|
||||
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
|
||||
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
|
||||
```
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
pythono collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- interval: 1d
|
||||
- region: CN
|
||||
312
scripts/data_collector/fund/collector.py
Normal file
312
scripts/data_collector/fund/collector.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
||||
|
||||
|
||||
class FundCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
fund save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
super(FundCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
self.init_datetime()
|
||||
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
try:
|
||||
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
data = json.loads(resp.text.split("(")[-1].split(")")[0])
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class FundollectorCN(FundCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get cn fund symbols......")
|
||||
symbols = get_en_fund_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def normalize_fund(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
class FundNormalize1d(FundNormalize):
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN"], default "CN"
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"FundCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"FundNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
10
scripts/data_collector/fund/requirements.txt
Normal file
10
scripts/data_collector/fund/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
json
|
||||
@@ -114,6 +114,8 @@ class IndexBase:
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
if df is None or df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
@@ -184,7 +186,10 @@ class IndexBase:
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
changers_df = self.get_changes()
|
||||
new_df = self.get_new_companies().copy()
|
||||
new_df = self.get_new_companies()
|
||||
if new_df is None or new_df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
new_df = new_df.copy()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
|
||||
if _row.type == self.ADD:
|
||||
|
||||
@@ -35,7 +35,7 @@ WIKI_INDEX_NAME_MAP = {
|
||||
class WIKIIndex(IndexBase):
|
||||
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = "_"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
@@ -123,7 +123,7 @@ class NASDAQ100Index(WIKIIndex):
|
||||
MAX_WORKERS = 16
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if not (set(df.columns) - {"Company", "Ticker"}):
|
||||
if len(df) >= 100 and "Ticker" in df.columns:
|
||||
return df.loc[:, ["Ticker"]].copy()
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -9,11 +10,16 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -34,6 +40,7 @@ _BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
@@ -93,6 +100,78 @@ def get_calendar_list(bench_code="CSI300") -> list:
|
||||
return calendar
|
||||
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
source_dir: [str, Path],
|
||||
date_field_name: str = "date",
|
||||
threshold: float = 0.5,
|
||||
minimum_count: int = 10,
|
||||
max_workers: int = 16,
|
||||
) -> list:
|
||||
"""get calendar list by selecting the date when few funds trade in this day
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
threshold: float
|
||||
threshold to exclude some days when few funds trade in this day, default 0.5
|
||||
minimum_count: int
|
||||
minimum count of funds should trade in one day
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
|
||||
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
file_list = list(source_dir.glob("*.csv"))
|
||||
|
||||
_number_all_funds = len(file_list)
|
||||
|
||||
logger.info(f"count how many funds trade in this day......")
|
||||
_dict_count_trade = dict() # dict{date:count}
|
||||
_fun = partial(return_date_list, date_field_name)
|
||||
all_oldest_list = []
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, file_list):
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
if date < oldest_date:
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
@@ -220,6 +299,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get en fund symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = []
|
||||
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
|
||||
data = sub_data.replace('"', "").replace("'", "")
|
||||
# TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
|
||||
_symbols.append(data.split(",")[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
return _symbols
|
||||
|
||||
if _EN_FUND_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney()
|
||||
|
||||
_EN_FUND_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _EN_FUND_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
@@ -305,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
|
||||
return res
|
||||
|
||||
|
||||
def generate_minutes_calendar_from_daily(
|
||||
calendars: Iterable,
|
||||
freq: str = "1min",
|
||||
am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
|
||||
pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
|
||||
) -> pd.Index:
|
||||
"""generate minutes calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendars: Iterable
|
||||
daily calendar
|
||||
freq: str
|
||||
by default 1min
|
||||
am_range: Tuple[str, str]
|
||||
AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
|
||||
pm_range: Tuple[str, str]
|
||||
PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")
|
||||
|
||||
"""
|
||||
daily_format: str = "%Y-%m-%d"
|
||||
res = []
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}",
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}",
|
||||
freq=freq,
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
|
||||
@@ -185,7 +190,7 @@ class YahooCollector(BaseCollector):
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get HS stock symbos......")
|
||||
symbols = get_hs_stock_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
@@ -249,7 +254,7 @@ class YahooCollectorCN1min(YahooCollectorCN):
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get US stock symbols......")
|
||||
symbols = get_us_stock_symbols() + [
|
||||
"^GSPC",
|
||||
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
res = []
|
||||
daily_format = self.DAILY_FORMAT
|
||||
am_range = self.AM_RANGE
|
||||
pm_range = self.PM_RANGE
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{_day.strftime(daily_format)} {_range[0]}",
|
||||
f"{_day.strftime(daily_format)} {_range[1]}",
|
||||
freq="1min",
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
|
||||
@@ -219,7 +219,7 @@ class DumpDataBase:
|
||||
# used when creating a bin file
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
bin_path = features_dir.joinpath(f"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if bin_path.exists() and self._mode == self.UPDATE_MODE:
|
||||
|
||||
2
setup.py
2
setup.py
@@ -35,7 +35,7 @@ REQUIRED = [
|
||||
"scipy>=1.0.0",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"python-socketio==3.1.2",
|
||||
"python-socketio",
|
||||
"redis>=3.0.1",
|
||||
"python-redis-lock>=3.3.1",
|
||||
"schedule>=0.6.0",
|
||||
|
||||
@@ -6,24 +6,11 @@ import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN, C
|
||||
from qlib.utils import drop_nan_by_y_index
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.config import C
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
@@ -166,8 +153,6 @@ def train_with_sigana():
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
@@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData):
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_train"))
|
||||
_suite.addTest(TestAllFlow("test_1_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_0_train_with_sigana"))
|
||||
_suite.addTest(TestAllFlow("test_1_train"))
|
||||
_suite.addTest(TestAllFlow("test_2_backtest"))
|
||||
_suite.addTest(TestAllFlow("test_3_expmanager"))
|
||||
return _suite
|
||||
|
||||
|
||||
|
||||
27
tests/test_contrib_model.py
Normal file
27
tests/test_contrib_model.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
|
||||
from qlib.contrib.model import all_model_classes
|
||||
|
||||
|
||||
class TestAllFlow(unittest.TestCase):
|
||||
def test_0_initialize(self):
|
||||
num = 0
|
||||
for model_class in all_model_classes:
|
||||
if model_class is not None:
|
||||
model = model_class()
|
||||
num += 1
|
||||
print("There are {:}/{:} valid models in total.".format(num, len(all_model_classes)))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_initialize"))
|
||||
return _suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
111
tests/test_contrib_workflow.py
Normal file
111
tests/test_contrib_workflow.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def train_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
sr.generate(dict(valid="valid", test="test"), True)
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
def train_mse():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
sr.generate()
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
def test_0_multiseg(self):
|
||||
uri_path = train_multiseg()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_mse(self):
|
||||
uri_path = train_mse()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_multiseg"))
|
||||
_suite.addTest(TestAllFlow("test_1_mse"))
|
||||
return _suite
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
Reference in New Issue
Block a user