1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 00:51:19 +08:00

Compare commits

..

16 Commits

Author SHA1 Message Date
Huoran Li
1f5f3a6af0 Do not create venv each iteration & use separate data iterator for each parallel worker (#1522)
* Test passed

* CI

* Cache exchange

* Refine backtest scripts

* Minor

* Rename backtest file

* Add async mode for potential use

* Slient backtest. Add .
2023-06-12 12:05:51 +08:00
Huoran Li
2f8fc8d28a Black 2023-05-24 10:37:21 +08:00
Huoran Li
3e9ccd3ad2 Train on full simulation 2023-05-24 10:36:27 +08:00
you-n-g
94268619c4 Update README.md 2023-05-23 09:50:00 +08:00
Huoran Li
8d60a6a02b Resolve RL FIXMES (#1503)
* Solve several small FIXMEs left in RL

* Add TODO in example

* Minor bugfix

* black
2023-05-17 16:57:08 +08:00
Fivele-Li
7234308651 Add base config in yml (#1500)
* path on Windows contains double '/' which may cause open file failed.

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* add baseConfig in yml,user can add new keys or update/drop keys in baseConfig;

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs. The pip version has been temporarily fixed to 23.0.1.

* 1.Search for baseConfig in multiple directories;
2.Add user instructions in qrun;

* fix format with black

* 1.modify baseConfig key to BASE_CONFIG_PATH;
2.only find config file in absolute path and relative path;

* load BASE_CONFIG_PATH on absolute path & relative path;

* fix Lint with black

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-05-12 17:35:37 +08:00
Chaoying
acf5df27ce Add support for redis password (#1508) 2023-05-08 16:17:15 +08:00
Chaoying
37a59f28d3 Fix deprecated syntax in numpy (#1507)
* Fix deprecated syntax in numpy

* Replace np.bool with bool
2023-05-08 16:17:02 +08:00
YQ Tsui
b084c352f5 provide dtype to empty series to surpress warning; fix type (#1449) 2023-05-05 17:47:44 +08:00
Maksim Zayakin
9e22e5168b Remove unused DNNModelPytorch params (#1470)
* Remove lr_decay and lr_decay_steps params

More flexible way to pass a scheduler (via callable function) is already
supported

* remove lr_decay and lr_decay_steps from mlp workflow configs
2023-04-28 17:48:40 +08:00
Fivele-Li
dceff7b471 Specify the tianshou version to match the dev environment to avoid the error in issue #1477. (#1502) 2023-04-28 13:50:25 +08:00
Huoran Li
7f1e8c5206 Refine Qlib RL data format (#1480)
* wip

* wip

* wip

* Fix naming errors

* Backtest test passed

* Why training stuck?

* Minor

* Refine train configs

* Use dummy in training

* Remove pickle_dataframe

* CI

* CI

* Add more strict condition to filter orders

* Pass test

* Add TODO in example

---------

Co-authored-by: Young <afe.young@gmail.com>
2023-04-26 21:14:30 +08:00
Fivele-Li
46264dfec9 normpath for Windows (#1495)
* path on Windows contains double '/' which may cause open file failed.

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* locate import numpy error

* pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs. The pip version has been temporarily fixed to 23.0.1.

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-04-26 16:26:12 +08:00
Fivele-Li
754799ab05 update ubuntu CI version; (#1488)
* update ubuntu CI version;
(End of standard support for 18.04 LTS - 31 May 2023)

* update ubuntu CI version;

---------

Co-authored-by: lijinhui <362237642@qq.com>
2023-04-10 17:06:48 +08:00
you-n-g
32c3070b73 Refine DDG-DA (#1472)
* Run ddg-da successfully

* Support include valid; More parameters

* Support L2 reg & visualization

* Blackformat

* Enable fill_method

* Support specify handler & optim dataset

* Fix Pylint
2023-04-07 15:00:21 +08:00
you-n-g
40de67265a Update Docs about some concepts in DataHandler (#1485) 2023-04-07 10:02:16 +08:00
61 changed files with 1599 additions and 895 deletions

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -28,8 +28,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0.1
run: |
python -m pip install --upgrade pip
python -m pip install pip==23.0.1
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
@@ -37,15 +39,13 @@ jobs:
python -m pip install torch torchvision torchaudio
- name: Installing pytorch for ubuntu
if: ${{ matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' }}
if: ${{ matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-22.04' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- name: Installing pytorch for windows
if: ${{ matrix.os == 'windows-latest' }}
run: |
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio
- name: Set up Python tools

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-18.04, ubuntu-20.04, macos-11, macos-latest]
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -28,9 +28,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0.1
run: |
python -m pip install --upgrade pip
# python -m pip is necessary to upgrade pip.
python -m pip install pip==23.0.1
pip install --upgrade cython numpy
pip install -e .[dev]

View File

@@ -42,13 +42,11 @@ Features released before 2021 are not listed here.
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
An increasing number of SOTA Quant research works/papers in diverse paradigms are being released in Qlib to collaboratively solve key challenges in quantitative investment. For example, 1) using supervised learning to mine the market's complex non-linear patterns from rich and heterogeneous financial data, 2) modeling the dynamic nature of the financial market using adaptive concept drift technology, and 3) using reinforcement learning to model continuous investment decisions and assist investors in optimizing their trading strategies.
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
With Qlib, users can easily try ideas to create better Quant investment strategies.
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).

View File

@@ -64,8 +64,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 8192

View File

@@ -64,8 +64,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 8192

View File

@@ -52,8 +52,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096

View File

@@ -52,8 +52,6 @@ task:
kwargs:
loss: mse
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096

View File

@@ -0,0 +1,107 @@
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False
from tqdm.auto import tqdm
# tqdm.pandas() # for progress_apply
# %matplotlib inline
# %load_ext autoreload
# # Meta Input
# +
with open("./internal_data_s20.pkl", "rb") as f:
data = pickle.load(f)
data.data_ic_df.columns.names = ["start_date", "end_date"]
data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")
data_sim.index.name = "test datetime"
# -
plt.figure(figsize=(40, 20))
sns.heatmap(data_sim)
plt.figure(figsize=(40, 20))
sns.heatmap(data_sim.rolling(20).mean())
# # Meta Model
from qlib import auto_init
auto_init()
from qlib.workflow import R
exp = R.get_exp(experiment_name="DDG-DA")
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
meta_m = meta_rec.load_object("model")
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()
pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()
# # Meta Output
# +
with open("./tasks_s20.pkl", "rb") as f:
tasks = pickle.load(f)
task_df = {}
for t in tasks:
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
if None not in test_seg:
# The last rolling is skipped.
task_df[test_seg] = t["reweighter"].time_weight
task_df = pd.concat(task_df)
task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
task_df = task_df.droplevel(["OS_end", "IS_end"])
task_df = task_df.unstack("OS_start")
# -
plt.figure(figsize=(40, 20))
sns.heatmap(task_df.T)
plt.figure(figsize=(40, 20))
sns.heatmap(task_df.rolling(10).mean().T)
# # Sub Models
#
# NOTE:
# - this section assumes that the model is Linear model!!
# - Other models does not support this analysis
exp = R.get_exp(experiment_name="rolling_ds")
def show_linear_weight(exp):
coef_df = {}
for r in exp.list_recorders("list"):
t = r.load_object("task")
if None in t["dataset"]["kwargs"]["segments"]["test"]:
continue
m = r.load_object("params.pkl")
coef_df[t["dataset"]["kwargs"]["segments"]["test"]] = pd.Series(m.coef_)
coef_df = pd.concat(coef_df)
coef_df.index.names = ["test_start", "test_end", "coef_idx"]
coef_df = coef_df.droplevel("test_end").unstack("coef_idx").T
plt.figure(figsize=(40, 20))
sns.heatmap(coef_df)
plt.show()
show_linear_weight(R.get_exp(experiment_name="rolling_ds"))
show_linear_weight(R.get_exp(experiment_name="rolling_models"))

View File

@@ -10,8 +10,10 @@ import pandas as pd
import fire
import sys
import pickle
from typing import Optional
from qlib import auto_init
from qlib.model.trainer import TrainerR
from qlib.typehint import Literal
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.tests.data import GetData
@@ -30,7 +32,33 @@ class DDGDA:
- `rm -r mlruns`
"""
def __init__(self, sim_task_model="linear", forecast_model="linear"):
def __init__(
self,
sim_task_model: Literal["linear", "gbdt"] = "linear",
forecast_model: Literal["linear", "gbdt"] = "linear",
h_path: Optional[str] = None,
test_end: Optional[str] = None,
train_start: Optional[str] = None,
meta_1st_train_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
alpha: float = 0.0,
proxy_hd: str = "handler_proxy.pkl",
):
"""
Parameters
----------
train_start: Optional[str]
the start datetime for data. It is used in training start time (for both tasks & meta learing)
test_end: Optional[str]
the end datetime for data. It is used in test end time
meta_1st_train_end: Optional[str]
the datetime of training end of the first meta_task
alpha: float
Setting the L2 regularization for ridge
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
"""
self.step = 20
# NOTE:
# the horizon must match the meaning in the base task template
@@ -38,10 +66,19 @@ class DDGDA:
self.meta_exp_name = "DDG-DA"
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
self.forecast_model = forecast_model # downstream forecasting models' type
self.rb_kwargs = {
"h_path": h_path,
"test_end": test_end,
"train_start": train_start,
"task_ext_conf": task_ext_conf,
}
self.alpha = alpha
self.meta_1st_train_end = meta_1st_train_end
self.proxy_hd = proxy_hd
def get_feature_importance(self):
# this must be lightGBM, because it needs to get the feature importance
rb = RollingBenchmark(model_type="gbdt")
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
task = rb.basic_task()
with R.start(experiment_name="feature_importance"):
@@ -69,7 +106,7 @@ class DDGDA:
fi = self.get_feature_importance()
col_selected = fi.nlargest(topk)
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
task = rb.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -96,7 +133,7 @@ class DDGDA:
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
}
)
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)
@property
def _internal_data_path(self):
@@ -108,7 +145,7 @@ class DDGDA:
This function will dump the input data for meta model
"""
# According to the experiments, the choice of the model type is very important for achieving good results
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()
if self.sim_task_model == "gbdt":
@@ -122,24 +159,27 @@ class DDGDA:
with self._internal_data_path.open("wb") as f:
pickle.dump(internal_data, f)
def train_meta_model(self):
def train_meta_model(self, fill_method="max"):
"""
training a meta model based on a simplified linear proxy model;
"""
# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
proxy_forecast_model_task = {
# "model": "qlib.contrib.model.linear.LinearModel",
"dataset": {
"class": "qlib.data.dataset.DatasetH",
"kwargs": {
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
"segments": {
"train": ("2008-01-01", "2010-12-31"),
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
"train": (train_start, train_end),
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
},
},
},
@@ -156,7 +196,7 @@ class DDGDA:
segments=0.62, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=30,
fill_method="max",
fill_method=fill_method,
rolling_ext_days=0,
)
# NOTE:
@@ -165,12 +205,15 @@ class DDGDA:
# So the misalignment will not affect the effectiveness of the method.
with self._internal_data_path.open("rb") as f:
internal_data = pickle.load(f)
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
# 3) train and logging meta model
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43)
mm = MetaModelDS(
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
)
mm.fit(md)
R.save_objects(model=mm)
@@ -203,7 +246,7 @@ class DDGDA:
hist_step_n = int(param["hist_step_n"])
fill_method = param.get("fill_method", "max")
rb = RollingBenchmark(model_type=self.forecast_model)
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
task_l = rb.create_rolling_tasks()
# 2.2) create meta dataset for final dataset
@@ -233,13 +276,13 @@ class DDGDA:
"""
with self._task_path.open("rb") as f:
tasks = pickle.load(f)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
rb.train_rolling_tasks(tasks)
rb.ens_rolling()
rb.update_rolling_rec()
def run_all(self):
# 1) file: handler_proxy.pkl
# 1) file: handler_proxy.pkl (self.proxy_hd)
self.dump_data_for_proxy_model()
# 2)
# file: internal_data_s20.pkl

View File

@@ -1,13 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from qlib.model.ens.ensemble import RollingEnsemble
from qlib.utils import init_instance_by_config
import fire
import yaml
import pandas as pd
from qlib import auto_init
from pathlib import Path
from tqdm.auto import tqdm
from qlib.model.trainer import TrainerR
from qlib.log import get_module_logger
from qlib.utils.data import update_config
from qlib.workflow import R
from qlib.tests.data import GetData
@@ -25,11 +29,40 @@ class RollingBenchmark:
"""
def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
def __init__(
self,
rolling_exp: str = "rolling_models",
model_type: str = "linear",
h_path: Optional[str] = None,
train_start: Optional[str] = None,
test_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
) -> None:
"""
Parameters
----------
rolling_exp : str
The name for the experiments for rolling
model_type : str
The model to be boosted.
h_path : Optional[str]
the dumped data handler;
test_end : Optional[str]
the test end for the data. It is typically used together with the handler
train_start : Optional[str]
the train start for the data. It is typically used together with the handler.
task_ext_conf : Optional[dict]
some option to update the
"""
self.step = 20
self.horizon = 20
self.rolling_exp = rolling_exp
self.model_type = model_type
self.h_path = h_path
self.train_start = train_start
self.test_end = test_end
self.logger = get_module_logger("RollingBenchmark")
self.task_ext_conf = task_ext_conf
def basic_task(self):
"""For fast training rolling"""
@@ -42,6 +75,10 @@ class RollingBenchmark:
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
else:
raise AssertionError("Model type is not supported!")
if self.h_path is not None:
h_path = Path(self.h_path)
with conf_path.open("r") as f:
conf = yaml.safe_load(f)
@@ -52,6 +89,9 @@ class RollingBenchmark:
task = conf["task"]
if self.task_ext_conf is not None:
task = update_config(task, self.task_ext_conf)
if not h_path.exists():
h_conf = task["dataset"]["kwargs"]["handler"]
h = init_instance_by_config(h_conf)
@@ -59,6 +99,15 @@ class RollingBenchmark:
task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
if self.train_start is not None:
seg = task["dataset"]["kwargs"]["segments"]["train"]
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]
if self.test_end is not None:
seg = task["dataset"]["kwargs"]["segments"]["test"]
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
self.logger.info(task)
return task
def create_rolling_tasks(self):
@@ -93,7 +142,7 @@ class RollingBenchmark:
"""
Evaluate the combined rolling results
"""
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for rt_cls in SigAnaRecord, PortAnaRecord:
rt = rt_cls(recorder=rec, skip_existing=True)
rt.generate()

View File

@@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
[//]: # (TODO: Instead of dumping dataframe with different format &#40;like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`&#41;, we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
@@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
data
├── bin
├── orders
── pickle
└── pickle_dataframe
── pickle
```
## Training

View File

@@ -1,17 +1,9 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
@@ -45,10 +37,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:

View File

@@ -1,17 +1,9 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
@@ -45,10 +37,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:

View File

@@ -1,17 +1,9 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]

View File

@@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
@@ -32,7 +35,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235

View File

@@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
@@ -33,7 +36,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235

View File

@@ -1,26 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from joblib import Parallel, delayed
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)

View File

@@ -4,17 +4,22 @@
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = dataset.handler.fetch(level=None).reset_index()
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
return False
df["date"] = df["datetime"].dt.date.astype("datetime64")
df = df.set_index(["instrument", "datetime", "date"])
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
return True
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
for stock in tqdm(stocks):
generate_order(stock, 0, 240 // 5 - 1)
np.random.shuffle(stocks)
cnt = 0
for stock in stocks:
if generate_order(stock, 0, 240 // 5 - 1):
cnt += 1
if cnt == 100:
break

View File

@@ -179,7 +179,7 @@ def get_strategy_executor(
executor: Union[str, dict, object, Path],
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
@@ -197,12 +197,15 @@ def get_strategy_executor(
pos_type=pos_type,
)
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
exchange_kwargs["start_time"] = start_time
if "end_time" not in exchange_kwargs:
exchange_kwargs["end_time"] = end_time
trade_exchange = get_exchange(**exchange_kwargs)
if isinstance(exchange_kwargs, Exchange):
trade_exchange = exchange_kwargs
else:
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
exchange_kwargs["start_time"] = start_time
if "end_time" not in exchange_kwargs:
exchange_kwargs["end_time"] = end_time
trade_exchange = get_exchange(**exchange_kwargs)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)

View File

@@ -56,6 +56,7 @@ def collect_data_loop(
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict | None = None,
show_progress: bool = True,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training
@@ -74,6 +75,8 @@ def collect_data_loop(
the outermost executor
return_value : dict
used for backtest_loop
show_progress: bool
whether to show execution progress
Yields
-------
@@ -83,7 +86,8 @@ def collect_data_loop(
trade_executor.reset(start_time=start_time, end_time=end_time)
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
disable = not show_progress
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
_execute_result = None
while not trade_executor.finished():
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)

View File

@@ -177,7 +177,7 @@ class Exchange:
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple)
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
for exp in limit_threshold:
necessary_fields.add(exp)
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
@@ -263,6 +263,9 @@ class Exchange:
"""get limit type"""
if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP
if isinstance(limit_threshold, list):
assert len(limit_threshold) == 2
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
return self.LT_FLT
elif limit_threshold is None:
@@ -325,7 +328,7 @@ class Exchange:
assert isinstance(volume_threshold, dict)
for key, vol_limit in volume_threshold.items():
assert isinstance(vol_limit, tuple)
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
fields.add(vol_limit[1])
if key in ("buy", "all"):
@@ -803,7 +806,7 @@ class Exchange:
vol_limit_num: List[float] = []
for limit in vol_limit:
assert isinstance(limit, tuple)
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
if limit[0] == "current":
limit_value = self.quote.get_data(
order.stock_id,

View File

@@ -147,6 +147,7 @@ _default_config = {
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"redis_password": None,
# This value can be reset via qlib.init
"logging_level": logging.INFO,
# Global configuration of qlib log

View File

@@ -55,8 +55,10 @@ class InternalData:
# The handler is initialized for only once.
if not trainer.has_worker():
self.dh = init_task_handler(perf_task_tpl)
self.dh.config(dump_all=False) # in some cases, the data handler are saved to disk with `dump_all=True`
else:
self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"])
assert self.dh.dump_all is False # otherwise, it will save all the detailed data
seg = perf_task_tpl["dataset"]["kwargs"]["segments"]
@@ -77,7 +79,7 @@ class InternalData:
get_module_logger("Internal Data").info("the data has been initialized")
else:
# train new models
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData``"
assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`"
trainer.train(gen_task)
# 2) extract the similarity matrix
@@ -119,6 +121,7 @@ class MetaTaskDS(MetaTask):
def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"):
"""
The description of the processed data
time_perf: A array with shape <hist_step_n * step, data pieces> -> data piece performance
@@ -132,6 +135,10 @@ class MetaTaskDS(MetaTask):
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 0., 1.]])
Parameters
----------
meta_info: pd.DataFrame
please refer to the docs of _prepare_meta_ipt for detailed explanation.
"""
super().__init__(task, meta_info)
self.fill_method = fill_method
@@ -180,12 +187,41 @@ class MetaTaskDS(MetaTask):
self.processed_meta_input = data_to_tensor(self.processed_meta_input)
def _get_processed_meta_info(self):
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0) # .fillna(0.)
if self.fill_method == "max":
meta_info_norm = meta_info_norm.T.fillna(
meta_info_norm.max(axis=1)
).T # fill it with row max to align with previous implementation
meta_info_norm = self.meta_info.sub(self.meta_info.mean(axis=1), axis=0)
if self.fill_method.startswith("max"):
suffix = self.fill_method.lstrip("max")
if suffix == "seg":
fill_value = {}
for col in meta_info_norm.columns:
fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max()
fill_value = pd.Series(fill_value).sort_index()
# The NaN Values are filled segment-wise. Below is an exampleof fill_value
# 2009-01-05 2009-02-06 0.145809
# 2009-02-09 2009-03-06 0.148005
# 2009-03-09 2009-04-03 0.090385
# 2009-04-07 2009-05-05 0.114318
# 2009-05-06 2009-06-04 0.119328
# ...
meta_info_norm = meta_info_norm.fillna(fill_value)
else:
if len(suffix) > 0:
get_module_logger("MetaTaskDS").warning(
f"fill_method={self.fill_method}; the info after can't be correctly parsed. Please check your parameters."
)
fill_value = meta_info_norm.max(axis=1)
# fill it with row max to align with previous implementation
# This will magnify the data similarity when data is in daily freq
# the fill value corresponds to data like this
# It get a performance value for each day.
# The performance value are get from other models on this day
# 2009-01-16 0.276320
# 2009-01-19 0.280603
# ...
# 2011-06-27 0.203773
meta_info_norm = meta_info_norm.T.fillna(fill_value).T
elif self.fill_method == "zero":
# It will fillna(0.0) at the end.
pass
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -286,7 +322,33 @@ class MetaDatasetDS(MetaTaskDataset):
logger.warning(f"ValueError: {e}")
assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting"
def _prepare_meta_ipt(self, task):
def _prepare_meta_ipt(self, task) -> pd.DataFrame:
"""
Please refer to `self.internal_data.setup` for detailed information about `self.internal_data.data_ic_df`
Indices with format below can be successfully sliced by `ic_df.loc[:end, pd.IndexSlice[:, :end]]`
2021-06-21 2021-06-04 .. 2021-03-22 2021-03-08
2021-07-02 2021-06-18 .. 2021-04-02 None
Returns
-------
a pd.DataFrame with similar content below.
- each column corresponds to a trained model named by the training data range
- each row corresponds to a day of data tested by the models of the columns
- The rows cells that overlaps with the data used by columns are masked
2009-01-05 2009-02-09 ... 2011-04-27 2011-05-26
2009-02-06 2009-03-06 ... 2011-05-25 2011-06-23
datetime ...
2009-01-13 NaN 0.310639 ... -0.169057 0.137792
2009-01-14 NaN 0.261086 ... -0.143567 0.082581
... ... ... ... ... ...
2011-06-30 -0.054907 -0.020219 ... -0.023226 NaN
2011-07-01 -0.075762 -0.026626 ... -0.003167 NaN
"""
ic_df = self.internal_data.data_ic_df
segs = task["dataset"]["kwargs"]["segments"]
@@ -294,15 +356,19 @@ class MetaDatasetDS(MetaTaskDataset):
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
# meta data set focus on the **information** instead of preprocess
# 1) filter the future info
def mask_future(s):
"""mask future information"""
# from qlib.utils import get_date_by_shift
# 1) filter the overlap info
def mask_overlap(s):
"""
mask overlap information
data after self.name[end] with self.trunc_days that contains future info are also considered as overlap info
Approximately the diagnal + horizon length of data are masked.
"""
start, end = s.name
end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True)
return s.mask((s.index >= start) & (s.index <= end))
ic_df_avail = ic_df_avail.apply(mask_future) # apply to each col
ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col
# 2) filter the info with too long periods
total_len = self.step * self.hist_step_n

View File

@@ -52,6 +52,7 @@ class MetaModelDS(MetaTaskModel):
lr=0.0001,
max_epoch=100,
seed=43,
alpha=0.0,
):
self.step = step
self.hist_step_n = hist_step_n
@@ -61,6 +62,7 @@ class MetaModelDS(MetaTaskModel):
self.lr = lr
self.max_epoch = max_epoch
self.fitted = False
self.alpha = alpha
torch.manual_seed(seed)
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
@@ -144,7 +146,11 @@ class MetaModelDS(MetaTaskModel):
) # debug: record when the test phase starts
self.tn = PredNet(
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
step=self.step,
hist_step_n=self.hist_step_n,
clip_weight=self.clip_weight,
clip_method=self.clip_method,
alpha=self.alpha,
)
opt = optim.Adam(self.tn.parameters(), lr=self.lr)

View File

@@ -41,11 +41,18 @@ class TimeWeightMeta(SingleMetaBase):
class PredNet(nn.Module):
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh"):
def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0):
"""
Parameters
----------
alpha : float
the regularization for sub model (useful when align meta model with linear submodel)
"""
super().__init__()
self.step = step
self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method)
self.init_paramters(hist_step_n)
self.alpha = alpha
def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False):
weights = torch.from_numpy(np.ones(X.shape[0])).float().to(X.device)
@@ -59,7 +66,7 @@ class PredNet(nn.Module):
"""Please refer to the docs of MetaTaskDS for the description of the variables"""
weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight)
X_w = X.T * weights.view(1, -1)
theta = torch.inverse(X_w @ X) @ X_w @ y
theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y
return X_test @ theta, weights
def init_paramters(self, hist_step_n):

View File

@@ -5,6 +5,9 @@ import numpy as np
import torch
from torch import nn
from qlib.constant import EPS
from qlib.log import get_module_logger
class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50):
@@ -24,6 +27,7 @@ class ICLoss(nn.Module):
diff_point.append(i)
prev = date
diff_point.append(None)
# The lengths of diff_point will be one more larger then diff_point
ic_all = 0.0
skip_n = 0
@@ -34,13 +38,23 @@ class ICLoss(nn.Module):
skip_n += 1
continue
y_focus = y[start_i:end_i]
if pred_focus.std() < EPS or y_focus.std() < EPS:
# These cases often happend at the end of test data.
# Usually caused by fillna(0.)
skip_n += 1
continue
ic_day = torch.dot(
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
)
ic_all += ic_day
if len(diff_point) - 1 - skip_n <= 0:
raise ValueError("No enough data for calculating iC")
raise ValueError("No enough data for calculating IC")
if skip_n > 0:
get_module_logger("ICLoss").info(
f"{skip_n} days are skipped due to zero std or small scale of valid samples."
)
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
return -ic_mean # ic loss

View File

@@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from qlib.log import get_module_logger
from qlib.data.dataset.weight import Reweighter
from scipy.optimize import nnls
from sklearn.linear_model import LinearRegression, Ridge, Lasso
@@ -29,7 +30,7 @@ class LinearModel(Model):
RIDGE = "ridge"
LASSO = "lasso"
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False):
"""
Parameters
----------
@@ -39,6 +40,9 @@ class LinearModel(Model):
l1 or l2 regularization parameter
fit_intercept : bool
whether fit intercept
include_valid: bool
Should the validation data be included for training?
The validation data should be included
"""
assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
self.estimator = estimator
@@ -49,9 +53,16 @@ class LinearModel(Model):
self.fit_intercept = fit_intercept
self.coef_ = None
self.include_valid = include_valid
def fit(self, dataset: DatasetH, reweighter: Reweighter = None):
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if self.include_valid:
try:
df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
df_train = pd.concat([df_train, df_valid])
except KeyError:
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is not None:

View File

@@ -47,10 +47,6 @@ class DNNModelPytorch(Model):
layer sizes
lr : float
learning rate
lr_decay : float
learning rate decay
lr_decay_steps : int
learning rate decay steps
optimizer : str
optimizer name
GPU : int
@@ -64,8 +60,6 @@ class DNNModelPytorch(Model):
batch_size=2000,
early_stop_rounds=50,
eval_steps=20,
lr_decay=0.96,
lr_decay_steps=100,
optimizer="gd",
loss="mse",
GPU=0,
@@ -93,8 +87,6 @@ class DNNModelPytorch(Model):
self.batch_size = batch_size
self.early_stop_rounds = early_stop_rounds
self.eval_steps = eval_steps
self.lr_decay = lr_decay
self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower()
self.loss_type = loss
if isinstance(GPU, str):
@@ -116,8 +108,6 @@ class DNNModelPytorch(Model):
f"\nbatch_size : {batch_size}"
f"\nearly_stop_rounds : {early_stop_rounds}"
f"\neval_steps : {eval_steps}"
f"\nlr_decay : {lr_decay}"
f"\nlr_decay_steps : {lr_decay_steps}"
f"\noptimizer : {optimizer}"
f"\nloss_type : {loss}"
f"\nseed : {seed}"

View File

@@ -635,7 +635,7 @@ class FileOrderStrategy(BaseStrategy):
self.order_df = file
else:
with get_io_object(file) as f:
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
self.order_df = pd.read_csv(f, dtype={"datetime": str})
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
self.order_df = self.order_df.set_index(["datetime", "instrument"])

View File

@@ -783,7 +783,7 @@ class LocalPITProvider(PITProvider):
index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index"
data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
if not (index_path.exists() and data_path.exists()):
raise FileNotFoundError("No file is found. Raise exception and ")
raise FileNotFoundError("No file is found.")
# NOTE: The most significant performance loss is here.
# Does the acceleration that makes the program complicated really matters?
# - It makes parameters of the interface complicate
@@ -797,14 +797,14 @@ class LocalPITProvider(PITProvider):
cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day)
loc = np.searchsorted(data["date"], cur_time_int, side="right")
if loc <= 0:
return pd.Series()
return pd.Series(dtype=C.pit_record_type["value"])
last_period = data["period"][:loc].max() # return the latest quarter
first_period = data["period"][:loc].min()
period_list = get_period_list(first_period, last_period, quarterly)
if period is not None:
# NOTE: `period` has higher priority than `start_index` & `end_index`
if period not in period_list:
return pd.Series()
return pd.Series(dtype=C.pit_record_type["value"])
else:
period_list = [period]
else:
@@ -868,7 +868,7 @@ class LocalExpressionProvider(ExpressionProvider):
# Ensure that each column type is consistent
# FIXME:
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
# 2) The the precision should be configurable
# 2) The precision should be configurable
try:
series = series.astype(np.float32)
except ValueError:

View File

@@ -417,7 +417,7 @@ class TSDataSampler:
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.swaplevel()
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data)[0]]

