mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Merge branch 'main' into main
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,3 +34,7 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
@@ -218,6 +218,25 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
|
||||
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
@@ -25,7 +26,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arik, et al.) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
|
||||
Binary file not shown.
@@ -55,7 +55,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -105,7 +105,7 @@ _default_config = {
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
# logging_level can control the logging level more finely
|
||||
"logging_config": {
|
||||
@@ -124,12 +124,12 @@ _default_config = {
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"level": logging.DEBUG,
|
||||
"formatter": "logger_format",
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
@@ -185,7 +185,7 @@ MODE_CONF = {
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
# servers (such as PAI) [auto_mount:True]
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
|
||||
@@ -184,7 +184,7 @@ class DEnsembleModel(Model):
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / np.std(loss_feat - loss_values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -39,8 +40,8 @@ class ALSTM(Model):
|
||||
the evaluation metric used in early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -76,8 +77,7 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -93,7 +93,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -107,7 +107,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -123,6 +123,9 @@ class ALSTM(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +136,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +208,13 @@ class ALSTM(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +222,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +234,7 @@ class ALSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +296,7 @@ class ALSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(x_batch).detach().numpy()
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,8 +41,8 @@ class ALSTM(Model):
|
||||
the evaluation metric used in early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,9 +79,8 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +96,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +111,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +127,10 @@ class ALSTM(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +141,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +195,13 @@ class ALSTM(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +209,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +217,14 @@ class ALSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +281,7 @@ class ALSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -42,8 +43,8 @@ class GATs(Model):
|
||||
the evaluation metric used in early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -83,7 +84,7 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
@@ -102,7 +103,7 @@ class GATs(Model):
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -118,7 +119,7 @@ class GATs(Model):
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -135,6 +136,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -145,6 +149,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -232,7 +240,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -245,8 +252,7 @@ class GATs(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
@@ -324,10 +330,7 @@ class GATs(Model):
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(x_batch).detach().numpy()
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -62,8 +63,8 @@ class GATs(Model):
|
||||
the evaluation metric used in early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -104,9 +105,8 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -157,6 +157,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -167,6 +170,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -245,7 +252,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -258,11 +264,10 @@ class GATs(Model):
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -345,10 +350,7 @@ class GATs(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(feature.float()).detach().numpy()
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -76,8 +77,7 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -123,6 +123,9 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +136,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.gru_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +208,13 @@ class GRU(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +222,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +234,7 @@ class GRU(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +296,7 @@ class GRU(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.gru_model(x_batch).detach().numpy()
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -78,9 +79,8 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +96,7 @@ class GRU(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +111,7 @@ class GRU(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +127,10 @@ class GRU(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +141,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.GRU_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +195,13 @@ class GRU(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +209,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +217,14 @@ class GRU(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +281,7 @@ class GRU(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GRU_model(feature.float()).detach().numpy()
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -76,8 +76,7 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -133,6 +132,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.lstm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -214,7 +217,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +229,7 @@ class LSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +291,7 @@ class LSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.lstm_model(x_batch).detach().numpy()
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -78,9 +78,8 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +95,7 @@ class LSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +110,7 @@ class LSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -138,6 +137,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.LSTM_model.to(self.device)
|
||||
|
||||
@property
def use_gpu(self):
    """Whether ``self.device`` is something other than ``torch.device("cpu")``."""
    on_cpu = self.device == torch.device("cpu")
    return not on_cpu
|
||||
|
||||
def mse(self, pred, label):
    """Mean squared error: average of ``(pred - label) ** 2`` over all elements."""
    return torch.mean((pred - label) ** 2)
|
||||
@@ -201,7 +204,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +212,14 @@ class LSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +276,7 @@ class LSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.LSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -15,10 +15,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...workflow import R
|
||||
|
||||
@@ -42,8 +43,8 @@ class DNNModelPytorch(Model):
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -80,8 +81,7 @@ class DNNModelPytorch(Model):
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
@@ -99,7 +99,7 @@ class DNNModelPytorch(Model):
|
||||
"\nloss_type : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
@@ -114,8 +114,8 @@ class DNNModelPytorch(Model):
|
||||
loss,
|
||||
eval_steps,
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
)
|
||||
)
|
||||
@@ -129,6 +129,9 @@ class DNNModelPytorch(Model):
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -153,6 +156,10 @@ class DNNModelPytorch(Model):
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
def use_gpu(self):
    """Read-only flag derived from ``self.device``.

    It is ``True`` for any device that differs from ``torch.device("cpu")``.
    """
    reference = torch.device("cpu")
    is_accelerated = self.device != reference
    return is_accelerated
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
@@ -172,7 +179,7 @@ class DNNModelPytorch(Model):
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
@@ -215,7 +222,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
if step and step % self.eval_steps == 0:
|
||||
# for every `eval_steps` steps or at the last step, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
@@ -248,9 +256,9 @@ class DNNModelPytorch(Model):
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training ??
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
if self.use_GPU:
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_loss(self, pred, w, target, loss_type):
|
||||
@@ -272,10 +280,7 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_GPU:
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
|
||||
@@ -13,7 +13,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -196,8 +197,8 @@ class SFM(Model):
|
||||
learning rate
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -216,7 +217,7 @@ class SFM(Model):
|
||||
eval_steps=5,
|
||||
loss="mse",
|
||||
optimizer="gd",
|
||||
GPU="0",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -239,8 +240,7 @@ class SFM(Model):
|
||||
self.eval_steps = eval_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -259,7 +259,7 @@ class SFM(Model):
|
||||
"\neval_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -276,7 +276,7 @@ class SFM(Model):
|
||||
eval_steps,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -295,6 +295,9 @@ class SFM(Model):
|
||||
dropout_U=self.dropout_U,
|
||||
device=self.device,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.sfm_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -305,6 +308,10 @@ class SFM(Model):
|
||||
self.fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
@property
def use_gpu(self):
    """Return ``True`` unless the configured ``self.device`` is the CPU device."""
    if self.device != torch.device("cpu"):
        return True
    return False
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
@@ -365,7 +372,6 @@ class SFM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -377,6 +383,7 @@ class SFM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -409,7 +416,10 @@ class SFM(Model):
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sfm_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
if self.device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -49,12 +50,12 @@ class TabnetModel(Model):
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
GPU=0,
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
pretrain_file=None,
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
@@ -75,18 +76,18 @@ class TabnetModel(Model):
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.pretrain_file = get_or_create_path(pretrain_file)
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
"\ndevice : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, self.device, self.pretrain)
|
||||
)
|
||||
self.fitted = False
|
||||
np.random.seed(self.seed)
|
||||
@@ -98,6 +99,8 @@ class TabnetModel(Model):
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
@@ -113,11 +116,12 @@ class TabnetModel(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
@property
def use_gpu(self):
    """Tell whether the model is configured to run on a non-CPU device.

    ``self.device`` is stored as a plain string in this class (for example
    ``"cuda:0"`` or ``"cpu"``). A string never compares equal to a
    ``torch.device``, so it is normalised to a ``torch.device`` first;
    otherwise this property would report ``True`` even on CPU.

    Returns
    -------
    bool
        ``False`` when the device is the CPU, ``True`` otherwise.
    """
    # torch.device() accepts both strings and torch.device instances.
    return torch.device(self.device) != torch.device("cpu")
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
get_or_create_path(pretrain_file)
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
@@ -159,7 +163,6 @@ class TabnetModel(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
@@ -179,10 +182,11 @@ class TabnetModel(Model):
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
@@ -201,16 +205,23 @@ class TabnetModel(Model):
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
best_param = copy.deepcopy(self.tabnet_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.tabnet_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
@@ -260,12 +271,13 @@ class TabnetModel(Model):
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -348,10 +360,11 @@ class TabnetModel(Model):
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
with torch.no_grad():
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
37
qlib/contrib/model/pytorch_utils.py
Normal file
37
qlib/contrib/model/pytorch_utils.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(models_or_parameters, unit="m"):
    """
    Count the parameters of one or more models, optionally scaled by a unit.

    The result is a number of elements, not a size in bytes: the unit only
    divides the element count by a power of 1024.

    Parameters
    ----------
    models_or_parameters : a PyTorch module, a single ``nn.Parameter``, an
        iterable of tensors, or a list/tuple of any of these.
    unit : "k"/"kb", "m"/"mb", "g"/"gb" (case-insensitive) to divide the count
        by 2**10, 2**20 or 2**30, or None to return the raw count.

    Returns
    -------
    The number of parameters of the given model(s) or parameters, divided by
    the factor selected through ``unit``.

    Raises
    ------
    ValueError
        If ``unit`` is a string that is not one of the supported units.
    """
    if isinstance(models_or_parameters, nn.Module):
        counts = sum(v.numel() for v in models_or_parameters.parameters())
    elif isinstance(models_or_parameters, nn.Parameter):
        counts = models_or_parameters.numel()
    elif isinstance(models_or_parameters, (list, tuple)):
        # Each element is scaled on its own, so the sum is already in `unit`.
        return sum(count_parameters(x, unit) for x in models_or_parameters)
    else:
        counts = sum(v.numel() for v in models_or_parameters)

    # None means "no scaling"; check it before calling .lower() on it.
    if unit is None:
        return counts

    unit = unit.lower()
    if unit in ("kb", "k"):
        counts /= 2 ** 10
    elif unit in ("mb", "m"):
        counts /= 2 ** 20
    elif unit in ("gb", "g"):
        counts /= 2 ** 30
    else:
        raise ValueError("Unknow unit: {:}".format(unit))
    return counts
|
||||
@@ -1,5 +1,5 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
@@ -113,7 +113,7 @@ class DatasetH(Dataset):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -145,6 +145,11 @@ class DatasetH(Dataset):
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
|
||||
def __repr__(self):
    """Render the dataset as ``ClassName(handler=..., segments=...)``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(handler={self.handler}, segments={self.segments})"
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
@@ -157,7 +162,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -167,7 +172,7 @@ class DatasetH(Dataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : Union[List[str], Tuple[str], str, slice]
|
||||
segments : Union[List[Text], Tuple[Text], Text, slice]
|
||||
Describe the scope of the data to be prepared
|
||||
Here are some examples:
|
||||
|
||||
@@ -397,7 +402,7 @@ class TSDataSampler:
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems, we use np.float64. # TODO: I'm not sure whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
|
||||
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
Any order of the index level can be supported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -47,8 +47,8 @@ class DataHandler(Serializable):
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ class NpElemOperator(ElemOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature, func):
|
||||
self.feature = feature
|
||||
self.func = func
|
||||
super(NpElemOperator, self).__init__(feature)
|
||||
|
||||
@@ -289,8 +288,6 @@ class NpPairOperator(PairOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, func):
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.func = func
|
||||
super(NpPairOperator, self).__init__(feature_left, feature_right)
|
||||
|
||||
@@ -1489,7 +1486,7 @@ OpsList = [
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
class OpsWrapper:
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
10
qlib/log.py
10
qlib/log.py
@@ -3,8 +3,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from typing import Optional, Text, Dict, Any
|
||||
import re
|
||||
from logging import config as logging_config
|
||||
from time import time
|
||||
@@ -13,16 +12,13 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level=None):
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
:param module_name: str
|
||||
Logic module name.
|
||||
:param level: int
|
||||
:param sh_level: int
|
||||
Stream handler log level.
|
||||
:param log_format: str
|
||||
:return: Logger
|
||||
Logger object.
|
||||
"""
|
||||
@@ -103,7 +99,7 @@ class TimeInspector:
|
||||
cls.log_cost_time(info=f"{name} Done")
|
||||
|
||||
|
||||
def set_log_with_config(log_config: dict):
|
||||
def set_log_with_config(log_config: Dict[Text, Any]):
|
||||
"""set log with config
|
||||
|
||||
:param log_config:
|
||||
|
||||
@@ -24,7 +24,7 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
from typing import Union, Tuple, Text, Optional
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -64,7 +64,7 @@ def np_ffill(arr: np.array):
|
||||
arr : np.array
|
||||
Input numpy 1D array
|
||||
"""
|
||||
mask = np.isnan(arr.astype(np.float)) # np.isnan only works on np.float
|
||||
mask = np.isnan(arr.astype(float)) # np.isnan only works on np.float
|
||||
# get fill index
|
||||
idx = np.where(~mask, np.arange(mask.shape[0]), 0)
|
||||
np.maximum.accumulate(idx, out=idx)
|
||||
@@ -276,23 +276,31 @@ def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
return changes
|
||||
|
||||
|
||||
def create_save_path(save_path=None):
|
||||
"""Create save path
|
||||
def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
    """Create or get a file or directory given the path and return_dir.

    Parameters
    ----------
    path: a string indicates the path or None indicates creating a temporary path.
    return_dir: if True, create and return a directory; otherwise create and return a file.

    Returns
    -------
    The given ``path`` when one was supplied (its directory, or the parent
    directory of the file, is created if missing), otherwise the path of a
    freshly created temporary directory or file under ``~/tmp``.
    """
    if path:
        if return_dir:
            if not os.path.exists(path):
                # exist_ok guards against another process creating it meanwhile.
                os.makedirs(path, exist_ok=True)
        else:
            # A file is requested, so only its parent directory has to exist.
            parent_dir = os.path.abspath(os.path.join(path, ".."))
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)
    else:
        temp_dir = os.path.expanduser("~/tmp")
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir, exist_ok=True)
        if return_dir:
            # mkdtemp returns only the path (unlike mkstemp's (fd, path) pair).
            path = tempfile.mkdtemp(dir=temp_dir)
        else:
            fd, path = tempfile.mkstemp(dir=temp_dir)
            # Callers only use the path; close the descriptor to avoid a leak.
            os.close(fd)
    return path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional
|
||||
from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
@@ -16,8 +17,13 @@ class QlibRecorder:
|
||||
def __init__(self, exp_manager):
    """Keep a reference to the experiment manager passed in.

    Parameters
    ----------
    exp_manager : the object stored as ``self.exp_manager``.
    """
    # The constructor only records the manager it is given.
    self.exp_manager = exp_manager
|
||||
|
||||
def __repr__(self):
    """Render the recorder as ``ClassName(manager=...)``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(manager={self.exp_manager})"
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name=None, recorder_name=None):
|
||||
def start(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
"""
|
||||
Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
|
||||
|
||||
@@ -34,8 +40,13 @@ class QlibRecorder:
|
||||
name of the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when users call this function in the same experiment,
|
||||
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name)
|
||||
run = self.start_exp(experiment_name, recorder_name, uri)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -272,7 +283,13 @@ class QlibRecorder:
|
||||
-------
|
||||
The uri of current experiment manager.
|
||||
"""
|
||||
return self.exp_manager.get_uri()
|
||||
return self.exp_manager.uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text]):
    """
    Forward ``uri`` to the ``set_uri`` method of the current experiment manager.

    Parameters
    ----------
    uri : str or None
        the value handed over unchanged to ``self.exp_manager.set_uri``.
    """
    manager = self.exp_manager
    manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,7 @@ class Experiment:
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def __str__(self):
    """Return the string form of ``self.info``."""
    info = self.info
    return str(info)
|
||||
@@ -173,11 +173,12 @@ class MLflowExperiment(Experiment):
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def __repr__(self):
    """Render the experiment as ``ClassName(id=..., info=...)``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(id={self.id}, info={self.info})"
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.name)
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# set up recorder
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
@@ -210,7 +211,6 @@ class MLflowExperiment(Experiment):
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
mlflow.set_experiment(self.name)
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -239,7 +239,7 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input at least one of recorder id or name before retrieving recorder."
|
||||
if recorder_id is not None:
|
||||
try:
|
||||
run = self.client.get_run(recorder_id)
|
||||
run = self._client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException:
|
||||
@@ -260,7 +260,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -268,10 +268,10 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
self._client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
self._client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
@@ -280,7 +280,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -7,8 +7,11 @@ from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,12 +23,21 @@ class ExpManager:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
self.uri = uri
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._current_uri = uri
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
def __repr__(self):
|
||||
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
@@ -45,7 +57,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
@@ -58,7 +70,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -203,7 +215,17 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_exp` method.")
|
||||
|
||||
def get_uri(self):
|
||||
@property
|
||||
def default_uri(self):
|
||||
"""
|
||||
Get the default tracking URI from qlib.config.C
|
||||
"""
|
||||
if "kwargs" not in C.exp_manager or "uri" not in C.exp_manager["kwargs"]:
|
||||
raise ValueError("The default URI is not set in qlib.config.C")
|
||||
return C.exp_manager["kwargs"]["uri"]
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
@@ -211,7 +233,31 @@ class ExpManager:
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self.uri
|
||||
return self._current_uri or self.default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
    """
    Set the current tracking URI and the corresponding variables.

    Parameters
    ----------
    uri : str
        The tracking URI to use from now on. When it is None, the default
        tracking URI (``self.default_uri``) is used instead.

    """
    if uri is None:
        logger.info("No tracking URI is provided. Use the default tracking URI.")
        self._current_uri = self.default_uri
    else:
        # Temporarily re-set the current uri as the uri argument.
        self._current_uri = uri
    # Customized features for subclasses: the hook runs after the URI is stored,
    # so it can read the new value from ``self._current_uri``.
    self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -232,31 +278,32 @@ class MLflowExpManager(ExpManager):
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if not hasattr(self, "_client"):
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
def start_exp(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
# Set the tracking uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
|
||||
@@ -34,7 +34,7 @@ class Recorder:
|
||||
self.status = Recorder.STATUS_S
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -201,7 +201,7 @@ class MLflowRecorder(Recorder):
|
||||
def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
|
||||
super(MLflowRecorder, self).__init__(experiment_id, name)
|
||||
self._uri = uri
|
||||
self.artifact_uri = None
|
||||
self._artifact_uri = None
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
# construct from mlflow run
|
||||
if mlflow_run is not None:
|
||||
@@ -220,14 +220,51 @@ class MLflowRecorder(Recorder):
|
||||
else None
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
space_length = len(name) + 1
|
||||
return "{name}(info={info},\n{space}uri={uri},\n{space}artifact_uri={artifact_uri},\n{space}client={client})".format(
|
||||
name=name,
|
||||
space=" " * space_length,
|
||||
info=self.info,
|
||||
uri=self.uri,
|
||||
artifact_uri=self.artifact_uri,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
@property
|
||||
def artifact_uri(self):
|
||||
return self._artifact_uri
|
||||
|
||||
def get_local_dir(self):
    """
    This function will return the directory path of this recorder.

    Returns
    -------
    str
        The resolved parent directory of the artifact directory.

    Raises
    ------
    Exception
        If ``self.artifact_uri`` is None (the recorder has not been started).
    RuntimeError
        If the resolved directory does not exist on the local file system.
    """
    if self.artifact_uri is None:
        raise Exception(
            "Please make sure the recorder has been created and started properly before getting artifact uri."
        )

    # Remove the "file:" scheme as a *prefix*. ``str.lstrip("file:")`` strips any
    # leading run of the characters f, i, l, e and ':' and therefore also eats the
    # start of paths such as "file:life/artifacts".
    artifact_path = self.artifact_uri
    scheme = "file:"
    if artifact_path.startswith(scheme):
        artifact_path = artifact_path[len(scheme) :]

    # The recorder directory is the parent of the artifact directory.
    local_dir_path = str((Path(artifact_path) / "..").resolve())
    if os.path.isdir(local_dir_path):
        return local_dir_path
    raise RuntimeError("This recorder is not saved in the local file system.")
|
||||
|
||||
def start_run(self):
|
||||
# set the tracking uri
|
||||
mlflow.set_tracking_uri(self._uri)
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the run
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
self.id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self._artifact_uri = run.info.artifact_uri
|
||||
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
@@ -247,7 +284,7 @@ class MLflowRecorder(Recorder):
|
||||
self.status = status
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
if local_path is not None:
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
@@ -259,7 +296,7 @@ class MLflowRecorder(Recorder):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
@@ -289,7 +326,7 @@ class MLflowRecorder(Recorder):
|
||||
)
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return [art.path for art in artifacts]
|
||||
|
||||
|
||||
430
scripts/data_collector/base.py
Normal file
430
scripts/data_collector/base.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import abc
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
class BaseCollector(abc.ABC):
    """Abstract base class for collectors that download data symbol by symbol.

    Subclasses provide the symbol universe (``get_stock_list``), symbol
    normalization (``normalize_symbol``), the actual download (``get_data``) and
    the minimum acceptable number of rows (``min_numbers_trading``). This class
    drives the download on a thread pool, retries failed symbols and appends the
    results to one CSV file per symbol under ``save_dir``.
    """

    # Result flags returned by ``_simple_collector``.
    CACHE_FLAG = "CACHED"
    NORMAL_FLAG = "NORMAL"

    # NOTE: the ``datetime.now()`` based defaults are evaluated once, when the
    # class body is executed, not on every instantiation.
    DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
    DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
    DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))

    INTERVAL_1min = "1min"
    INTERVAL_1d = "1d"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            stock save dir
        max_workers: int
            workers, default 4
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None
        """
        self.save_dir = Path(save_dir).expanduser().resolve()
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.delay = delay
        self.max_workers = max_workers
        self.max_collector_count = max_collector_count
        # symbol -> list of "too short" frames collected so far (see cache_small_data)
        self.mini_symbol_map = {}
        self.interval = interval
        self.check_small_data = check_data_length

        # ``self.interval`` must be set before these calls: the defaults are
        # looked up by interval name.
        self.start_datetime = self.normalize_start_datetime(start)
        self.end_datetime = self.normalize_end_datetime(end)

        self.stock_list = sorted(set(self.get_stock_list()))

        if limit_nums is not None:
            try:
                self.stock_list = self.stock_list[: int(limit_nums)]
            except (TypeError, ValueError):
                # int() rejected the value; keep the full list.
                logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")

    def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
        """Return ``start_datetime`` as a Timestamp, or the class default for ``self.interval`` when it is falsy."""
        return (
            pd.Timestamp(str(start_datetime))
            if start_datetime
            else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
        )

    def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
        """Return ``end_datetime`` as a Timestamp, or the class default for ``self.interval`` when it is falsy."""
        return (
            pd.Timestamp(str(end_datetime))
            if end_datetime
            else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
        )

    @property
    @abc.abstractmethod
    def min_numbers_trading(self):
        """Row-count threshold used by ``cache_small_data``."""
        # daily, one year: 252 / 4
        # us 1min, a week: 6.5 * 60 * 5
        # cn 1min, a week: 4 * 60 * 5
        raise NotImplementedError("rewrite min_numbers_trading")

    @abc.abstractmethod
    def get_stock_list(self):
        """Return an iterable of the symbols to collect."""
        raise NotImplementedError("rewrite get_stock_list")

    @abc.abstractmethod
    def normalize_symbol(self, symbol: str):
        """normalize symbol"""
        raise NotImplementedError("rewrite normalize_symbol")

    @abc.abstractmethod
    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> pd.DataFrame:
        """get data with symbol

        Parameters
        ----------
        symbol: str
        interval: str
            value from [1min, 1d]
        start_datetime: pd.Timestamp
        end_datetime: pd.Timestamp

        Returns
        ---------
            pd.DataFrame, "symbol" in pd.columns

        """
        raise NotImplementedError("rewrite get_data")

    def sleep(self):
        """Pause for ``self.delay`` seconds (throttles remote requests)."""
        time.sleep(self.delay)

    def _simple_collector(self, symbol: str):
        """Download and store the data of one symbol.

        Parameters
        ----------
        symbol: str

        Returns
        -------
        str
            ``NORMAL_FLAG`` when the data went through ``save_instrument``,
            ``CACHE_FLAG`` when it was held back by ``cache_small_data``.
        """
        self.sleep()
        df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
        if df is None:
            # An implementation may return None when the request fails; handle
            # it like "no rows" instead of crashing on len()/.empty below.
            df = pd.DataFrame()
        _result = self.NORMAL_FLAG
        if self.check_small_data:
            _result = self.cache_small_data(symbol, df)
        if _result == self.NORMAL_FLAG:
            self.save_instrument(symbol, df)
        return _result

    def save_instrument(self, symbol, df: pd.DataFrame):
        """save stock data to file

        Rows are appended to ``<save_dir>/<symbol>.csv`` when that file already exists.

        Parameters
        ----------
        symbol: str
            stock code
        df : pd.DataFrame
            df.columns must contain "symbol" and "datetime"
        """
        if df is None or df.empty:
            logger.warning(f"{symbol} is empty")
            return

        symbol = self.normalize_symbol(symbol)
        symbol = code_to_fname(symbol)
        stock_path = self.save_dir.joinpath(f"{symbol}.csv")
        # NOTE: this writes the file-name form of the symbol into the caller's frame.
        df["symbol"] = symbol
        if stock_path.exists():
            _old_df = pd.read_csv(stock_path)
            # DataFrame.append was removed in pandas 2.0; concat is the equivalent.
            df = pd.concat([_old_df, df], sort=False)
        df.to_csv(stock_path, index=False)

    def cache_small_data(self, symbol, df):
        """Hold back frames with at most ``min_numbers_trading`` rows.

        Returns ``CACHE_FLAG`` after storing a copy of ``df`` in
        ``mini_symbol_map``; otherwise drops any cached frames of ``symbol`` and
        returns ``NORMAL_FLAG``.
        """
        if len(df) <= self.min_numbers_trading:
            logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
            _temp = self.mini_symbol_map.setdefault(symbol, [])
            _temp.append(df.copy())
            return self.CACHE_FLAG
        else:
            if symbol in self.mini_symbol_map:
                self.mini_symbol_map.pop(symbol)
            return self.NORMAL_FLAG

    def _collector(self, stock_list):
        """Collect every symbol of ``stock_list`` on a thread pool.

        Returns the sorted, de-duplicated symbols that should be retried: those
        whose result was not ``NORMAL_FLAG`` plus everything currently held in
        ``mini_symbol_map``.
        """
        error_symbol = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            with tqdm(total=len(stock_list)) as p_bar:
                # executor.map yields results in input order, so zip pairs them correctly.
                for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
                    if _result != self.NORMAL_FLAG:
                        error_symbol.append(_symbol)
                    p_bar.update()
        logger.info(f"error symbols: {error_symbol}")
        logger.info(f"error symbol nums: {len(error_symbol)}")
        logger.info(f"current get symbol nums: {len(stock_list)}")
        error_symbol.extend(self.mini_symbol_map.keys())
        return sorted(set(error_symbol))

    def collector_data(self):
        """collector data"""
        logger.info("start collector data......")
        stock_list = self.stock_list
        # Each round only retries the symbols the previous round reported.
        for i in range(self.max_collector_count):
            if not stock_list:
                break
            logger.info(f"getting data: {i+1}")
            stock_list = self._collector(stock_list)
            logger.info(f"{i+1} finish.")
        # Symbols that stayed "too short" after all rounds are saved anyway.
        for _symbol, _df_list in self.mini_symbol_map.items():
            _df = pd.concat(_df_list, sort=False)
            if not _df.empty:
                # Empty frames may have no "date" column; save_instrument only logs them.
                _df = _df.drop_duplicates(["date"]).sort_values(["date"])
            self.save_instrument(_symbol, _df)
        if self.mini_symbol_map:
            logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
        logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
    """Abstract base class of the per-file normalizers used by ``Normalize``."""

    def __init__(
        self,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """

        Parameters
        ----------
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        self._date_field_name = date_field_name
        self._symbol_field_name = symbol_field_name

        # The calendar is fetched once at construction time by the subclass hook.
        self._calendar_list = self._get_calendar_list()

    @abc.abstractmethod
    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the normalized version of ``df``."""
        # normalize
        raise NotImplementedError("rewrite normalize")

    @abc.abstractmethod
    def _get_calendar_list(self):
        """Get benchmark calendar"""
        raise NotImplementedError("rewrite _get_calendar_list")
|
||||
|
||||
|
||||
class Normalize:
    """Run a ``BaseNormalize`` implementation over every CSV file of a directory."""

    def __init__(
        self,
        source_dir: [str, Path],
        target_dir: [str, Path],
        normalize_class: Type["BaseNormalize"],
        max_workers: int = 16,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """

        Parameters
        ----------
        source_dir: str or Path
            The directory where the raw data collected from the Internet is saved
        target_dir: str or Path
            Directory for normalize data
        normalize_class: Type[BaseNormalize]
            normalize class
        max_workers: int
            Concurrent number, default is 16
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol

        Raises
        ------
        ValueError
            If ``source_dir`` or ``target_dir`` is empty/None.
        """
        if not (source_dir and target_dir):
            raise ValueError("source_dir and target_dir cannot be None")
        self._source_dir = Path(source_dir).expanduser()
        self._target_dir = Path(target_dir).expanduser()
        self._target_dir.mkdir(parents=True, exist_ok=True)

        self._max_workers = max_workers

        self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)

    def _executor(self, file_path: Path):
        """Normalize one CSV file and write it, under the same name, to the target directory.

        Nothing is written when the normalized frame is None or empty.
        """
        file_path = Path(file_path)
        df = pd.read_csv(file_path)
        df = self._normalize_obj.normalize(df)
        if df is not None and not df.empty:
            df.to_csv(self._target_dir.joinpath(file_path.name), index=False)

    def normalize(self):
        """Normalize every ``*.csv`` file of the source directory on a process pool."""
        logger.info("normalize data......")

        file_list = list(self._source_dir.glob("*.csv"))
        with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
            with tqdm(total=len(file_list)) as p_bar:
                # Consuming the map iterator re-raises any worker exception here.
                for _ in worker.map(self._executor, file_list):
                    p_bar.update()
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
    """Abstract command-line entry point tying a collector class and a normalize class together.

    Subclasses name the two classes (looked up by name in the imported
    ``collector`` module) and the default base directory.
    """

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "<default_base_dir>/_source"
        normalize_dir: str
            Directory for normalize data, default "<default_base_dir>/normalize"
        max_workers: int
            Concurrent number, default is 4
        interval: str
            freq, value from [1min, 1d], default 1d
        """
        if source_dir is None:
            source_dir = Path(self.default_base_dir).joinpath("_source")
        self.source_dir = Path(source_dir).expanduser().resolve()
        self.source_dir.mkdir(parents=True, exist_ok=True)

        if normalize_dir is None:
            normalize_dir = Path(self.default_base_dir).joinpath("normalize")
        self.normalize_dir = Path(normalize_dir).expanduser().resolve()
        self.normalize_dir.mkdir(parents=True, exist_ok=True)

        # The collector/normalize classes are resolved by name from a module
        # called "collector", which must be importable at this point.
        self._cur_module = importlib.import_module("collector")
        self.max_workers = max_workers
        self.interval = interval

    @property
    @abc.abstractmethod
    def collector_class_name(self):
        """Name of the collector class inside the ``collector`` module."""
        raise NotImplementedError("rewrite collector_class_name")

    @property
    @abc.abstractmethod
    def normalize_class_name(self):
        """Name of the normalize class inside the ``collector`` module."""
        raise NotImplementedError("rewrite normalize_class_name")

    @property
    @abc.abstractmethod
    def default_base_dir(self) -> [Path, str]:
        """Base directory used when ``source_dir`` / ``normalize_dir`` are not given."""
        raise NotImplementedError("rewrite default_base_dir")

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        interval="1d",
        check_data_length=False,
        limit_nums=None,
    ):
        """download data from Internet

        NOTE: the ``interval`` argument of this method is what gets passed to
        the collector; ``self.interval`` is not consulted here.

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
            # get 1min data
            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
        """

        _class = getattr(self._cur_module, self.collector_class_name)  # type: Type[BaseCollector]
        _class(
            self.source_dir,
            max_workers=self.max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        ).collector_data()

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
        """
        _class = getattr(self._cur_module, self.normalize_class_name)
        yc = Normalize(
            source_dir=self.source_dir,
            target_dir=self.normalize_dir,
            normalize_class=_class,
            max_workers=self.max_workers,
            date_field_name=date_field_name,
            symbol_field_name=symbol_field_name,
        )
        yc.normalize()
|
||||
@@ -10,158 +10,26 @@ import importlib
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooData:
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timezone: str = None,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
delay=0,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timezone: str
|
||||
The timezone where the data is located
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
show_1min_logging: bool
|
||||
show 1min logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self._timezone = tzlocal() if timezone is None else timezone
|
||||
self._delay = delay
|
||||
self._interval = interval
|
||||
self._show_1min_logging = show_1min_logging
|
||||
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
def _show_logging_func():
|
||||
if interval == YahooData.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(self, symbol: str) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
show_1min_logging=self._show_1min_logging,
|
||||
)
|
||||
|
||||
_result = None
|
||||
if self._interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
elif self._interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
class YahooCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
@@ -173,7 +41,6 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -197,131 +64,118 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._delay = delay
|
||||
self.max_workers = max_workers
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
self.yahoo_data = YahooData(
|
||||
timezone=self._timezone,
|
||||
super(YahooCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
show_1min_logging=show_1min_logging,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
self.init_datetime()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
    """Re-express ``dt`` (interpreted in ``timezone``) in the local timezone.

    ``dt`` is first turned into a ``pd.Timestamp`` in ``timezone`` and
    reduced to epoch seconds; those seconds are then wrapped in a new
    ``pd.Timestamp`` carrying ``tzlocal()``.

    A ``ValueError`` raised by either step is deliberately swallowed: the
    value produced so far (the untouched input, or the epoch seconds if only
    the second step failed) is returned instead.
    """
    converted = dt
    try:
        converted = pd.Timestamp(dt, tz=timezone).timestamp()
        converted = pd.Timestamp(converted, tz=tzlocal(), unit="s")
    except ValueError:
        # best effort: fall through with whatever has been computed so far
        pass
    return converted
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
def _show_logging_func():
|
||||
if interval == YahooCollector.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def _save_small_data(self, symbol, df):
    """Buffer ``df`` when it has too few rows, otherwise accept the symbol.

    Returns
    -------
    None
        when ``len(df) <= self.min_numbers_trading``; a copy of ``df`` is
        appended to ``self._mini_symbol_map[symbol]`` and a warning is logged
    symbol
        otherwise; any frames previously buffered for ``symbol`` are dropped
    """
    if len(df) <= self.min_numbers_trading:
        logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
        self._mini_symbol_map.setdefault(symbol, []).append(df.copy())
        return None

    # enough rows this time: forget the short frames collected earlier
    self._mini_symbol_map.pop(symbol, None)
    return symbol
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
def _get_data(self, symbol):
    """Fetch one symbol through ``self.yahoo_data`` and persist it.

    Returns
    -------
    symbol
        when a non-empty ``pd.DataFrame`` was fetched and handed to
        ``save_stock``
    None
        when the fetch result is not a DataFrame, is empty, or (with
        ``_check_small_data`` enabled) ``_save_small_data`` returned ``None``
    """
    df = self.yahoo_data.get_data(symbol)
    if not isinstance(df, pd.DataFrame) or df.empty:
        return None
    # the length check is only consulted when it is switched on
    if self._check_small_data and self._save_small_data(symbol, df) is None:
        return None
    self.save_stock(symbol, df)
    return symbol
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
super(YahooCollector, self).collector_data()
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -329,11 +183,6 @@ class YahooCollector:
|
||||
"""download index data"""
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
@@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
_begin = self.yahoo_data.start_datetime.strftime(_format)
|
||||
_end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
_begin = self.start_datetime.strftime(_format)
|
||||
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
@@ -396,7 +245,7 @@ class YahooCollectorCN1min(YahooCollectorCN):
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
@@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS):
|
||||
return 60 * 6.5 * 5
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
class YahooNormalize(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame,
|
||||
@@ -498,11 +328,6 @@ class YahooNormalize:
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""adjusted price"""
|
||||
@@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Normalize:
    """Run a normalize object over every raw CSV file of a directory.

    Each ``*.csv`` file in the source directory is read with pandas, passed
    through ``normalize`` of an instance of ``normalize_class`` and, when the
    result is not empty, written under the same file name in the target
    directory.
    """

    def __init__(
        self,
        source_dir: [str, Path],
        target_dir: [str, Path],
        normalize_class: Type[YahooNormalize],
        max_workers: int = 16,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """

        Parameters
        ----------
        source_dir: str or Path
            The directory where the raw data collected from the Internet is saved
        target_dir: str or Path
            Directory for normalize data; created if it does not exist
        normalize_class: Type[YahooNormalize]
            normalize class, instantiated with the two field names below
        max_workers: int
            Concurrent number, default is 16
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol

        Raises
        ------
        ValueError
            if ``source_dir`` or ``target_dir`` is falsy
        """
        if not source_dir or not target_dir:
            raise ValueError("source_dir and target_dir cannot be None")

        self._source_dir = Path(source_dir).expanduser()
        self._target_dir = Path(target_dir).expanduser()
        self._target_dir.mkdir(parents=True, exist_ok=True)

        self._max_workers = max_workers

        self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)

    def _executor(self, file_path: Path):
        """Normalize one CSV file; nothing is written when the result is empty."""
        csv_path = Path(file_path)
        normalized = self._normalize_obj.normalize(pd.read_csv(csv_path))
        if normalized.empty:
            return
        normalized.to_csv(self._target_dir.joinpath(csv_path.name), index=False)

    def normalize(self):
        """Normalize all ``*.csv`` files of the source directory in a process pool."""
        logger.info("normalize data......")

        with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
            file_list = list(self._source_dir.glob("*.csv"))
            with tqdm(total=len(file_list)) as p_bar:
                # results are discarded; the iteration only drives the progress bar
                for _ in worker.map(self._executor, file_list):
                    p_bar.update()
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -789,23 +562,26 @@ class Run:
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN", "US"], default "CN"
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"YahooCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"YahooNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
@@ -815,7 +591,6 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1min_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -835,8 +610,6 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -846,29 +619,13 @@ class Run:
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(
|
||||
self._cur_module, f"YahooCollector{self.region.upper()}{interval}"
|
||||
) # type: Type[YahooCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1min_logging=show_1min_logging,
|
||||
).collector_data()
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
@@ -878,16 +635,7 @@ class Run:
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -96,7 +96,6 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
|
||||
"""train model
|
||||
|
||||
@@ -111,6 +110,9 @@ def train():
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
@@ -119,6 +121,10 @@ def train():
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
# To test get_local_dir
|
||||
print(recorder.get_local_dir())
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
@@ -133,6 +139,27 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def fake_experiment():
    """Run a throw-away experiment to check how ``R.get_uri()`` behaves.

    ``R.get_uri()`` is sampled three times: before the experiment starts,
    inside an experiment started with an explicit ``file:`` uri, and again
    after that experiment has been left.

    Returns
    -------
    pass_or_not_for_default_uri: bool
        whether the uri read after the experiment equals the one read before it
    pass_or_not_for_current_uri: bool
        whether the uri read inside the experiment equals the uri it was started with
    temporary_exp_dir: str
        the ``file:`` uri handed to the experiment
    """
    uri_before = R.get_uri()
    current_uri = "file:./temp-test-exp-mag"

    # start exp
    with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
        R.log_params(**flatten_dict(task))
        uri_inside = R.get_uri()
    uri_after = R.get_uri()

    default_uri_ok = uri_before == uri_after
    current_uri_ok = current_uri == uri_inside
    return default_uri_ok, current_uri_ok, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
@@ -181,6 +208,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
    """Check the uri handling of the experiment manager, then clean up.

    ``fake_experiment`` reports whether the default uri and the explicitly
    requested uri were observed as expected; the directory behind the
    returned ``file:`` uri is removed afterwards.
    """
    pass_default, pass_current, uri_path = fake_experiment()
    self.assertTrue(pass_default, msg="default uri is incorrect")
    self.assertTrue(pass_current, msg="current uri is incorrect")

    # Drop the "file:" scheme as a prefix. str.strip("file:") would instead
    # strip any of the characters f/i/l/e/: from BOTH ends of the string and
    # could truncate a directory name that ends in one of them.
    scheme = "file:"
    local_dir = uri_path[len(scheme):] if uri_path.startswith(scheme) else uri_path
    shutil.rmtree(str(Path(local_dir).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user