diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index e8b316795..0cb09a05a 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -230,8 +230,7 @@ class ALSTM(Model): x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 best_score = -np.inf diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index d17cfb934..fe562fd1c 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -220,8 +220,7 @@ class ALSTM(Model): dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True ) - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 7d3b00232..b5330146f 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -248,8 +248,7 @@ class GATs(Model): x_train, y_train = df_train["feature"], df_train["label"] 
x_valid, y_valid = df_valid["feature"], df_valid["label"] - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 best_score = -np.inf best_epoch = 0 diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 1d9b525f1..369d1ca7f 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -264,8 +264,7 @@ class GATs(Model): train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True) valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True) - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index d4dc88452..697b71cc9 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -230,8 +230,7 @@ class GRU(Model): x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 best_score = -np.inf diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index b0a72f72d..483f419ce 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -14,7 
+14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -220,8 +220,7 @@ class GRU(Model): dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True ) - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index be8d20a15..648a909c7 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -226,8 +226,7 @@ class LSTM(Model): x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 best_score = -np.inf diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index e2eec9c28..95476fedf 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -14,7 +14,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -216,8 +216,7 @@ class LSTM(Model): dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True ) - if save_path == None: - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 diff --git a/qlib/contrib/model/pytorch_nn.py 
b/qlib/contrib/model/pytorch_nn.py index a51481c85..37d8dec3e 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -19,7 +19,7 @@ from .pytorch_utils import count_parameters from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP -from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index +from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index from ...log import get_module_logger, TimeInspector from ...workflow import R @@ -176,7 +176,7 @@ class DNNModelPytorch(Model): w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index) - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 best_loss = np.inf diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 40b991c9c..cc600a955 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -13,7 +13,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -380,6 +380,7 @@ class SFM(Model): x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 best_score = -np.inf @@ -412,7 +413,10 @@ class SFM(Model): if stop_steps >= self.early_stop: self.logger.info("early stop") break + self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) + self.sfm_model.load_state_dict(best_param) + torch.save(best_param, save_path) if self.device != "cpu": torch.cuda.empty_cache() diff --git 
a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 020bbaff2..93b2a36da 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -12,7 +12,7 @@ import logging from ...utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from ...log import get_module_logger, TimeInspector @@ -117,10 +117,7 @@ class TabnetModel(Model): raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"): - # make a directory if pretrian director does not exist - if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"): - self.logger.info("make folder to store model...") - os.makedirs("pretrain") + get_or_create_path(pretrain_file) [df_train, df_valid] = dataset.prepare( ["pretrain", "pretrain_validation"], @@ -181,6 +178,7 @@ class TabnetModel(Model): df_train.fillna(df_train.mean(), inplace=True) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] + save_path = get_or_create_path(save_path) stop_steps = 0 train_loss = 0 @@ -207,12 +205,16 @@ class TabnetModel(Model): best_score = val_score stop_steps = 0 best_epoch = epoch_idx + best_param = copy.deepcopy(self.tabnet_model.state_dict()) else: stop_steps += 1 if stop_steps >= self.early_stop: self.logger.info("early stop") break + self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) + self.tabnet_model.load_state_dict(best_param) + torch.save(best_param, save_path) def predict(self, dataset): if not self.fitted: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index f550a0419..68d7d8f3f 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -24,7 +24,7 @@ import collections import numpy as np import pandas as pd from pathlib import Path -from typing import Union, Tuple 
from typing import Union, Tuple, Text, Optional


def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
    """Return a usable file or directory path, creating directories as needed.

    Parameters
    ----------
    path: Optional[Text]
        The requested location. If it is None (or empty), a fresh temporary
        file or directory is created under ``~/tmp`` instead.
    return_dir: bool
        If True, ``path`` is treated as a directory, which is created when it
        does not exist. If False, ``path`` is treated as a file: only its
        parent directory is created, the file itself is not touched.

    Returns
    -------
    Text
        ``path`` unchanged when one was given; otherwise the path of the newly
        created temporary file (``return_dir=False``) or temporary directory
        (``return_dir=True``).
    """
    if path:
        if return_dir:
            target_dir = path
        else:
            # A file was requested, so only its parent directory must exist.
            target_dir = os.path.dirname(os.path.abspath(path))
        # The existence check keeps the historical "leave whatever is there
        # alone" behaviour; exist_ok tolerates another process creating the
        # directory between the check and the call.
        if not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)
        return path

    temp_dir = os.path.expanduser("~/tmp")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir, exist_ok=True)
    if return_dir:
        # mkdtemp returns the directory path only (not a tuple).
        return tempfile.mkdtemp(dir=temp_dir)
    # mkstemp returns an *open* descriptor together with the path; close it so
    # the descriptor is not leaked -- callers only need the path.
    fd, path = tempfile.mkstemp(dir=temp_dir)
    os.close(fd)
    return path