1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Split classes in riskmodel.py & optimizer.py into seperate files.

This commit is contained in:
Charles Young
2021-03-04 22:08:11 +08:00
parent 527718a440
commit 2bff6eb781
5 changed files with 150 additions and 62 deletions

View File

View File

View File

View File

View File

@@ -2,32 +2,39 @@
# Licensed under the MIT License.
import sys
import math
import shutil
import unittest
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import qlib
from qlib.config import REG_CN, C
from qlib.utils import drop_nan_by_y_index
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.config import C
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.config import REG_CN
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord
from qlib.tests import TestAutoData
from qlib.portfolio.optimizer import EnhancedIndexingOptimizer
from qlib.model.riskmodel import StructuredCovEstimator
from qlib.data.dataset.loader import QlibDataLoader
from qlib.data.dataset.handler import DataHandler
from qlib.data import D
from qlib.utils import exists_qlib_data, init_instance_by_config
market = "all"
trade_gap = 21
label_config = "Ref($close, -{}) / Ref($close, -1) - 1".format(trade_gap) # reconstruct portfolio once a month
market = "csi300"
benchmark = "SH000300"
provider_uri = "~/.qlib_ei/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path.cwd().parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
###################################
# train model
@@ -36,8 +43,9 @@ data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"fit_end_time": "2014-11-30",
"instruments": market,
"label": [label_config]
}
task = {
@@ -53,7 +61,7 @@ task = {
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
"num_threads": 32,
},
},
"dataset": {
@@ -66,37 +74,104 @@ task = {
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
"train": ("2008-01-01", "2014-11-30"),
"valid": ("2015-01-01", "2016-11-30"),
"test": ("2017-01-01", "2018-01-01"),
},
},
},
}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
}
class CSI300:
"""Simulate CSI300 as the Benchmark for Enhanced Indexing to Track"""
def __init__(self):
# provider_uri = '/nfs_data/qlib_data/ycz_daily/qlib'
# qlib.init(provider_uri=provider_uri, region=REG_CN, dataset_cache=None, expression_cache=None)
self.csi_weight = D.features(D.instruments('csi300'), ['$csi300_weight'])
def __call__(self, pd_index, trade_date):
weights = np.zeros(len(pd_index))
for idx, instrument in enumerate(pd_index):
if (instrument, trade_date) in self.csi_weight.index:
weight = self.csi_weight.loc[(instrument, trade_date)].values[0]
if not math.isnan(weight):
weights[idx] = weight
assert weights.sum() > 0, ' Fetch CSI Weights Error!'
weights = weights / weights.sum()
return weights
class EnhancedIndexingStrategy:
"""Enhanced Indexing Strategy"""
def __init__(self):
self.benchmark = CSI300()
provider_uri = "~/.qlib_ei/qlib_data/cn_data"
qlib.init(provider_uri=provider_uri, region=REG_CN)
self.data_handler = DataHandler(market, "2015-01-01", "2019-01-01", QlibDataLoader(["$close"]))
self.label_handler = DataHandler(market, "2015-01-01", "2019-01-01", QlibDataLoader([label_config]))
self.cov_estimator = StructuredCovEstimator()
self.optimizer = EnhancedIndexingOptimizer(lamb=0.1, delta=0.4, bench_dev=0.03, max_iters=50000)
def update(self, score_series, current, pred_date):
"""
Parameters
-----------
score_series : pd.Series
stock_id , score.
current : Position()
current of account.
trade_exchange : Exchange()
exchange.
trade_date : pd.Timestamp
date.
"""
print(score_series)
score_series = score_series.dropna()
# portfolio init weight
init_weight = current.reindex(score_series.index, fill_value=0).values.squeeze()
init_weight_sum = init_weight.sum()
if init_weight_sum > 0:
init_weight /= init_weight_sum
# covariance estimation
selector = (self.data_handler.get_range_selector(pred_date, 252), score_series.index)
price = self.data_handler.fetch(selector, level=None, squeeze=True)
F, cov_b, var_u = self.cov_estimator.predict(price, return_decomposed_components=True)
# optimize target portfolio
w_bench = self.benchmark(score_series.index, pred_date)
passed_init_weight = init_weight if init_weight_sum > 0 else None
# print(F)
# print(cov_b)
# print(var_u)
# print(passed_init_weight)
# print(w_bench)
target_weight = self.optimizer(score_series.values, F, cov_b, var_u, passed_init_weight, w_bench)
# print(target_weight)
target = pd.DataFrame(data=target_weight, index=score_series.index)
active_weights = target_weight - w_bench
selector = (self.label_handler.get_range_selector(pred_date, 1), score_series.index)
label = self.label_handler.fetch(selector, level=None, squeeze=True)
alpha = 0
for instrument, weight in zip(score_series.index, active_weights):
delta = label.loc[(pred_date, instrument)]
alpha += weight * (0 if math.isnan(delta) else delta)
print(alpha)
return alpha, target
# train
def train():
"""train model
@@ -108,7 +183,7 @@ def train():
model performance
"""
# model initiaiton
# model initiation
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
@@ -133,29 +208,42 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid
def backtest_analysis(pred, rid):
"""backtest and analysis
def backtest_analysis(scores):
"""backtest enhanced indexing
Parameters
----------
pred : pandas.DataFrame
predict scores
rid : str
the id of the recorder to be used in this function
scores: pandas.DataFrame
predict scores
Returns
-------
analysis : pandas.DataFrame
the analysis result
sharpe_ratio: floating-point
sharpe ratio of the enhanced indexing portfolio
"""
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
# backtest
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
analysis_df = par.load(par.get_path("port_analysis.pkl"))
print(analysis_df)
return analysis_df
# backtest and analysis
with R.start(experiment_name="backtest_analysis"):
strategy = EnhancedIndexingStrategy()
dates = scores.index.get_level_values(0).unique()
alphas = []
current = pd.DataFrame()
gap_between_next_trade = 0
for date in tqdm(dates):
if gap_between_next_trade == 0:
score_series = scores.loc[date]
alpha, current = strategy.update(score_series, current, date)
alphas.append(alpha)
gap_between_next_trade = trade_gap
else:
gap_between_next_trade -= 1
alphas = np.array(alphas)
sharpe_ratio = alphas.mean() / np.std(alphas)
print('Sharpe:', sharpe_ratio)
return sharpe_ratio
class TestAllFlow(TestAutoData):
@@ -174,10 +262,10 @@ class TestAllFlow(TestAutoData):
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def test_1_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
sharpe_ratio = backtest_analysis(TestAllFlow.PRED_SCORE)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
sharpe_ratio,
0.90,
"backtest failed",
)