1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 03:50:57 +08:00

Refine Qlib RL data format (#1480)

* wip

* wip

* wip

* Fix naming errors

* Backtest test passed

* Why training stuck?

* Minor

* Refine train configs

* Use dummy in training

* Remove pickle_dataframe

* CI

* CI

* Add more strict condition to filter orders

* Pass test

* Add TODO in example

---------

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Huoran Li
2023-04-26 21:14:30 +08:00
committed by GitHub
parent 46264dfec9
commit 7f1e8c5206
17 changed files with 237 additions and 250 deletions

View File

@@ -1,26 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import pickle
import pandas as pd
from joblib import Parallel, delayed
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
cur = cur.set_index(["instrument", "datetime", "date"])
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
for tag in ("backtest", "feature"):
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
df = pd.concat(list(df.values())).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")
instruments = sorted(set(df["instrument"]))
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)

View File

@@ -4,17 +4,22 @@
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
OUTPUT_PATH = Path(os.path.join("data", "orders"))
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
df = dataset.handler.fetch(level=None).reset_index()
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
return False
df["date"] = df["datetime"].dt.date.astype("datetime64")
df = df.set_index(["instrument", "datetime", "date"])
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
os.makedirs(path, exist_ok=True)
if len(order) > 0:
order.to_pickle(path / f"{stock}.pkl.target")
return True
np.random.seed(1234)
file_list = sorted(os.listdir(DATA_PATH))
stocks = [f.replace(".pkl", "") for f in file_list]
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
for stock in tqdm(stocks):
generate_order(stock, 0, 240 // 5 - 1)
np.random.shuffle(stocks)
cnt = 0
for stock in stocks:
if generate_order(stock, 0, 240 // 5 - 1):
cnt += 1
if cnt == 100:
break