mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Refine Qlib RL data format (#1480)
* wip * wip * wip * Fix naming errors * Backtest test passed * Why training stuck? * Minor * Refine train configs * Use dummy in training * Remove pickle_dataframe * CI * CI * Add more strict condition to filter orders * Pass test * Add TODO in example --------- Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region
|
||||
|
||||
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
|
||||
|
||||
[//]: # (TODO: Instead of dumping dataframe with different format (like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`), we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
|
||||
|
||||
```
|
||||
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
|
||||
python scripts/collect_pickle_dataframe.py
|
||||
python scripts/gen_training_orders.py
|
||||
python scripts/merge_orders.py
|
||||
```
|
||||
@@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
|
||||
data
|
||||
├── bin
|
||||
├── orders
|
||||
├── pickle
|
||||
└── pickle_dataframe
|
||||
└── pickle
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
@@ -3,15 +3,6 @@ start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
@@ -45,10 +36,12 @@ strategies:
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
|
||||
@@ -3,15 +3,6 @@ start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
@@ -45,10 +36,12 @@ strategies:
|
||||
data_ticks: 48
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
module_path: qlib.rl.order_execution.strategy
|
||||
30min:
|
||||
|
||||
@@ -3,15 +3,6 @@ start_time: "9:30"
|
||||
end_time: "14:54"
|
||||
qlib:
|
||||
provider_uri_5min: ./data/bin/
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
|
||||
]
|
||||
feature_columns_yesterday: [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
|
||||
]
|
||||
exchange:
|
||||
limit_threshold: null
|
||||
deal_price: ["$close", "$close"]
|
||||
|
||||
@@ -3,8 +3,8 @@ simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 48
|
||||
parallel_mode: shmem
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
@@ -18,10 +18,13 @@ state_interpreter:
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PAPenaltyReward
|
||||
@@ -32,7 +35,9 @@ reward:
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
|
||||
@@ -3,8 +3,8 @@ simulator:
|
||||
time_per_step: 30
|
||||
vol_limit: null
|
||||
env:
|
||||
concurrency: 48
|
||||
parallel_mode: shmem
|
||||
concurrency: 32
|
||||
parallel_mode: dummy
|
||||
action_interpreter:
|
||||
class: CategoricalActionInterpreter
|
||||
kwargs:
|
||||
@@ -18,10 +18,13 @@ state_interpreter:
|
||||
data_ticks: 48 # 48 = 240 min / 5 min
|
||||
max_step: 8
|
||||
processed_data_provider:
|
||||
class: PickleProcessedDataProvider
|
||||
module_path: qlib.rl.data.pickle_styled
|
||||
class: HandlerProcessedDataProvider
|
||||
kwargs:
|
||||
data_dir: ./data/pickle_dataframe/feature
|
||||
data_dir: ./data/pickle/
|
||||
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
|
||||
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
|
||||
backtest: false
|
||||
module_path: qlib.rl.data.native
|
||||
module_path: qlib.rl.order_execution.interpreter
|
||||
reward:
|
||||
class: PPOReward
|
||||
@@ -33,7 +36,9 @@ reward:
|
||||
data:
|
||||
source:
|
||||
order_dir: ./data/orders
|
||||
data_dir: ./data/pickle_dataframe/backtest
|
||||
feature_root_dir: ./data/pickle/
|
||||
feature_columns_today: ["$close0", "$volume0"]
|
||||
feature_columns_yesterday: []
|
||||
total_time: 240
|
||||
default_start_time_index: 0
|
||||
default_end_time_index: 235
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
|
||||
|
||||
|
||||
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
|
||||
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
|
||||
cur = cur.set_index(["instrument", "datetime", "date"])
|
||||
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
|
||||
|
||||
|
||||
for tag in ("backtest", "feature"):
|
||||
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
|
||||
df = pd.concat(list(df.values())).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
instruments = sorted(set(df["instrument"]))
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
|
||||
|
||||
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)
|
||||
@@ -4,17 +4,22 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
|
||||
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
|
||||
OUTPUT_PATH = Path(os.path.join("data", "orders"))
|
||||
|
||||
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
df = dataset.handler.fetch(level=None).reset_index()
|
||||
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
|
||||
return False
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if len(order) > 0:
|
||||
order.to_pickle(path / f"{stock}.pkl.target")
|
||||
return True
|
||||
|
||||
|
||||
np.random.seed(1234)
|
||||
file_list = sorted(os.listdir(DATA_PATH))
|
||||
stocks = [f.replace(".pkl", "") for f in file_list]
|
||||
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
|
||||
for stock in tqdm(stocks):
|
||||
generate_order(stock, 0, 240 // 5 - 1)
|
||||
np.random.shuffle(stocks)
|
||||
|
||||
cnt = 0
|
||||
for stock in stocks:
|
||||
if generate_order(stock, 0, 240 // 5 - 1):
|
||||
cnt += 1
|
||||
if cnt == 100:
|
||||
break
|
||||
|
||||
@@ -154,12 +154,7 @@ def single_with_simulator(
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
@@ -253,12 +248,7 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
init_qlib(backtest_config["qlib"], part=stock_id)
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
init_qlib(backtest_config["qlib"], part=day)
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
@@ -9,13 +11,12 @@ from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
@@ -49,19 +50,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_file_path: Path,
|
||||
data_dir: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_file_path = order_file_path
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
@@ -74,12 +73,17 @@ class LazyLoadDataset(Dataset):
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
backtest_data = load_simple_intraday_backtest_data(
|
||||
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
@@ -104,8 +108,6 @@ def train_and_test(
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
qlib.init()
|
||||
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
@@ -113,10 +115,11 @@ def train_and_test(
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
feature_columns_today=data_config["source"]["feature_columns_today"],
|
||||
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
|
||||
data_granularity=data_granularity,
|
||||
deal_price_type=data_config["source"].get("deal_price_column", "close"),
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
@@ -126,8 +129,8 @@ def train_and_test(
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / tag,
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
@@ -178,8 +181,8 @@ def train_and_test(
|
||||
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / "test",
|
||||
data_dir=Path(data_config["source"]["data_dir"]),
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
dataset = None
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
|
||||
)
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
dataset = self.backtest_dataset if backtest else self.feature_dataset
|
||||
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
|
||||
def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
def init_qlib(qlib_config: dict) -> None:
|
||||
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
|
||||
|
||||
Parameters
|
||||
@@ -72,12 +38,8 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
part
|
||||
Identifying which part (stock / date) to load.
|
||||
"""
|
||||
|
||||
global dataset # pylint: disable=W0603
|
||||
|
||||
def _convert_to_path(path: str | Path) -> Path:
|
||||
return path if isinstance(path, Path) else Path(path)
|
||||
|
||||
@@ -118,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
|
||||
if part == "skip":
|
||||
return
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
from qlib.data import D # noqa pylint: disable=C0415,W0611
|
||||
|
||||
if part is None:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
|
||||
else:
|
||||
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
|
||||
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
|
||||
|
||||
with feature_path.open("rb") as f:
|
||||
feature_dataset = pickle.load(f)
|
||||
with backtest_path.open("rb") as f:
|
||||
backtest_dataset = pickle.load(f)
|
||||
|
||||
dataset = DataWrapper(
|
||||
feature_dataset,
|
||||
backtest_dataset,
|
||||
qlib_config["feature_columns_today"],
|
||||
qlib_config["feature_columns_yesterday"],
|
||||
_internal=True,
|
||||
)
|
||||
|
||||
|
||||
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
|
||||
assert dataset is not None, "You must call init_qlib() before doing this."
|
||||
|
||||
if backtest:
|
||||
fields = ["$close", "$volume"]
|
||||
else:
|
||||
fields = dataset.columns_yesterday if yesterday else dataset.columns_today
|
||||
|
||||
data = dataset.get(stock_id, date, backtest)
|
||||
if data is None or len(data) == 0:
|
||||
# create a fake index, but RL doesn't care about index
|
||||
data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
|
||||
else:
|
||||
data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
|
||||
data = data[fields]
|
||||
return data
|
||||
|
||||
@@ -2,17 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from pathlib import Path
|
||||
from typing import cast, List
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.rl.order_execution.utils import get_ticks_slice
|
||||
|
||||
from qlib.constant import EPS_T
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
from .integration import fetch_features
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - EPS_T
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
@@ -71,6 +83,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
||||
|
||||
|
||||
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
|
||||
"""Backtest data from dataframe"""
|
||||
|
||||
def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
|
||||
self.df = df
|
||||
self.price_column = price_column
|
||||
self.volume_column = volume_column
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.df})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.df)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
return self.df[self.price_column]
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
return self.df[self.volume_column]
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return cast(pd.DatetimeIndex, self.df.index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100),
|
||||
key=lambda order, _, __: order.key_by_day,
|
||||
@@ -103,13 +140,18 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -117,8 +159,18 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
df = df.drop(columns=["instrument"])
|
||||
return df.set_index(["datetime"])
|
||||
|
||||
self.today = _drop_stock_id(fetch_features(stock_id, date))
|
||||
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -127,12 +179,42 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
|
||||
return NTIntradayProcessedData(stock_id, date)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
)
|
||||
|
||||
|
||||
class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.backtest = backtest
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
stock_id: str,
|
||||
@@ -140,4 +222,12 @@ class NTProcessedDataProvider(ProcessedDataProvider):
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return load_nt_intraday_processed_data(stock_id, date)
|
||||
return load_handler_intraday_processed_data(
|
||||
self.data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
@@ -158,8 +158,8 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
class IntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
|
||||
class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -217,14 +217,14 @@ def load_simple_intraday_backtest_data(
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_pickled_intraday_processed_data(
|
||||
def load_pickle_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
@@ -240,7 +240,7 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return load_pickled_intraday_processed_data(
|
||||
return load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=stock_id,
|
||||
date=date,
|
||||
|
||||
@@ -67,7 +67,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
cash_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
if qlib_config is not None:
|
||||
init_qlib(qlib_config, part="skip")
|
||||
init_qlib(qlib_config)
|
||||
|
||||
strategy, self._executor = get_strategy_executor(
|
||||
start_time=order.date,
|
||||
|
||||
@@ -3,17 +3,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pathlib import Path
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
|
||||
from .state import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["SingleAssetOrderExecutionSimple"]
|
||||
@@ -36,12 +38,16 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
----------
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
data_dir
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
feature_columns_yesterday
|
||||
Columns of yesterday's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
data_dir
|
||||
Path to load backtest data
|
||||
vol_threshold
|
||||
Maximum execution volume (divided by market execution volume).
|
||||
"""
|
||||
@@ -73,9 +79,10 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self,
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
feature_columns_yesterday: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
@@ -83,18 +90,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
assert ticks_per_step % data_granularity == 0
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_simple_intraday_backtest_data(
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
self.backtest_data = self.get_backtest_data()
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
|
||||
# Get time index available for trading
|
||||
@@ -118,6 +120,30 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def get_backtest_data(self) -> BaseIntradayBacktestData:
|
||||
try:
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self.data_dir,
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=self.feature_columns_yesterday,
|
||||
backtest=True,
|
||||
index_only=False,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
# TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)
|
||||
# TODO: In the future, we should modify the data format used by the test script,
|
||||
# TODO: and then delete this branch.
|
||||
return load_simple_intraday_backtest_data(
|
||||
self.data_dir / "backtest",
|
||||
self.order.stock_id,
|
||||
pd.Timestamp(self.order.start_time.date()),
|
||||
"close",
|
||||
self.order.direction,
|
||||
)
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
||||
|
||||
@@ -10,18 +10,7 @@ import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
|
||||
from qlib.constant import EPS_T, float_or_ndarray
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - EPS_T
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
from qlib.constant import float_or_ndarray
|
||||
|
||||
|
||||
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
|
||||
@@ -31,7 +31,6 @@ FEATURE_DATA_DIR = DATA_DIR / "processed"
|
||||
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
|
||||
|
||||
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
|
||||
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
|
||||
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
|
||||
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
|
||||
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
|
||||
@@ -49,7 +48,7 @@ def test_pickle_data_inspect():
|
||||
def test_simulator_first_step():
|
||||
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
|
||||
assert state.position == 30.0
|
||||
@@ -83,7 +82,7 @@ def test_simulator_first_step():
|
||||
def test_simulator_stop_twap():
|
||||
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
for _ in range(13):
|
||||
simulator.step(1.0)
|
||||
|
||||
@@ -106,10 +105,10 @@ def test_simulator_stop_early():
|
||||
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
simulator.step(2.0)
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
simulator.step(1.0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
@@ -119,7 +118,7 @@ def test_simulator_stop_early():
|
||||
def test_simulator_start_middle():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
simulator.step(2.0)
|
||||
@@ -138,7 +137,7 @@ def test_simulator_start_middle():
|
||||
def test_interpreter():
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 330
|
||||
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
|
||||
|
||||
@@ -219,7 +218,7 @@ def test_network_sanity():
|
||||
# we won't check the correctness of networks here
|
||||
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
|
||||
|
||||
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
|
||||
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
|
||||
assert len(simulator.ticks_for_order) == 390
|
||||
|
||||
class EmulateEnvWrapper(NamedTuple):
|
||||
@@ -259,7 +258,7 @@ def test_twap_strategy(finite_env_type):
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -290,7 +289,7 @@ def test_cn_ppo_strategy():
|
||||
csv_writer = CsvWriter(Path(__file__).parent / ".output")
|
||||
|
||||
backtest(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
@@ -319,7 +318,7 @@ def test_ppo_train():
|
||||
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
|
||||
|
||||
train(
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
|
||||
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
|
||||
state_interp,
|
||||
action_interp,
|
||||
orders,
|
||||
|
||||
Reference in New Issue
Block a user