1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Refine Qlib RL data format (#1480)

* wip

* wip

* wip

* Fix naming errors

* Backtest test passed

* Why training stuck?

* Minor

* Refine train configs

* Use dummy in training

* Remove pickle_dataframe

* CI

* CI

* Add more strict condition to filter orders

* Pass test

* Add TODO in example

---------

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Huoran Li
2023-04-26 21:14:30 +08:00
committed by GitHub
parent 46264dfec9
commit 7f1e8c5206
17 changed files with 237 additions and 250 deletions

View File

@@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region
To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
[//]: # (TODO: Instead of dumping dataframe with different format &#40;like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`&#41;, we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
```
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
python scripts/collect_pickle_dataframe.py
python scripts/gen_training_orders.py
python scripts/merge_orders.py
```
@@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
data
├── bin
├── orders
── pickle
└── pickle_dataframe
── pickle
```
## Training

View File

@@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
@@ -45,10 +36,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:

View File

@@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]
@@ -45,10 +36,12 @@ strategies:
data_ticks: 48
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
module_path: qlib.rl.data.pickle_styled
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
module_path: qlib.rl.order_execution.strategy
30min:

View File

@@ -3,15 +3,6 @@ start_time: "9:30"
end_time: "14:54"
qlib:
provider_uri_5min: ./data/bin/
feature_root_dir: ./data/pickle/
feature_columns_today: [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
]
feature_columns_yesterday: [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
]
exchange:
limit_threshold: null
deal_price: ["$close", "$close"]

View File

@@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PAPenaltyReward
@@ -32,7 +35,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235

View File

@@ -3,8 +3,8 @@ simulator:
time_per_step: 30
vol_limit: null
env:
concurrency: 48
parallel_mode: shmem
concurrency: 32
parallel_mode: dummy
action_interpreter:
class: CategoricalActionInterpreter
kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
data_ticks: 48 # 48 = 240 min / 5 min
max_step: 8
processed_data_provider:
class: PickleProcessedDataProvider
module_path: qlib.rl.data.pickle_styled
class: HandlerProcessedDataProvider
kwargs:
data_dir: ./data/pickle_dataframe/feature
data_dir: ./data/pickle/
feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
backtest: false
module_path: qlib.rl.data.native
module_path: qlib.rl.order_execution.interpreter
reward:
class: PPOReward
@@ -33,7 +36,9 @@ reward:
data:
source:
order_dir: ./data/orders
data_dir: ./data/pickle_dataframe/backtest
feature_root_dir: ./data/pickle/
feature_columns_today: ["$close0", "$volume0"]
feature_columns_yesterday: []
total_time: 240
default_start_time_index: 0
default_end_time_index: 235

View File

@@ -1,26 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from joblib import Parallel, delayed
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)

View File

@@ -4,17 +4,22 @@
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = dataset.handler.fetch(level=None).reset_index()
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
return False
df["date"] = df["datetime"].dt.date.astype("datetime64")
df = df.set_index(["instrument", "datetime", "date"])
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
return True
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
for stock in tqdm(stocks):
generate_order(stock, 0, 240 // 5 - 1)
np.random.shuffle(stocks)
cnt = 0
for stock in stocks:
if generate_order(stock, 0, 240 // 5 - 1):
cnt += 1
if cnt == 100:
break

View File

@@ -154,12 +154,7 @@ def single_with_simulator(
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])
stocks = orders.instrument.unique().tolist()
@@ -253,12 +248,7 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)
init_qlib(backtest_config["qlib"])
trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()

View File

@@ -1,5 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
@@ -9,13 +11,12 @@ from typing import cast, List, Optional
import numpy as np
import pandas as pd
import qlib
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.data.native import load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
@@ -49,19 +50,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame:
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_file_path: Path,
data_dir: Path,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_file_path = order_file_path
self._order_df = _read_orders(order_file_path).reset_index()
self._data_dir = data_dir
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)
def __len__(self) -> int:
return len(self._order_df)
@@ -74,12 +73,17 @@ class LazyLoadDataset(Dataset):
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
# TODO: of all dates.
backtest_data = load_simple_intraday_backtest_data(
data = load_handler_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
index_only=True,
)
self._ticks_index = [t - date for t in backtest_data.get_time_index()]
self._ticks_index = [t - date for t in data.today.index]
order = Order(
stock_id=row["instrument"],
@@ -104,8 +108,6 @@ def train_and_test(
run_training: bool,
run_backtest: bool,
) -> None:
qlib.init()
order_root_path = Path(data_config["source"]["order_dir"])
data_granularity = simulator_config.get("data_granularity", 1)
@@ -113,10 +115,11 @@ def train_and_test(
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=Path(data_config["source"]["data_dir"]),
ticks_per_step=simulator_config["time_per_step"],
data_dir=data_config["source"]["feature_root_dir"],
feature_columns_today=data_config["source"]["feature_columns_today"],
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
data_granularity=data_granularity,
deal_price_type=data_config["source"].get("deal_price_column", "close"),
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)
@@ -126,8 +129,8 @@ def train_and_test(
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / tag,
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
@@ -178,8 +181,8 @@ def train_and_test(
if run_backtest:
test_dataset = LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / "test",
data_dir=Path(data_config["source"]["data_dir"]),
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)

View File

@@ -8,48 +8,14 @@ TODO: The implementation here is kind of adhoc. It is better to design a more un
from __future__ import annotations
import pickle
from pathlib import Path
from typing import List
import cachetools
import numpy as np
import pandas as pd
import qlib
from qlib.constant import REG_CN
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
from qlib.data.dataset import DatasetH
dataset = None
class DataWrapper:
def __init__(
self,
feature_dataset: DatasetH,
backtest_dataset: DatasetH,
columns_today: List[str],
columns_yesterday: List[str],
_internal: bool = False,
):
assert _internal, "Init function of data wrapper is for internal use only."
self.feature_dataset = feature_dataset
self.backtest_dataset = backtest_dataset
self.columns_today = columns_today
self.columns_yesterday = columns_yesterday
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
)
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
dataset = self.backtest_dataset if backtest else self.feature_dataset
return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
def init_qlib(qlib_config: dict, part: str | None = None) -> None:
def init_qlib(qlib_config: dict) -> None:
"""Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
Parameters
@@ -72,12 +38,8 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
],
}
part
Identifying which part (stock / date) to load.
"""
global dataset # pylint: disable=W0603
def _convert_to_path(path: str | Path) -> Path:
return path if isinstance(path, Path) else Path(path)
@@ -118,47 +80,3 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
redis_port=-1,
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
)
if part == "skip":
return
# this won't work if it's put outside in case of multiprocessing
from qlib.data import D # noqa pylint: disable=C0415,W0611
if part is None:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
else:
feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
with feature_path.open("rb") as f:
feature_dataset = pickle.load(f)
with backtest_path.open("rb") as f:
backtest_dataset = pickle.load(f)
dataset = DataWrapper(
feature_dataset,
backtest_dataset,
qlib_config["feature_columns_today"],
qlib_config["feature_columns_yesterday"],
_internal=True,
)
def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
assert dataset is not None, "You must call init_qlib() before doing this."
if backtest:
fields = ["$close", "$volume"]
else:
fields = dataset.columns_yesterday if yesterday else dataset.columns_today
data = dataset.get(stock_id, date, backtest)
if data is None or len(data) == 0:
# create a fake index, but RL doesn't care about index
data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here
else:
data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
data = data[fields]
return data

View File

@@ -2,17 +2,29 @@
# Licensed under the MIT License.
from __future__ import annotations
from typing import cast
from pathlib import Path
from typing import cast, List
import cachetools
import pandas as pd
import pickle
import os
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.rl.order_execution.utils import get_ticks_slice
from qlib.constant import EPS_T
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
from .integration import fetch_features
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
class IntradayBacktestData(BaseIntradayBacktestData):
@@ -71,6 +83,31 @@ class IntradayBacktestData(BaseIntradayBacktestData):
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
class DataframeIntradayBacktestData(BaseIntradayBacktestData):
"""Backtest data from dataframe"""
def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
self.df = df
self.price_column = price_column
self.volume_column = volume_column
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.df})"
def __len__(self) -> int:
return len(self.df)
def get_deal_price(self) -> pd.Series:
return self.df[self.price_column]
def get_volume(self) -> pd.Series:
return self.df[self.volume_column]
def get_time_index(self) -> pd.DatetimeIndex:
return cast(pd.DatetimeIndex, self.df.index)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100),
key=lambda order, _, __: order.key_by_day,
@@ -103,13 +140,18 @@ def load_backtest_data(
return backtest_data
class NTIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle NT style data."""
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
df = df.reset_index()
@@ -117,8 +159,18 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
df = df.drop(columns=["instrument"])
return df.set_index(["datetime"])
self.today = _drop_stock_id(fetch_features(stock_id, date))
self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
with open(path, "rb") as fstream:
dataset = pickle.load(fstream)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
if index_only:
self.today = _drop_stock_id(data[[]])
self.yesterday = _drop_stock_id(data[[]])
else:
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
@@ -127,12 +179,42 @@ class NTIntradayProcessedData(BaseIntradayProcessedData):
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
stock_id,
date,
backtest,
index_only,
),
)
def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
return NTIntradayProcessedData(stock_id, date)
def load_handler_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> HandlerIntradayProcessedData:
return HandlerIntradayProcessedData(
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
)
class NTProcessedDataProvider(ProcessedDataProvider):
class HandlerProcessedDataProvider(ProcessedDataProvider):
def __init__(
self,
data_dir: str,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
super().__init__()
self.data_dir = Path(data_dir)
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.backtest = backtest
def get_data(
self,
stock_id: str,
@@ -140,4 +222,12 @@ class NTProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_nt_intraday_processed_data(stock_id, date)
return load_handler_intraday_processed_data(
self.data_dir,
stock_id,
date,
self.feature_columns_today,
self.feature_columns_yesterday,
backtest=self.backtest,
index_only=False,
)

View File

@@ -158,8 +158,8 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
return cast(pd.DatetimeIndex, self.data.index)
class IntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
class PickleIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
def __init__(
self,
@@ -217,14 +217,14 @@ def load_simple_intraday_backtest_data(
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
)
def load_pickled_intraday_processed_data(
def load_pickle_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
class PickleProcessedDataProvider(ProcessedDataProvider):
@@ -240,7 +240,7 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return load_pickled_intraday_processed_data(
return load_pickle_intraday_processed_data(
data_dir=self._data_dir,
stock_id=stock_id,
date=date,

View File

@@ -67,7 +67,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
cash_limit: Optional[float] = None,
) -> None:
if qlib_config is not None:
init_qlib(qlib_config, part="skip")
init_qlib(qlib_config)
strategy, self._executor = get_strategy_executor(
start_time=order.date,

View File

@@ -3,17 +3,19 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, cast, Optional
from typing import Any, cast, List, Optional
import numpy as np
import pandas as pd
from pathlib import Path
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS, EPS_T, float_or_ndarray
from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.utils import LogLevel
from .state import SAOEMetrics, SAOEState
__all__ = ["SingleAssetOrderExecutionSimple"]
@@ -36,12 +38,16 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
----------
order
The seed to start an SAOE simulator is an order.
data_dir
Path to load backtest data.
feature_columns_today
Columns of today's feature.
feature_columns_yesterday
Columns of yesterday's feature.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
How many ticks per step.
data_dir
Path to load backtest data
vol_threshold
Maximum execution volume (divided by market execution volume).
"""
@@ -73,9 +79,10 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self,
order: Order,
data_dir: Path,
feature_columns_today: List[str] = [],
feature_columns_yesterday: List[str] = [],
data_granularity: int = 1,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: Optional[float] = None,
) -> None:
super().__init__(initial=order)
@@ -83,18 +90,13 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
assert ticks_per_step % data_granularity == 0
self.order = order
self.ticks_per_step: int = ticks_per_step // data_granularity
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_simple_intraday_backtest_data(
self.data_dir,
order.stock_id,
pd.Timestamp(order.start_time.date()),
self.deal_price_type,
order.direction,
)
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.ticks_per_step: int = ticks_per_step // data_granularity
self.vol_threshold = vol_threshold
self.backtest_data = self.get_backtest_data()
self.ticks_index = self.backtest_data.get_time_index()
# Get time index available for trading
@@ -118,6 +120,30 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.market_vol: Optional[np.ndarray] = None
self.market_vol_limit: Optional[np.ndarray] = None
def get_backtest_data(self) -> BaseIntradayBacktestData:
try:
data = load_handler_intraday_processed_data(
data_dir=self.data_dir,
stock_id=self.order.stock_id,
date=pd.Timestamp(self.order.start_time.date()),
feature_columns_today=self.feature_columns_today,
feature_columns_yesterday=self.feature_columns_yesterday,
backtest=True,
index_only=False,
)
return DataframeIntradayBacktestData(data.today)
except (AttributeError, FileNotFoundError):
# TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py)
# TODO: In the future, we should modify the data format used by the test script,
# TODO: and then delete this branch.
return load_simple_intraday_backtest_data(
self.data_dir / "backtest",
self.order.stock_id,
pd.Timestamp(self.order.start_time.date()),
"close",
self.order.direction,
)
def step(self, amount: float) -> None:
"""Execute one step or SAOE.

View File

@@ -10,18 +10,7 @@ import pandas as pd
from qlib.backtest.decision import OrderDir
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.constant import EPS_T, float_or_ndarray
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - EPS_T
return ticks_index[ticks_index.slice_indexer(start, end)]
from qlib.constant import float_or_ndarray
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:

View File

@@ -31,7 +31,6 @@ FEATURE_DATA_DIR = DATA_DIR / "processed"
ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
CN_DATA_DIR = DATA_ROOT_DIR / "cn"
CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
@@ -49,7 +48,7 @@ def test_pickle_data_inspect():
def test_simulator_first_step():
order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
assert state.position == 30.0
@@ -83,7 +82,7 @@ def test_simulator_first_step():
def test_simulator_stop_twap():
order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
for _ in range(13):
simulator.step(1.0)
@@ -106,10 +105,10 @@ def test_simulator_stop_early():
order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
with pytest.raises(ValueError):
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
simulator.step(2.0)
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
simulator.step(1.0)
with pytest.raises(AssertionError):
@@ -119,7 +118,7 @@ def test_simulator_stop_early():
def test_simulator_start_middle():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
simulator.step(2.0)
@@ -138,7 +137,7 @@ def test_simulator_start_middle():
def test_interpreter():
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 330
assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
@@ -219,7 +218,7 @@ def test_network_sanity():
# we won't check the correctness of networks here
order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
assert len(simulator.ticks_for_order) == 390
class EmulateEnvWrapper(NamedTuple):
@@ -259,7 +258,7 @@ def test_twap_strategy(finite_env_type):
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -290,7 +289,7 @@ def test_cn_ppo_strategy():
csv_writer = CsvWriter(Path(__file__).parent / ".output")
backtest(
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,
@@ -319,7 +318,7 @@ def test_ppo_train():
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
train(
partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
state_interp,
action_interp,
orders,