mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Merge github.com:microsoft/qlib into qlib_register_ops
This commit is contained in:
@@ -226,11 +226,12 @@ Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al.)](qlib/contrib/model/xgboost.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al.)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al.)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al.)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al.)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al.)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
|
||||
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
Binary file not shown.
4
examples/benchmarks/TabNet/requirements.txt
Normal file
4
examples/benchmarks/TabNet/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,74 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.1.dev"
|
||||
__version__ = "0.6.1.99"
|
||||
|
||||
|
||||
import os
|
||||
@@ -39,11 +39,10 @@ def init(default_conf="client", **kwargs):
|
||||
else:
|
||||
raise NotImplementedError(f"This type of URI is not supported")
|
||||
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
C.register()
|
||||
|
||||
if "flask_server" in C:
|
||||
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
logger.info("qlib successfully initialized based on %s settings." % default_conf)
|
||||
logger.info(f"data_path={C.get_data_path()}")
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ class DNNModelPytorch(Model):
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCELoss()
|
||||
loss = nn.BCELoss(weight=w)
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
642
qlib/contrib/model/pytorch_tabnet.py
Normal file
642
qlib/contrib/model/pytorch_tabnet.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TabnetModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=158,
|
||||
out_dim=64,
|
||||
final_out_dim=1,
|
||||
batch_size=4096,
|
||||
n_d=64,
|
||||
n_a=64,
|
||||
n_shared=2,
|
||||
n_ind=2,
|
||||
n_steps=5,
|
||||
n_epochs=100,
|
||||
pretrain_n_epochs=50,
|
||||
relax=1.3,
|
||||
vbs=2048,
|
||||
seed=993,
|
||||
optimizer="adam",
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
|
||||
Args:
|
||||
ps: probability to generate the bernoulli mask
|
||||
"""
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.out_dim = out_dim
|
||||
self.final_out_dim = final_out_dim
|
||||
self.lr = lr
|
||||
self.batch_size = batch_size
|
||||
self.optimizer = optimizer.lower()
|
||||
self.pretrain_loss = pretrain_loss
|
||||
self.seed = seed
|
||||
self.ps = ps
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
)
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tabnet_model = TabNet(
|
||||
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
|
||||
).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr
|
||||
)
|
||||
self.train_optimizer = optim.Adam(self.tabnet_model.parameters(), lr=self.lr)
|
||||
|
||||
elif optimizer.lower() == "gd":
|
||||
self.pretrain_optimizer = optim.SGD(
|
||||
list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr
|
||||
)
|
||||
self.train_optimizer = optim.SGD(self.tabnet_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
df_valid.fillna(df_valid.mean(), inplace=True)
|
||||
|
||||
x_train = df_train["feature"]
|
||||
x_valid = df_valid["feature"]
|
||||
|
||||
# Early stop setup
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
|
||||
for epoch_idx in range(self.pretrain_n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
self.logger.info("pre-training...")
|
||||
self.pretrain_epoch(x_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss = self.pretrain_test_epoch(x_train)
|
||||
valid_loss = self.pretrain_test_epoch(x_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_loss, valid_loss))
|
||||
|
||||
if valid_loss < best_loss:
|
||||
self.logger.info("Save Model...")
|
||||
torch.save(self.tabnet_model.state_dict(), pretrain_file)
|
||||
best_loss = valid_loss
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
# there is a pretrained model, load the model
|
||||
self.logger.info("Pretrain...")
|
||||
self.pretrain_fn(dataset, self.pretrain_file)
|
||||
self.logger.info("Load Pretrain model")
|
||||
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file))
|
||||
|
||||
# adding one more linear layer to fit the final output dimension
|
||||
self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
for epoch_idx in range(self.n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
valid_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
x_values[torch.isnan(x_values)] = 0
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = x_values[begin:end].float().to(self.device)
|
||||
priors = torch.ones(end - begin, self.d_feat).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(x_batch, priors).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = torch.from_numpy(data_x.values)
|
||||
y_values = torch.from_numpy(np.squeeze(data_y.values))
|
||||
x_values[torch.isnan(x_values)] = 0
|
||||
y_values[torch.isnan(y_values)] = 0
|
||||
self.tabnet_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = torch.from_numpy(x_train.values)
|
||||
y_train_values = torch.from_numpy(np.squeeze(y_train.values))
|
||||
x_train_values[torch.isnan(x_train_values)] = 0
|
||||
y_train_values[torch.isnan(y_train_values)] = 0
|
||||
self.tabnet_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = x_train_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_train_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.tabnet_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def pretrain_epoch(self, x_train):
|
||||
train_set = torch.from_numpy(x_train.values)
|
||||
train_set[torch.isnan(train_set)] = 0
|
||||
indices = np.arange(len(train_set))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
self.tabnet_model.train()
|
||||
self.tabnet_decoder.train()
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))
|
||||
x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask)
|
||||
y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)
|
||||
|
||||
S_mask = S_mask.to(self.device)
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
self.pretrain_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.pretrain_optimizer.step()
|
||||
|
||||
def pretrain_test_epoch(self, x_train):
|
||||
train_set = torch.from_numpy(x_train.values)
|
||||
train_set[torch.isnan(train_set)] = 0
|
||||
indices = np.arange(len(train_set))
|
||||
|
||||
self.tabnet_model.eval()
|
||||
self.tabnet_decoder.eval()
|
||||
|
||||
losses = []
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))
|
||||
x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask)
|
||||
y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)
|
||||
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def pretrain_loss_fn(self, f_hat, f, S):
|
||||
"""
|
||||
Pretrain loss function defined in the original paper, read "Tabular self-supervised learning" in https://arxiv.org/pdf/1908.07442.pdf
|
||||
"""
|
||||
down_mean = torch.mean(f, dim=0)
|
||||
down = torch.sqrt(torch.sum(torch.square(f - down_mean), dim=0))
|
||||
up = (f_hat - f) * S
|
||||
return torch.sum(torch.square(up / down))
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
|
||||
class FinetuneModel(nn.Module):
|
||||
"""
|
||||
FinuetuneModel for adding a layer by the end
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, trained_model):
|
||||
super().__init__()
|
||||
self.model = trained_model
|
||||
self.fc = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x, priors):
|
||||
return self.fc(self.model(x, priors)[0]).squeeze() # take the vec out
|
||||
|
||||
|
||||
class DecoderStep(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
|
||||
self.fc = nn.Linear(out_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fea_tran(x)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class TabNet_Decoder(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
|
||||
"""
|
||||
TabNet decoder that is used in pre-training
|
||||
"""
|
||||
self.out_dim = out_dim
|
||||
|
||||
super().__init__()
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
|
||||
for x in range(n_shared - 1):
|
||||
self.shared.append(nn.Linear(out_dim, 2 * out_dim)) # preset the linear function we will use
|
||||
else:
|
||||
self.shared = None
|
||||
self.n_steps = n_steps
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps):
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
|
||||
|
||||
def forward(self, x):
|
||||
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
|
||||
for step in self.steps:
|
||||
out += step(x)
|
||||
return out
|
||||
|
||||
|
||||
class TabNet(nn.Module):
|
||||
def __init__(
|
||||
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
|
||||
):
|
||||
"""
|
||||
TabNet AKA the original encoder
|
||||
|
||||
Args:
|
||||
n_d: dimension of the features used to calculate the final results
|
||||
n_a: dimension of the features input to the attention transformer of the next step
|
||||
n_shared: numbr of shared steps in feature transfomer(optional)
|
||||
n_ind: number of independent steps in feature transformer
|
||||
n_steps: number of steps of pass through tabbet
|
||||
relax coefficient:
|
||||
virtual batch size:
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# set the number of shared step in feature transformer
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * (n_d + n_a)))
|
||||
for x in range(n_shared - 1):
|
||||
self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a))) # preset the linear function we will use
|
||||
else:
|
||||
self.shared = None
|
||||
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps - 1):
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
|
||||
self.fc = nn.Linear(n_d, out_dim)
|
||||
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
|
||||
self.n_d = n_d
|
||||
|
||||
def forward(self, x, priors):
|
||||
assert not torch.isnan(x).any()
|
||||
x = self.bn(x)
|
||||
x_a = self.first_step(x)[:, self.n_d :]
|
||||
sparse_loss = torch.zeros(1).to(x.device)
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss += l
|
||||
return self.fc(out), sparse_loss
|
||||
|
||||
|
||||
class GBN(nn.Module):
|
||||
"""
|
||||
Ghost Batch Normalization
|
||||
an efficient way of doing batch normalization
|
||||
|
||||
Args:
|
||||
vbs: virtual batch size
|
||||
"""
|
||||
|
||||
def __init__(self, inp, vbs=1024, momentum=0.01):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm1d(inp, momentum=momentum)
|
||||
self.vbs = vbs
|
||||
|
||||
def forward(self, x):
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
"""
|
||||
GLU block that extracts only the most essential information
|
||||
|
||||
Args:
|
||||
vbs: virtual batch size
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim, fc=None, vbs=1024):
|
||||
super().__init__()
|
||||
if fc:
|
||||
self.fc = fc
|
||||
else:
|
||||
self.fc = nn.Linear(inp_dim, out_dim * 2)
|
||||
self.bn = GBN(out_dim * 2, vbs=vbs)
|
||||
self.od = out_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bn(self.fc(x))
|
||||
return torch.mul(x[:, : self.od], torch.sigmoid(x[:, self.od :]))
|
||||
|
||||
|
||||
class AttentionTransformer(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
relax: relax coefficient. The greater it is, we can
|
||||
use the same features more. When it is set to 1
|
||||
we can use every feature only once
|
||||
"""
|
||||
|
||||
def __init__(self, d_a, inp_dim, relax, vbs=1024):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(d_a, inp_dim)
|
||||
self.bn = GBN(inp_dim, vbs=vbs)
|
||||
self.r = relax
|
||||
|
||||
# a:feature from previous decision step
|
||||
def forward(self, a, priors):
|
||||
a = self.bn(self.fc(a))
|
||||
mask = SparsemaxFunction.apply(a * priors)
|
||||
priors = priors * (self.r - mask) # updating the prior
|
||||
return mask
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
first = True
|
||||
self.shared = nn.ModuleList()
|
||||
if shared:
|
||||
self.shared.append(GLU(inp_dim, out_dim, shared[0], vbs=vbs))
|
||||
first = False
|
||||
for fc in shared[1:]:
|
||||
self.shared.append(GLU(out_dim, out_dim, fc, vbs=vbs))
|
||||
else:
|
||||
self.shared = None
|
||||
self.independ = nn.ModuleList()
|
||||
if first:
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
|
||||
|
||||
def forward(self, x):
|
||||
if self.shared:
|
||||
x = self.shared[0](x)
|
||||
for glu in self.shared[1:]:
|
||||
x = torch.add(x, glu(x))
|
||||
x = x * self.scale
|
||||
for glu in self.independ:
|
||||
x = torch.add(x, glu(x))
|
||||
x = x * self.scale
|
||||
return x
|
||||
|
||||
|
||||
class DecisionStep(nn.Module):
|
||||
"""
|
||||
One step for the TabNet
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
|
||||
super().__init__()
|
||||
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
|
||||
|
||||
def forward(self, x, a, priors):
|
||||
mask = self.atten_tran(a, priors)
|
||||
sparse_loss = ((-1) * mask * torch.log(mask + 1e-10)).mean()
|
||||
x = self.fea_tran(x * mask)
|
||||
return x, sparse_loss
|
||||
|
||||
|
||||
def make_ix_like(input, dim=0):
|
||||
d = input.size(dim)
|
||||
rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
|
||||
view = [1] * input.dim()
|
||||
view[0] = -1
|
||||
return rho.view(view).transpose(0, dim)
|
||||
|
||||
|
||||
class SparsemaxFunction(Function):
|
||||
"""
|
||||
SparseMax function for replacing reLU
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, dim=-1):
|
||||
ctx.dim = dim
|
||||
max_val, _ = input.max(dim=dim, keepdim=True)
|
||||
input -= max_val # same numerical stability trick as for softmax
|
||||
tau, supp_size = SparsemaxFunction.threshold_and_support(input, dim=dim)
|
||||
output = torch.clamp(input - tau, min=0)
|
||||
ctx.save_for_backward(supp_size, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
supp_size, output = ctx.saved_tensors
|
||||
dim = ctx.dim
|
||||
grad_input = grad_output.clone()
|
||||
grad_input[output == 0] = 0
|
||||
|
||||
v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
|
||||
v_hat = v_hat.unsqueeze(dim)
|
||||
grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
|
||||
return grad_input, None
|
||||
|
||||
@staticmethod
|
||||
def threshold_and_support(input, dim=-1):
|
||||
input_srt, _ = torch.sort(input, descending=True, dim=dim)
|
||||
input_cumsum = input_srt.cumsum(dim) - 1
|
||||
rhos = make_ix_like(input, dim)
|
||||
support = rhos * input_srt > input_cumsum
|
||||
|
||||
support_size = support.sum(dim=dim).unsqueeze(dim)
|
||||
tau = input_cumsum.gather(dim, support_size - 1)
|
||||
tau /= support_size.to(input.dtype)
|
||||
return tau, support_size
|
||||
@@ -15,14 +15,13 @@ import importlib
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from multiprocessing import Pool
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .ops import Operators
|
||||
from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
@@ -225,24 +224,6 @@ class InstrumentProvider(abc.ABC):
|
||||
return cls.LIST
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
def convert_instruments(self, instrument):
|
||||
_instruments_map = getattr(self, "_instruments_map", None)
|
||||
if _instruments_map is None:
|
||||
_df_list = []
|
||||
# FIXME: each process will read these files
|
||||
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
|
||||
_df = pd.read_csv(
|
||||
_path,
|
||||
sep="\t",
|
||||
names=["inst", "start_datetime", "end_datetime", "save_inst"],
|
||||
)
|
||||
_df_list.append(_df.iloc[:, [0, -1]])
|
||||
df = pd.concat(_df_list, sort=False).sort_values("save_inst")
|
||||
df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill")
|
||||
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
|
||||
setattr(self, "_instruments_map", _instruments_map)
|
||||
return _instruments_map.get(instrument, instrument)
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC):
|
||||
"""Feature provider class
|
||||
@@ -600,14 +581,16 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
fname = self._uri_inst.format(market)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
names=["inst", "start_datetime", "end_datetime", "save_inst"],
|
||||
usecols=[0, 1, 2],
|
||||
names=["inst", "start_datetime", "end_datetime"],
|
||||
dtype={"inst": str},
|
||||
parse_dates=["start_datetime", "end_datetime"],
|
||||
)
|
||||
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
|
||||
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
@@ -664,7 +647,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = Inst.convert_instruments(instrument)
|
||||
instrument = code_to_fname(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
|
||||
@@ -7,6 +7,9 @@ from ..config import REG_CN
|
||||
|
||||
|
||||
class TestAutoData(unittest.TestCase):
|
||||
|
||||
_setup_kwargs = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
@@ -15,6 +18,10 @@ class TestAutoData(unittest.TestCase):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN)
|
||||
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import qlib
|
||||
import shutil
|
||||
import zipfile
|
||||
import requests
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v1"
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -20,13 +27,24 @@ class GetData:
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
def normalize_dataset_version(self, dataset_version: str = None):
|
||||
if dataset_version is None:
|
||||
dataset_version = self.DATASET_VERSION
|
||||
return dataset_version
|
||||
|
||||
def merge_remote_url(self, file_name: str, dataset_version: str = None):
|
||||
return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}"
|
||||
|
||||
def _download_data(
|
||||
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
|
||||
):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
# saved file name
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
@@ -42,19 +60,59 @@ class GetData:
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
self._unzip(target_path, target_dir, delete_old)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlink()
|
||||
|
||||
def check_dataset(self, file_name: str, dataset_version: str = None):
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True)
|
||||
status = True
|
||||
if resp.status_code == 404:
|
||||
status = False
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
|
||||
if delete_old:
|
||||
logger.warning(
|
||||
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
|
||||
)
|
||||
GetData._delete_qlib_data(target_dir)
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
@staticmethod
|
||||
def _delete_qlib_data(file_dir: Path):
|
||||
logger.info(f"delete {file_dir}")
|
||||
rm_dirs = []
|
||||
for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]:
|
||||
_p = file_dir.joinpath(_name)
|
||||
if _p.exists():
|
||||
rm_dirs.append(str(_p.resolve()))
|
||||
if rm_dirs:
|
||||
flag = input(
|
||||
f"Will be deleted: "
|
||||
f"\n\t{rm_dirs}"
|
||||
f"\nIf you do not need to delete {file_dir}, please change the <--target_dir>"
|
||||
f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
|
||||
)
|
||||
if str(flag) not in ["Y", "y"]:
|
||||
exit()
|
||||
for _p in rm_dirs:
|
||||
logger.warning(f"delete: {_p}")
|
||||
shutil.rmtree(_p)
|
||||
|
||||
def qlib_data(
|
||||
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
|
||||
self,
|
||||
name="qlib_data",
|
||||
target_dir="~/.qlib/qlib_data/cn_data",
|
||||
version=None,
|
||||
interval="1d",
|
||||
region="cn",
|
||||
delete_old=True,
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
@@ -65,20 +123,31 @@ class GetData:
|
||||
name: str
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
data version, value from [v1, ...], by default None(use script to specify version)
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name(v):
|
||||
return self.QLIB_DATA_NAME.format(
|
||||
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
|
||||
)
|
||||
|
||||
file_name = _get_file_name(qlib_version)
|
||||
if not self.check_dataset(file_name, version):
|
||||
file_name = _get_file_name("latest")
|
||||
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
@@ -15,7 +15,6 @@ import bisect
|
||||
import shutil
|
||||
import difflib
|
||||
import hashlib
|
||||
import logging
|
||||
import datetime
|
||||
import requests
|
||||
import tempfile
|
||||
@@ -27,10 +26,9 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
|
||||
from ..config import C, REG_CN
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
|
||||
|
||||
log = get_module_logger("utils")
|
||||
|
||||
|
||||
@@ -645,15 +643,28 @@ def exists_qlib_data(qlib_dir):
|
||||
# check instruments
|
||||
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
|
||||
_instrument = instruments_dir.joinpath("all.txt")
|
||||
df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
|
||||
df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
|
||||
miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
|
||||
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
|
||||
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_qlib_data(qlib_config):
|
||||
inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
|
||||
for _p in inst_dir.glob("*.txt"):
|
||||
try:
|
||||
assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
|
||||
f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
|
||||
f"\n\tIf you are using the data provided by qlib: "
|
||||
f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
|
||||
f"\n\tIf you are using your own data, please dump the data again: "
|
||||
f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
|
||||
)
|
||||
except AssertionError:
|
||||
raise
|
||||
|
||||
|
||||
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
"""
|
||||
make the df index sorted
|
||||
@@ -744,3 +755,36 @@ def load_dataset(path_or_obj):
|
||||
elif extension == ".csv":
|
||||
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
|
||||
raise ValueError(f"unsupported file type `{extension}`")
|
||||
|
||||
|
||||
def code_to_fname(code: str):
|
||||
"""stock code to file name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code: str
|
||||
"""
|
||||
# NOTE: In windows, the following name is I/O device, and the file with the corresponding name cannot be created
|
||||
# reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows
|
||||
replace_names = ["CON", "PRN", "AUX", "NUL"]
|
||||
replace_names += [f"COM{i}" for i in range(10)]
|
||||
replace_names += [f"LPT{i}" for i in range(10)]
|
||||
|
||||
prefix = "_qlib_"
|
||||
if str(code).upper() in replace_names:
|
||||
code = prefix + str(code)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def fname_to_code(fname: str):
|
||||
"""file name to stock code
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
"""
|
||||
prefix = "_qlib_"
|
||||
if fname.startswith(prefix):
|
||||
fname = fname.lstrip(prefix)
|
||||
return fname
|
||||
|
||||
@@ -27,11 +27,6 @@ class Serializable:
|
||||
def dump_all(self):
|
||||
"""
|
||||
will the object dump all object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_dump_all", False)
|
||||
|
||||
@@ -39,11 +34,6 @@ class Serializable:
|
||||
def exclude(self):
|
||||
"""
|
||||
What attribute will be dumped
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ pip install -r requirements.txt
|
||||
|
||||
### Download data and Normalize data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize
|
||||
```
|
||||
|
||||
### Download Data
|
||||
|
||||
@@ -18,6 +18,7 @@ from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -40,7 +41,7 @@ class YahooCollector:
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=5,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
@@ -55,7 +56,7 @@ class YahooCollector:
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 5
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
@@ -147,11 +148,10 @@ class YahooCollector:
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=False)
|
||||
_temp_df = pd.read_csv(stock_path, nrows=0)
|
||||
df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a")
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
df.to_csv(stock_path, index=False, mode="w")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
@@ -350,7 +350,7 @@ class YahooCollectorUS(YahooCollector):
|
||||
pass
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol.upper()
|
||||
return code_to_fname(symbol).upper()
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
|
||||
@@ -14,6 +14,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname
|
||||
|
||||
|
||||
class DumpDataBase:
|
||||
@@ -27,7 +28,6 @@ class DumpDataBase:
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
SAVE_INST_FIELD = "save_inst"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
@@ -45,7 +45,6 @@ class DumpDataBase:
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
inst_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -73,9 +72,6 @@ class DumpDataBase:
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
inst_prefix: str
|
||||
add a column to the instruments file and record the saved instrument name,
|
||||
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
@@ -84,7 +80,6 @@ class DumpDataBase:
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self._inst_prefix = inst_prefix.strip()
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
@@ -145,7 +140,7 @@ class DumpDataBase:
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return file_path.name[: -len(self.file_suffix)].strip().lower()
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
@@ -173,7 +168,6 @@ class DumpDataBase:
|
||||
self.symbol_field_name,
|
||||
self.INSTRUMENTS_START_FIELD,
|
||||
self.INSTRUMENTS_END_FIELD,
|
||||
self.SAVE_INST_FIELD,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -190,13 +184,11 @@ class DumpDataBase:
|
||||
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
|
||||
if isinstance(instruments_data, pd.DataFrame):
|
||||
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
|
||||
if self._inst_prefix:
|
||||
_df_fields.append(self.SAVE_INST_FIELD)
|
||||
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: f"{self._inst_prefix}{x}"
|
||||
)
|
||||
instruments_data = instruments_data.loc[:, _df_fields]
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: fname_to_code(x.lower()).upper()
|
||||
)
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
|
||||
else:
|
||||
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
|
||||
|
||||
@@ -223,26 +215,26 @@ class DumpDataBase:
|
||||
logger.warning(f"{features_dir.name} data is None or empty")
|
||||
return
|
||||
# align index
|
||||
_df = self.data_merge_calendar(df, self._calendars_list)
|
||||
_df = self.data_merge_calendar(df, calendar_list)
|
||||
# used when creating a bin file
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if self._mode == self.UPDATE_MODE:
|
||||
if bin_path.exists() and self._mode == self.UPDATE_MODE:
|
||||
# update
|
||||
with bin_path.open("ab") as fp:
|
||||
np.array(_df[field]).astype("<f").tofile(fp)
|
||||
elif self._mode == self.ALL_MODE:
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
else:
|
||||
raise ValueError(f"{self._mode} cannot support!")
|
||||
# append; self._mode == self.ALL_MODE or not bin_path.exists()
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
|
||||
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = file_or_data.iloc[0][self.symbol_field_name].lower()
|
||||
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
@@ -253,8 +245,7 @@ class DumpDataBase:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
code = self._inst_prefix + code if self._inst_prefix else code
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
|
||||
@@ -283,8 +274,6 @@ class DumpDataAll(DumpDataBase):
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
_inst_fields = [symbol.upper(), _begin_time, _end_time]
|
||||
if self._inst_prefix:
|
||||
_inst_fields.append(self._inst_prefix + symbol.upper())
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
@@ -323,12 +312,18 @@ class DumpDataFix(DumpDataAll):
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
|
||||
new_stock_files = sorted(
|
||||
filter(
|
||||
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
|
||||
not in self._old_instruments,
|
||||
self.csv_files,
|
||||
)
|
||||
)
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
@@ -406,10 +401,10 @@ class DumpDataUpdate(DumpDataBase):
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
self._update_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
self._update_instruments = (
|
||||
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
|
||||
.set_index([self.symbol_field_name])
|
||||
.to_dict(orient="index")
|
||||
) # type: dict
|
||||
|
||||
# load all csv files
|
||||
@@ -425,10 +420,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
if self._include_fields:
|
||||
_df = pd.read_csv(file_path, usecols=self._include_fields)
|
||||
else:
|
||||
_df = pd.read_csv(file_path)
|
||||
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
@@ -436,7 +428,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
if df:
|
||||
if not df.empty:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
|
||||
@@ -455,25 +447,27 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
futures = {}
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name):
|
||||
_code = str(_code).upper()
|
||||
_code = fname_to_code(str(_code).lower()).upper()
|
||||
_start, _end = self._get_date(_df, is_begin_end=True)
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
self._update_instruments[_code]["end_time"] = _end
|
||||
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
_dt_range["start_time"] = _start
|
||||
_dt_range["end_time"] = _end
|
||||
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
|
||||
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
|
||||
|
||||
for _future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
with tqdm(total=len(futures)) as p_bar:
|
||||
for _future in as_completed(futures):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
p_bar.update()
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
@@ -481,7 +475,9 @@ class DumpDataUpdate(DumpDataBase):
|
||||
def dump(self):
|
||||
self.save_calendars(self._new_calendar_list)
|
||||
self._dump_features()
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
|
||||
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
|
||||
df.index.names = [self.symbol_field_name]
|
||||
self.save_instruments(df.reset_index())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -11,7 +11,7 @@ NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
REQUIRES_PYTHON = ">=3.5.0"
|
||||
|
||||
VERSION = "0.6.1.dev"
|
||||
VERSION = "0.6.1.99"
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
|
||||
|
||||
def test_0_qlib_data(self):
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
|
||||
df = D.features(D.instruments("csi300"), self.FIELDS)
|
||||
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
@@ -62,13 +62,8 @@ class Distance(PairOperator):
|
||||
class TestRegiterCustomOps(TestAutoData):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple_1" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", region="cn", interval="1d", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, custom_ops=[Diff, Distance], region=REG_CN)
|
||||
cls._setup_kwargs.update({"custom_ops": [Diff, Distance]})
|
||||
super().setUpClass()
|
||||
|
||||
def test_regiter_custom_ops(self):
|
||||
instruments = ["SH600000"]
|
||||
|
||||
Reference in New Issue
Block a user