View File

@@ -720,3 +720,26 @@ class DataHandlerLP(DataHandler):
]:
setattr(new_hd, key, getattr(handler, key, None))
return new_hd
@classmethod
def from_df(cls, df: pd.DataFrame) -> "DataHandlerLP":
"""
Motivation:
- When user want to get a quick data handler.
The created data handler will have only one shared Dataframe without processors.
After creating the handler, user may often want to dump the handler for reuse
Here is a typical use case
.. code-block:: python
from qlib.data.dataset import DataHandlerLP
dh = DataHandlerLP.from_df(df)
dh.to_pickle(fname, dump_all=True)
TODO:
- The StaticDataLoader is quite slow. It don't have to copy the data again...
"""
loader = data_loader_module.StaticDataLoader(df)
return cls(data_loader=loader)

View File

@@ -2,9 +2,8 @@
# Licensed under the MIT License.
from __future__ import annotations
import pandas as pd
from typing import Union, List
from typing import Union, List, TYPE_CHECKING
from qlib.utils import init_instance_by_config
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
@@ -121,7 +120,7 @@ def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datet
return df
def init_task_handler(task: dict) -> Union[DataHandler, None]:
def init_task_handler(task: dict) -> DataHandler:
"""
initialize the handler part of the task **inplace**
@@ -142,5 +141,6 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
if h_conf is not None:
handler = init_instance_by_config(h_conf, accept_types=DataHandler)
task["dataset"]["kwargs"]["handler"] = handler
return handler
else:
raise ValueError("The task does not contains a handler part.")

View File

@@ -16,13 +16,12 @@ import torch
from joblib import Parallel, delayed
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.typehint import Literal
@@ -30,12 +29,13 @@ def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float | None = None,
generate_report: bool = False,
data_granularity: str = "1min",
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "5min", # FIXME: move this into config
"time_per_step": data_granularity,
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
@@ -123,109 +123,13 @@ def _generate_report(
return report
def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
A new simulator will be created and used for every single-day order.
Parameters
----------
backtest_config:
Backtest config
orders:
Orders to be executed. Example format:
datetime instrument amount direction
0 2020-06-01 INST 600.0 0
1 2020-06-02 INST 700.0 1
...
split
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
cash_limit
Limitation of cash.
generate_report
Whether to generate reports.
Returns
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
stocks = orders.instrument.unique().tolist()
reports = []
decisions = []
for _, row in orders.iterrows():
date = pd.Timestamp(row["datetime"])
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(row["direction"]),
start_time=start_time,
end_time=end_time,
)
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": "5min", # FIXME: move this into config
}
)
simulator = SingleAssetOrderExecution(
order=order,
executor_config=executor_config,
exchange_config=exchange_config,
qlib_config=None,
cash_limit=None,
)
reports.append(simulator.report_dict)
decisions += simulator.decisions
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
_report = _generate_report(decisions, [report["indicator"] for report in reports])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: _report}
return records, report
else:
return records
def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
time_range: Tuple[str, str],
exchange_config: dict,
strategy_config: dict,
split: Literal["stock", "day"] = "stock",
data_granularity: str = "1min",
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
@@ -253,48 +157,42 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
stocks = orders.instrument.unique().tolist()
strategy_config = {
top_strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"file": orders,
"trade_range": TradeRangeByTime(
pd.Timestamp(backtest_config["start_time"]).time(),
pd.Timestamp(backtest_config["end_time"]).time(),
pd.Timestamp(time_range[0]).time(),
pd.Timestamp(time_range[1]).time(),
),
},
}
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
top_executor_config = _get_multi_level_executor_config(
strategy_config=strategy_config,
cash_limit=cash_limit,
generate_report=generate_report,
data_granularity=data_granularity,
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
exchange_config = {
**exchange_config,
**{
"codes": stocks,
"freq": "5min", # FIXME: move this into config
}
)
"freq": data_granularity,
},
}
strategy, executor = get_strategy_executor(
start_time=pd.Timestamp(trade_start_time),
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
strategy=strategy_config,
executor=executor_config,
strategy=top_strategy_config,
executor=top_executor_config,
benchmark=None,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=exchange_config,
@@ -302,7 +200,7 @@ def single_with_collect_data_loop(
)
report_dict: dict = {}
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
@@ -322,46 +220,54 @@ def single_with_collect_data_loop(
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
order_df = read_order_file(backtest_config["order_file"])
cash_limit = backtest_config["exchange"].pop("cash_limit")
generate_report = backtest_config.pop("generate_report")
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
single = single_with_simulator if with_simulator else single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
init_qlib(backtest_config["simulator"]["qlib"])
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
res = Parallel(**mp_config)(
delayed(single)(
backtest_config=backtest_config,
orders=order_df[order_df["instrument"] == stock].copy(),
split="stock",
cash_limit=cash_limit,
generate_report=generate_report,
single = single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
for task_config in backtest_config["tasks"]:
order_df = read_order_file(task_config["order_file"])
exchange_config = task_config["exchange"]
cash_limit = exchange_config.pop("cash_limit")
generate_report = backtest_config["runtime"]["generate_report"]
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
#
res = Parallel(**mp_config)(
delayed(single)(
orders=order_df[order_df["instrument"] == stock].copy(),
time_range=task_config["time_range"],
exchange_config=task_config["exchange"],
strategy_config=backtest_config["strategies"],
split="stock",
data_granularity=task_config["data_granularity"],
cash_limit=cash_limit,
generate_report=generate_report,
)
for stock in stock_pool
)
for stock in stock_pool
)
output_path = Path(backtest_config["output_dir"])
if generate_report:
with (output_path / "report.pkl").open("wb") as f:
report = {}
for r in res:
report.update(r[1])
pickle.dump(report, f)
res = pd.concat([r[0] for r in res], 0)
else:
res = pd.concat(res)
if not output_path.exists():
os.makedirs(output_path)
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
return res
#
output_path = Path(task_config["output_dir"])
os.makedirs(output_path, exist_ok=True)
if generate_report:
with (output_path / "report.pkl").open("wb") as f:
report = {}
for r in res:
report.update(r[1])
pickle.dump(report, f)
res = pd.concat([r[0] for r in res], 0)
else:
res = pd.concat(res)
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
# return res # TODO
if __name__ == "__main__":
@@ -369,6 +275,7 @@ if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
@@ -381,9 +288,11 @@ if __name__ == "__main__":
)
args = parser.parse_args()
config = get_backtest_config_fromfile(args.config_path)
if args.n_jobs is not None:
config["concurrency"] = args.n_jobs
config_parser = BacktestConfigParser(args.config_path)
config = config_parser.parse()
if args.n_jobs is not None: # Overwrite concurrency
config["runtime"]["concurrency"] = args.n_jobs
backtest(
backtest_config=config,

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import os
import platform
import shutil
@@ -30,7 +31,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
raise FileNotFoundError(msg_tmpl.format(filename))
def parse_backtest_config(path: str) -> dict:
def load_config(path: str) -> dict:
abs_path = os.path.abspath(path)
check_file_exist(abs_path)
@@ -65,42 +66,154 @@ def parse_backtest_config(path: str) -> dict:
base_file_name = [base_file_name]
for f in base_file_name:
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
config = merge_a_into_b(a=config, b=base_config)
return config
def _convert_all_list_to_tuple(config: dict) -> dict:
for k, v in config.items():
if isinstance(v, list):
config[k] = tuple(v)
elif isinstance(v, dict):
config[k] = _convert_all_list_to_tuple(v)
return config
class BacktestConfigParser:
def __init__(self, path: str) -> None:
self.raw_config = load_config(path)
def parse(self) -> dict:
self._simulator_config = self._parse_simulator()
self._exchange_config = self._simulator_config.pop("exchange")
config = {
"strategies": self.raw_config["strategies"],
"runtime": self.raw_config["runtime"],
"tasks": self._parse_tasks(),
"simulator": self._simulator_config,
}
return config
def _parse_tasks(self) -> dict:
task_config = []
for task in self.raw_config["tasks"]:
if "output_dir" not in task:
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
if "exchange" not in task:
task["exchange"] = copy.deepcopy(self._exchange_config)
else:
task["exchange"] = self._complete_exchange_config(task["exchange"])
task_config.append(task)
return task_config
def _complete_exchange_config(self, exchange_config: dict) -> dict:
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
}
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
return exchange_config
def _parse_simulator(self) -> dict:
config = self.raw_config["simulator"]
return {
"qlib": config["qlib"],
"exchange": self._complete_exchange_config(config["exchange"]),
}
def get_backtest_config_fromfile(path: str) -> dict:
backtest_config = parse_backtest_config(path)
class TrainingConfigParser:
def __init__(self, path: str) -> None:
self.raw_config = load_config(path)
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
}
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
def parse(self) -> dict:
return {
"general": self._parse_general(),
"policy": self.raw_config["policy"],
"interpreter": self.raw_config["interpreter"],
"runtime": self._parse_runtime(),
"training": self._parse_training(),
"simulator": self._parse_simulator(),
}
backtest_config_default = {
"debug_single_stock": None,
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs_backtest/",
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
def _parse_general(self) -> dict:
default = {
"freq": "1min",
"extra_module_paths": [],
}
return {**default, **self.raw_config["general"]}
return backtest_config
def _parse_runtime(self) -> dict:
default = {
"seed": None,
"use_cuda": False,
"concurrency": 1,
"parallel_mode": "dummy",
}
return {**default, **self.raw_config["runtime"]}
def _parse_training(self) -> dict:
default = {
"max_epoch": 100,
"repeat_per_collect": 2,
"earlystop_patience": float("inf"),
"episode_per_collect": 10000,
"batch_size": 256,
"val_every_n_epoch": None,
"checkpoint_path": "./outputs",
"checkpoint_every_n_iters": 10,
}
config = self.raw_config["training"]
assert "order_dir" in config
return {**default, **config}
def _parse_simulator(self) -> dict:
config = self.raw_config["simulator"]
sim_type = config["type"]
assert sim_type in ("simple", "full")
if sim_type == "simple":
return {
"type": sim_type,
"data": {
"feature_root_dir": config["data"]["feature_root_dir"],
"feature_columns_today": config["data"]["feature_columns_today"],
"default_start_time_index": config["data"].get("default_start_time_index", 0),
"default_end_time_index": config["data"].get("default_end_time_index", 240),
},
"time_per_step": config["time_per_step"],
"vol_limit": config["vol_limit"],
}
else:
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
# "cash_limit": None,
}
exchange_config = {**exchange_config_default, **config["exchange"]}
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
ret_config = {
"type": sim_type,
"data": {
"feature_root_dir": config["data"]["feature_root_dir"],
"default_start_time_index": config["data"].get("default_start_time_index", 0),
"default_end_time_index": config["data"].get("default_end_time_index", 240),
},
"qlib": {
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
},
"exchange": exchange_config,
}
return ret_config
if __name__ == "__main__":
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
from pprint import pprint
pprint(parser.parse())

362
qlib/rl/contrib/train.py Normal file
View File

@@ -0,0 +1,362 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Callable, cast, List, Optional, Sequence
import numpy as np
import pandas as pd
import torch
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl import Simulator
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
from qlib.rl.data.integration import init_qlib
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def get_executor_config(freq: int) -> dict:
return {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"inner_executor": {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"generate_report": False,
"time_per_step": f"{freq}min",
"track_data": True,
"trade_type": "serial",
"verbose": False,
},
},
"inner_strategy": {
"class": "TWAPStrategy",
"kwargs": {},
"module_path": "qlib.contrib.strategy.rule_strategy",
},
"time_per_step": "30min",
"track_data": True,
},
},
"inner_strategy": {
"class": "ProxySAOEStrategy",
"module_path": "qlib.rl.order_execution.strategy",
"kwargs": {},
},
"time_per_step": "1day",
"track_data": True,
},
}
def seed_everything(seed: int) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
def _freq_str_to_int(freq: str) -> int:
if freq.endswith("min"):
return int(freq.replace("min", ""))
elif freq.endswith("hour"):
return int(freq.replace("hour", "") * 60)
else:
raise ValueError(f"Unrecognized freq string: {freq}")
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_df: pd.DataFrame,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_df = order_df
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)
def __len__(self) -> int:
return len(self._order_df)
def __getitem__(self, index: int) -> Order:
row = self._order_df.iloc[index]
date = pd.Timestamp(str(row["date"]))
if self._ticks_index is None:
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
# TODO: of all dates.
data = load_pickle_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
)
self._ticks_index = [t - date for t in data.today.index]
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(int(row["order_type"])),
start_time=date + self._ticks_index[self._default_start_time_index],
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
)
return order
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
df = df.copy()
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
return dfs
def _get_simulator_factory(
sim_type: str,
data_dir: Path,
freq_min: int,
simulator_config: dict,
) -> Callable[[Order], Simulator]:
if sim_type == "simple":
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
simulator = SingleAssetOrderExecutionSimple(
order=order,
data_dir=data_dir,
feature_columns_today=simulator_config["data"]["feature_columns_today"],
data_granularity=freq_min,
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)
return simulator
return _simulator_factory_simple
elif sim_type == "full":
init_qlib(simulator_config["qlib"])
executor_config = get_executor_config(freq_min)
exchange_config = simulator_config["exchange"]
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
simulator = SingleAssetOrderExecution(
order=order,
executor_config=executor_config,
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
qlib_config=None,
cash_limit=None,
)
return simulator
return _simulator_factory_full
else:
raise ValueError(f"Unknown simulator type: {sim_type}")
def train_and_test(
freq: str,
concurrency: int,
parallel_mode: str,
training_config: dict,
simulator_config: dict,
policy: BasePolicy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
freq_min: int = _freq_str_to_int(freq)
order_root_path = Path(training_config["order_dir"])
feature_root_dir = simulator_config["data"]["feature_root_dir"]
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
_simulator_factory = _get_simulator_factory(
sim_type=simulator_config["type"],
data_dir=feature_root_dir,
freq_min=freq_min,
simulator_config=simulator_config,
)
# Load orders
load_data_tags = []
orders_by_tag = {}
if run_training:
load_data_tags += ["train", "valid"]
if run_backtest:
load_data_tags += ["test"]
for tag in load_data_tags:
order_df = _read_orders(order_root_path / tag).reset_index()
dfs = _split_order_df_by_instrument(order_df, concurrency)
datasets = [
LazyLoadDataset(
data_dir=feature_root_dir,
order_df=df,
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
)
for df in dfs
]
orders_by_tag[tag] = datasets
if run_training:
callbacks: List[Callback] = [
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
Checkpoint(
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
every_n_iters=training_config["checkpoint_every_n_iters"],
save_latest="copy",
),
EarlyStopping(
patience=training_config["earlystop_patience"],
monitor="val/pa",
),
]
train(
simulator_fn=_simulator_factory,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
trainer_kwargs={
"max_iters": training_config["max_epoch"],
"finite_env_type": parallel_mode,
"concurrency": concurrency,
"val_every_n_iters": training_config["val_every_n_epoch"],
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": training_config["episode_per_collect"],
"update_kwargs": {
"batch_size": training_config["batch_size"],
"repeat": training_config["repeat_per_collect"],
},
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
},
)
if run_backtest:
backtest(
simulator_fn=_simulator_factory,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
policy=policy,
logger=CsvWriter(Path(training_config["checkpoint_path"])),
reward=reward,
finite_env_type=parallel_mode, # type: ignore[arg-type]
concurrency=concurrency,
)
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
seed = config["runtime"]["seed"]
if seed is not None:
seed_everything(seed)
for extra_module_path in config["general"]["extra_module_paths"]:
sys.path.append(extra_module_path)
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "network" in config["policy"]:
network_config = config["policy"]["network"]
network_config["kwargs"] = {
**network_config.get("kwargs", {}),
**{"obs_space": state_interpreter.observation_space},
}
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
# Create policy
policy_config = config["policy"]["policy"]
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
policy: BasePolicy = init_instance_by_config(policy_config)
use_cuda = config["runtime"]["use_cuda"]
if use_cuda:
policy.cuda()
train_and_test(
freq=config["general"]["freq"],
concurrency=config["runtime"]["concurrency"],
parallel_mode=config["runtime"]["parallel_mode"],
training_config=config["training"],
simulator_config=config["simulator"],
policy=policy,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
config_parser = TrainingConfigParser(args.config_path)
config = config_parser.parse()
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)

View File

@@ -1,261 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import random
import warnings
from pathlib import Path
from typing import cast, List, Optional
import numpy as np
import pandas as pd
import qlib
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def seed_everything(seed: int) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
class LazyLoadDataset(Dataset):
def __init__(
self,
order_file_path: Path,
data_dir: Path,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_file_path = order_file_path
self._order_df = _read_orders(order_file_path).reset_index()
self._data_dir = data_dir
self._ticks_index: Optional[pd.DatetimeIndex] = None
def __len__(self) -> int:
return len(self._order_df)
def __getitem__(self, index: int) -> Order:
row = self._order_df.iloc[index]
date = pd.Timestamp(str(row["date"]))
if self._ticks_index is None:
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
# TODO: of all dates.
backtest_data = load_simple_intraday_backtest_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
)
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(int(row["order_type"])),
start_time=date + self._ticks_index[self._default_start_time_index],
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
)
return order
def train_and_test(
env_config: dict,
simulator_config: dict,
trainer_config: dict,
data_config: dict,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()
order_root_path = Path(data_config["source"]["order_dir"])
data_granularity = simulator_config.get("data_granularity", 1)
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=Path(data_config["source"]["data_dir"]),
ticks_per_step=simulator_config["time_per_step"],
data_granularity=data_granularity,
deal_price_type=data_config["source"].get("deal_price_column", "close"),
vol_threshold=simulator_config["vol_limit"],
)
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid")
]
callbacks: List[Callback] = []
if "checkpoint_path" in trainer_config:
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs={
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
},
)
if run_backtest:
test_dataset = LazyLoadDataset(
order_file_path=order_root_path / "test",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
initial_states=test_dataset,
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=env_config["parallel_mode"],
concurrency=env_config["concurrency"],
)
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "network" in config:
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
# Create policy
if "kwargs" not in config["policy"]:
config["policy"]["kwargs"] = {}
config["policy"]["kwargs"].update(additional_policy_kwargs)
policy: BasePolicy = init_instance_by_config(config["policy"])
use_cuda = config["runtime"].get("use_cuda", False)
if use_cuda:
policy.cuda()
train_and_test(
env_config=config["env"],
simulator_config=config["simulator"],
data_config=config["data"],
trainer_config=config["trainer"],
action_interpreter=action_interpreter,
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)

View File

@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
from __future__ import annotations
import pickle
from pathlib import Path
from typing import List
import cachetools
import numpy as np
import pandas as pd
import qlib
from qlib.constant import REG_CN
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
from qlib.data.dataset import DatasetH
dataset = None
class DataWrapper:
def __init__(
self,
feature_dataset: DatasetH,
backtest_dataset: DatasetH,
columns_today: List[str],
columns_yesterday: List[str],
_internal: bool = False,
):
assert _internal, "Init function of data wrapper is for internal use only."
self.feature_dataset = feature_dataset
self.backtest_dataset = backtest_dataset
self.columns_today = columns_today
self.columns_yesterday = columns_yesterday
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
)
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
dataset = self.backtest_dataset if backtest else self.feature_dataset
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
def init_qlib(qlib_config: dict, part: str | None = None) -> None:
def init_qlib(qlib_config: dict) -> None:
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
Parameters
@@ -72,12 +38,8 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
],
}
part
Identifying which part (stock / date) to load.
"""
global dataset # pylint: disable=W0603
def _convert_to_path(path: str | Path) -> Path:
return path if isinstance(path, Path) else Path(path)
@@ -118,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
redis_port=-1,
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
)
if part == "skip":
return
# this won't work if it's put outside in case of multiprocessing
from qlib.data import D # noqa pylint: disable=C0415,W0611
if part is None:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
else:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
with feature_path.open("rb") as f:
feature_dataset = pickle.load(f)
with backtest_path.open("rb") as f:
backtest_dataset = pickle.load(f)
dataset = DataWrapper(
feature_dataset,
backtest_dataset,
qlib_config["feature_columns_today"],
qlib_config["feature_columns_yesterday"],
_internal=True,
)
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
assert dataset is not None, "You must call init_qlib() before doing this."
if backtest:
fields = ["$close", "$volume"]
else:
fields = dataset.columns_yesterday if yesterday else dataset.columns_today
data = dataset.get(stock_id, date, backtest)
if data is None or len(data) == 0:
# create a fake index, but RL doesn't care about index
data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
else:
data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
data = data[fields]
return data

