mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
Refine Qlib RL data format (#1480)
* wip * wip * wip * Fix naming errors * Backtest test passed * Why training stuck? * Minor * Refine train configs * Use dummy in training * Remove pickle_dataframe * CI * CI * Add more strict condition to filter orders * Pass test * Add TODO in example --------- Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)
|
||||
|
||||
|
||||
def _collect(df: pd.DataFrame, instrument: str, tag: str) -> None:
|
||||
cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
|
||||
cur = cur.set_index(["instrument", "datetime", "date"])
|
||||
pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
|
||||
|
||||
|
||||
for tag in ("backtest", "feature"):
|
||||
df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
|
||||
df = pd.concat(list(df.values())).reset_index()
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
instruments = sorted(set(df["instrument"]))
|
||||
|
||||
os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
|
||||
|
||||
Parallel(n_jobs=-1, verbose=10)(delayed(_collect)(df, instrument, tag) for instrument in instruments)
|
||||
@@ -4,17 +4,22 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
DATA_PATH = Path(os.path.join("data", "pickle_dataframe", "backtest"))
|
||||
DATA_PATH = Path(os.path.join("data", "pickle", "backtest"))
|
||||
OUTPUT_PATH = Path(os.path.join("data", "orders"))
|
||||
|
||||
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
df = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
|
||||
dataset = pd.read_pickle(DATA_PATH / f"{stock}.pkl")
|
||||
df = dataset.handler.fetch(level=None).reset_index()
|
||||
if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5:
|
||||
return False
|
||||
|
||||
df["date"] = df["datetime"].dt.date.astype("datetime64")
|
||||
df = df.set_index(["instrument", "datetime", "date"])
|
||||
df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
|
||||
div = df["$volume0"].rolling((end_idx - start_idx) * 60).mean().shift(1).groupby(level="date").transform("first")
|
||||
|
||||
order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
|
||||
order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
|
||||
@@ -32,11 +37,17 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> None:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if len(order) > 0:
|
||||
order.to_pickle(path / f"{stock}.pkl.target")
|
||||
return True
|
||||
|
||||
|
||||
np.random.seed(1234)
|
||||
file_list = sorted(os.listdir(DATA_PATH))
|
||||
stocks = [f.replace(".pkl", "") for f in file_list]
|
||||
stocks = sorted(np.random.choice(stocks, size=100, replace=False))
|
||||
for stock in tqdm(stocks):
|
||||
generate_order(stock, 0, 240 // 5 - 1)
|
||||
np.random.shuffle(stocks)
|
||||
|
||||
cnt = 0
|
||||
for stock in stocks:
|
||||
if generate_order(stock, 0, 240 // 5 - 1):
|
||||
cnt += 1
|
||||
if cnt == 100:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user