mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update CI & add black formatter
This commit is contained in:
10
.github/workflows/test.yml
vendored
10
.github/workflows/test.yml
vendored
@@ -38,14 +38,12 @@ jobs:
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install flake8 pytest
|
||||
pip install black pytest
|
||||
|
||||
- name: Lint with flake8
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||
cd ..
|
||||
python -m black qlib -l 120
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
[](https://pypi.org/project/pyqlib/#files)
|
||||
[](https://pypi.org/project/pyqlib/#history)
|
||||
[](https://pypi.org/project/pyqlib/)
|
||||
[](https://github.com/microsoft/qlib/actions)
|
||||
[](https://qlib.readthedocs.io/en/latest/?badge=latest)
|
||||
[](LICENSE)
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
@@ -53,7 +53,6 @@ source_suffix = ".rst"
|
||||
master_doc = "index"
|
||||
|
||||
|
||||
|
||||
# General information about the project.
|
||||
project = u"QLib"
|
||||
copyright = u"Microsoft"
|
||||
@@ -104,8 +103,7 @@ todo_include_todos = True
|
||||
#
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
html_logo = '_static/img/logo/1.png'
|
||||
|
||||
html_logo = "_static/img/logo/1.png"
|
||||
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
@@ -161,15 +159,12 @@ latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#
|
||||
# 'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# 'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# 'figure_align': 'htbp',
|
||||
|
||||
@@ -54,9 +54,9 @@ if __name__ == "__main__":
|
||||
|
||||
# use default DataHandler
|
||||
# custom DataHandler, refer to: TODO: DataHandler API url
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
|
||||
**DATA_HANDLER_CONFIG
|
||||
).get_split_data(**TRAINER_CONFIG)
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
@@ -114,6 +114,8 @@ if __name__ == "__main__":
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
|
||||
@@ -44,7 +44,7 @@ def init(default_conf="client", **kwargs):
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
|
||||
C.set_region(kwargs.get('region', C['region'] if 'region' in C else REG_CN ))
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
C.resolve_path()
|
||||
|
||||
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
|
||||
@@ -83,6 +83,7 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
# FIXME: the C["provider_uri"] is modified in this function
|
||||
@@ -161,9 +162,7 @@ def _mount_nfs_uri(C):
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
command_res = command_res.readlines()
|
||||
if not command_res:
|
||||
raise OSError(
|
||||
"nfs-common is not found, please install it by execute: sudo apt install nfs-common"
|
||||
)
|
||||
raise OSError("nfs-common is not found, please install it by execute: sudo apt install nfs-common")
|
||||
# manually mount
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
|
||||
@@ -17,7 +17,6 @@ import re
|
||||
|
||||
|
||||
class Config:
|
||||
|
||||
def __init__(self, default_conf):
|
||||
self.__dict__["_default_config"] = default_conf # avoiding conflictions with __getattr__
|
||||
self.reset()
|
||||
@@ -128,7 +127,7 @@ _default_config = {
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
'server': {
|
||||
"server": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
@@ -147,8 +146,7 @@ MODE_CONF = {
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
},
|
||||
|
||||
'client': {
|
||||
"client": {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
"instrument_provider": "LocalInstrumentProvider",
|
||||
@@ -172,7 +170,7 @@ MODE_CONF = {
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"region": REG_CN,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -192,8 +190,8 @@ _default_region_config = {
|
||||
|
||||
class QlibConfig(Config):
|
||||
# URI_TYPE
|
||||
LOCAL_URI = 'local'
|
||||
NFS_URI = 'nfs'
|
||||
LOCAL_URI = "local"
|
||||
NFS_URI = "nfs"
|
||||
|
||||
def set_mode(self, mode):
|
||||
# raise KeyError
|
||||
@@ -222,9 +220,9 @@ class QlibConfig(Config):
|
||||
|
||||
def get_data_path(self):
|
||||
if self.get_uri_type() == QlibConfig.LOCAL_URI:
|
||||
return self['provider_uri']
|
||||
return self["provider_uri"]
|
||||
elif self.get_uri_type() == QlibConfig.NFS_URI:
|
||||
return self['mount_path']
|
||||
return self["mount_path"]
|
||||
else:
|
||||
raise NotImplementedError(f"This type of uri is not supported")
|
||||
|
||||
|
||||
@@ -186,7 +186,9 @@ class Estimator(object):
|
||||
# analysis["pred_short"] = risk_analysis(long_short_reports["short"])
|
||||
# analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
TimeInspector.log_cost_time(
|
||||
"Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
|
||||
|
||||
@@ -558,16 +558,16 @@ class QLibDataHandlerV1(ConfigQLibDataHandler):
|
||||
|
||||
class Alpha158(QLibDataHandlerV1):
|
||||
config_template = {
|
||||
'kbar': {},
|
||||
'price': {
|
||||
'windows': [0],
|
||||
'feature': ['OPEN', 'HIGH', 'LOW', 'CLOSE'],
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "CLOSE"],
|
||||
},
|
||||
'rolling': {}
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
kwargs['labels'] = ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
super(Alpha158, self)._init_kwargs(**kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -34,8 +34,13 @@ def risk_analysis(r, N=252):
|
||||
annualized_return = mean * N
|
||||
information_ratio = mean / std * np.sqrt(N)
|
||||
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
|
||||
data = {"mean": mean, "std": std, "annualized_return": annualized_return,
|
||||
"information_ratio": information_ratio, "max_drawdown": max_drawdown}
|
||||
data = {
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
"annualized_return": annualized_return,
|
||||
"information_ratio": information_ratio,
|
||||
"max_drawdown": max_drawdown,
|
||||
}
|
||||
res = pd.Series(data, index=data.keys()).to_frame("risk")
|
||||
return res
|
||||
|
||||
@@ -230,7 +235,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
limit move 0.1 (10%) for example, long and short with same limit
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
"""
|
||||
# check strategy:
|
||||
|
||||
@@ -167,7 +167,7 @@ class DNNModelPytorch(Model):
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
#return
|
||||
# return
|
||||
# prepare training data
|
||||
x_train_values = torch.from_numpy(x_train.values).float()
|
||||
y_train_values = torch.from_numpy(y_train.values).float()
|
||||
@@ -210,7 +210,7 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
#print(loss.val)
|
||||
# print(loss.val)
|
||||
if step and step % self.eval_steps == 0:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
@@ -263,7 +263,7 @@ class DNNModelPytorch(Model):
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = torch.from_numpy(x_test.values).float().cuda()
|
||||
self.dnn_model.eval()
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return preds
|
||||
|
||||
@@ -14,9 +14,7 @@ from scipy import stats
|
||||
from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph
|
||||
|
||||
|
||||
def _group_return(
|
||||
pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs
|
||||
) -> tuple:
|
||||
def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple:
|
||||
"""
|
||||
|
||||
:param pred_label:
|
||||
@@ -48,9 +46,7 @@ def _group_return(
|
||||
t_df["long-short"] = t_df["Group1"] - t_df["Group%d" % N]
|
||||
|
||||
# Long-Average
|
||||
t_df["long-average"] = (
|
||||
t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()
|
||||
)
|
||||
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()
|
||||
|
||||
t_df = t_df.dropna(how="all") # for days which does not contain label
|
||||
# FIXME: support HIGH-FREQ
|
||||
@@ -58,9 +54,7 @@ def _group_return(
|
||||
# Cumulative Return By Group
|
||||
group_scatter_figure = ScatterGraph(
|
||||
t_df.cumsum(),
|
||||
layout=dict(
|
||||
title="Cumulative Return", xaxis=dict(type="category", tickangle=45)
|
||||
),
|
||||
layout=dict(title="Cumulative Return", xaxis=dict(type="category", tickangle=45)),
|
||||
).figure
|
||||
|
||||
t_df = t_df.loc[:, ["long-short", "long-average"]]
|
||||
@@ -103,13 +97,9 @@ def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> t
|
||||
lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True))
|
||||
)
|
||||
else:
|
||||
ic = pred_label.groupby(level="datetime").apply(
|
||||
lambda x: x["label"].corr(x["score"])
|
||||
)
|
||||
ic = pred_label.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
|
||||
|
||||
_index = (
|
||||
ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
)
|
||||
_index = ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = ic.groupby(_index).mean()
|
||||
_monthly_ic.index = pd.MultiIndex.from_arrays(
|
||||
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
|
||||
@@ -186,17 +176,13 @@ def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> t
|
||||
def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
|
||||
pred = pred_label.copy()
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
ac = pred.groupby(level="datetime").apply(
|
||||
lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True))
|
||||
)
|
||||
ac = pred.groupby(level="datetime").apply(lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)))
|
||||
# FIXME: support HIGH-FREQ
|
||||
_df = ac.to_frame("value")
|
||||
_df.index = _df.index.strftime("%Y-%m-%d")
|
||||
ac_figure = ScatterGraph(
|
||||
_df,
|
||||
layout=dict(
|
||||
title="Auto Correlation", xaxis=dict(type="category", tickangle=45)
|
||||
),
|
||||
layout=dict(title="Auto Correlation", xaxis=dict(type="category", tickangle=45)),
|
||||
).figure
|
||||
return (ac_figure,)
|
||||
|
||||
@@ -206,9 +192,7 @@ def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
top = pred.groupby(level="datetime").apply(
|
||||
lambda x: 1
|
||||
- x.nlargest(len(x) // N, columns="score")
|
||||
.index.isin(x.nlargest(len(x) // N, columns="score_last").index)
|
||||
.sum()
|
||||
- x.nlargest(len(x) // N, columns="score").index.isin(x.nlargest(len(x) // N, columns="score_last").index).sum()
|
||||
/ (len(x) // N)
|
||||
)
|
||||
bottom = pred.groupby(level="datetime").apply(
|
||||
@@ -218,14 +202,17 @@ def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
|
||||
.sum()
|
||||
/ (len(x) // N)
|
||||
)
|
||||
r_df = pd.DataFrame({"Top": top, "Bottom": bottom,})
|
||||
r_df = pd.DataFrame(
|
||||
{
|
||||
"Top": top,
|
||||
"Bottom": bottom,
|
||||
}
|
||||
)
|
||||
# FIXME: support HIGH-FREQ
|
||||
r_df.index = r_df.index.strftime("%Y-%m-%d")
|
||||
turnover_figure = ScatterGraph(
|
||||
r_df,
|
||||
layout=dict(
|
||||
title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)
|
||||
),
|
||||
layout=dict(title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)),
|
||||
).figure
|
||||
return (turnover_figure,)
|
||||
|
||||
@@ -270,12 +257,12 @@ def model_performance_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
instrument datetime score label
|
||||
SH600004 2017-12-11 -0.013502 -0.013502
|
||||
2017-12-12 -0.072367 -0.072367
|
||||
2017-12-13 -0.068605 -0.068605
|
||||
2017-12-14 0.012440 0.012440
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
instrument datetime score label
|
||||
SH600004 2017-12-11 -0.013502 -0.013502
|
||||
2017-12-12 -0.072367 -0.072367
|
||||
2017-12-13 -0.068605 -0.068605
|
||||
2017-12-14 0.012440 0.012440
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
|
||||
|
||||
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
|
||||
|
||||
@@ -36,9 +36,7 @@ def _get_cum_return_data_with_position(
|
||||
end_date=end_date,
|
||||
).copy()
|
||||
|
||||
_cumulative_return_df["label"] = (
|
||||
_cumulative_return_df["label"] - _cumulative_return_df["bench"]
|
||||
)
|
||||
_cumulative_return_df["label"] = _cumulative_return_df["label"] - _cumulative_return_df["bench"]
|
||||
_cumulative_return_df = _cumulative_return_df.dropna()
|
||||
df_gp = _cumulative_return_df.groupby(level="datetime")
|
||||
result_list = []
|
||||
@@ -105,26 +103,20 @@ def _get_figure_with_position(
|
||||
:return:
|
||||
"""
|
||||
|
||||
cum_return_df = _get_cum_return_data_with_position(
|
||||
position, report_normal, label_data, start_date, end_date
|
||||
)
|
||||
cum_return_df = _get_cum_return_data_with_position(position, report_normal, label_data, start_date, end_date)
|
||||
cum_return_df = cum_return_df.set_index("date")
|
||||
# FIXME: support HIGH-FREQ
|
||||
cum_return_df.index = cum_return_df.index.strftime('%Y-%m-%d')
|
||||
cum_return_df.index = cum_return_df.index.strftime("%Y-%m-%d")
|
||||
|
||||
# Create figures
|
||||
for _t_name in ["buy", "sell", "buy_minus_sell", "hold"]:
|
||||
sub_graph_data = [
|
||||
(
|
||||
"cum_{}".format(_t_name),
|
||||
dict(
|
||||
row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"}
|
||||
),
|
||||
dict(row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"}),
|
||||
),
|
||||
(
|
||||
"{}_weight".format(
|
||||
_t_name.replace("minus", "plus") if "minus" in _t_name else _t_name
|
||||
),
|
||||
"{}_weight".format(_t_name.replace("minus", "plus") if "minus" in _t_name else _t_name),
|
||||
dict(row=2, col=1),
|
||||
),
|
||||
(
|
||||
@@ -240,13 +232,13 @@ def cumulative_return_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return cost bench turnover
|
||||
return cost bench turnover
|
||||
date
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
|
||||
|
||||
:param label_data: `D.features` result; index is `pd.MultiIndex`, index name is [`instrument`, `datetime`]; columns names is [`label`].
|
||||
@@ -256,12 +248,12 @@ def cumulative_return_graph(
|
||||
.. code-block:: python
|
||||
|
||||
label
|
||||
instrument datetime
|
||||
SH600004 2017-12-11 -0.013502
|
||||
2017-12-12 -0.072367
|
||||
2017-12-13 -0.068605
|
||||
2017-12-14 0.012440
|
||||
2017-12-15 -0.102778
|
||||
instrument datetime
|
||||
SH600004 2017-12-11 -0.013502
|
||||
2017-12-12 -0.072367
|
||||
2017-12-13 -0.068605
|
||||
2017-12-14 0.012440
|
||||
2017-12-15 -0.102778
|
||||
|
||||
|
||||
:param show_notebook: True or False. If True, show graph in notebook, else return figures
|
||||
@@ -272,9 +264,7 @@ def cumulative_return_graph(
|
||||
position = copy.deepcopy(position)
|
||||
report_normal = report_normal.copy()
|
||||
label_data.columns = ["label"]
|
||||
_figures = _get_figure_with_position(
|
||||
position, report_normal, label_data, start_date, end_date
|
||||
)
|
||||
_figures = _get_figure_with_position(position, report_normal, label_data, start_date, end_date)
|
||||
if show_notebook:
|
||||
BaseGraph.show_graph_in_notebook(_figures)
|
||||
else:
|
||||
|
||||
@@ -20,13 +20,13 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
print(position_df.head())
|
||||
# status: 0-hold, -1-sell, 1-buy
|
||||
|
||||
amount cash count price status weight
|
||||
instrument datetime
|
||||
SZ000547 2017-01-04 44.154290 211405.285654 1 205.189575 1 0.031255
|
||||
SZ300202 2017-01-04 60.638845 211405.285654 1 154.356506 1 0.032290
|
||||
SH600158 2017-01-04 46.531681 211405.285654 1 153.895142 1 0.024704
|
||||
SH600545 2017-01-04 197.173093 211405.285654 1 48.607037 1 0.033063
|
||||
SZ000930 2017-01-04 103.938300 211405.285654 1 80.759453 1 0.028958
|
||||
amount cash count price status weight
|
||||
instrument datetime
|
||||
SZ000547 2017-01-04 44.154290 211405.285654 1 205.189575 1 0.031255
|
||||
SZ300202 2017-01-04 60.638845 211405.285654 1 154.356506 1 0.032290
|
||||
SH600158 2017-01-04 46.531681 211405.285654 1 153.895142 1 0.024704
|
||||
SH600545 2017-01-04 197.173093 211405.285654 1 48.607037 1 0.033063
|
||||
SZ000930 2017-01-04 103.938300 211405.285654 1 80.759453 1 0.028958
|
||||
|
||||
|
||||
"""
|
||||
@@ -63,15 +63,12 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
# Trading day sell
|
||||
if not result_df.empty:
|
||||
_trading_day_sell_df = result_df.loc[
|
||||
(result_df["date"] == previous_data["date"])
|
||||
& (result_df.index.isin(_cur_day_sell))
|
||||
(result_df["date"] == previous_data["date"]) & (result_df.index.isin(_cur_day_sell))
|
||||
].copy()
|
||||
if not _trading_day_sell_df.empty:
|
||||
_trading_day_sell_df["status"] = -1
|
||||
_trading_day_sell_df["date"] = _trading_date
|
||||
_trading_day_df = _trading_day_df.append(
|
||||
_trading_day_sell_df, sort=False
|
||||
)
|
||||
_trading_day_df = _trading_day_df.append(_trading_day_sell_df, sort=False)
|
||||
|
||||
result_df = result_df.append(_trading_day_df, sort=True)
|
||||
|
||||
@@ -85,9 +82,7 @@ def parse_position(position: dict = None) -> pd.DataFrame:
|
||||
return result_df.set_index(["instrument", "datetime"])
|
||||
|
||||
|
||||
def _add_label_to_position(
|
||||
position_df: pd.DataFrame, label_data: pd.DataFrame
|
||||
) -> pd.DataFrame:
|
||||
def _add_label_to_position(position_df: pd.DataFrame, label_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Concat position with custom label
|
||||
|
||||
:param position_df: position DataFrame
|
||||
@@ -98,16 +93,12 @@ def _add_label_to_position(
|
||||
_start_time = position_df.index.get_level_values(level="datetime").min()
|
||||
_end_time = position_df.index.get_level_values(level="datetime").max()
|
||||
label_data = label_data.loc(axis=0)[:, pd.to_datetime(_start_time) :]
|
||||
_result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex(
|
||||
label_data.index
|
||||
)
|
||||
_result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex(label_data.index)
|
||||
_result_df = _result_df.loc[_result_df.index.get_level_values(1) <= _end_time]
|
||||
return _result_df
|
||||
|
||||
|
||||
def _add_bench_to_position(
|
||||
position_df: pd.DataFrame = None, bench: pd.Series = None
|
||||
) -> pd.DataFrame:
|
||||
def _add_bench_to_position(position_df: pd.DataFrame = None, bench: pd.Series = None) -> pd.DataFrame:
|
||||
"""Concat position with bench
|
||||
|
||||
:param position_df: position DataFrame
|
||||
@@ -135,9 +126,7 @@ def _calculate_label_rank(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
# Sell: -1, Hold: 0, Buy: 1
|
||||
for i in [-1, 0, 1]:
|
||||
g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[
|
||||
g_df["status"] == i
|
||||
]["rank_ratio"].mean()
|
||||
g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[g_df["status"] == i]["rank_ratio"].mean()
|
||||
|
||||
g_df["excess_return"] = g_df[_label_name] - g_df[_label_name].mean()
|
||||
return g_df
|
||||
@@ -181,7 +170,5 @@ def get_position_data(
|
||||
_date_list = _position_df.index.get_level_values(level="datetime")
|
||||
start_date = _date_list.min() if start_date is None else start_date
|
||||
end_date = _date_list.max() if end_date is None else end_date
|
||||
_position_df = _position_df.loc[
|
||||
(start_date <= _date_list) & (_date_list <= end_date)
|
||||
]
|
||||
_position_df = _position_df.loc[(start_date <= _date_list) & (_date_list <= end_date)]
|
||||
return _position_df
|
||||
|
||||
@@ -46,7 +46,7 @@ def _get_figure_with_position(
|
||||
|
||||
_res_df = pd.DataFrame.from_dict(res_dict, orient="index")
|
||||
# FIXME: support HIGH-FREQ
|
||||
_res_df.index = _res_df.index.strftime('%Y-%m-%d')
|
||||
_res_df.index = _res_df.index.strftime("%Y-%m-%d")
|
||||
for _col in _res_df.columns:
|
||||
yield ScatterGraph(
|
||||
_res_df.loc[:, [_col]],
|
||||
@@ -105,12 +105,12 @@ def rank_label_graph(
|
||||
.. code-block:: python
|
||||
|
||||
label
|
||||
instrument datetime
|
||||
SH600004 2017-12-11 -0.013502
|
||||
2017-12-12 -0.072367
|
||||
2017-12-13 -0.068605
|
||||
2017-12-14 0.012440
|
||||
2017-12-15 -0.102778
|
||||
instrument datetime
|
||||
SH600004 2017-12-11 -0.013502
|
||||
2017-12-12 -0.072367
|
||||
2017-12-13 -0.068605
|
||||
2017-12-14 0.012440
|
||||
2017-12-15 -0.102778
|
||||
|
||||
|
||||
:param start_date: start date
|
||||
|
||||
@@ -48,20 +48,12 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
report_df["cum_return_w_cost"] = (df["return"] - df["cost"]).cumsum()
|
||||
# report_df['cum_return'] - report_df['cum_return'].cummax()
|
||||
report_df["return_wo_mdd"] = _calculate_mdd(report_df["cum_return_wo_cost"])
|
||||
report_df["return_w_cost_mdd"] = _calculate_mdd(
|
||||
(df["return"] - df["cost"]).cumsum()
|
||||
)
|
||||
report_df["return_w_cost_mdd"] = _calculate_mdd((df["return"] - df["cost"]).cumsum())
|
||||
|
||||
report_df["cum_ex_return_wo_cost"] = (df["return"] - df["bench"]).cumsum()
|
||||
report_df["cum_ex_return_w_cost"] = (
|
||||
df["return"] - df["bench"] - df["cost"]
|
||||
).cumsum()
|
||||
report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd(
|
||||
(df["return"] - df["bench"]).cumsum()
|
||||
)
|
||||
report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd(
|
||||
(df["return"] - df["cost"] - df["bench"]).cumsum()
|
||||
)
|
||||
report_df["cum_ex_return_w_cost"] = (df["return"] - df["bench"] - df["cost"]).cumsum()
|
||||
report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd((df["return"] - df["bench"]).cumsum())
|
||||
report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd((df["return"] - df["cost"] - df["bench"]).cumsum())
|
||||
# return_wo_mdd , return_w_cost_mdd, cum_ex_return_wo_cost_mdd, cum_ex_return_w
|
||||
|
||||
report_df["turnover"] = df["turnover"]
|
||||
@@ -113,13 +105,7 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]:
|
||||
)
|
||||
for i in range(2, 8):
|
||||
# yaxis
|
||||
_subplot_layout.update(
|
||||
{
|
||||
"yaxis{}".format(i): dict(
|
||||
zeroline=True, showline=True, showticklabels=True
|
||||
)
|
||||
}
|
||||
)
|
||||
_subplot_layout.update({"yaxis{}".format(i): dict(zeroline=True, showline=True, showticklabels=True)})
|
||||
_layout_style = dict(
|
||||
height=1200,
|
||||
title=" ",
|
||||
@@ -134,7 +120,9 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]:
|
||||
"y1": 1,
|
||||
"fillcolor": "#d3d3d3",
|
||||
"opacity": 0.3,
|
||||
"line": {"width": 0,},
|
||||
"line": {
|
||||
"width": 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "rect",
|
||||
@@ -146,7 +134,9 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]:
|
||||
"y1": 0.55,
|
||||
"fillcolor": "#d3d3d3",
|
||||
"opacity": 0.3,
|
||||
"line": {"width": 0,},
|
||||
"line": {
|
||||
"width": 0,
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
@@ -200,13 +190,13 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return cost bench turnover
|
||||
return cost bench turnover
|
||||
date
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
|
||||
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**
|
||||
|
||||
@@ -32,13 +32,9 @@ def _get_risk_analysis_data_with_report(
|
||||
# analysis["pred_long_short"] = risk_analysis(report_long_short_df["long_short"])
|
||||
|
||||
if not report_normal_df.empty:
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal_df["return"] - report_normal_df["bench"]
|
||||
)
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal_df["return"] - report_normal_df["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal_df["return"]
|
||||
- report_normal_df["bench"]
|
||||
- report_normal_df["cost"]
|
||||
report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
analysis_df["date"] = date
|
||||
@@ -67,9 +63,7 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd
|
||||
"""
|
||||
|
||||
# Group by month
|
||||
report_normal_gp = report_normal_df.groupby(
|
||||
[report_normal_df.index.year, report_normal_df.index.month]
|
||||
)
|
||||
report_normal_gp = report_normal_df.groupby([report_normal_df.index.year, report_normal_df.index.month])
|
||||
# report_long_short_gp = report_long_short_df.groupby(
|
||||
# [report_long_short_df.index.year, report_long_short_df.index.month]
|
||||
# )
|
||||
@@ -96,9 +90,7 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd
|
||||
return _monthly_df
|
||||
|
||||
|
||||
def _get_monthly_analysis_with_feature(
|
||||
monthly_df: pd.DataFrame, feature: str = "annualized_return"
|
||||
) -> pd.DataFrame:
|
||||
def _get_monthly_analysis_with_feature(monthly_df: pd.DataFrame, feature: str = "annualized_return") -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
:param monthly_df:
|
||||
@@ -108,9 +100,7 @@ def _get_monthly_analysis_with_feature(
|
||||
_monthly_df_gp = monthly_df.reset_index().groupby(["level_1"])
|
||||
|
||||
_name_df = _monthly_df_gp.get_group(feature).set_index(["level_0", "level_1"])
|
||||
_temp_df = _name_df.pivot_table(
|
||||
index="date", values=["risk"], columns=_name_df.index
|
||||
)
|
||||
_temp_df = _name_df.pivot_table(index="date", values=["risk"], columns=_name_df.index)
|
||||
_temp_df.columns = map(lambda x: "_".join(x[-1]), _temp_df.columns)
|
||||
_temp_df.index = _temp_df.index.strftime("%Y-%m")
|
||||
|
||||
@@ -126,9 +116,7 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
|
||||
if analysis_df is None:
|
||||
return []
|
||||
|
||||
_figure = SubplotsGraph(
|
||||
_get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={})
|
||||
).figure
|
||||
_figure = SubplotsGraph(_get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={})).figure
|
||||
return (_figure,)
|
||||
|
||||
|
||||
@@ -141,7 +129,7 @@ def _get_monthly_risk_analysis_figure(report_normal_df: pd.DataFrame) -> Iterabl
|
||||
"""
|
||||
|
||||
# if report_normal_df is None and report_long_short_df is None:
|
||||
# return []
|
||||
# return []
|
||||
if report_normal_df is None:
|
||||
return []
|
||||
|
||||
@@ -231,13 +219,13 @@ def risk_analysis_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return cost bench turnover
|
||||
return cost bench turnover
|
||||
date
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
2017-01-04 0.003421 0.000864 0.011693 0.576325
|
||||
2017-01-05 0.000508 0.000447 0.000721 0.227882
|
||||
2017-01-06 -0.003321 0.000212 -0.004322 0.102765
|
||||
2017-01-09 0.006753 0.000212 0.006874 0.105864
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
|
||||
|
||||
:param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**
|
||||
@@ -245,13 +233,13 @@ def risk_analysis_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
long short long_short
|
||||
long short long_short
|
||||
date
|
||||
2017-01-04 -0.001360 0.001394 0.000034
|
||||
2017-01-05 0.002456 0.000058 0.002514
|
||||
2017-01-06 0.000120 0.002739 0.002859
|
||||
2017-01-09 0.001436 0.001838 0.003273
|
||||
2017-01-10 0.000824 -0.001944 -0.001120
|
||||
2017-01-04 -0.001360 0.001394 0.000034
|
||||
2017-01-05 0.002456 0.000058 0.002514
|
||||
2017-01-06 0.000120 0.002739 0.002859
|
||||
2017-01-09 0.001436 0.001838 0.003273
|
||||
2017-01-10 0.000824 -0.001944 -0.001120
|
||||
|
||||
|
||||
:param show_notebook: Whether to display graphics in a notebook, default **True**
|
||||
@@ -263,7 +251,7 @@ def risk_analysis_graph(
|
||||
_get_monthly_risk_analysis_figure(
|
||||
report_normal_df,
|
||||
# report_long_short_df,
|
||||
)
|
||||
)
|
||||
)
|
||||
if show_notebook:
|
||||
ScatterGraph.show_graph_in_notebook(_figure_list)
|
||||
|
||||
@@ -14,18 +14,12 @@ def _get_score_ic(pred_label: pd.DataFrame):
|
||||
"""
|
||||
concat_data = pred_label.copy()
|
||||
concat_data.dropna(axis=0, how="any", inplace=True)
|
||||
_ic = concat_data.groupby(level="datetime").apply(
|
||||
lambda x: x["label"].corr(x["score"])
|
||||
)
|
||||
_rank_ic = concat_data.groupby(level="datetime").apply(
|
||||
lambda x: x["label"].corr(x["score"], method="spearman")
|
||||
)
|
||||
_ic = concat_data.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
|
||||
_rank_ic = concat_data.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"], method="spearman"))
|
||||
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})
|
||||
|
||||
|
||||
def score_ic_graph(
|
||||
pred_label: pd.DataFrame, show_notebook: bool = True
|
||||
) -> [list, tuple]:
|
||||
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:
|
||||
"""score IC
|
||||
|
||||
Example:
|
||||
@@ -47,12 +41,12 @@ def score_ic_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
instrument datetime score label
|
||||
SH600004 2017-12-11 -0.013502 -0.013502
|
||||
2017-12-12 -0.072367 -0.072367
|
||||
2017-12-13 -0.068605 -0.068605
|
||||
2017-12-14 0.012440 0.012440
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
instrument datetime score label
|
||||
SH600004 2017-12-11 -0.013502 -0.013502
|
||||
2017-12-12 -0.072367 -0.072367
|
||||
2017-12-13 -0.068605 -0.068605
|
||||
2017-12-14 0.012440 0.012440
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
|
||||
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**
|
||||
|
||||
@@ -142,7 +142,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp
|
||||
"""
|
||||
fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1]
|
||||
filter_series = filter_series.astype('bool') # Make sure the filter_series is boolean
|
||||
filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean
|
||||
timestamp_series[fstart:fend] = timestamp_series[fstart:fend] & filter_series
|
||||
return timestamp_series
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
|
||||
except ImportError as err:
|
||||
print(err)
|
||||
print('Do not import qlib package in the repository directory')
|
||||
print("Do not import qlib package in the repository directory")
|
||||
exit(-1)
|
||||
|
||||
__all__ = (
|
||||
@@ -1342,7 +1342,9 @@ class PairRolling(ExpressionOps):
|
||||
if self.N == 0:
|
||||
return np.inf
|
||||
return (
|
||||
max(self.feature_left.get_longest_back_rolling(), self.feature_right.get_longest_back_rolling()) + self.N - 1
|
||||
max(self.feature_left.get_longest_back_rolling(), self.feature_right.get_longest_back_rolling())
|
||||
+ self.N
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_extended_window_size(self):
|
||||
@@ -1411,4 +1413,3 @@ class Cov(PairRolling):
|
||||
|
||||
def __init__(self, feature_left, feature_right, N):
|
||||
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ def get_module_by_module_path(module_path):
|
||||
:return:
|
||||
"""
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
import sys, platform
|
||||
import qlib
|
||||
|
||||
|
||||
def linux_distribution():
|
||||
try:
|
||||
return platform.linux_distribution()
|
||||
except:
|
||||
return "N/A"
|
||||
|
||||
print('Qlib version: {} \n'.format(qlib.__version__))
|
||||
print("""Python version: {} \n
|
||||
|
||||
print("Qlib version: {} \n".format(qlib.__version__))
|
||||
print(
|
||||
"""Python version: {} \n
|
||||
linux_distribution: {}
|
||||
system: {}
|
||||
machine: {}
|
||||
platform: {}
|
||||
version: {}
|
||||
""".format(
|
||||
sys.version.split('\n'),
|
||||
linux_distribution(),
|
||||
platform.system(),
|
||||
platform.machine(),
|
||||
platform.platform(),
|
||||
platform.version(),
|
||||
))
|
||||
sys.version.split("\n"),
|
||||
linux_distribution(),
|
||||
platform.system(),
|
||||
platform.machine(),
|
||||
platform.platform(),
|
||||
platform.version(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -116,9 +116,7 @@ class YahooCollector:
|
||||
return error_symbol
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data
|
||||
|
||||
"""
|
||||
"""collector data"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
@@ -131,7 +129,7 @@ class YahooCollector:
|
||||
self.save_stock(_symbol, max(_df_list, key=len))
|
||||
|
||||
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
|
||||
|
||||
self.download_csi300_data()
|
||||
|
||||
def download_csi300_data(self):
|
||||
@@ -280,8 +278,7 @@ class Run:
|
||||
YahooCollector(self.source_dir).download_csi300_data()
|
||||
|
||||
def download_bench_data(self):
|
||||
"""download bench stock data(SH000300)
|
||||
"""
|
||||
"""download bench stock data(SH000300)"""
|
||||
|
||||
def collector_data(self):
|
||||
"""download -> normalize -> dump data
|
||||
|
||||
@@ -34,7 +34,9 @@ class GetData:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
logger.warning(f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)")
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
|
||||
8
setup.py
8
setup.py
@@ -61,7 +61,7 @@ NUMPY_INCLUDE = numpy.get_include()
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
|
||||
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
||||
|
||||
@@ -85,11 +85,11 @@ extensions = [
|
||||
setup(
|
||||
name=NAME,
|
||||
version=VERSION,
|
||||
license = "MIT Licence",
|
||||
url = "https://github.com/microsoft/qlib",
|
||||
license="MIT Licence",
|
||||
url="https://github.com/microsoft/qlib",
|
||||
description=DESCRIPTION,
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
long_description_content_type="text/markdown",
|
||||
python_requires=REQUIRES_PYTHON,
|
||||
packages=find_packages(exclude=("tests",)),
|
||||
# if your package is a single module, use this instead of 'packages':
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import qlib
|
||||
@@ -10,7 +9,6 @@ from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class TestDataset(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
@@ -24,9 +22,9 @@ class TestDataset(unittest.TestCase):
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def testCSI300(self):
|
||||
close_p = D.features(D.instruments('csi300'), ['$close'])
|
||||
size = close_p.groupby('datetime').size()
|
||||
cnt = close_p.groupby('datetime').count()['$close']
|
||||
close_p = D.features(D.instruments("csi300"), ["$close"])
|
||||
size = close_p.groupby("datetime").size()
|
||||
cnt = close_p.groupby("datetime").count()["$close"]
|
||||
size_desc = size.describe(percentiles=np.arange(0.1, 1.0, 0.1))
|
||||
cnt_desc = cnt.describe(percentiles=np.arange(0.1, 1.0, 0.1))
|
||||
|
||||
@@ -35,22 +33,21 @@ class TestDataset(unittest.TestCase):
|
||||
|
||||
self.assertLessEqual(size_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks")
|
||||
self.assertGreaterEqual(size_desc.loc["80%"], 290, "Insufficient number of CSI300 constituent stocks")
|
||||
|
||||
|
||||
self.assertLessEqual(cnt_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks")
|
||||
# FIXME: Due to the low quality of data. Hard to make sure there are enough data
|
||||
# self.assertEqual(cnt_desc.loc["80%"], 300, "Insufficient number of CSI300 constituent stocks")
|
||||
|
||||
def testClose(self):
|
||||
close_p = D.features(D.instruments('csi300'), ['Ref($close, 1)/$close - 1'])
|
||||
close_p = D.features(D.instruments("csi300"), ["Ref($close, 1)/$close - 1"])
|
||||
close_desc = close_p.describe(percentiles=np.arange(0.1, 1.0, 0.1))
|
||||
print(close_desc)
|
||||
self.assertLessEqual(abs(close_desc.loc["90%"][0]), 0.1, "Close value is abnormal")
|
||||
self.assertLessEqual(abs(close_desc.loc["10%"][0]), 0.1, "Close value is abnormal")
|
||||
# FIXME: The yahoo data is not perfect. We have to
|
||||
# FIXME: The yahoo data is not perfect. We have to
|
||||
# self.assertLessEqual(abs(close_desc.loc["max"][0]), 0.2, "Close value is abnormal")
|
||||
# self.assertGreaterEqual(close_desc.loc["min"][0], -0.2, "Close value is abnormal")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
@@ -79,9 +79,9 @@ def train():
|
||||
model performance
|
||||
"""
|
||||
# get data
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
|
||||
**DATA_HANDLER_CONFIG
|
||||
).get_split_data(**TRAINER_CONFIG)
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
# train
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
@@ -127,7 +127,9 @@ def backtest(pred):
|
||||
def analyze(report_normal):
|
||||
_analysis = dict()
|
||||
_analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
_analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
_analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(_analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
@@ -155,12 +157,12 @@ class TestAllFlow(unittest.TestCase):
|
||||
self.assertGreaterEqual(model_pearsonr["model_pearsonr"], 0, "train failed")
|
||||
|
||||
def test_1_backtest(self):
|
||||
TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(
|
||||
TestAllFlow.PRED_SCORE
|
||||
)
|
||||
TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(TestAllFlow.PRED_SCORE)
|
||||
analyze_df = analyze(TestAllFlow.REPORT_NORMAL)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], 0.10, "backtest failed",
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
0.10,
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user