View File

@@ -2,17 +2,30 @@
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
from pathlib import Path
from typing import cast, List
import cachetools
import pandas as pd
import pickle
import os
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.rl.order_execution.utils import get_ticks_slice
from qlib.constant import EPS_T
from qlib.data.dataset import DatasetH
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -71,6 +84,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
"""Backtest data from dataframe"""
def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
self.df = df
self.price_column = price_column
self.volume_column = volume_column
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.df})"
def __len__(self) -> int:
return len(self.df)
def get_deal_price(self) -> pd.Series:
return self.df[self.price_column]
def get_volume(self) -> pd.Series:
return self.df[self.volume_column]
def get_time_index(self) -> pd.DatetimeIndex:
return cast(pd.DatetimeIndex, self.df.index)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda order, _, __: order.key_by_day,
@@ -103,13 +141,27 @@ def load_backtest_data(
return backtest_data
class NTIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda path: path,
)
def _load_handler_pickle(path: str) -> DatasetH:
with open(path, "rb") as fstream:
obj = pickle.load(fstream)
return obj
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
df = df.reset_index()
@@ -117,22 +169,52 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
df = df.drop(columns=["instrument"])
return df.set_index(["datetime"])
self.today = _drop_stock_id(fetch_features(stock_id, date))
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
dataset = _load_handler_pickle(path)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
)
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
return NTIntradayProcessedData(stock_id, date)
def load_handler_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> HandlerIntradayProcessedData:
return HandlerIntradayProcessedData(
data_dir,
stock_id,
date,
feature_columns_today,
feature_columns_yesterday,
backtest,
)
class NTProcessedDataProvider(ProcessedDataProvider):
class HandlerProcessedDataProvider(ProcessedDataProvider):
def __init__(
self,
data_dir: str,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
super().__init__()
self.data_dir = Path(data_dir)
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.backtest = backtest
def get_data(
self,
stock_id: str,
@@ -140,4 +222,11 @@ class NTProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_nt_intraday_processed_data(stock_id, date)
return load_handler_intraday_processed_data(
self.data_dir,
stock_id,
date,
self.feature_columns_today,
self.feature_columns_yesterday,
backtest=self.backtest,
)

View File

@@ -26,7 +26,6 @@ from typing import List, Sequence, cast
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import Order, OrderDir
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
@@ -158,44 +157,35 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda path: path,
)
def _load_df_pickle(path: str) -> pd.DataFrame:
df = pd.read_pickle(path)
return df
class PickleIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
def __init__(
self,
data_dir: Path | str,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool,
) -> None:
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
if isinstance(data_dir, str):
data_dir = Path(data_dir)
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
df = _load_df_pickle(str(path))
df = df.loc[pd.IndexSlice[stock_id, :, date]]
# We have to infer the names here because,
# unfortunately they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)
time_length: int = len(time_index)
try:
# new data format
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
proc_today = proc[cnames]
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
except (IndexError, KeyError):
# legacy data
proc = proc.loc[pd.IndexSlice[stock_id, date]]
assert time_length * feature_dim * 2 == len(proc)
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
self.today: pd.DataFrame = proc_today
self.yesterday: pd.DataFrame = proc_yesterday
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
self.today = df[feature_columns_today]
self.yesterday = df[feature_columns_yesterday]
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
@@ -213,25 +203,38 @@ def load_simple_intraday_backtest_data(
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
)
def load_pickled_intraday_processed_data(
def load_pickle_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> BaseIntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
return PickleIntradayProcessedData(
data_dir,
stock_id,
date,
feature_columns_today,
feature_columns_yesterday,
backtest,
)
class PickleProcessedDataProvider(ProcessedDataProvider):
def __init__(self, data_dir: Path) -> None:
def __init__(
self,
data_dir: Path,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
super().__init__()
self._data_dir = data_dir
self._backtest = backtest
self._feature_columns_today = feature_columns_today
self._feature_columns_yesterday = feature_columns_yesterday
def get_data(
self,
@@ -240,12 +243,13 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_pickled_intraday_processed_data(
return load_pickle_intraday_processed_data(
data_dir=self._data_dir,
stock_id=stock_id,
date=date,
feature_dim=feature_dim,
time_index=time_index,
feature_columns_today=self._feature_columns_today,
feature_columns_yesterday=self._feature_columns_yesterday,
backtest=self._backtest,
)

View File

@@ -4,10 +4,11 @@
from __future__ import annotations
from typing import Generator, List, Optional
import cachetools
import pandas as pd
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
from qlib.backtest.executor import NestedExecutor
from qlib.rl.data.integration import init_qlib
@@ -16,6 +17,18 @@ from .state import SAOEState
from .strategy import SAOEStateAdapter, SAOEStrategy
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda order, _: order.stock_id,
)
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
exchange_kwargs = {
**exchange_config,
"codes": [order.stock_id],
}
return get_exchange(**exchange_kwargs)
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
@@ -67,7 +80,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
cash_limit: Optional[float] = None,
) -> None:
if qlib_config is not None:
init_qlib(qlib_config, part="skip")
init_qlib(qlib_config)
strategy, self._executor = get_strategy_executor(
start_time=order.date,
@@ -76,7 +89,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
executor=executor_config,
benchmark=order.stock_id,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=exchange_config,
exchange_kwargs=_create_exchange(order, exchange_config),
pos_type="Position" if cash_limit is not None else "InfPosition",
)
@@ -90,6 +103,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
trade_strategy=strategy,
trade_executor=self._executor,
return_value=self.report_dict,
show_progress=False,
)
assert isinstance(self._collect_data_loop, Generator)

View File

@@ -3,17 +3,20 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, cast, Optional
from typing import Any, cast, List, Optional
import numpy as np
import pandas as pd
from pathlib import Path
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS, EPS_T, float_or_ndarray
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import DataframeIntradayBacktestData
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.utils import LogLevel
from .state import SAOEMetrics, SAOEState
__all__ = ["SingleAssetOrderExecutionSimple"]
@@ -36,12 +39,14 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
----------
order
The seed to start an SAOE simulator is an order.
data_dir
Path to load backtest data.
feature_columns_today
Columns of today's feature.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
How many ticks per step.
data_dir
Path to load backtest data
vol_threshold
Maximum execution volume (divided by market execution volume).
"""
@@ -73,9 +78,9 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self,
order: Order,
data_dir: Path,
feature_columns_today: List[str] = [],
data_granularity: int = 1,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: Optional[float] = None,
) -> None:
super().__init__(initial=order)
@@ -83,18 +88,12 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
assert ticks_per_step % data_granularity == 0
self.order = order
self.ticks_per_step: int = ticks_per_step // data_granularity
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_simple_intraday_backtest_data(
self.data_dir,
order.stock_id,
pd.Timestamp(order.start_time.date()),
self.deal_price_type,
order.direction,
)
self.feature_columns_today = feature_columns_today
self.ticks_per_step: int = ticks_per_step // data_granularity
self.vol_threshold = vol_threshold
self.backtest_data = self.get_backtest_data()
self.ticks_index = self.backtest_data.get_time_index()
# Get time index available for trading
@@ -118,6 +117,29 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.market_vol: Optional[np.ndarray] = None
self.market_vol_limit: Optional[np.ndarray] = None
def get_backtest_data(self) -> BaseIntradayBacktestData:
try:
data = load_pickle_intraday_processed_data(
data_dir=self.data_dir,
stock_id=self.order.stock_id,
date=pd.Timestamp(self.order.start_time.date()),
feature_columns_today=self.feature_columns_today,
feature_columns_yesterday=[],
backtest=True,
)
return DataframeIntradayBacktestData(data.today)
except (AttributeError, FileNotFoundError):
# TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)
# TODO: In the future, we should modify the data format used by the test script,
# TODO: and then delete this branch.
return load_simple_intraday_backtest_data(
self.data_dir / "backtest",
self.order.stock_id,
pd.Timestamp(self.order.start_time.date()),
"close",
self.order.direction,
)
def step(self, amount: float) -> None:
"""Execute one step or SAOE.

View File

@@ -451,6 +451,7 @@ class SAOEIntStrategy(SAOEStrategy):
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
network: dict | torch.nn.Module | None = None,
immediate_addition: bool = False,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
@@ -501,9 +502,12 @@ class SAOEIntStrategy(SAOEStrategy):
if self._policy is not None:
self._policy.eval()
self.immediate_addition = immediate_addition
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
self.trade_amount_planned = collections.defaultdict(float)
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
assert hasattr(self.outer_trade_decision, "order_list")
@@ -539,9 +543,15 @@ class SAOEIntStrategy(SAOEStrategy):
oh = self.trade_exchange.get_order_helper()
order_list = []
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
order = cast(Order, decision)
if self.immediate_addition:
self.trade_amount_planned[order.stock_id] += exec_vol
amount_planned = self.trade_amount_planned[order.stock_id]
amount_finished = order.amount - state.position
exec_vol = min(state.position, amount_planned - amount_finished)
if exec_vol != 0:
order = cast(Order, decision)
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
return TradeDecisionWithDetails(

View File

@@ -10,18 +10,7 @@ import pandas as pd
from qlib.backtest.decision import OrderDir
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.constant import EPS_T, float_or_ndarray
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
from qlib.constant import float_or_ndarray
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:

View File

@@ -20,7 +20,7 @@ def train(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
initial_states: List[Sequence[InitialStateType]],
policy: BasePolicy,
reward: Reward,
vessel_kwargs: Dict[str, Any],
@@ -39,7 +39,9 @@ def train(
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
state will be run exactly once. Otherwise, every worker will have its own iterator.
policy
Policy to train against.
reward
@@ -67,7 +69,7 @@ def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: Sequence[InitialStateType],
initial_states: List[Sequence[InitialStateType]],
policy: BasePolicy,
logger: LogWriter | List[LogWriter],
reward: Reward | None = None,
@@ -87,7 +89,9 @@ def backtest(
action_interpreter
Interprets the policy actions.
initial_states
Initial states to iterate over. Every state will be run exactly once.
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
state will be run exactly once. Otherwise, every worker will have its own iterator.
policy
Policy to test against.
logger

View File

@@ -5,8 +5,9 @@ from __future__ import annotations
import collections
import copy
from contextlib import AbstractContextManager, contextmanager
from contextlib import AbstractContextManager, ExitStack, contextmanager
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
@@ -206,45 +207,50 @@ class Trainer:
self._call_callback_hooks("on_fit_start")
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
_logger.info(msg)
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
vessel.val_seed_iterators()
) as valid_iterators:
train_vector_env = self.venv_from_iterator(train_iterators)
valid_vector_env = self.venv_from_iterator(valid_iterators)
self.initialize_iter()
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
print(msg)
_logger.info(msg)
self._call_callback_hooks("on_iter_start")
self.initialize_iter()
self.current_stage = "train"
self._call_callback_hooks("on_train_start")
self._call_callback_hooks("on_iter_start")
# TODO
# Add a feature that supports reloading the training environment every few iterations.
with _wrap_context(vessel.train_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.train(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self.current_stage = "train"
self._call_callback_hooks("on_train_start")
self._call_callback_hooks("on_train_end")
# TODO
# Add a feature that supports reloading the training environment every few iterations.
self.vessel.train(train_vector_env)
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
# Implementation of validation loop
self.current_stage = "val"
self._call_callback_hooks("on_validate_start")
with _wrap_context(vessel.val_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.validate(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_train_end")
self._call_callback_hooks("on_validate_end")
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
# Implementation of validation loop
self.current_stage = "val"
self._call_callback_hooks("on_validate_start")
# This iteration is considered complete.
# Bumping the current iteration counter.
self.current_iter += 1
self.vessel.validate(valid_vector_env)
if self.max_iters is not None and self.current_iter >= self.max_iters:
self.should_stop = True
self._call_callback_hooks("on_validate_end")
self._call_callback_hooks("on_iter_end")
# This iteration is considered complete.
# Bumping the current iteration counter.
self.current_iter += 1
if self.max_iters is not None and self.current_iter >= self.max_iters:
self.should_stop = True
self._call_callback_hooks("on_iter_end")
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_fit_end")
@@ -265,16 +271,16 @@ class Trainer:
self.current_stage = "test"
self._call_callback_hooks("on_test_start")
with _wrap_context(vessel.test_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
with _wrap_context(vessel.test_seed_iterators()) as iterators:
vector_env = self.venv_from_iterator(iterators)
self.vessel.test(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_test_end")
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
"""Create a vectorized environment from iterator and the training vessel."""
def env_factory():
def env_factory(iterator):
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
@@ -300,7 +306,7 @@ class Trainer:
)
return vectorize_env(
env_factory,
[partial(env_factory, iterator=it) for it in iterators],
self.finite_env_type,
self.concurrency,
self.loggers,
@@ -334,8 +340,11 @@ class Trainer:
@contextmanager
def _wrap_context(obj):
"""Make any object a (possibly dummy) context manager."""
if isinstance(obj, AbstractContextManager):
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
with ExitStack() as stack:
yield [stack.enter_context(e) for e in obj]
stack.pop_all().close()
elif isinstance(obj, AbstractContextManager):
# obj has __enter__ and __exit__
with obj as ctx:
yield ctx

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
import weakref
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
import numpy as np
from tianshou.data import Collector, VectorReplayBuffer
@@ -49,19 +49,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
def assign_trainer(self, trainer: Trainer) -> None:
self.trainer = weakref.proxy(trainer) # type: ignore
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for training.
def train_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for training.
If the iterable is a context manager, the whole training will be invoked in the with-block,
and the iterator will be automatically closed after the training is done."""
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
raise SeedIteratorNotAvailable("Seed iterators for training is not available.")
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for validation."""
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for validation."""
raise SeedIteratorNotAvailable("Seed iterators for validation is not available.")
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for testing."""
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
def test_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for testing."""
raise SeedIteratorNotAvailable("Seed iterators for testing is not available.")
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
@@ -120,9 +124,9 @@ class TrainingVessel(TrainingVesselBase):
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
policy: BasePolicy,
reward: Reward,
train_initial_states: Sequence[InitialStateType] | None = None,
val_initial_states: Sequence[InitialStateType] | None = None,
test_initial_states: Sequence[InitialStateType] | None = None,
train_initial_states: List[Sequence[InitialStateType]] | None = None,
val_initial_states: List[Sequence[InitialStateType]] | None = None,
test_initial_states: List[Sequence[InitialStateType]] | None = None,
buffer_size: int = 20000,
episode_per_iter: int = 1000,
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
@@ -132,34 +136,49 @@ class TrainingVessel(TrainingVesselBase):
self.action_interpreter = action_interpreter
self.policy = policy
self.reward = reward
self.train_initial_states = train_initial_states
self.val_initial_states = val_initial_states
self.test_initial_states = test_initial_states
self.train_initial_states = None if train_initial_states is None else train_initial_states
self.val_initial_states = None if val_initial_states is None else val_initial_states
self.test_initial_states = None if test_initial_states is None else test_initial_states
self.buffer_size = buffer_size
self.episode_per_iter = episode_per_iter
self.update_kwargs = update_kwargs or {}
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
def train_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
if self.train_initial_states is not None:
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
# Implement fast_dev_run here.
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
return super().train_seed_iterator()
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
train_initial_states = [
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
]
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().train_seed_iterators()
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
if self.val_initial_states is not None:
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
return DataQueue(val_initial_states, repeat=1)
return super().val_seed_iterator()
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
val_initial_states = [
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
]
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().val_seed_iterators()
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
def test_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
if self.test_initial_states is not None:
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
return DataQueue(test_initial_states, repeat=1)
return super().test_seed_iterator()
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
test_initial_states = [
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
]
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().test_seed_iterators()
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
"""Create a collector and collects ``episode_per_iter`` episodes.

View File

@@ -258,6 +258,46 @@ class FiniteVectorEnv(BaseVectorEnv):
return np.stack(obs)
def step2(
self,
action: np.ndarray,
id: int | List[int] | np.ndarray | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert not self._zombie
wrapped_id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(wrapped_id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
result = {}
# ask super to step alive envs and remap to current index
if request_id:
valid_act = np.stack([action[id2idx[i]] for i in request_id])
tmp = super().step(valid_act, request_id)
for obs_next, rew, done, info in zip(*tmp):
obs_next = self._postproc_env_obs(obs_next)
result[info["env_id"]] = [obs_next, rew, done, info]
# logging
for i, r in result.items():
if i in self._alive_env_ids and r[0] is not None:
for logger in self._logger:
logger.on_env_step(i, *r)
for _, reward, __, info in result.values():
self._set_default_info(info)
self._set_default_rew(reward)
for r in result.values():
if r[0] is None:
r[0] = self._get_default_obs()
if r[1] is None:
r[1] = self._get_default_rew()
if r[3] is None:
r[3] = self._get_default_info()
ret = list(map(np.stack, zip(*result.values())))
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
def step(
self,
action: np.ndarray,
@@ -311,7 +351,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
def vectorize_env(
env_factory: Callable[..., gym.Env],
env_factories: List[Callable[..., gym.Env]],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | List[LogWriter],
@@ -334,9 +374,10 @@ def vectorize_env(
Parameters
----------
env_factory
Callable to instantiate one single ``gym.Env``.
All concurrent workers will have the same ``env_factory``.
env_factories
Callables to instantiate one single ``gym.Env``.
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
the same env_factory. Otherwise, each worker will have its own env_factory.
env_type
dummy or subproc or shmem. Corresponding to
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
@@ -358,6 +399,8 @@ def vectorize_env(
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
assert len(env_factories) in (1, concurrency)
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
@@ -366,4 +409,7 @@ def vectorize_env(
finite_env_cls = env_type_cls_mapping[env_type]
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
if len(env_factories) == 1:
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
else:
return finite_env_cls(logger, env_factories)

View File

@@ -0,0 +1,30 @@
import time
from contextlib import contextmanager
from typing import Callable, Generator
from line_profiler import LineProfiler
@contextmanager
def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]:
s = time.perf_counter()
yield
e = time.perf_counter()
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
if out_path is not None:
with open(out_path, "a") as fstream:
fstream.write(msg + "\n")
else:
print(msg)
def lprofile(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
lp = LineProfiler()
lpw = lp(func)
res = lpw(*args, **kwargs)
lp.print_stats()
return res
return wrapper

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TODO: this utils covers too much utilities, please seperat it into sub modules
from __future__ import division
from __future__ import print_function
@@ -43,7 +44,7 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
#################### Server ####################
def get_redis_connection():
"""get redis connection instance."""
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db)
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
#################### Data ####################
@@ -427,7 +428,7 @@ def init_instance_by_config(
pr = urlparse(config)
if pr.scheme == "file":
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
with open(pr_path, "rb") as f:
with open(os.path.normpath(pr_path), "rb") as f:
return pickle.load(f)
else:
with config.open("rb") as f:

View File

@@ -1,6 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
"""
This module covers some utility functions that operate on data or basic object
"""
from copy import deepcopy
from typing import List, Union
import pandas as pd
import numpy as np
@@ -54,3 +58,48 @@ def deepcopy_basic_type(obj: object) -> object:
return {k: deepcopy_basic_type(v) for k, v in obj.items()}
else:
return obj
S_DROP = "__DROP__" # this is a symbol which indicates drop the value
def update_config(base_config: dict, ext_config: Union[dict, List[dict]]):
"""
supporting adding base config based on the ext_config
>>> bc = {"a": "xixi"}
>>> ec = {"b": "haha"}
>>> new_bc = update_config(bc, ec)
>>> print(new_bc)
{'a': 'xixi', 'b': 'haha'}
>>> print(bc) # base config should not be changed
{'a': 'xixi'}
>>> print(update_config(bc, {"b": S_DROP}))
{'a': 'xixi'}
>>> print(update_config(new_bc, {"b": S_DROP}))
{'a': 'xixi'}
"""
base_config = deepcopy(base_config) # in case of modifying base config
for ec in ext_config if isinstance(ext_config, (list, tuple)) else [ext_config]:
for key in ec:
if key not in base_config:
# if it is not in the default key, then replace it.
# ADD if not drop
if ec[key] != S_DROP:
base_config[key] = ec[key]
else:
if isinstance(base_config[key], dict) and isinstance(ec[key], dict):
# Recursive
# Both of them are dict, then update it nested
base_config[key] = update_config(base_config[key], ec[key])
elif ec[key] == S_DROP:
# DROP
del base_config[key]
else:
# REPLACE
# one of then are not dict. Then replace
base_config[key] = ec[key]
return base_config

View File

@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import sys
import os
from pathlib import Path
@@ -10,6 +10,12 @@ import fire
import ruamel.yaml as yaml
from qlib.config import C
from qlib.model.trainer import task_train
from qlib.utils.data import update_config
from qlib.log import get_module_logger
from qlib.utils import set_log_with_config
set_log_with_config(C.logging_config)
logger = get_module_logger("qrun", logging.INFO)
def get_path_list(path):
@@ -47,10 +53,47 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
This is a Qlib CLI entrance.
User can run the whole Quant research workflow defined by a configure file
- the code is located here ``qlib/workflow/cli.py`
User can specify a base_config file in your workflow.yml file by adding "BASE_CONFIG_PATH".
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
in their own workflow.yml file.
For examples:
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
BASE_CONFIG_PATH: "workflow_config_lightgbm_Alpha158_csi500.yaml"
market: csi300
"""
with open(config_path) as fp:
config = yaml.safe_load(fp)
base_config_path = config.get("BASE_CONFIG_PATH", None)
if base_config_path:
logger.info(f"Use BASE_CONFIG_PATH: {base_config_path}")
base_config_path = Path(base_config_path)
# it will find config file in absolute path and relative path
if base_config_path.exists():
path = base_config_path
else:
logger.info(
f"Can't find BASE_CONFIG_PATH base on: {Path.cwd()}, "
f"try using relative path to config path: {Path(config_path).absolute()}"
)
relative_path = Path(config_path).absolute().parent.joinpath(base_config_path)
if relative_path.exists():
path = relative_path
else:
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}")
with open(path) as fp:
base_config = yaml.safe_load(fp)
logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}")
config = update_config(base_config, config)
# config the `sys` section
sys_config(config, config_path)

View File

@@ -170,7 +170,7 @@ setup(
"gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail.
],
"rl": [
"tianshou",
"tianshou<=0.4.10",
"torch",
],
},

View File

@@ -0,0 +1,5 @@
# Introduction
The middle layers of data, which mainly includes
- Handler
- processors
- Datasets

View File

@@ -0,0 +1,37 @@
import os
import pickle
import shutil
import unittest
from qlib.tests import TestAutoData
from qlib.data import D
from qlib.data.dataset.handler import DataHandlerLP
class HandlerTests(TestAutoData):
def to_str(self, obj):
return "".join(str(obj).split())
def test_handler_df(self):
df = D.features(["sh600519"], start_time="20190101", end_time="20190201", fields=["$close"])
dh = DataHandlerLP.from_df(df)
print(dh.fetch())
self.assertTrue(dh._data.equals(df))
self.assertTrue(dh._infer is dh._data)
self.assertTrue(dh._learn is dh._data)
self.assertTrue(dh.data_loader._data is dh._data)
fname = "_handler_test.pkl"
dh.to_pickle(fname, dump_all=True)
with open(fname, "rb") as f:
dh_d = pickle.load(f)
self.assertTrue(dh_d._data.equals(df))
self.assertTrue(dh_d._infer is dh_d._data)
self.assertTrue(dh_d._learn is dh_d._data)
# Data loader will no longer be useful
self.assertTrue("_data" not in dh_d.data_loader.__dict__.keys())
os.remove(fname)
if __name__ == "__main__":
unittest.main()

View File

@@ -76,7 +76,7 @@ class IndexDataTest(unittest.TestCase):
self.assertTrue(np.isnan(sd.loc["bar", "g"]))
# support slicing
print(sd.loc[~sd.loc[:, "g"].isna().data.astype(np.bool)])
print(sd.loc[~sd.loc[:, "g"].isna().data.astype(bool)])
print(self.assertTrue(idd.SingleData().index == idd.SingleData().index))

View File

@@ -31,7 +31,6 @@ FEATURE_DATA_DIR = DATA_DIR / "processed"
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
@@ -49,7 +48,7 @@ def test_pickle_data_inspect():
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
@@ -83,7 +82,7 @@ def test_simulator_first_step():
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
for _ in range(13):
simulator.step(1.0)
@@ -106,10 +105,10 @@ def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
@@ -119,7 +118,7 @@ def test_simulator_stop_early():
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
@@ -138,7 +137,7 @@ def test_simulator_start_middle():
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
@@ -219,7 +218,7 @@ def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
@@ -259,7 +258,7 @@ def test_twap_strategy(finite_env_type):
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -290,7 +289,7 @@ def test_cn_ppo_strategy():
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -319,7 +318,7 @@ def test_ppo_train():
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,