mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
16 Commits
yx/docs_fo
...
6cma
| Author | SHA1 | Date |
|---|---|---|
| | 1f5f3a6af0 | |
| | 2f8fc8d28a | |
| | 3e9ccd3ad2 | |
| | 94268619c4 | |
| | 8d60a6a02b | |
| | 7234308651 | |
| | acf5df27ce | |
| | 37a59f28d3 | |
| | b084c352f5 | |
| | 9e22e5168b | |
| | dceff7b471 | |
| | 7f1e8c5206 | |
| | 46264dfec9 | |
| | 754799ab05 | |
| | 32c3070b73 | |
| | 40de67265a | |
2
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because the annotations feature is unavailable: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
10
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because the annotations feature is unavailable: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,8 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip released version 23.1 on Apr. 15, 2023, and CI failed to run. Please refer to #1495 for detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pip==23.0.1
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -37,15 +39,13 @@ jobs:
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Installing pytorch for ubuntu
|
||||
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
|
||||
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Installing pytorch for windows
|
||||
if: ${{ matrix.os == 'windows-latest' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision torchaudio
|
||||
|
||||
- name: Set up Python tools
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because the annotations feature is unavailable: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -28,9 +28,10 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip released version 23.1 on Apr. 15, 2023, and CI failed to run. Please refer to #1495 for detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# python -m pip is necessary to upgrade pip.
|
||||
python -m pip install pip==23.0.1
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
@@ -42,13 +42,11 @@ Features released before 2021 are not listed here.
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
|
||||
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
|
||||
An increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -64,8 +64,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 8192
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
@@ -52,8 +52,6 @@ task:
|
||||
kwargs:
|
||||
loss: mse
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
|
||||
107
examples/benchmarks_dynamic/DDG-DA/vis_data.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# tqdm.pandas() # for progress_apply
|
||||
# %matplotlib inline
|
||||
# %load_ext autoreload
|
||||
|
||||
|
||||
# # Meta Input
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")
|
||||
|
||||
data_sim.index.name = "test datetime"
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(data_sim.rolling(20).mean())
|
||||
|
||||
# # Meta Model
|
||||
|
||||
from qlib import auto_init
|
||||
|
||||
auto_init()
|
||||
from qlib.workflow import R
|
||||
|
||||
exp = R.get_exp(experiment_name="DDG-DA")
|
||||
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
|
||||
meta_m = meta_rec.load_object("model")
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()
|
||||
|
||||
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()
|
||||
|
||||
# # Meta Output
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
|
||||
if None not in test_seg:
|
||||
# The last rolling is skipped.
|
||||
task_df[test_seg] = t["reweighter"].time_weight
|
||||
task_df = pd.concat(task_df)
|
||||
|
||||
task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
|
||||
task_df = task_df.droplevel(["OS_end", "IS_end"])
|
||||
task_df = task_df.unstack("OS_start")
|
||||
# -
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.T)
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(task_df.rolling(10).mean().T)
|
||||
|
||||
# # Sub Models
|
||||
#
|
||||
# NOTE:
|
||||
# - this section assumes that the model is Linear model!!
|
||||
# - Other models do not support this analysis
|
||||
|
||||
exp = R.get_exp(experiment_name="rolling_ds")
|
||||
|
||||
|
||||
def show_linear_weight(exp):
|
||||
coef_df = {}
|
||||
for r in exp.list_recorders("list"):
|
||||
t = r.load_object("task")
|
||||
if None in t["dataset"]["kwargs"]["segments"]["test"]:
|
||||
continue
|
||||
m = r.load_object("params.pkl")
|
||||
coef_df[t["dataset"]["kwargs"]["segments"]["test"]] = pd.Series(m.coef_)
|
||||
|
||||
coef_df = pd.concat(coef_df)
|
||||
|
||||
coef_df.index.names = ["test_start", "test_end", "coef_idx"]
|
||||
|
||||
coef_df = coef_df.droplevel("test_end").unstack("coef_idx").T
|
||||
|
||||
plt.figure(figsize=(40, 20))
|
||||
sns.heatmap(coef_df)
|
||||
plt.show()
|
||||
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_ds"))
|
||||
|
||||
show_linear_weight(R.get_exp(experiment_name="rolling_models"))
|
||||
@@ -10,8 +10,10 @@ import pandas as pd
|
||||
import fire
|
||||
import sys
|
||||
import pickle
|
||||
from typing import Optional
|
||||
from qlib import auto_init
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.typehint import Literal
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
@@ -30,7 +32,33 @@ class DDGDA:
|
||||
- `rm -r mlruns`
|
||||
"""
|
||||
|
||||
def __init__(self, sim_task_model="linear", forecast_model="linear"):
|
||||
def __init__(
|
||||
self,
|
||||
sim_task_model: Literal["linear", "gbdt"] = "linear",
|
||||
forecast_model: Literal["linear", "gbdt"] = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
alpha: float = 0.0,
|
||||
proxy_hd: str = "handler_proxy.pkl",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
train_start: Optional[str]
|
||||
the start datetime for data. It is used as the training start time (for both tasks & meta learning)
|
||||
test_end: Optional[str]
|
||||
the end datetime for data. It is used in test end time
|
||||
meta_1st_train_end: Optional[str]
|
||||
the datetime of training end of the first meta_task
|
||||
alpha: float
|
||||
Setting the L2 regularization for ridge
|
||||
The `alpha` is only passed to MetaModelDS (it is currently not passed to sim_task_model)
|
||||
"""
|
||||
self.step = 20
|
||||
# NOTE:
|
||||
# the horizon must match the meaning in the base task template
|
||||
@@ -38,10 +66,19 @@ class DDGDA:
|
||||
self.meta_exp_name = "DDG-DA"
|
||||
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
|
||||
self.forecast_model = forecast_model # downstream forecasting models' type
|
||||
self.rb_kwargs = {
|
||||
"h_path": h_path,
|
||||
"test_end": test_end,
|
||||
"train_start": train_start,
|
||||
"task_ext_conf": task_ext_conf,
|
||||
}
|
||||
self.alpha = alpha
|
||||
self.meta_1st_train_end = meta_1st_train_end
|
||||
self.proxy_hd = proxy_hd
|
||||
|
||||
def get_feature_importance(self):
|
||||
# this must be lightGBM, because it needs to get the feature importance
|
||||
rb = RollingBenchmark(model_type="gbdt")
|
||||
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
|
||||
with R.start(experiment_name="feature_importance"):
|
||||
@@ -69,7 +106,7 @@ class DDGDA:
|
||||
fi = self.get_feature_importance()
|
||||
col_selected = fi.nlargest(topk)
|
||||
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
task = rb.basic_task()
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -96,7 +133,7 @@ class DDGDA:
|
||||
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
|
||||
}
|
||||
)
|
||||
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
|
||||
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)
|
||||
|
||||
@property
|
||||
def _internal_data_path(self):
|
||||
@@ -108,7 +145,7 @@ class DDGDA:
|
||||
This function will dump the input data for meta model
|
||||
"""
|
||||
# According to the experiments, the choice of the model type is very important for achieving good results
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
|
||||
if self.sim_task_model == "gbdt":
|
||||
@@ -122,24 +159,27 @@ class DDGDA:
|
||||
with self._internal_data_path.open("wb") as f:
|
||||
pickle.dump(internal_data, f)
|
||||
|
||||
def train_meta_model(self):
|
||||
def train_meta_model(self, fill_method="max"):
|
||||
"""
|
||||
training a meta model based on a simplified linear proxy model;
|
||||
"""
|
||||
|
||||
# 1) leverage the simplified proxy forecasting model to train meta model.
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model)
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
proxy_forecast_model_task = {
|
||||
# "model": "qlib.contrib.model.linear.LinearModel",
|
||||
"dataset": {
|
||||
"class": "qlib.data.dataset.DatasetH",
|
||||
"kwargs": {
|
||||
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
|
||||
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2010-12-31"),
|
||||
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
"train": (train_start, train_end),
|
||||
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
|
||||
},
|
||||
},
|
||||
},
|
||||
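A note on the defaults in the hunk above: `self.rb_kwargs` is built with a `"train_start"` key even when the caller passes nothing, so `self.rb_kwargs.get("train_start", "2008-01-01")` would return `None` rather than the fallback date. The sketch below uses only the standard library and a hypothetical helper name, `resolve_segments`, to show one way the three dates could be derived so that a `None` value still falls back to the default. It is an illustration under those assumptions, not the code in this commit.

```python
from datetime import date, timedelta
from typing import Optional, Tuple


# Hypothetical helper; the default dates mirror the literals in the hunk above.
def resolve_segments(
    rb_kwargs: dict,
    meta_1st_train_end: Optional[str] = None,
) -> Tuple[str, str, str]:
    # dict.get(key, default) only uses the default when the key is absent,
    # so a key that is present with value None needs an explicit check.
    train_start = rb_kwargs.get("train_start")
    if train_start is None:
        train_start = "2008-01-01"
    train_end = "2010-12-31" if meta_1st_train_end is None else meta_1st_train_end
    # The test segment starts the day after the training segment ends.
    test_start = (date.fromisoformat(train_end) + timedelta(days=1)).isoformat()
    return train_start, train_end, test_start


kwargs = {"train_start": None}
print(kwargs.get("train_start", "2008-01-01"))  # None: the key exists
print(resolve_segments(kwargs))
print(resolve_segments({"train_start": "2012-01-01"}, "2014-06-30"))
```

The explicit `is None` check keeps the behaviour the same whether the key is missing or present with no value.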
@@ -156,7 +196,7 @@ class DDGDA:
|
||||
segments=0.62, # keep test period consistent with the dataset yaml
|
||||
trunc_days=1 + self.horizon,
|
||||
hist_step_n=30,
|
||||
fill_method="max",
|
||||
fill_method=fill_method,
|
||||
rolling_ext_days=0,
|
||||
)
|
||||
# NOTE:
|
||||
@@ -165,12 +205,15 @@ class DDGDA:
|
||||
# So the misalignment will not affect the effectiveness of the method.
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = pickle.load(f)
|
||||
|
||||
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
# 3) train and log the meta model
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -203,7 +246,7 @@ class DDGDA:
|
||||
hist_step_n = int(param["hist_step_n"])
|
||||
fill_method = param.get("fill_method", "max")
|
||||
|
||||
rb = RollingBenchmark(model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
|
||||
task_l = rb.create_rolling_tasks()
|
||||
|
||||
# 2.2) create meta dataset for final dataset
|
||||
@@ -233,13 +276,13 @@ class DDGDA:
|
||||
"""
|
||||
with self._task_path.open("rb") as f:
|
||||
tasks = pickle.load(f)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
|
||||
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
|
||||
rb.train_rolling_tasks(tasks)
|
||||
rb.ens_rolling()
|
||||
rb.update_rolling_rec()
|
||||
|
||||
def run_all(self):
|
||||
# 1) file: handler_proxy.pkl
|
||||
# 1) file: handler_proxy.pkl (self.proxy_hd)
|
||||
self.dump_data_for_proxy_model()
|
||||
# 2)
|
||||
# file: internal_data_s20.pkl
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Optional
|
||||
from qlib.model.ens.ensemble import RollingEnsemble
|
||||
from qlib.utils import init_instance_by_config
|
||||
import fire
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from qlib import auto_init
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
@@ -25,11 +29,40 @@ class RollingBenchmark:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
rolling_exp: str = "rolling_models",
|
||||
model_type: str = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
rolling_exp : str
|
||||
The name for the experiments for rolling
|
||||
model_type : str
|
||||
The model to be boosted.
|
||||
h_path : Optional[str]
|
||||
the dumped data handler;
|
||||
test_end : Optional[str]
|
||||
the test end for the data. It is typically used together with the handler
|
||||
train_start : Optional[str]
|
||||
the train start for the data. It is typically used together with the handler.
|
||||
task_ext_conf : Optional[dict]
|
||||
some option to update the
|
||||
"""
|
||||
self.step = 20
|
||||
self.horizon = 20
|
||||
self.rolling_exp = rolling_exp
|
||||
self.model_type = model_type
|
||||
self.h_path = h_path
|
||||
self.train_start = train_start
|
||||
self.test_end = test_end
|
||||
self.logger = get_module_logger("RollingBenchmark")
|
||||
self.task_ext_conf = task_ext_conf
|
||||
|
||||
def basic_task(self):
|
||||
"""For fast training rolling"""
|
||||
@@ -42,6 +75,10 @@ class RollingBenchmark:
|
||||
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
else:
|
||||
raise AssertionError("Model type is not supported!")
|
||||
|
||||
if self.h_path is not None:
|
||||
h_path = Path(self.h_path)
|
||||
|
||||
with conf_path.open("r") as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
@@ -52,6 +89,9 @@ class RollingBenchmark:
|
||||
|
||||
task = conf["task"]
|
||||
|
||||
if self.task_ext_conf is not None:
|
||||
task = update_config(task, self.task_ext_conf)
|
||||
|
||||
if not h_path.exists():
|
||||
h_conf = task["dataset"]["kwargs"]["handler"]
|
||||
h = init_instance_by_config(h_conf)
|
||||
@@ -59,6 +99,15 @@ class RollingBenchmark:
|
||||
|
||||
task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
|
||||
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
|
||||
|
||||
if self.train_start is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["train"]
|
||||
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]
|
||||
|
||||
if self.test_end is not None:
|
||||
seg = task["dataset"]["kwargs"]["segments"]["test"]
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
|
||||
self.logger.info(task)
|
||||
return task
|
||||
|
||||
def create_rolling_tasks(self):
|
||||
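In the hunks above, `task_ext_conf` is merged into the task template through `update_config`. The exact semantics of that helper are not shown in this diff. The sketch below assumes a recursive dictionary merge and uses a hypothetical `deep_update` function, with made-up task keys and values, to show what overriding a single nested key would look like under that assumption.

```python
import copy


# Hypothetical stand-in for a recursive config merge; not qlib's implementation.
def deep_update(base: dict, ext: dict) -> dict:
    merged = copy.deepcopy(base)  # leave the caller's template untouched
    for key, value in ext.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key.
            merged[key] = deep_update(merged[key], value)
        else:
            # Otherwise the extension value replaces the base value.
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative template and override; the keys and values are assumptions.
task = {
    "model": {"class": "LinearModel", "kwargs": {"estimator": "ridge", "alpha": 0.05}},
    "dataset": {"kwargs": {"segments": {"train": ("2008-01-01", "2014-12-31")}}},
}
task_ext_conf = {"model": {"kwargs": {"alpha": 0.1}}}

merged = deep_update(task, task_ext_conf)
print(merged["model"]["kwargs"])
print(task["model"]["kwargs"])
```

Only `alpha` changes in the merged result. The sibling key `estimator` and the original template are left as they were.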
@@ -93,7 +142,7 @@ class RollingBenchmark:
|
||||
"""
|
||||
Evaluate the combined rolling results
|
||||
"""
|
||||
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
|
||||
for rt_cls in SigAnaRecord, PortAnaRecord:
|
||||
rt = rt_cls(recorder=rec, skip_existing=True)
|
||||
rt.generate()
|
||||
|
||||
@@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region
|
||||
|
||||
To run the code in this example, we need data in pickle format. To generate it, run the following commands (this might take a few minutes to finish):
|
||||
|
||||
[//]: # (TODO: Instead of dumping dataframes in different formats (like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`), we encourage implementing different subclasses of `Dataset` and `DataHandler`. This will keep the workflow cleaner and the interfaces more consistent, and move all the complexity into the subclasses.)
|
||||
|
||||
```
|
||||
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
|
||||
python scripts/collect_pickle_dataframe.py
|
||||
python scripts/gen_training_orders.py
|
||||
python scripts/merge_orders.py
|
||||
```
|
||||
@@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
|
||||
data
|
||||
├── bin
|
||||
├── orders
|
||||
├── pickle
|
||||
└── pickle_dataframe
|
||||
└── pickle
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
@@ -45,10 +37,12 @@ strategies:
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
@@ -45,10 +37,12 @@ strategies:
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
order_file: ./data/orders/test_orders.pkl
|
||||
start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
data_granularity: "5min"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
|
||||
@@ -3,8 +3,8 @@ simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 48
|
||||
parallel_mode: shmem
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
@@ -18,10 +18,13 @@ state_interpreter:
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
@@ -32,7 +35,9 @@ reward:
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
|
||||
@@ -3,8 +3,8 @@ simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 48
|
||||
parallel_mode: shmem
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
@@ -18,10 +18,13 @@ state_interpreter:
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PPOReward
|
||||
@@ -33,7 +36,9 @@ reward:
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
|
||||
|
||||
|
||||
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
|
||||
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
|
||||
cur = cur.set_index(["instrument", "datetime", "date"])
|
||||
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
|
||||
|
||||
|
||||
for tag in ("backtest", "feature"):
|
||||
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
|
||||
df = pd.concat(list(df.values())).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
instruments = sorted(set(df["instrument"]))
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
|
||||
|
||||
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)
|
||||
@@ -4,17 +4,22 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
|
||||
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
|
||||
OUTPUT_PATH = Path(os.path.join("data", "orders"))
|
||||
|
||||
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
df = dataset.handler.fetch(level=None).reset_index()
|
||||
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
|
||||
return False
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if len(order) > 0:
|
||||
order.to_pickle(path / f"{stock}.pkl.target")
|
||||
return True
|
||||
|
||||
|
||||
np.random.seed(1234)
|
||||
file_list = sorted(os.listdir(DATA_PATH))
|
||||
stocks = [f.replace(".pkl", "") for f in file_list]
|
||||
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
|
||||
for stock in tqdm(stocks):
|
||||
generate_order(stock, 0, 240 // 5 - 1)
|
||||
np.random.shuffle(stocks)
|
||||
|
||||
cnt = 0
|
||||
for stock in stocks:
|
||||
if generate_order(stock, 0, 240 // 5 - 1):
|
||||
cnt += 1
|
||||
if cnt == 100:
|
||||
break
|
||||
|
||||
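The revised script no longer samples 100 stocks up front. It shuffles the full list and counts successful calls until it reaches 100, so stocks rejected for empty, null, or near-zero-volume data do not reduce the number of order files. A minimal standard-library sketch of that pattern follows; the helper name `take_first_successes`, the seed handling, and the toy validity rule are assumptions for illustration.

```python
import random
from typing import Callable, List, Sequence


# Hypothetical helper illustrating "shuffle, then accept until the quota is met".
def take_first_successes(
    items: Sequence[str],
    try_one: Callable[[str], bool],
    limit: int,
    seed: int = 1234,
) -> List[str]:
    rng = random.Random(seed)  # local generator keeps the run reproducible
    shuffled = list(items)
    rng.shuffle(shuffled)
    accepted: List[str] = []
    for item in shuffled:
        if try_one(item):  # plays the role of generate_order(...) returning True
            accepted.append(item)
            if len(accepted) == limit:
                break
    return accepted


# Toy data: only even-numbered "stocks" are treated as valid.
stocks = [f"S{i:03d}" for i in range(10)]
picked = take_first_successes(stocks, lambda s: int(s[1:]) % 2 == 0, limit=3)
print(picked)
```

If fewer than `limit` items pass the check, the loop simply ends with a shorter list, which matches the behaviour of the counting loop in the diff.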
@@ -179,7 +179,7 @@ def get_strategy_executor(
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
@@ -197,12 +197,15 @@ def get_strategy_executor(
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
if isinstance(exchange_kwargs, Exchange):
|
||||
trade_exchange = exchange_kwargs
|
||||
else:
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
|
||||
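The hunk above lets `exchange_kwargs` be either a ready-made `Exchange` or a dict of keyword arguments. A minimal sketch of that dispatch is given below; the `Exchange` class here is a hypothetical stand-in (not qlib's class) and `resolve_exchange` is an illustrative name.

```python
import copy


class Exchange:
    """Hypothetical stand-in used only for this sketch."""

    def __init__(self, start_time=None, end_time=None, **kwargs):
        self.start_time = start_time
        self.end_time = end_time
        self.extra = kwargs


def resolve_exchange(exchange_or_kwargs, start_time, end_time):
    """Return the instance unchanged, or build one from a copied kwargs dict."""
    if isinstance(exchange_or_kwargs, Exchange):
        return exchange_or_kwargs
    kwargs = copy.copy(exchange_or_kwargs)  # shallow copy: the caller's dict is not mutated
    kwargs.setdefault("start_time", start_time)
    kwargs.setdefault("end_time", end_time)
    return Exchange(**kwargs)


cfg = {"freq": "day"}
built = resolve_exchange(cfg, "2020-01-01", "2020-12-31")
```

Copying before filling in the defaults keeps the caller's configuration reusable across several calls.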
@@ -56,6 +56,7 @@ def collect_data_loop(
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -74,6 +75,8 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -83,7 +86,8 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
|
||||
@@ -177,7 +177,7 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
@@ -263,6 +263,9 @@ class Exchange:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
if isinstance(limit_threshold, list):
|
||||
assert len(limit_threshold) == 2
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
@@ -325,7 +328,7 @@ class Exchange:
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
if key in ("buy", "all"):
|
||||
@@ -803,7 +806,7 @@ class Exchange:
|
||||
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple)
|
||||
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
|
||||
if limit[0] == "current":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
|
||||
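The relaxed assertions in the `Exchange` hunks accept a tuple or a two-element list, presumably so that thresholds loaded from list-based configuration files pass the check (that motivation is an assumption). The predicate can be sketched as a helper; `is_threshold_pair` is a hypothetical name.

```python
def is_threshold_pair(value):
    """Mirror the relaxed check: any tuple, or a list of exactly two items."""
    return isinstance(value, tuple) or (isinstance(value, list) and len(value) == 2)


ok_tuple = is_threshold_pair(("current", "$volume"))
ok_list = is_threshold_pair(["current", "$volume"])
bad_list = is_threshold_pair(["current"])
bad_float = is_threshold_pair(0.1)
```

As in the diff, tuples are not length-checked; only lists are required to have exactly two elements.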
@@ -147,6 +147,7 @@ _default_config = {
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"redis_password": None,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
|
||||
@@ -55,8 +55,10 @@ class InternalData:
|
||||
# The handler is initialized only once.
|
||||
if not trainer.has_worker():
|
||||
self.dh = init_task_handler(perf_task_tpl)
|
||||
self.dh.config(dump_all=False) # in some cases, the data handlers are saved to disk with `dump_all=True`
|
||||
else:
|
||||
self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"])
|
||||
assert self.dh.dump_all is False # otherwise, it will save all the detailed data
|
||||
|
||||
seg = perf_task_tpl["dataset"]["kwargs"]["segments"]
|
||||
|
||||
@@ -77,7 +79,7 @@ class InternalData:
|
||||
get_module_logger("Internal Data").info("the data has been initialized")
|
||||
else:
|
||||
# train new models
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData``"
|
||||
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`"
|
||||
trainer.train(gen_task)
|
||||
|
||||
# 2) extract the similarity matrix
|
||||
@@ -119,6 +121,7 @@ class MetaTaskDS(MetaTask):
|
||||
|
||||
def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"):
|
||||
"""
|
||||
|
||||
The description of the processed data
|
||||
|
||||
time_perf: An array with shape <hist_step_n * step, data pieces> -> data piece performance
|
||||
@@ -132,6 +135,10 @@ class MetaTaskDS(MetaTask):
|
||||
[0., 0., 0., ..., 0., 0., 1.],
|
||||
[0., 0., 0., ..., 0., 0., 1.]])
|
||||
|
||||
Parameters
|
||||
----------
|
||||
meta_info: pd.DataFrame
|
||||
please refer to the docs of _prepare_meta_ipt for detailed explanation.
|
||||
"""
|
||||
super().__init__(task, meta_info)
|
||||
self.fill_method = fill_method
|
||||
@@ -180,12 +187,41 @@ class MetaTaskDS(MetaTask):
|
||||
self.processed_meta_input = data_to_tensor(self.processed_meta_input)
|
||||
|
||||
def _get_processed_meta_info(self):
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0) # .fillna(0.)
|
||||
if self.fill_method == "max":
|
||||
meta_info_norm = meta_info_norm.T.fillna(
|
||||
meta_info_norm.max(axis=1)
|
||||
).T # fill it with row max to align with previous implementation
|
||||
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)
|
||||
if self.fill_method.startswith("max"):
|
||||
suffix = self.fill_method.lstrip("max")
|
||||
if suffix == "seg":
|
||||
fill_value = {}
|
||||
for col in meta_info_norm.columns:
|
||||
fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()
|
||||
fill_value = pd.Series(fill_value).sort_index()
|
||||
# The NaN values are filled segment-wise. Below is an example of fill_value
|
||||
# 2009-01-05 2009-02-06 0.145809
|
||||
# 2009-02-09 2009-03-06 0.148005
|
||||
# 2009-03-09 2009-04-03 0.090385
|
||||
# 2009-04-07 2009-05-05 0.114318
|
||||
# 2009-05-06 2009-06-04 0.119328
|
||||
# ...
|
||||
meta_info_norm = meta_info_norm.fillna(fill_value)
|
||||
else:
|
||||
if len(suffix) > 0:
|
||||
get_module_logger("MetaTaskDS").warning(
|
||||
f"fill_method={self.fill_method}; the info after can't be correctly parsed. Please check your parameters."
|
||||
)
|
||||
fill_value = meta_info_norm.max(axis=1)
|
||||
# fill it with row max to align with previous implementation
|
||||
# This will magnify the data similarity when data is in daily freq
|
||||
|
||||
# the fill value corresponds to data like this
|
||||
# It gets a performance value for each day.
|
||||
# The performance values are obtained from other models on this day
|
||||
# 2009-01-16 0.276320
|
||||
# 2009-01-19 0.280603
|
||||
# ...
|
||||
# 2011-06-27 0.203773
|
||||
meta_info_norm = meta_info_norm.T.fillna(fill_value).T
|
||||
elif self.fill_method == "zero":
|
||||
# It will fillna(0.0) at the end.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
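One detail of the suffix parsing above is worth spelling out: `str.lstrip("max")` removes any leading run of the characters `m`, `a`, and `x`, not the literal prefix `"max"`. For the values handled in the hunk (`"max"` and `"maxseg"`) the result is the same either way, but other suffixes can be truncated. The sketch below contrasts the two; `split_fill_method` is a hypothetical helper, not part of the repository.

```python
def split_fill_method(fill_method, prefix="max"):
    """Return the text after a literal prefix (slice instead of lstrip)."""
    if not fill_method.startswith(prefix):
        raise ValueError(f"{fill_method!r} does not start with {prefix!r}")
    return fill_method[len(prefix):]


same_result = "maxseg".lstrip("max")          # 's' is not in {'m', 'a', 'x'}
truncated = "maxmean".lstrip("max")           # the leading 'm' of "mean" is stripped too
by_slice = split_fill_method("maxmean")       # keeps the full suffix
```

Slicing works on every Python 3 version; `str.removeprefix` is an equivalent choice on Python 3.9 and later.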
@@ -286,7 +322,33 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
logger.warning(f"ValueError: {e}")
|
||||
assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting"
|
||||
|
||||
def _prepare_meta_ipt(self, task):
|
||||
def _prepare_meta_ipt(self, task) -> pd.DataFrame:
|
||||
"""
|
||||
Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`
|
||||
|
||||
Indices with format below can be successfully sliced by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`
|
||||
|
||||
2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08
|
||||
2021-07-02 2021-06-18 .. 2021-04-02 None
|
||||
|
||||
Returns
|
||||
-------
|
||||
a pd.DataFrame with similar content below.
|
||||
- each column corresponds to a trained model named by the training data range
|
||||
- each row corresponds to a day of data tested by the models of the columns
|
||||
- The row cells that overlap with the data used by the columns are masked
|
||||
|
||||
|
||||
2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26
|
||||
2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23
|
||||
datetime ...
|
||||
2009-01-13 NaN 0.310639 ... -0.169057 0.137792
|
||||
2009-01-14 NaN 0.261086 ... -0.143567 0.082581
|
||||
... ... ... ... ... ...
|
||||
2011-06-30 -0.054907 -0.020219 ... -0.023226 NaN
|
||||
2011-07-01 -0.075762 -0.026626 ... -0.003167 NaN
|
||||
|
||||
"""
|
||||
ic_df = self.internal_data.data_ic_df
|
||||
|
||||
segs = task["dataset"]["kwargs"]["segments"]
|
||||
@@ -294,15 +356,19 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
|
||||
|
||||
# the meta dataset focuses on the **information** instead of preprocessing
|
||||
# 1) filter the future info
|
||||
def mask_future(s):
|
||||
"""mask future information"""
|
||||
# from qlib.utils import get_date_by_shift
|
||||
# 1) filter the overlap info
|
||||
def mask_overlap(s):
|
||||
"""
|
||||
mask overlap information
|
||||
data within self.trunc_days after self.name[end], which contains future info, is also considered overlap info
|
||||
|
||||
Approximately the diagonal + horizon length of data is masked.
|
||||
"""
|
||||
start, end = s.name
|
||||
end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)
|
||||
return s.mask((s.index >= start) & (s.index <= end))
|
||||
|
||||
ic_df_avail = ic_df_avail.apply(mask_future) # apply to each col
|
||||
ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col
|
||||
|
||||
# 2) filter the info with too long periods
|
||||
total_len = self.step * self.hist_step_n
|
||||
|
||||
@@ -52,6 +52,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
lr=0.0001,
|
||||
max_epoch=100,
|
||||
seed=43,
|
||||
alpha=0.0,
|
||||
):
|
||||
self.step = step
|
||||
self.hist_step_n = hist_step_n
|
||||
@@ -61,6 +62,7 @@ class MetaModelDS(MetaTaskModel):
|
||||
self.lr = lr
|
||||
self.max_epoch = max_epoch
|
||||
self.fitted = False
|
||||
self.alpha = alpha
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
|
||||
@@ -144,7 +146,11 @@ class MetaModelDS(MetaTaskModel):
|
||||
) # debug: record when the test phase starts
|
||||
|
||||
self.tn = PredNet(
|
||||
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
|
||||
step=self.step,
|
||||
hist_step_n=self.hist_step_n,
|
||||
clip_weight=self.clip_weight,
|
||||
clip_method=self.clip_method,
|
||||
alpha=self.alpha,
|
||||
)
|
||||
|
||||
opt = optim.Adam(self.tn.parameters(), lr=self.lr)
|
||||
|
||||
@@ -41,11 +41,18 @@ class TimeWeightMeta(SingleMetaBase):
|
||||
|
||||
|
||||
class PredNet(nn.Module):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh"):
|
||||
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
alpha : float
|
||||
the regularization for the sub-model (useful when aligning the meta model with a linear sub-model)
|
||||
"""
|
||||
super().__init__()
|
||||
self.step = step
|
||||
self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)
|
||||
self.init_paramters(hist_step_n)
|
||||
self.alpha = alpha
|
||||
|
||||
def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):
|
||||
weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)
|
||||
@@ -59,7 +66,7 @@ class PredNet(nn.Module):
|
||||
"""Please refer to the docs of MetaTaskDS for the description of the variables"""
|
||||
weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)
|
||||
X_w = X.T * weights.view(1, -1)
|
||||
theta = torch.inverse(X_w @ X) @ X_w @ y
|
||||
theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y
|
||||
return X_test @ theta, weights
|
||||
|
||||
def init_paramters(self, hist_step_n):
|
||||
|
||||
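The changed `theta` line in `PredNet` can be read as a weighted ridge regression in closed form. The derivation below is a sketch under the assumption that `weights` acts as a diagonal matrix $W$ and that `alpha` is a plain L2 penalty on $\theta$; the symbols $W$, $d$, and $n$ are introduced here for illustration.

```latex
% Objective: weighted least squares with an L2 penalty (alpha >= 0)
\min_{\theta}\; \sum_{i=1}^{n} w_i \left(y_i - x_i^{\top}\theta\right)^2 + \alpha \lVert \theta \rVert_2^2,
\qquad W = \operatorname{diag}(w_1, \dots, w_n)

% Setting the gradient to zero
-2\, X^{\top} W \left(y - X\theta\right) + 2\alpha\,\theta = 0
\;\Longrightarrow\;
\left(X^{\top} W X + \alpha I_d\right)\theta = X^{\top} W y

% Closed form, matching X_w = X^T W in the code (X_w has shape d x n)
\hat{\theta} = \left(X^{\top} W X + \alpha I_d\right)^{-1} X^{\top} W y
```

With $\alpha = 0$ this reduces to the previous expression `torch.inverse(X_w @ X) @ X_w @ y`, which is consistent with the default `alpha=0.0`.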
@@ -5,6 +5,9 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
@@ -24,6 +27,7 @@ class ICLoss(nn.Module):
|
||||
diff_point.append(i)
|
||||
prev = date
|
||||
diff_point.append(None)
|
||||
# diff_point holds one more entry than the number of date segments
|
||||
|
||||
ic_all = 0.0
|
||||
skip_n = 0
|
||||
@@ -34,13 +38,23 @@ class ICLoss(nn.Module):
|
||||
skip_n += 1
|
||||
continue
|
||||
y_focus = y[start_i:end_i]
|
||||
if pred_focus.std() < EPS or y_focus.std() < EPS:
|
||||
# These cases often happen at the end of test data.
|
||||
# Usually caused by fillna(0.)
|
||||
skip_n += 1
|
||||
continue
|
||||
|
||||
ic_day = torch.dot(
|
||||
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
|
||||
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
|
||||
)
|
||||
ic_all += ic_day
|
||||
if len(diff_point) - 1 - skip_n <= 0:
|
||||
raise ValueError("No enough data for calculating iC")
|
||||
raise ValueError("No enough data for calculating IC")
|
||||
if skip_n > 0:
|
||||
get_module_logger("ICLoss").info(
|
||||
f"{skip_n} days are skipped due to zero std or small scale of valid samples."
|
||||
)
|
||||
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
|
||||
return -ic_mean # ic loss
|
||||
|
||||
|
||||
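The `ICLoss` change skips days whose predictions or labels have (near) zero spread, since the correlation is undefined there. A dependency-free sketch of the same idea is shown below. It is not the torch implementation: `EPS` is an assumed tolerance (the source imports its own from `qlib.constant`), the tolerance is applied to the root of the summed squared deviations rather than to `std()`, and the helper names are hypothetical.

```python
import math

EPS = 1e-12  # assumed tolerance for this sketch


def pearson(xs, ys):
    """Pearson correlation, or None when either series is (nearly) constant."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx < EPS or sy < EPS:
        return None
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / (sx * sy)


def mean_daily_ic(groups):
    """Average the per-day correlation over (pred, label) pairs, skipping degenerate days."""
    ics, skipped = [], 0
    for pred, label in groups:
        ic = pearson(pred, label)
        if ic is None:
            skipped += 1
            continue
        ics.append(ic)
    if not ics:
        raise ValueError("Not enough data for calculating IC")
    return sum(ics) / len(ics), skipped


days = [
    ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]),  # perfectly correlated
    ([1.0, 1.0, 1.0], [0.0, 1.0, 2.0]),  # constant prediction, skipped
    ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]),  # perfectly anti-correlated
]
ic_mean, skipped_days = mean_daily_ic(days)
```

Dividing by the number of valid days rather than by all days keeps the mean unbiased by the skipped ones, which is what `len(diff_point) - 1 - skip_n` does in the diff.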
@@ -4,6 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
@@ -29,7 +30,7 @@ class LinearModel(Model):
|
||||
RIDGE = "ridge"
|
||||
LASSO = "lasso"
|
||||
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
|
||||
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -39,6 +40,9 @@ class LinearModel(Model):
|
||||
l1 or l2 regularization parameter
|
||||
fit_intercept : bool
|
||||
whether fit intercept
|
||||
include_valid: bool
|
||||
Should the validation data be included for training?
|
||||
The validation data should be included
|
||||
"""
|
||||
assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
|
||||
self.estimator = estimator
|
||||
@@ -49,9 +53,16 @@ class LinearModel(Model):
|
||||
self.fit_intercept = fit_intercept
|
||||
|
||||
self.coef_ = None
|
||||
self.include_valid = include_valid
|
||||
|
||||
def fit(self, dataset: DatasetH, reweighter: Reweighter = None):
|
||||
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if self.include_valid:
|
||||
try:
|
||||
df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
df_train = pd.concat([df_train, df_valid])
|
||||
except KeyError:
|
||||
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
if reweighter is not None:
|
||||
|
||||
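The `include_valid` branch above appends the validation segment to the training data and tolerates a missing segment. A minimal sketch of that control flow using plain lists instead of DataFrames follows; `collect_training_rows` and the `segments` dict are hypothetical and only illustrate the fallback.

```python
def collect_training_rows(segments, include_valid=False):
    """Concatenate train (and optionally valid) rows; a missing 'valid' key is ignored."""
    rows = list(segments["train"])
    if include_valid:
        try:
            rows += segments["valid"]
        except KeyError:
            # Mirrors the diff: note that 'valid' does not exist and continue with train only.
            pass
    if not rows:
        raise ValueError("Empty data from dataset, please check your dataset config.")
    return rows


with_valid = collect_training_rows({"train": [1, 2], "valid": [3]}, include_valid=True)
without_valid = collect_training_rows({"train": [1, 2]}, include_valid=True)
```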
@@ -47,10 +47,6 @@ class DNNModelPytorch(Model):
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
lr_decay : float
|
||||
learning rate decay
|
||||
lr_decay_steps : int
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
@@ -64,8 +60,6 @@ class DNNModelPytorch(Model):
|
||||
batch_size=2000,
|
||||
early_stop_rounds=50,
|
||||
eval_steps=20,
|
||||
lr_decay=0.96,
|
||||
lr_decay_steps=100,
|
||||
optimizer="gd",
|
||||
loss="mse",
|
||||
GPU=0,
|
||||
@@ -93,8 +87,6 @@ class DNNModelPytorch(Model):
|
||||
self.batch_size = batch_size
|
||||
self.early_stop_rounds = early_stop_rounds
|
||||
self.eval_steps = eval_steps
|
||||
self.lr_decay = lr_decay
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
if isinstance(GPU, str):
|
||||
@@ -116,8 +108,6 @@ class DNNModelPytorch(Model):
|
||||
f"\nbatch_size : {batch_size}"
|
||||
f"\nearly_stop_rounds : {early_stop_rounds}"
|
||||
f"\neval_steps : {eval_steps}"
|
||||
f"\nlr_decay : {lr_decay}"
|
||||
f"\nlr_decay_steps : {lr_decay_steps}"
|
||||
f"\noptimizer : {optimizer}"
|
||||
f"\nloss_type : {loss}"
|
||||
f"\nseed : {seed}"
|
||||
|
||||
@@ -635,7 +635,7 @@ class FileOrderStrategy(BaseStrategy):
|
||||
self.order_df = file
|
||||
else:
|
||||
with get_io_object(file) as f:
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": str})
|
||||
|
||||
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
|
||||
self.order_df = self.order_df.set_index(["datetime", "instrument"])
|
||||
|
||||
@@ -783,7 +783,7 @@ class LocalPITProvider(PITProvider):
|
||||
index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index"
|
||||
data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
|
||||
if not (index_path.exists() and data_path.exists()):
|
||||
raise FileNotFoundError("No file is found. Raise exception and ")
|
||||
raise FileNotFoundError("No file is found.")
|
||||
# NOTE: The most significant performance loss is here.
|
||||
# Does the acceleration that makes the program complicated really matter?
|
||||
# - It makes the parameters of the interface complicated
|
||||
@@ -797,14 +797,14 @@ class LocalPITProvider(PITProvider):
|
||||
cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)
|
||||
loc = np.searchsorted(data["date"], cur_time_int, side="right")
|
||||
if loc <= 0:
|
||||
return pd.Series()
|
||||
return pd.Series(dtype=C.pit_record_type["value"])
|
||||
last_period = data["period"][:loc].max() # return the latest quarter
|
||||
first_period = data["period"][:loc].min()
|
||||
period_list = get_period_list(first_period, last_period, quarterly)
|
||||
if period is not None:
|
||||
# NOTE: `period` has higher priority than `start_index` & `end_index`
|
||||
if period not in period_list:
|
||||
return pd.Series()
|
||||
return pd.Series(dtype=C.pit_record_type["value"])
|
||||
else:
|
||||
period_list = [period]
|
||||
else:
|
||||
@@ -868,7 +868,7 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
# Ensure that each column type is consistent
|
||||
# FIXME:
|
||||
# 1) The stock data is currently float. If there are other types of data, this part needs to be re-implemented.
|
||||
# 2) The the precision should be configurable
|
||||
# 2) The precision should be configurable
|
||||
try:
|
||||
series = series.astype(np.float32)
|
||||
except ValueError:
|
||||
|
||||
@@ -417,7 +417,7 @@ class TSDataSampler:
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.swaplevel()
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data)[0]]
|
||||
|
||||
@@ -720,3 +720,26 @@ class DataHandlerLP(DataHandler):
|
||||
]:
|
||||
setattr(new_hd, key, getattr(handler, key, None))
|
||||
return new_hd
|
||||
|
||||
@classmethod
|
||||
def from_df(cls, df: pd.DataFrame) -> "DataHandlerLP":
|
||||
"""
|
||||
Motivation:
|
||||
- When the user wants to get a quick data handler.
|
||||
|
||||
The created data handler will have only one shared Dataframe without processors.
|
||||
After creating the handler, the user may often want to dump the handler for reuse.
|
||||
Here is a typical use case
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.data.dataset import DataHandlerLP
|
||||
dh = DataHandlerLP.from_df(df)
|
||||
dh.to_pickle(fname, dump_all=True)
|
||||
|
||||
TODO:
|
||||
- The StaticDataLoader is quite slow. It doesn't have to copy the data again...
|
||||
|
||||
"""
|
||||
loader = data_loader_module.StaticDataLoader(df)
|
||||
return cls(data_loader=loader)
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from typing import Union, List
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.data.dataset import DataHandler
|
||||
@@ -121,7 +120,7 @@ def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datet
|
||||
return df
|
||||
|
||||
|
||||
def init_task_handler(task: dict) -> Union[DataHandler, None]:
|
||||
def init_task_handler(task: dict) -> DataHandler:
|
||||
"""
|
||||
initialize the handler part of the task **inplace**
|
||||
|
||||
@@ -142,5 +141,6 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
|
||||
if h_conf is not None:
|
||||
handler = init_instance_by_config(h_conf, accept_types=DataHandler)
|
||||
task["dataset"]["kwargs"]["handler"] = handler
|
||||
|
||||
return handler
|
||||
else:
|
||||
raise ValueError("The task does not contain a handler part.")
|
||||
|
||||
@@ -16,13 +16,12 @@ import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
@@ -30,12 +29,13 @@ def _get_multi_level_executor_config(
|
||||
strategy_config: dict,
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
data_granularity: str = "1min",
|
||||
) -> dict:
|
||||
executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "5min", # FIXME: move this into config
|
||||
"time_per_step": data_granularity,
|
||||
"verbose": False,
|
||||
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
|
||||
"generate_report": generate_report,
|
||||
@@ -123,109 +123,13 @@ def _generate_report(
|
||||
return report
|
||||
|
||||
|
||||
def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
A new simulator will be created and used for every single-day order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
reports = []
|
||||
decisions = []
|
||||
for _, row in orders.iterrows():
|
||||
date = pd.Timestamp(row["datetime"])
|
||||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(row["direction"]),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": "5min", # FIXME: move this into config
|
||||
}
|
||||
)
|
||||
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
_report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: _report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
time_range: Tuple[str, str],
|
||||
exchange_config: dict,
|
||||
strategy_config: dict,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
data_granularity: str = "1min",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
@@ -253,48 +157,42 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
strategy_config = {
|
||||
top_strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"file": orders,
|
||||
"trade_range": TradeRangeByTime(
|
||||
pd.Timestamp(backtest_config["start_time"]).time(),
|
||||
pd.Timestamp(backtest_config["end_time"]).time(),
|
||||
pd.Timestamp(time_range[0]).time(),
|
||||
pd.Timestamp(time_range[1]).time(),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
top_executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=strategy_config,
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=data_granularity,
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
exchange_config = {
|
||||
**exchange_config,
|
||||
**{
|
||||
"codes": stocks,
|
||||
"freq": "5min", # FIXME: move this into config
|
||||
}
|
||||
)
|
||||
"freq": data_granularity,
|
||||
},
|
||||
}
|
||||
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
strategy=top_strategy_config,
|
||||
executor=top_executor_config,
|
||||
benchmark=None,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
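The hunk above replaces `copy.deepcopy(...)` followed by `.update(...)` with a dict-unpacking merge that builds a new mapping. A small sketch of that idiom is given below; `with_overrides` and the sample keys are hypothetical.

```python
def with_overrides(config, **overrides):
    """Return a new dict where `overrides` win over `config`; `config` itself is untouched."""
    return {**config, **overrides}


base = {"freq": "5min", "deal_price": "close"}  # hypothetical exchange settings
merged = with_overrides(base, codes=["SH600000"], freq="1min")
```

Note that the merge is shallow: nested values are shared between `base` and `merged`, unlike with `copy.deepcopy`.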
@@ -302,7 +200,7 @@ def single_with_collect_data_loop(
|
||||
)
|
||||
|
||||
report_dict: dict = {}
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
|
||||
|
||||
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
|
||||
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
|
||||
@@ -322,46 +220,54 @@ def single_with_collect_data_loop(
|
||||
|
||||
|
||||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
generate_report = backtest_config.pop("generate_report")
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
init_qlib(backtest_config["simulator"]["qlib"])
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
backtest_config=backtest_config,
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
split="stock",
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
|
||||
single = single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
|
||||
for task_config in backtest_config["tasks"]:
|
||||
order_df = read_order_file(task_config["order_file"])
|
||||
exchange_config = task_config["exchange"]
|
||||
cash_limit = exchange_config.pop("cash_limit")
|
||||
generate_report = backtest_config["runtime"]["generate_report"]
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
#
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
time_range=task_config["time_range"],
|
||||
exchange_config=task_config["exchange"],
|
||||
strategy_config=backtest_config["strategies"],
|
||||
split="stock",
|
||||
data_granularity=task_config["data_granularity"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
for stock in stock_pool
|
||||
)
|
||||
|
||||
output_path = Path(backtest_config["output_dir"])
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
return res
|
||||
|
||||
#
|
||||
output_path = Path(task_config["output_dir"])
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
# return res # TODO
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -369,6 +275,7 @@ if __name__ == "__main__":
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
@@ -381,9 +288,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = get_backtest_config_fromfile(args.config_path)
|
||||
if args.n_jobs is not None:
|
||||
config["concurrency"] = args.n_jobs
|
||||
|
||||
config_parser = BacktestConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
if args.n_jobs is not None: # Overwrite concurrency
|
||||
config["runtime"]["concurrency"] = args.n_jobs
|
||||
|
||||
backtest(
|
||||
backtest_config=config,
|
||||
|
||||
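The reworked `backtest()` above reads `concurrency` and `generate_report` from `backtest_config["runtime"]`, then loops over `backtest_config["tasks"]` and takes `order_file`, `exchange`, `time_range`, `data_granularity` and `output_dir` from each task. Below is a minimal sketch of a config dict with that shape. All concrete values (paths, times, task name) are invented for illustration, and the helper `iter_task_settings` is hypothetical, not part of qlib.

```python
from typing import Iterator, Tuple

# Hypothetical config in the shape backtest() indexes into; every value is illustrative.
backtest_config = {
    "runtime": {"concurrency": 4, "generate_report": False},
    "strategies": {},  # passed through as strategy_config; contents omitted here
    "tasks": [
        {
            "name": "demo",
            "order_file": "orders/demo.csv",
            "time_range": ("9:30", "15:00"),
            "data_granularity": "1min",
            "output_dir": "outputs_backtest/demo",
            "exchange": {"cash_limit": None, "open_cost": 0.0005},
        }
    ],
}


def iter_task_settings(config: dict) -> Iterator[Tuple[str, object, dict]]:
    """Yield (output_dir, cash_limit, exchange_config) for each task."""
    for task in config["tasks"]:
        # Copy first so popping cash_limit does not mutate the caller's config.
        exchange_config = dict(task["exchange"])
        cash_limit = exchange_config.pop("cash_limit")
        yield task["output_dir"], cash_limit, exchange_config


settings = list(iter_task_settings(backtest_config))
```

The loop in the diff pops `cash_limit` directly from the task's own `exchange` dict. The sketch copies the dict first so the same config object can be reused. That copy is a design choice of this example, not something the diff does.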
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -30,7 +31,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
|
||||
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||
|
||||
|
||||
def parse_backtest_config(path: str) -> dict:
|
||||
def load_config(path: str) -> dict:
|
||||
abs_path = os.path.abspath(path)
|
||||
check_file_exist(abs_path)
|
||||
|
||||
@@ -65,42 +66,154 @@ def parse_backtest_config(path: str) -> dict:
|
||||
base_file_name = [base_file_name]
|
||||
|
||||
for f in base_file_name:
|
||||
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
config = merge_a_into_b(a=config, b=base_config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _convert_all_list_to_tuple(config: dict) -> dict:
|
||||
for k, v in config.items():
|
||||
if isinstance(v, list):
|
||||
config[k] = tuple(v)
|
||||
elif isinstance(v, dict):
|
||||
config[k] = _convert_all_list_to_tuple(v)
|
||||
return config
|
||||
class BacktestConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
def parse(self) -> dict:
|
||||
self._simulator_config = self._parse_simulator()
|
||||
self._exchange_config = self._simulator_config.pop("exchange")
|
||||
config = {
|
||||
"strategies": self.raw_config["strategies"],
|
||||
"runtime": self.raw_config["runtime"],
|
||||
"tasks": self._parse_tasks(),
|
||||
"simulator": self._simulator_config,
|
||||
}
|
||||
return config
|
||||
|
||||
def _parse_tasks(self) -> list:
|
||||
task_config = []
|
||||
for task in self.raw_config["tasks"]:
|
||||
if "output_dir" not in task:
|
||||
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
|
||||
if "exchange" not in task:
|
||||
task["exchange"] = copy.deepcopy(self._exchange_config)
|
||||
else:
|
||||
task["exchange"] = self._complete_exchange_config(task["exchange"])
|
||||
task_config.append(task)
|
||||
|
||||
return task_config
|
||||
|
||||
def _complete_exchange_config(self, exchange_config: dict) -> dict:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
|
||||
return exchange_config
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
|
||||
return {
|
||||
"qlib": config["qlib"],
|
||||
"exchange": self._complete_exchange_config(config["exchange"]),
|
||||
}
|
||||
|
||||
|
||||
def get_backtest_config_fromfile(path: str) -> dict:
|
||||
backtest_config = parse_backtest_config(path)
|
||||
class TrainingConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
|
||||
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
|
||||
def parse(self) -> dict:
|
||||
return {
|
||||
"general": self._parse_general(),
|
||||
"policy": self.raw_config["policy"],
|
||||
"interpreter": self.raw_config["interpreter"],
|
||||
"runtime": self._parse_runtime(),
|
||||
"training": self._parse_training(),
|
||||
"simulator": self._parse_simulator(),
|
||||
}
|
||||
|
||||
backtest_config_default = {
|
||||
"debug_single_stock": None,
|
||||
"debug_single_day": None,
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs_backtest/",
|
||||
"generate_report": False,
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
def _parse_general(self) -> dict:
|
||||
default = {
|
||||
"freq": "1min",
|
||||
"extra_module_paths": [],
|
||||
}
|
||||
return {**default, **self.raw_config["general"]}
|
||||
|
||||
return backtest_config
|
||||
def _parse_runtime(self) -> dict:
|
||||
default = {
|
||||
"seed": None,
|
||||
"use_cuda": False,
|
||||
"concurrency": 1,
|
||||
"parallel_mode": "dummy",
|
||||
}
|
||||
return {**default, **self.raw_config["runtime"]}
|
||||
|
||||
def _parse_training(self) -> dict:
|
||||
default = {
|
||||
"max_epoch": 100,
|
||||
"repeat_per_collect": 2,
|
||||
"earlystop_patience": float("inf"),
|
||||
"episode_per_collect": 10000,
|
||||
"batch_size": 256,
|
||||
"val_every_n_epoch": None,
|
||||
"checkpoint_path": "./outputs",
|
||||
"checkpoint_every_n_iters": 10,
|
||||
}
|
||||
|
||||
config = self.raw_config["training"]
|
||||
assert "order_dir" in config
|
||||
|
||||
return {**default, **config}
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
sim_type = config["type"]
|
||||
assert sim_type in ("simple", "full")
|
||||
|
||||
if sim_type == "simple":
|
||||
return {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"feature_columns_today": config["data"]["feature_columns_today"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"time_per_step": config["time_per_step"],
|
||||
"vol_limit": config["vol_limit"],
|
||||
}
|
||||
else:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
# "cash_limit": None,
|
||||
}
|
||||
exchange_config = {**exchange_config_default, **config["exchange"]}
|
||||
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
|
||||
|
||||
ret_config = {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"qlib": {
|
||||
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
|
||||
},
|
||||
"exchange": exchange_config,
|
||||
}
|
||||
|
||||
return ret_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(parser.parse())
|
||||
|
||||
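The `TrainingConfigParser` helpers above (`_parse_general`, `_parse_runtime`, `_parse_training`) fill in missing keys with the `{**default, **config}` idiom. The sketch below isolates that pattern, using the exchange defaults that appear in the diff. The function name `complete_with_defaults` and the sample user values are assumptions made for illustration.

```python
import copy

# Default values copied from exchange_config_default in the diff.
EXCHANGE_DEFAULTS = {
    "open_cost": 0.0005,
    "close_cost": 0.0015,
    "min_cost": 5.0,
    "trade_unit": 100.0,
    "cash_limit": None,
}


def complete_with_defaults(user_config: dict, defaults: dict) -> dict:
    """Return a new dict: defaults first, then user values override top-level keys."""
    return {**copy.deepcopy(defaults), **user_config}


# Hypothetical user-supplied section that overrides two of the defaults.
user_exchange = {"open_cost": 0.001, "cash_limit": 1e6}
completed = complete_with_defaults(user_exchange, EXCHANGE_DEFAULTS)
```

This merge is shallow. A nested dict supplied by the user replaces the default nested dict wholesale rather than being merged key by key.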
362
qlib/rl/contrib/train.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, cast, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl import Simulator
|
||||
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(freq: int) -> dict:
|
||||
return {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"generate_report": False,
|
||||
"time_per_step": f"{freq}min",
|
||||
"track_data": True,
|
||||
"trade_type": "serial",
|
||||
"verbose": False,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"kwargs": {},
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"time_per_step": "30min",
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "ProxySAOEStrategy",
|
||||
"module_path": "qlib.rl.order_execution.strategy",
|
||||
"kwargs": {},
|
||||
},
|
||||
"time_per_step": "1day",
|
||||
"track_data": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
def _freq_str_to_int(freq: str) -> int:
|
||||
if freq.endswith("min"):
|
||||
return int(freq.replace("min", ""))
|
||||
elif freq.endswith("hour"):
|
||||
return int(freq.replace("hour", "")) * 60
|
||||
else:
|
||||
raise ValueError(f"Unrecognized freq string: {freq}")
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_df: pd.DataFrame,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = order_df
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption does not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
|
||||
df = df.copy()
|
||||
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
|
||||
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
|
||||
return dfs
|
||||
|
||||
|
||||
def _get_simulator_factory(
|
||||
sim_type: str,
|
||||
data_dir: Path,
|
||||
freq_min: int,
|
||||
simulator_config: dict,
|
||||
) -> Callable[[Order], Simulator]:
|
||||
if sim_type == "simple":
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
simulator = SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_dir,
|
||||
feature_columns_today=simulator_config["data"]["feature_columns_today"],
|
||||
data_granularity=freq_min,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_simple
|
||||
elif sim_type == "full":
|
||||
init_qlib(simulator_config["qlib"])
|
||||
executor_config = get_executor_config(freq_min)
|
||||
exchange_config = simulator_config["exchange"]
|
||||
|
||||
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_full
|
||||
else:
|
||||
raise ValueError(f"Unknown simulator type: {sim_type}")
|
||||
|
||||
|
||||
def train_and_test(
|
||||
freq: str,
|
||||
concurrency: int,
|
||||
parallel_mode: str,
|
||||
training_config: dict,
|
||||
simulator_config: dict,
|
||||
policy: BasePolicy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
freq_min: int = _freq_str_to_int(freq)
|
||||
order_root_path = Path(training_config["order_dir"])
|
||||
feature_root_dir = simulator_config["data"]["feature_root_dir"]
|
||||
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
|
||||
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
|
||||
|
||||
_simulator_factory = _get_simulator_factory(
|
||||
sim_type=simulator_config["type"],
|
||||
data_dir=feature_root_dir,
|
||||
freq_min=freq_min,
|
||||
simulator_config=simulator_config,
|
||||
)
|
||||
|
||||
# Load orders
|
||||
load_data_tags = []
|
||||
orders_by_tag = {}
|
||||
if run_training:
|
||||
load_data_tags += ["train", "valid"]
|
||||
if run_backtest:
|
||||
load_data_tags += ["test"]
|
||||
for tag in load_data_tags:
|
||||
order_df = _read_orders(order_root_path / tag).reset_index()
|
||||
dfs = _split_order_df_by_instrument(order_df, concurrency)
|
||||
datasets = [
|
||||
LazyLoadDataset(
|
||||
data_dir=feature_root_dir,
|
||||
order_df=df,
|
||||
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
|
||||
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
|
||||
)
|
||||
for df in dfs
|
||||
]
|
||||
orders_by_tag[tag] = datasets
|
||||
|
||||
if run_training:
|
||||
callbacks: List[Callback] = [
|
||||
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
|
||||
Checkpoint(
|
||||
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=training_config["checkpoint_every_n_iters"],
|
||||
save_latest="copy",
|
||||
),
|
||||
EarlyStopping(
|
||||
patience=training_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
),
|
||||
]
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
|
||||
trainer_kwargs={
|
||||
"max_iters": training_config["max_epoch"],
|
||||
"finite_env_type": parallel_mode,
|
||||
"concurrency": concurrency,
|
||||
"val_every_n_iters": training_config["val_every_n_epoch"],
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": training_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": training_config["batch_size"],
|
||||
"repeat": training_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(training_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=parallel_mode, # type: ignore[arg-type]
|
||||
concurrency=concurrency,
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
seed = config["runtime"]["seed"]
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
for extra_module_path in config["general"]["extra_module_paths"]:
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
|
||||
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
# Create torch network
|
||||
if "network" in config["policy"]:
|
||||
network_config = config["policy"]["network"]
|
||||
network_config["kwargs"] = {
|
||||
**network_config.get("kwargs", {}),
|
||||
**{"obs_space": state_interpreter.observation_space},
|
||||
}
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
|
||||
|
||||
# Create policy
|
||||
policy_config = config["policy"]["policy"]
|
||||
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
|
||||
policy: BasePolicy = init_instance_by_config(policy_config)
|
||||
|
||||
use_cuda = config["runtime"]["use_cuda"]
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
freq=config["general"]["freq"],
|
||||
concurrency=config["runtime"]["concurrency"],
|
||||
parallel_mode=config["runtime"]["parallel_mode"],
|
||||
training_config=config["training"],
|
||||
simulator_config=config["simulator"],
|
||||
policy=policy,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_parser = TrainingConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
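`_split_order_df_by_instrument` in the new `train.py` assigns each order to a worker group with `hash(s) % k`. Python salts `hash()` for strings per interpreter process unless `PYTHONHASHSEED` is set, so the grouping can differ from one run to the next. Where reproducible grouping matters, a deterministic checksum is one alternative. The sketch below uses `zlib.crc32` on plain lists rather than a DataFrame. The function names and instrument codes are illustrative and do not come from qlib.

```python
import zlib
from typing import List


def stable_group(instrument: str, k: int) -> int:
    """Map an instrument code to one of k groups, identically in every process."""
    return zlib.crc32(instrument.encode("utf-8")) % k


def split_by_instrument(instruments: List[str], k: int) -> List[List[str]]:
    """Partition instrument codes into k lists; equal codes always share a group."""
    groups: List[List[str]] = [[] for _ in range(k)]
    for code in instruments:
        groups[stable_group(code, k)].append(code)
    return groups


# Made-up instrument codes; the duplicate must land in the same group.
codes = ["SH600000", "SZ000001", "SH600000", "SZ300750"]
parts = split_by_instrument(codes, 3)
```

As with the original, group sizes depend on how the codes happen to hash, so the split is not guaranteed to be balanced.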
@@ -1,261 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
order_file_path: Path,
|
||||
data_dir: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_file_path = order_file_path
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption does not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
backtest_data = load_simple_intraday_backtest_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
)
|
||||
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def train_and_test(
|
||||
env_config: dict,
|
||||
simulator_config: dict,
|
||||
trainer_config: dict,
|
||||
data_config: dict,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
qlib.init()
|
||||
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
data_granularity=data_granularity,
|
||||
deal_price_type=data_config["source"].get("deal_price_column", "close"),
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
order_file_path=order_root_path / tag,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid")
|
||||
]
|
||||
|
||||
callbacks: List[Callback] = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs={
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
order_file_path=order_root_path / "test",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=env_config["parallel_mode"],
|
||||
concurrency=env_config["concurrency"],
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
|
||||
reward: Reward = init_instance_by_config(config["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
|
||||
# Create torch network
|
||||
if "network" in config:
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
|
||||
|
||||
# Create policy
|
||||
if "kwargs" not in config["policy"]:
|
||||
config["policy"]["kwargs"] = {}
|
||||
config["policy"]["kwargs"].update(additional_policy_kwargs)
|
||||
policy: BasePolicy = init_instance_by_config(config["policy"])
|
||||
|
||||
use_cuda = config["runtime"].get("use_cuda", False)
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
env_config=config["env"],
|
||||
simulator_config=config["simulator"],
|
||||
data_config=config["data"],
|
||||
trainer_config=config["trainer"],
|
||||
action_interpreter=action_interpreter,
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
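Note on the command line above: training is opt-out (`--no_training`) while backtest is opt-in (`--run_backtest`), so the two booleans passed to `main` come from flags with opposite polarity. A minimal sketch of that mapping follows. `parse_flags` is a hypothetical helper written for illustration, and `--config_path` is left out so the snippet runs without a config file.

```python
import argparse


def parse_flags(argv):
    """Hypothetical helper: mirror the two workflow switches of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args(argv)
    # Training runs unless disabled; backtest runs only when requested.
    return {"run_training": not args.no_training, "run_backtest": args.run_backtest}


for argv in ([], ["--run_backtest"], ["--no_training"], ["--no_training", "--run_backtest"]):
    print(argv, parse_flags(argv))
```

Passing `--no_training` alone leaves both values false, which is the case `main` reports with a warning before returning.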
||||
@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
dataset = None
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
|
||||
)
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
dataset = self.backtest_dataset if backtest else self.feature_dataset
|
||||
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
|
||||
def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
def init_qlib(qlib_config: dict) -> None:
|
||||
"""Initialize the resources needed to launch the workflow, including data directories, feature columns, etc.
|
||||
|
||||
Parameters
|
||||
@@ -72,12 +38,8 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
part
|
||||
Identifying which part (stock / date) to load.
|
||||
"""
|
||||
|
||||
global dataset # pylint: disable=W0603
|
||||
|
||||
def _convert_to_path(path: str | Path) -> Path:
|
||||
return path if isinstance(path, Path) else Path(path)
|
||||
|
||||
@@ -118,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False,  # init_qlib will be called multiple times; keep the cache to improve performance
|
||||
)
|
||||
|
||||
if part == "skip":
|
||||
return
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
from qlib.data import D # noqa pylint: disable=C0415,W0611
|
||||
|
||||
if part is None:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
|
||||
else:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
|
||||
|
||||
with feature_path.open("rb") as f:
|
||||
feature_dataset = pickle.load(f)
|
||||
with backtest_path.open("rb") as f:
|
||||
backtest_dataset = pickle.load(f)
|
||||
|
||||
dataset = DataWrapper(
|
||||
feature_dataset,
|
||||
backtest_dataset,
|
||||
qlib_config["feature_columns_today"],
|
||||
qlib_config["feature_columns_yesterday"],
|
||||
_internal=True,
|
||||
)
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
|
||||
assert dataset is not None, "You must call init_qlib() before doing this."
|
||||
|
||||
if backtest:
|
||||
fields = ["$close", "$volume"]
|
||||
else:
|
||||
fields = dataset.columns_yesterday if yesterday else dataset.columns_today
|
||||
|
||||
data = dataset.get(stock_id, date, backtest)
|
||||
if data is None or len(data) == 0:
|
||||
# create a fake index, but RL doesn't care about index
|
||||
data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
|
||||
else:
|
||||
data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
|
||||
data = data[fields]
|
||||
return data
|
||||
|
||||
@@ -2,17 +2,30 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from pathlib import Path
|
||||
from typing import cast, List
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.data.dataset import DatasetH
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - EPS_T
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -71,6 +84,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
||||
|
||||
|
||||
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
|
||||
"""Backtest data from dataframe"""
|
||||
|
||||
def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
|
||||
self.df = df
|
||||
self.price_column = price_column
|
||||
self.volume_column = volume_column
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.df})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.df)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
return self.df[self.price_column]
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
return self.df[self.volume_column]
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return cast(pd.DatetimeIndex, self.df.index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda order, _, __: order.key_by_day,
|
||||
@@ -103,13 +141,27 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_handler_pickle(path: str) -> DatasetH:
|
||||
with open(path, "rb") as fstream:
|
||||
obj = pickle.load(fstream)
|
||||
return obj
|
||||
|
||||
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -117,22 +169,52 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
df = df.drop(columns=["instrument"])
|
||||
return df.set_index(["datetime"])
|
||||
|
||||
self.today = _drop_stock_id(fetch_features(stock_id, date))
|
||||
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
dataset = _load_handler_pickle(path)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
)
|
||||
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
|
||||
return NTIntradayProcessedData(stock_id, date)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.backtest = backtest
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
stock_id: str,
|
||||
@@ -140,4 +222,11 @@ class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return load_nt_intraday_processed_data(stock_id, date)
|
||||
return load_handler_intraday_processed_data(
|
||||
self.data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from typing import List, Sequence, cast
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
@@ -158,44 +157,35 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
class IntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_df_pickle(path: str) -> pd.DataFrame:
|
||||
df = pd.read_pickle(path)
|
||||
return df
|
||||
|
||||
|
||||
class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool,
|
||||
) -> None:
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
if isinstance(data_dir, str):
|
||||
data_dir = Path(data_dir)
|
||||
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
|
||||
df = _load_df_pickle(str(path))
|
||||
df = df.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
time_length: int = len(time_index)
|
||||
|
||||
try:
|
||||
# new data format
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
|
||||
proc_today = proc[cnames]
|
||||
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
|
||||
except (IndexError, KeyError):
|
||||
# legacy data
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, date]]
|
||||
assert time_length * feature_dim * 2 == len(proc)
|
||||
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
|
||||
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
|
||||
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
|
||||
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
|
||||
|
||||
self.today: pd.DataFrame = proc_today
|
||||
self.yesterday: pd.DataFrame = proc_yesterday
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
self.today = df[feature_columns_today]
|
||||
self.yesterday = df[feature_columns_yesterday]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -213,25 +203,38 @@ def load_simple_intraday_backtest_data(
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_pickled_intraday_processed_data(
|
||||
def load_pickle_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
return PickleIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(self, data_dir: Path) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._backtest = backtest
|
||||
self._feature_columns_today = feature_columns_today
|
||||
self._feature_columns_yesterday = feature_columns_yesterday
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
@@ -240,12 +243,13 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return load_pickled_intraday_processed_data(
|
||||
return load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=stock_id,
|
||||
date=date,
|
||||
feature_dim=feature_dim,
|
||||
time_index=time_index,
|
||||
feature_columns_today=self._feature_columns_today,
|
||||
feature_columns_yesterday=self._feature_columns_yesterday,
|
||||
backtest=self._backtest,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, List, Optional
|
||||
import cachetools
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
@@ -16,6 +17,18 @@ from .state import SAOEState
|
||||
from .strategy import SAOEStateAdapter, SAOEStrategy
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda order, _: order.stock_id,
|
||||
)
|
||||
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
|
||||
exchange_kwargs = {
|
||||
**exchange_config,
|
||||
"codes": [order.stock_id],
|
||||
}
|
||||
return get_exchange(**exchange_kwargs)
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
@@ -67,7 +80,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
cash_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
if qlib_config is not None:
|
||||
init_qlib(qlib_config, part="skip")
|
||||
init_qlib(qlib_config)
|
||||
|
||||
strategy, self._executor = get_strategy_executor(
|
||||
start_time=order.date,
|
||||
@@ -76,7 +89,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
exchange_kwargs=_create_exchange(order, exchange_config),
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
|
||||
@@ -90,6 +103,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
show_progress=False,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
|
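Note on `_create_exchange` above: the cache key is built from `order.stock_id` only, so the `exchange_config` argument does not take part in the lookup. The sketch below illustrates that behaviour with a small standard-library stand-in for a keyed LRU cache. `cached_by` and `create_exchange` are hypothetical names for illustration, not the `cachetools` implementation and not Qlib's API.

```python
from collections import OrderedDict
from functools import wraps


def cached_by(key, maxsize=1000):
    """Tiny LRU memoizer keyed by key(*args); an illustrative stand-in only."""

    def decorator(fn):
        cache = OrderedDict()

        @wraps(fn)
        def wrapper(*args):
            k = key(*args)
            if k in cache:
                cache.move_to_end(k)  # mark as most recently used
                return cache[k]
            value = fn(*args)
            cache[k] = value
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used entry
            return value

        return wrapper

    return decorator


calls = []


@cached_by(key=lambda stock_id, _config: stock_id)
def create_exchange(stock_id, config):
    # Stand-in for an expensive constructor; records every real invocation.
    calls.append(stock_id)
    return {**config, "codes": [stock_id]}


first = create_exchange("SH600000", {"freq": "1min"})
second = create_exchange("SH600000", {"freq": "5min"})  # config is ignored by the key
```

Here `second` is the object built by the first call, so a later call with a different config for the same stock silently reuses the earlier result. That is harmless when the config is fixed for the whole run, which appears to be the assumption behind the key; a dict argument could not be hashed as part of the key in any case.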
||||
@@ -3,17 +3,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pathlib import Path
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
|
||||
from .state import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["SingleAssetOrderExecutionSimple"]
|
||||
@@ -36,12 +39,14 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
----------
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
data_dir
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
data_dir
|
||||
Path to load backtest data
|
||||
vol_threshold
|
||||
Maximum execution volume (divided by market execution volume).
|
||||
"""
|
||||
@@ -73,9 +78,9 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self,
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
@@ -83,18 +88,12 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
assert ticks_per_step % data_granularity == 0
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_simple_intraday_backtest_data(
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
self.backtest_data = self.get_backtest_data()
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
|
||||
# Get time index available for trading
|
||||
@@ -118,6 +117,29 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def get_backtest_data(self) -> BaseIntradayBacktestData:
|
||||
try:
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self.data_dir,
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
# TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)
|
||||
# TODO: In the future, we should modify the data format used by the test script,
|
||||
# TODO: and then delete this branch.
|
||||
return load_simple_intraday_backtest_data(
|
||||
self.data_dir / "backtest",
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(self.order.start_time.date()),
|
||||
"close",
|
||||
self.order.direction,
|
||||
)
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step of SAOE.
|
||||
|
||||
|
||||
@@ -451,6 +451,7 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
network: dict | torch.nn.Module | None = None,
|
||||
immediate_addition: bool = False,
|
||||
outer_trade_decision: BaseTradeDecision | None = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
@@ -501,9 +502,12 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
if self._policy is not None:
|
||||
self._policy.eval()
|
||||
|
||||
self.immediate_addition = immediate_addition
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
self.trade_amount_planned = collections.defaultdict(float)
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
|
||||
assert hasattr(self.outer_trade_decision, "order_list")
|
||||
@@ -539,9 +543,15 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order_list = []
|
||||
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
|
||||
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
|
||||
order = cast(Order, decision)
|
||||
if self.immediate_addition:
|
||||
self.trade_amount_planned[order.stock_id] += exec_vol
|
||||
amount_planned = self.trade_amount_planned[order.stock_id]
|
||||
amount_finished = order.amount - state.position
|
||||
exec_vol = min(state.position, amount_planned - amount_finished)
|
||||
|
||||
if exec_vol != 0:
|
||||
order = cast(Order, decision)
|
||||
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
|
||||
|
||||
return TradeDecisionWithDetails(
|
||||
|
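Note on the `immediate_addition` branch above: the executed volume is recomputed from the cumulative planned amount, so a shortfall in one step is added to the next. A small worked sketch of that arithmetic follows. It assumes `state.position` is the amount still unexecuted, which is what `amount_finished = order.amount - state.position` suggests; `adjust_exec_vol` is a hypothetical helper, and integers are used to keep the numbers exact.

```python
def adjust_exec_vol(planned_total, order_amount, position, exec_vol):
    """Return (new cumulative plan, volume to submit in this step).

    position is assumed to be the amount of the order still unexecuted.
    """
    planned_total += exec_vol
    amount_finished = order_amount - position
    # Submit whatever is needed to catch up with the plan, capped by what is left.
    return planned_total, min(position, planned_total - amount_finished)


order_amount = 100

# Step 1: the policy asks for 30 and nothing has been executed yet.
planned, vol1 = adjust_exec_vol(0, order_amount, position=100, exec_vol=30)

# Suppose only 20 of those 30 were filled, so 80 remain.
# Step 2: the policy asks for 30 again; the missing 10 are added on top.
planned, vol2 = adjust_exec_vol(planned, order_amount, position=80, exec_vol=30)

# Suppose step 2 filled completely, so 40 remain.
# Step 3: the policy asks for 50, but only 40 are left to trade.
planned, vol3 = adjust_exec_vol(planned, order_amount, position=40, exec_vol=50)

print(vol1, vol2, vol3)
```

With these assumed fills the submitted volumes are 30, 40 and 40: the second step carries the 10-unit shortfall, and the third is capped by the remaining position.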
||||
@@ -10,18 +10,7 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.constant import EPS_T, float_or_ndarray
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - EPS_T
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
from qlib.constant import float_or_ndarray
|
||||
|
||||
|
||||
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
|
||||
@@ -20,7 +20,7 @@ def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: Dict[str, Any],
|
||||
@@ -39,7 +39,9 @@ def train(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
@@ -67,7 +69,7 @@ def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: Sequence[InitialStateType],
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
@@ -87,7 +89,9 @@ def backtest(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
List of initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
|
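Note on the new `initial_states` contract above: the list holds either one iterator shared by all workers or one iterator per worker. The sketch below shows one plausible reading of that rule with plain Python iterators. `assign_iterators` and `drain_round_robin` are hypothetical helpers for illustration; the actual dispatch inside the vectorized environment may differ.

```python
def assign_iterators(iterators, concurrency):
    """Give every worker an iterator: one shared object, or one each."""
    if len(iterators) == 1:
        # The same iterator object is handed to every worker.
        return [iterators[0]] * concurrency
    if len(iterators) != concurrency:
        raise ValueError("expected 1 or `concurrency` iterators")
    return list(iterators)


def drain_round_robin(per_worker):
    """Let workers take turns pulling states until everything is exhausted."""
    seen = [[] for _ in per_worker]
    active = set(range(len(per_worker)))
    while active:
        for i in sorted(active):
            try:
                seen[i].append(next(per_worker[i]))
            except StopIteration:
                active.discard(i)
    return seen


# One shared iterator: the six states are split among workers, none repeated.
shared = drain_round_robin(assign_iterators([iter(range(6))], concurrency=3))

# One iterator per worker: each worker sees only its own states.
separate = drain_round_robin(
    assign_iterators([iter([0, 1]), iter([2, 3]), iter([4, 5])], concurrency=3)
)
print(shared, separate)
```

Because the shared case passes the same iterator object to every worker, each state is consumed exactly once no matter which worker asks for it, which matches the "run exactly once" wording in the docstring.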
||||
@@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from contextlib import AbstractContextManager, ExitStack, contextmanager
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
@@ -206,45 +207,50 @@ class Trainer:
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
_logger.info(msg)
|
||||
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
|
||||
vessel.val_seed_iterators()
|
||||
) as valid_iterators:
|
||||
train_vector_env = self.venv_from_iterator(train_iterators)
|
||||
valid_vector_env = self.venv_from_iterator(valid_iterators)
|
||||
|
||||
self.initialize_iter()
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
print(msg)
|
||||
_logger.info(msg)
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
self.initialize_iter()
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
self.vessel.train(train_vector_env)
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
self.vessel.validate(valid_vector_env)
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
@@ -265,16 +271,16 @@ class Trainer:
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
with _wrap_context(vessel.test_seed_iterators()) as iterators:
|
||||
vector_env = self.venv_from_iterator(iterators)
|
||||
self.vessel.test(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory():
|
||||
def env_factory(iterator):
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
@@ -300,7 +306,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
env_factory,
|
||||
[partial(env_factory, iterator=it) for it in iterators],
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
@@ -334,8 +340,11 @@ class Trainer:
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
|
||||
with ExitStack() as stack:
|
||||
yield [stack.enter_context(e) for e in obj]
|
||||
stack.pop_all().close()
|
||||
elif isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
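Note on the hunk above: the new list branch of `_wrap_context` enters every context manager in a list and yields the entered values together. The standalone sketch below shows the same `ExitStack` pattern under simplifying assumptions. `wrap_all`, `tracked`, and `events` are hypothetical names used only for illustration, not part of qlib, and the sketch lets `ExitStack` close everything when its with-block ends instead of calling `pop_all()` explicitly.

```python
from contextlib import ExitStack, contextmanager
from typing import Any, Iterator, List

events: List[str] = []


@contextmanager
def tracked(name: str) -> Iterator[str]:
    """Toy context manager that records when it is entered and exited."""
    events.append(f"enter {name}")
    try:
        yield name
    finally:
        events.append(f"exit {name}")


@contextmanager
def wrap_all(objs: List[Any]) -> Iterator[List[Any]]:
    """Enter every context manager in ``objs`` and yield the entered values."""
    with ExitStack() as stack:
        # enter_context registers each manager, so all of them are exited
        # (in reverse order) when this with-block ends.
        yield [stack.enter_context(o) for o in objs]


with wrap_all([tracked("a"), tracked("b")]) as entered:
    inside = list(entered)
```

One design point this makes visible: the managers are exited in reverse order of entry, which is the usual nesting behaviour of stacked with-blocks.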
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
@@ -49,19 +49,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create the seed iterators for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
raise SeedIteratorNotAvailable("Seed iterators for training are not available.")
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create the seed iterators for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for validation are not available.")
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create the seed iterators for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for testing are not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
@@ -120,9 +124,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
train_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
val_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
test_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
@@ -132,34 +136,49 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.train_initial_states = None if train_initial_states is None else train_initial_states
|
||||
self.val_initial_states = None if val_initial_states is None else val_initial_states
|
||||
self.test_initial_states = None if test_initial_states is None else test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
|
||||
train_initial_states = [
|
||||
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().train_seed_iterators()
|
||||
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
|
||||
val_initial_states = [
|
||||
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().val_seed_iterators()
|
||||
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
|
||||
test_initial_states = [
|
||||
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().test_seed_iterators()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
|
||||
@@ -258,6 +258,46 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step2(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = {}
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
valid_act = np.stack([action[id2idx[i]] for i in request_id])
|
||||
tmp = super().step(valid_act, request_id)
|
||||
|
||||
for obs_next, rew, done, info in zip(*tmp):
|
||||
obs_next = self._postproc_env_obs(obs_next)
|
||||
result[info["env_id"]] = [obs_next, rew, done, info]
|
||||
|
||||
# logging
|
||||
for i, r in result.items():
|
||||
if i in self._alive_env_ids and r[0] is not None:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
|
||||
for _, reward, __, info in result.values():
|
||||
self._set_default_info(info)
|
||||
self._set_default_rew(reward)
|
||||
for r in result.values():
|
||||
if r[0] is None:
|
||||
r[0] = self._get_default_obs()
|
||||
if r[1] is None:
|
||||
r[1] = self._get_default_rew()
|
||||
if r[3] is None:
|
||||
r[3] = self._get_default_info()
|
||||
|
||||
ret = list(map(np.stack, zip(*result.values())))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
@@ -311,7 +351,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
|
||||
|
||||
|
||||
def vectorize_env(
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_factories: List[Callable[..., gym.Env]],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
@@ -334,9 +374,10 @@ def vectorize_env(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_factory
|
||||
Callable to instantiate one single ``gym.Env``.
|
||||
All concurrent workers will have the same ``env_factory``.
|
||||
env_factories
|
||||
Callables that each instantiate a single ``gym.Env``.
|
||||
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
|
||||
the same env_factory. Otherwise, each worker will have its own env_factory.
|
||||
env_type
|
||||
dummy or subproc or shmem. Corresponding to
|
||||
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
|
||||
@@ -358,6 +399,8 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
assert len(env_factories) in (1, concurrency)
|
||||
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
@@ -366,4 +409,7 @@ def vectorize_env(
|
||||
|
||||
finite_env_cls = env_type_cls_mapping[env_type]
|
||||
|
||||
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
|
||||
if len(env_factories) == 1:
|
||||
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
|
||||
else:
|
||||
return finite_env_cls(logger, env_factories)
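Note on the hunk above: `vectorize_env` now accepts either one factory (shared by all workers) or exactly `concurrency` factories (one per worker), and the trainer builds the per-worker list with `functools.partial`. The sketch below illustrates only that expansion rule. `make_env` and `expand_factories` are hypothetical stand-ins, and plain dicts take the place of `gym.Env` instances so the example runs without gym or tianshou.

```python
from functools import partial
from typing import Any, Callable, Dict, Iterable, List


def make_env(iterator: Iterable[int]) -> Dict[str, Any]:
    """Stand-in for an environment factory; a real one would return a gym.Env."""
    return {"seeds": list(iterator)}


def expand_factories(env_factories: List[Callable[[], Any]], concurrency: int) -> List[Callable[[], Any]]:
    """Return one factory per worker, reusing a single factory if only one is given."""
    assert len(env_factories) in (1, concurrency)
    if len(env_factories) == 1:
        return [env_factories[0] for _ in range(concurrency)]
    return list(env_factories)


# One seed collection per worker, bound to the factory with functools.partial.
iterators = [[1, 2], [3, 4], [5, 6]]
per_worker = expand_factories([partial(make_env, iterator=it) for it in iterators], concurrency=3)

# A single factory is replicated so that every worker shares it.
shared = expand_factories([partial(make_env, iterator=[0])], concurrency=3)

envs = [factory() for factory in per_worker]
```

With per-worker factories each worker sees only its own seed collection, whereas the single-factory form keeps the previous behaviour of identical workers.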
|
||||
|
||||
30
qlib/rl/utils/profiling.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Generator, Optional
|
||||
|
||||
from line_profiler import LineProfiler
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simple_perf(desc: str = "", out_path: Optional[str] = None) -> Generator[None, None, None]:
|
||||
s = time.perf_counter()
|
||||
yield
|
||||
e = time.perf_counter()
|
||||
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
|
||||
|
||||
if out_path is not None:
|
||||
with open(out_path, "a") as fstream:
|
||||
fstream.write(msg + "\n")
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
|
||||
def lprofile(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
lp = LineProfiler()
|
||||
lpw = lp(func)
|
||||
res = lpw(*args, **kwargs)
|
||||
lp.print_stats()
|
||||
return res
|
||||
|
||||
return wrapper
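Note on the new file above: `simple_perf` times a with-block and either prints the result or appends it to a file. The sketch below shows one way it might be used. The function body is copied from the hunk (with the `out_path` annotation widened to `Optional[str]`) so that the example runs without qlib or `line_profiler` installed; the log file name and the timed workload are assumptions made for illustration.

```python
import os
import tempfile
import time
from contextlib import contextmanager
from typing import Generator, Optional


@contextmanager
def simple_perf(desc: str = "", out_path: Optional[str] = None) -> Generator[None, None, None]:
    """Measure the wall-clock time of the with-block and report it in milliseconds."""
    s = time.perf_counter()
    yield
    e = time.perf_counter()
    msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"

    if out_path is not None:
        # Append mode, so repeated measurements accumulate in the same file.
        with open(out_path, "a") as fstream:
            fstream.write(msg + "\n")
    else:
        print(msg)


# Hypothetical usage: time a small computation and log it to a temporary file.
log_path = os.path.join(tempfile.mkdtemp(), "perf.log")
with simple_perf("sum of squares", out_path=log_path):
    total = sum(i * i for i in range(10000))

with open(log_path) as f:
    lines = f.read().splitlines()
```

Because there is no try/finally around the `yield`, nothing is reported if the timed block raises; that matches the code in the hunk.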
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TODO: this utils module covers too many utilities; please separate it into submodules
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@@ -43,7 +44,7 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
|
||||
#################### Server ####################
|
||||
def get_redis_connection():
|
||||
"""get redis connection instance."""
|
||||
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db)
|
||||
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
@@ -427,7 +428,7 @@ def init_instance_by_config(
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
|
||||
with open(pr_path, "rb") as f:
|
||||
with open(os.path.normpath(pr_path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
else:
|
||||
with config.open("rb") as f:
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Union
|
||||
"""
|
||||
This module covers some utility functions that operate on data or basic objects
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
@@ -54,3 +58,48 @@ def deepcopy_basic_type(obj: object) -> object:
|
||||
return {k: deepcopy_basic_type(v) for k, v in obj.items()}
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
S_DROP = "__DROP__" # sentinel value indicating that the key should be dropped
|
||||
|
||||
|
||||
def update_config(base_config: dict, ext_config: Union[dict, List[dict]]):
|
||||
"""
|
||||
Update the base config with the entries in ext_config and return the result as a new dict.
|
||||
|
||||
>>> bc = {"a": "xixi"}
|
||||
>>> ec = {"b": "haha"}
|
||||
>>> new_bc = update_config(bc, ec)
|
||||
>>> print(new_bc)
|
||||
{'a': 'xixi', 'b': 'haha'}
|
||||
>>> print(bc) # base config should not be changed
|
||||
{'a': 'xixi'}
|
||||
>>> print(update_config(bc, {"b": S_DROP}))
|
||||
{'a': 'xixi'}
|
||||
>>> print(update_config(new_bc, {"b": S_DROP}))
|
||||
{'a': 'xixi'}
|
||||
"""
|
||||
|
||||
base_config = deepcopy(base_config) # in case of modifying base config
|
||||
|
||||
for ec in ext_config if isinstance(ext_config, (list, tuple)) else [ext_config]:
|
||||
for key in ec:
|
||||
if key not in base_config:
|
||||
# The key is not in the base config.
|
||||
# ADD it unless it is marked to be dropped
|
||||
if ec[key] != S_DROP:
|
||||
base_config[key] = ec[key]
|
||||
|
||||
else:
|
||||
if isinstance(base_config[key], dict) and isinstance(ec[key], dict):
|
||||
# Recursive
|
||||
# Both values are dicts, so merge them recursively
|
||||
base_config[key] = update_config(base_config[key], ec[key])
|
||||
elif ec[key] == S_DROP:
|
||||
# DROP
|
||||
del base_config[key]
|
||||
else:
|
||||
# REPLACE
|
||||
# At least one of them is not a dict, so replace the value
|
||||
base_config[key] = ec[key]
|
||||
return base_config
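Note on the hunk above: the doctest covers only flat dicts, so the nested-merge and drop rules are easier to see with a slightly larger input. The sketch below re-implements `update_config` from the hunk (slightly flattened, same behaviour as far as the diff shows) so that it runs without qlib installed. The example config keys and values are made up for illustration.

```python
from copy import deepcopy
from typing import List, Union

S_DROP = "__DROP__"  # sentinel value: drop the key from the result


def update_config(base_config: dict, ext_config: Union[dict, List[dict]]) -> dict:
    """Return a copy of base_config updated with ext_config (add, merge, drop, replace)."""
    base_config = deepcopy(base_config)  # never modify the caller's dict
    for ec in ext_config if isinstance(ext_config, (list, tuple)) else [ext_config]:
        for key in ec:
            if key not in base_config:
                if ec[key] != S_DROP:  # ADD, unless marked to be dropped
                    base_config[key] = ec[key]
            elif isinstance(base_config[key], dict) and isinstance(ec[key], dict):
                base_config[key] = update_config(base_config[key], ec[key])  # nested merge
            elif ec[key] == S_DROP:
                del base_config[key]  # DROP
            else:
                base_config[key] = ec[key]  # REPLACE
    return base_config


base = {"model": {"class": "LGBModel", "kwargs": {"lr": 0.1, "depth": 8}}, "market": "csi500"}
ext = {"model": {"kwargs": {"lr": 0.05, "depth": S_DROP}}, "market": "csi300"}
merged = update_config(base, ext)
```

Here `merged` keeps `model.class` from the base, takes `lr` and `market` from the extension, and loses `depth`, while `base` itself is left untouched.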
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -10,6 +10,12 @@ import fire
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils.data import update_config
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import set_log_with_config
|
||||
|
||||
set_log_with_config(C.logging_config)
|
||||
logger = get_module_logger("qrun", logging.INFO)
|
||||
|
||||
|
||||
def get_path_list(path):
|
||||
@@ -47,10 +53,47 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
This is a Qlib CLI entrance.
|
||||
Users can run the whole quant research workflow defined by a configuration file
|
||||
- the code is located in ``qlib/workflow/cli.py``
|
||||
|
||||
Users can specify a base config file in their workflow.yml file by adding "BASE_CONFIG_PATH".
|
||||
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
|
||||
in their own workflow.yml file.
|
||||
|
||||
For example:
|
||||
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
BASE_CONFIG_PATH: "workflow_config_lightgbm_Alpha158_csi500.yaml"
|
||||
market: csi300
|
||||
|
||||
"""
|
||||
with open(config_path) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
base_config_path = config.get("BASE_CONFIG_PATH", None)
|
||||
if base_config_path:
|
||||
logger.info(f"Use BASE_CONFIG_PATH: {base_config_path}")
|
||||
base_config_path = Path(base_config_path)
|
||||
|
||||
# Look for the base config at the path as given first, then relative to the directory of the workflow config
|
||||
if base_config_path.exists():
|
||||
path = base_config_path
|
||||
else:
|
||||
logger.info(
|
||||
f"Can't find BASE_CONFIG_PATH based on: {Path.cwd()}, "
|
||||
f"try using relative path to config path: {Path(config_path).absolute()}"
|
||||
)
|
||||
relative_path = Path(config_path).absolute().parent.joinpath(base_config_path)
|
||||
if relative_path.exists():
|
||||
path = relative_path
|
||||
else:
|
||||
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
|
||||
|
||||
with open(path) as fp:
|
||||
base_config = yaml.safe_load(fp)
|
||||
logger.info(f"Loaded BASE_CONFIG_PATH successfully: {path.resolve()}")
|
||||
config = update_config(base_config, config)
|
||||
|
||||
# config the `sys` section
|
||||
sys_config(config, config_path)
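Note on the hunk above: the base config is looked up in two places, first at `BASE_CONFIG_PATH` as given (absolute, or relative to the current working directory) and then relative to the directory that contains the workflow config. The sketch below isolates that lookup order. `resolve_base_config_path` is a hypothetical helper written for this note, not a qlib function, and the temporary directory layout is an assumption made for illustration.

```python
import os
import tempfile
from pathlib import Path


def resolve_base_config_path(base_config_path: str, config_path: str) -> Path:
    """Try the path as given first, then relative to the workflow config's directory."""
    candidate = Path(base_config_path)
    if candidate.exists():
        return candidate
    relative = Path(config_path).absolute().parent.joinpath(candidate)
    if relative.exists():
        return relative
    raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")


# Lay out <tmp>/configs/{workflow.yml, base.yml} and resolve "base.yml"
# from a working directory that does not contain it.
root = Path(tempfile.mkdtemp())
(root / "configs").mkdir()
(root / "configs" / "base.yml").write_text("market: csi500\n")
(root / "configs" / "workflow.yml").write_text('BASE_CONFIG_PATH: "base.yml"\n')

previous_cwd = os.getcwd()
os.chdir(root)
try:
    found = resolve_base_config_path("base.yml", str(root / "configs" / "workflow.yml"))
    try:
        resolve_base_config_path("missing.yml", str(root / "configs" / "workflow.yml"))
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
finally:
    os.chdir(previous_cwd)
```

The fallback means a workflow file can refer to a sibling base config by bare file name, regardless of the directory the command is run from.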
|
||||
|
||||
|
||||
2
setup.py
@@ -170,7 +170,7 @@ setup(
|
||||
"gym>=0.24", # If gym is not listed last, it gets downgraded, which causes the pytest run to fail.
|
||||
],
|
||||
"rl": [
|
||||
"tianshou",
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
],
|
||||
},
|
||||
|
||||
5
tests/data_mid_layer_tests/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Introduction
|
||||
The middle layers of data, which mainly include:
|
||||
- Handler
|
||||
- Processors
|
||||
- Datasets
|
||||
37
tests/data_mid_layer_tests/test_handler.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import unittest
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class HandlerTests(TestAutoData):
|
||||
def to_str(self, obj):
|
||||
return "".join(str(obj).split())
|
||||
|
||||
def test_handler_df(self):
|
||||
df = D.features(["sh600519"], start_time="20190101", end_time="20190201", fields=["$close"])
|
||||
dh = DataHandlerLP.from_df(df)
|
||||
print(dh.fetch())
|
||||
self.assertTrue(dh._data.equals(df))
|
||||
self.assertTrue(dh._infer is dh._data)
|
||||
self.assertTrue(dh._learn is dh._data)
|
||||
self.assertTrue(dh.data_loader._data is dh._data)
|
||||
fname = "_handler_test.pkl"
|
||||
dh.to_pickle(fname, dump_all=True)
|
||||
|
||||
with open(fname, "rb") as f:
|
||||
dh_d = pickle.load(f)
|
||||
|
||||
self.assertTrue(dh_d._data.equals(df))
|
||||
self.assertTrue(dh_d._infer is dh_d._data)
|
||||
self.assertTrue(dh_d._learn is dh_d._data)
|
||||
# Data loader will no longer be useful
|
||||
self.assertTrue("_data" not in dh_d.data_loader.__dict__.keys())
|
||||
os.remove(fname)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -76,7 +76,7 @@ class IndexDataTest(unittest.TestCase):
|
||||
self.assertTrue(np.isnan(sd.loc["bar", "g"]))
|
||||
|
||||
# support slicing
|
||||
print(sd.loc[~sd.loc[:, "g"].isna().data.astype(np.bool)])
|
||||
print(sd.loc[~sd.loc[:, "g"].isna().data.astype(bool)])
|
||||
|
||||
print(self.assertTrue(idd.SingleData().index == idd.SingleData().index))
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ FEATURE_DATA_DIR = DATA_DIR / "processed"
|
||||
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
|
||||
|
||||
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
|
||||
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
|
||||
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
|
||||
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
|
||||
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
|
||||
@@ -49,7 +48,7 @@ def test_pickle_data_inspect():
|
||||
def test_simulator_first_step():
|
||||
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
|
||||
assert state.position == 30.0
|
||||
@@ -83,7 +82,7 @@ def test_simulator_first_step():
|
||||
def test_simulator_stop_twap():
|
||||
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
for _ in range(13):
|
||||
simulator.step(1.0)
|
||||
|
||||
@@ -106,10 +105,10 @@ def test_simulator_stop_early():
|
||||
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
simulator.step(2.0)
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
simulator.step(1.0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
@@ -119,7 +118,7 @@ def test_simulator_stop_early():
|
||||
def test_simulator_start_middle():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
simulator.step(2.0)
|
||||
@@ -138,7 +137,7 @@ def test_simulator_start_middle():
|
||||
def test_interpreter():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
|
||||
@@ -219,7 +218,7 @@ def test_network_sanity():
|
||||
# we won't check the correctness of networks here
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 390
|
||||
|
||||
class EmulateEnvWrapper(NamedTuple):
|
||||
@@ -259,7 +258,7 @@ def test_twap_strategy(finite_env_type):
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -290,7 +289,7 @@ def test_cn_ppo_strategy():
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -319,7 +318,7 @@ def test_ppo_train():
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
train(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
|
||||