diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5d773a9d0..be111a5e8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,37 @@ jobs: - name: Install Qlib with pip run: | pip install numpy==1.19.5 ruamel.yaml - pip install pyqlib --ignore-installed + pip install pyqlib --ignore-installed + + # Check Qlib with pylint + # TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102 + # C0103: invalid-name + # C0209: consider-using-f-string + # R0402: consider-using-from-import + # R1705: no-else-return + # R1710: inconsistent-return-statements + # R1725: super-with-arguments + # R1735: use-dict-literal + # W0102: dangerous-default-value + # W0212: protected-access + # W0221: arguments-differ + # W0223: abstract-method + # W0231: super-init-not-called + # W0237: arguments-renamed + # W0612: unused-variable + # W0621: redefined-outer-name + # W0622: redefined-builtin + # FIXME: specify exception type + # W0703: broad-except + # W1309: f-string-without-interpolation + # E1102: not-callable + # E1136: unsubscriptable-object + # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 + - name: Check Qlib with pylint + run: | + pip install --upgrade pip + pip install pylint + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500" - name: Test data downloads run: | diff --git a/qlib/__init__.py b/qlib/__init__.py index 8b3894253..0229f2cb9 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -30,8 +30,8 @@ def init(default_conf="client", **kwargs): When using 
the recorder, skip_if_reg can set to True to avoid loss of recorder. """ - from .config import C - from .data.cache import H + from .config import C # pylint: disable=C0415 + from .data.cache import H # pylint: disable=C0415 # FIXME: this logger ignored the level in config logger = get_module_logger("Initialization", level=logging.INFO) @@ -85,7 +85,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path) # If the provider uri looks like this 172.23.233.89//data/csdesign' # It will be a nfs path. The client provider will be used - if not auto_mount: + if not auto_mount: # pylint: disable=R1702 if not Path(mount_path).exists(): raise FileNotFoundError( f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`" @@ -139,8 +139,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): if not _is_mount: try: Path(mount_path).mkdir(parents=True, exist_ok=True) - except Exception: - raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!") + except Exception as e: + raise OSError( + f"Failed to create directory {mount_path}, please create {mount_path} manually!" 
+ ) from e # check nfs-common command_res = os.popen("dpkg -l | grep nfs-common") diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index b1d92c5a5..36f0961be 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -171,8 +171,8 @@ def get_strategy_executor( # NOTE: # - for avoiding recursive import # - typing annotations is not reliable - from ..strategy.base import BaseStrategy - from .executor import BaseExecutor + from ..strategy.base import BaseStrategy # pylint: disable=C0415 + from .executor import BaseExecutor # pylint: disable=C0415 trade_account = create_account_instance( start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index f2e32c602..4c9330e4c 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -2,11 +2,11 @@ # Licensed under the MIT License. from __future__ import annotations import copy -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple from qlib.utils import init_instance_by_config import pandas as pd -from .position import BasePosition, InfPosition, Position +from .position import BasePosition from .report import PortfolioMetrics, Indicator from .decision import BaseTradeDecision, Order from .exchange import Exchange diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 3b15b06a4..e8f787a9f 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -7,19 +7,18 @@ from qlib.data.data import Cal from qlib.utils.time import concat_date_time, epsilon_change from qlib.log import get_module_logger +from typing import ClassVar, Optional, Union, List, Tuple + # try to fix circular imports when enabling type hints -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy from qlib.backtest.exchange import Exchange from qlib.backtest.utils 
import TradeCalendarManager -import warnings import numpy as np import pandas as pd -import numpy as np -from dataclasses import dataclass, field -from typing import ClassVar, Optional, Union, List, Set, Tuple +from dataclasses import dataclass class OrderDir(IntEnum): @@ -418,7 +417,7 @@ class BaseTradeDecision: return kwargs["default_value"] else: # Default to get full index - raise NotImplementedError(f"The decision didn't provide an index range") + raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError # clip index if getattr(self, "total_step", None) is not None: diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 902394a9c..4c020f8d8 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections import defaultdict from typing import TYPE_CHECKING +from typing import List, Tuple, Union if TYPE_CHECKING: from .account import Account from qlib.backtest.position import BasePosition, Position import random -from typing import List, Tuple, Union import numpy as np import pandas as pd @@ -18,7 +18,7 @@ from ..config import C from ..constant import REG_CN from ..log import get_module_logger from .decision import Order, OrderDir, OrderHelper -from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote +from .high_performance_ds import BaseQuote, NumpyQuote class Exchange: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 37098379b..821d36cc0 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -1,22 +1,18 @@ -from abc import abstractclassmethod, abstractmethod +from abc import abstractmethod import copy from qlib.backtest.position import BasePosition from qlib.log import get_module_logger from types import GeneratorType from qlib.backtest.account import Account -import warnings import pandas as pd from typing import List, Tuple, Union from collections import defaultdict -from 
qlib.backtest.report import Indicator - -from .decision import EmptyTradeDecision, Order, BaseTradeDecision +from .decision import Order, BaseTradeDecision from .exchange import Exchange from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx from ..utils import init_instance_by_config -from ..utils.time import Freq from ..strategy.base import BaseStrategy @@ -193,7 +189,8 @@ class BaseExecutor: pass return return_value.get("execute_result") - @abstractclassmethod + @classmethod + @abstractmethod def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: """ Please refer to the doc of collect_data @@ -453,7 +450,6 @@ class NestedExecutor(BaseExecutor): inner_exe_res : the execution result of inner task """ - pass def get_all_executors(self): """get all executors, including self and inner_executor.get_all_executors()""" diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index ddb5c24e8..925cb711d 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. 
-import copy -import pathlib from typing import Dict, List, Union import pandas as pd @@ -538,7 +536,7 @@ class InfPosition(BasePosition): def get_stock_amount_dict(self) -> Dict: raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict") - def get_stock_weight_dict(self, only_stock: bool) -> Dict: + def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict") def add_count_all(self, bar): diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index f846aea9d..023114623 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -10,11 +10,8 @@ import numpy as np import pandas as pd from qlib.backtest.exchange import Exchange -from .decision import IdxTradeRange from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir -from qlib.backtest.utils import TradeCalendarManager -from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric -from ..data import D +from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data import qlib.utils.index_data as idd diff --git a/qlib/config.py b/qlib/config.py index fa2e539dc..d831a3ad2 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -388,13 +388,11 @@ class QlibConfig(Config): default_conf : str the default config template chosen by user: "server", "client" """ - from .utils import set_log_with_config, get_module_logger, can_use_cache + from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415 self.reset() - _logging_config = self.logging_config - if "logging_config" in kwargs: - _logging_config = kwargs["logging_config"] + _logging_config = kwargs.get("logging_config", self.logging_config) # set global config if _logging_config: @@ -433,11 +431,11 @@ class 
QlibConfig(Config): ) def register(self): - from .utils import init_instance_by_config - from .data.ops import register_all_ops - from .data.data import register_all_wrappers - from .workflow import R, QlibRecorder - from .workflow.utils import experiment_exit_handler + from .utils import init_instance_by_config # pylint: disable=C0415 + from .data.ops import register_all_ops # pylint: disable=C0415 + from .data.data import register_all_wrappers # pylint: disable=C0415 + from .workflow import R, QlibRecorder # pylint: disable=C0415 + from .workflow.utils import experiment_exit_handler # pylint: disable=C0415 register_all_ops(self) register_all_wrappers(self) @@ -454,7 +452,7 @@ class QlibConfig(Config): self._registered = True def reset_qlib_version(self): - import qlib + import qlib # pylint: disable=C0415 reset_version = self.get("qlib_reset_version", None) if reset_version is not None: diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index e8a048762..60362130b 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -7,8 +7,7 @@ import warnings import numpy as np import pandas as pd -from qlib.utils import init_instance_by_config -from qlib.data.dataset import DatasetH, DataHandler +from qlib.data.dataset import DatasetH device = "cuda" if torch.cuda.is_available() else "cpu" @@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" def _to_tensor(x): if not isinstance(x, torch.Tensor): - return torch.tensor(x, dtype=torch.float, device=device) + return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101 return x diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 3c6a93f22..81b7c7392 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP from ...data.dataset.processor import Processor from ...utils import get_callable_kwargs from ...data.dataset import processor as 
processor_module -from ...log import TimeInspector from inspect import getfullargspec -import copy def check_transform_proc(proc_l, fit_start_time, fit_end_time): diff --git a/qlib/contrib/data/processor.py b/qlib/contrib/data/processor.py index 35b242510..e8ea38870 100644 --- a/qlib/contrib/data/processor.py +++ b/qlib/contrib/data/processor.py @@ -1,9 +1,6 @@ import numpy as np -import pandas as pd -import copy from ...log import TimeInspector -from ...utils.serial import Serializable from ...data.dataset.processor import Processor, get_group_columns diff --git a/qlib/contrib/evaluate_portfolio.py b/qlib/contrib/evaluate_portfolio.py index 920d2182c..0c598e2fa 100644 --- a/qlib/contrib/evaluate_portfolio.py +++ b/qlib/contrib/evaluate_portfolio.py @@ -5,12 +5,10 @@ from __future__ import division from __future__ import print_function -import copy import numpy as np import pandas as pd from scipy.stats import spearmanr, pearsonr - from ..data import D from collections import OrderedDict @@ -243,4 +241,4 @@ def get_rank_ic(a, b): def get_normal_ic(a, b): - return pearsonr(a, b).correlation + return pearsonr(a, b)[0] diff --git a/qlib/contrib/meta/data_selection/dataset.py b/qlib/contrib/meta/data_selection/dataset.py index f907af535..235b4f49a 100644 --- a/qlib/contrib/meta/data_selection/dataset.py +++ b/qlib/contrib/meta/data_selection/dataset.py @@ -1,24 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from copy import deepcopy -from qlib.data.dataset.utils import init_task_handler -from qlib.utils.data import deepcopy_basic_type -from qlib.contrib.torch import data_to_tensor -from qlib.workflow.task.utils import TimeAdjuster -from qlib.model.meta.task import MetaTask -from typing import Dict, List, Union, Text, Tuple -from qlib.data.dataset.handler import DataHandler -from qlib.log import get_module_logger -from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config -from qlib.workflow import R -from qlib.workflow.task.gen import RollingGen, task_generator -from joblib import Parallel, delayed -from qlib.model.meta.dataset import MetaTaskDataset -from qlib.model.trainer import task_train, TrainerR -from qlib.data.dataset import DatasetH -from tqdm.auto import tqdm import pandas as pd import numpy as np +from copy import deepcopy +from joblib import Parallel, delayed # pylint: disable=E0401 +from typing import Dict, List, Union, Text, Tuple +from qlib.data.dataset.utils import init_task_handler +from qlib.data.dataset import DatasetH +from qlib.contrib.torch import data_to_tensor +from qlib.model.meta.task import MetaTask +from qlib.model.meta.dataset import MetaTaskDataset +from qlib.model.trainer import TrainerR +from qlib.log import get_module_logger +from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config +from qlib.utils.data import deepcopy_basic_type +from qlib.workflow import R +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.utils import TimeAdjuster +from tqdm.auto import tqdm class InternalData: diff --git a/qlib/contrib/meta/data_selection/model.py b/qlib/contrib/meta/data_selection/model.py index c2106348a..76f16a2ff 100644 --- a/qlib/contrib/meta/data_selection/model.py +++ b/qlib/contrib/meta/data_selection/model.py @@ -1,28 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from qlib.log import get_module_logger import pandas as pd import numpy as np -from qlib.model.meta.task import MetaTask import torch from torch import nn from torch import optim from tqdm.auto import tqdm -import collections import copy -from typing import Union, List, Tuple, Dict +from typing import Union, List from ....data.dataset.weight import Reweighter from ....model.meta.dataset import MetaTaskDataset -from ....model.meta.model import MetaModel, MetaTaskModel +from ....model.meta.model import MetaTaskModel from ....workflow import R - from .utils import ICLoss from .dataset import MetaDatasetDS -from qlib.contrib.meta.data_selection.net import PredNet -from qlib.data.dataset.weight import Reweighter + from qlib.log import get_module_logger +from qlib.data.dataset.weight import Reweighter +from qlib.model.meta.task import MetaTask +from qlib.contrib.meta.data_selection.net import PredNet logger = get_module_logger("data selection") diff --git a/qlib/contrib/meta/data_selection/net.py b/qlib/contrib/meta/data_selection/net.py index c8b15d750..0aa8845cf 100644 --- a/qlib/contrib/meta/data_selection/net.py +++ b/qlib/contrib/meta/data_selection/net.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import pandas as pd import numpy as np import torch from torch import nn diff --git a/qlib/contrib/meta/data_selection/utils.py b/qlib/contrib/meta/data_selection/utils.py index 8d7dcf2e4..316f4e5cd 100644 --- a/qlib/contrib/meta/data_selection/utils.py +++ b/qlib/contrib/meta/data_selection/utils.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import pandas as pd import numpy as np import torch from torch import nn -from qlib.contrib.torch import data_to_tensor class ICLoss(nn.Module): diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index d2a093d2a..9572287df 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -101,7 +101,7 @@ class LGBModel(ModelFT, LightGBMFInt): verbose level """ # Based on existing model and finetune by train more rounds - dtrain, _ = self._prepare_data(dataset, reweighter) + dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632 if dtrain.empty: raise ValueError("Empty data from dataset, please check your dataset config.") self.model = lgb.train( diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index a9f1bd03e..3e6c33471 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt): """ Test the signal in high frequency test set """ - if self.model == None: + if self.model is None: raise ValueError("Model hasn't been trained yet") df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) df_test.dropna(inplace=True) diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index aaa949b14..b5700def6 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. 
import os -from pdb import set_trace from torch.utils.data import Dataset, DataLoader import copy from typing import Text, Union -import math import numpy as np import pandas as pd import torch @@ -182,11 +180,11 @@ class ADARNN(Model): continue total_loss = torch.zeros(1).cuda() - for i in range(len(index)): - feature_s = list_feat[index[i][0]] - feature_t = list_feat[index[i][1]] - label_reg_s = list_label[index[i][0]] - label_reg_t = list_label[index[i][1]] + for i, n in enumerate(index): + feature_s = list_feat[n[0]] + feature_t = list_feat[n[1]] + label_reg_s = list_label[n[0]] + label_reg_t = list_label[n[1]] feature_all = torch.cat((feature_s, feature_t), 0) if epoch < self.pre_epoch: @@ -410,7 +408,7 @@ class AdaRNN(nn.Module): in_size = hidden self.features = nn.Sequential(*features) - if use_bottleneck == True: # finance + if use_bottleneck is True: # finance self.bottleneck = nn.Sequential( nn.Linear(n_hiddens[-1], bottleneck_width), nn.Linear(bottleneck_width, bottleneck_width), @@ -449,7 +447,7 @@ class AdaRNN(nn.Module): def forward_pre_train(self, x, len_win=0): out = self.gru_features(x) fea = out[0] # [2N,L,H] - if self.use_bottleneck == True: + if self.use_bottleneck is True: fea_bottleneck = self.bottleneck(fea[:, -1, :]) fc_out = self.fc(fea_bottleneck).squeeze() else: @@ -458,8 +456,8 @@ class AdaRNN(nn.Module): out_list_all, out_weight_list = out[1], out[2] out_list_s, out_list_t = self.get_features(out_list_all) loss_transfer = torch.zeros((1,)).cuda() - for i in range(len(out_list_s)): - criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2]) + for i, n in enumerate(out_list_s): + criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2]) h_start = 0 for j in range(h_start, self.len_seq, 1): i_start = j - len_win if j - len_win >= 0 else 0 @@ -471,7 +469,7 @@ class AdaRNN(nn.Module): else 1 / (self.len_seq - h_start) * (2 * len_win + 1) ) loss_transfer = loss_transfer + weight * 
criterion_transder.compute( - out_list_s[i][:, j, :], out_list_t[i][:, k, :] + n[:, j, :], out_list_t[i][:, k, :] ) return fc_out, loss_transfer, out_weight_list @@ -484,7 +482,7 @@ class AdaRNN(nn.Module): out, _ = self.features[i](x_input.float()) x_input = out out_lis.append(out) - if self.model_type == "AdaRNN" and predict == False: + if self.model_type == "AdaRNN" and predict is False: out_gate = self.process_gate_weight(x_input, i) out_weight_list.append(out_gate) return out, out_lis, out_weight_list @@ -524,10 +522,10 @@ class AdaRNN(nn.Module): else: weight = weight_mat dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda() - for i in range(len(out_list_s)): - criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2]) + for i, n in enumerate(out_list_s): + criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2]) for j in range(self.len_seq): - loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :]) + loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :]) loss_transfer = loss_transfer + weight[i, j] * loss_trans dist_mat[i, j] = loss_trans return fc_out, loss_transfer, dist_mat, weight @@ -546,7 +544,7 @@ class AdaRNN(nn.Module): def predict(self, x): out = self.gru_features(x, predict=True) fea = out[0] - if self.use_bottleneck == True: + if self.use_bottleneck is True: fea_bottleneck = self.bottleneck(fea[:, -1, :]) fc_out = self.fc(fea_bottleneck).squeeze() else: @@ -572,12 +570,12 @@ class TransferLoss: Returns: [tensor] -- transfer loss """ - if self.loss_type == "mmd_lin" or self.loss_type == "mmd": + if self.loss_type in ("mmd_lin", "mmd"): mmdloss = MMD_loss(kernel_type="linear") loss = mmdloss(X, Y) elif self.loss_type == "coral": loss = CORAL(X, Y) - elif self.loss_type == "cosine" or self.loss_type == "cos": + elif self.loss_type in ("cosine", "cos"): loss = 1 - cosine(X, Y) elif self.loss_type == "kl": loss = kl_div(X, Y) diff 
--git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index 234d66299..b214daed3 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel from qlib.contrib.model.pytorch_utils import count_parameters from qlib.data.dataset import DatasetH from qlib.data.dataset.handler import DataHandlerLP -from qlib.data.dataset.processor import CSRankNorm from qlib.log import get_module_logger from qlib.model.base import Model from qlib.utils import get_or_create_path diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index f3f2f090d..13e3bf879 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -150,7 +149,7 @@ class ALSTM(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) @@ -312,8 +311,8 @@ class ALSTMModel(nn.Module): def _build_model(self): try: klass = getattr(nn, self.rnn_type.upper()) - except: - raise ValueError("unknown rnn_type `%s`" % self.rnn_type) + except Exception as e: + raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e self.net = nn.Sequential() self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size)) self.net.add_module("act", nn.Tanh()) diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index df724e8b9..60645e2a3 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas 
as pd from typing import Text, Union @@ -20,7 +19,7 @@ from torch.utils.data import DataLoader from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from ...model.utils import ConcatDataset from ...data.dataset.weight import Reweighter @@ -160,7 +159,7 @@ class ALSTM(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) @@ -320,8 +319,8 @@ class ALSTMModel(nn.Module): def _build_model(self): try: klass = getattr(nn, self.rnn_type.upper()) - except: - raise ValueError("unknown rnn_type `%s`" % self.rnn_type) + except Exception as e: + raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e self.net = nn.Sequential() self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size)) self.net.add_module("act", nn.Tanh()) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 7c2c99432..b5685fdab 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -158,7 +157,7 @@ class GATs(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) @@ -263,7 +262,9 @@ class GATs(Model): pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() - pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} + pretrained_dict = { + 
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict + } # pylint: disable=E1135 model_dict.update(pretrained_dict) self.GAT_model.load_state_dict(model_dict) self.logger.info("Loading pretrained model Done...") diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 53a7817e2..901d4c2bd 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy @@ -19,7 +18,6 @@ from torch.utils.data import Sampler from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from ...contrib.model.pytorch_lstm import LSTMModel from ...contrib.model.pytorch_gru import GRUModel @@ -178,7 +176,7 @@ class GATs(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) @@ -279,7 +277,9 @@ class GATs(Model): pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() - pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} + pretrained_dict = { + k: v for k, v in pretrained_model.state_dict().items() if k in model_dict + } # pylint: disable=E1135 model_dict.update(pretrained_dict) self.GAT_model.load_state_dict(model_dict) self.logger.info("Loading pretrained model Done...") diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 740bdd977..2275b86e1 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np 
import pandas as pd from typing import Text, Union @@ -150,7 +149,7 @@ class GRU(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 6e7c1594a..390a66924 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy @@ -19,7 +18,6 @@ from torch.utils.data import DataLoader from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH from ...data.dataset.handler import DataHandlerLP from ...model.utils import ConcatDataset from ...data.dataset.weight import Reweighter @@ -159,7 +157,7 @@ class GRU(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 7548c936f..6e7d91180 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -17,11 +16,9 @@ from ...log import get_module_logger import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP 
from torch.nn.modules.container import ModuleList @@ -102,7 +99,7 @@ class LocalformerModel(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index 9645e28f3..18ef7f112 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy @@ -18,9 +17,8 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from torch.nn.modules.container import ModuleList @@ -101,7 +99,7 @@ class LocalformerModel(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 4920613af..494fd4a0e 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -146,7 +145,7 @@ class LSTM(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git 
a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index 6e80127f4..d7705981a 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy @@ -18,7 +17,6 @@ import torch.optim as optim from torch.utils.data import DataLoader from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH from ...data.dataset.handler import DataHandlerLP from ...model.utils import ConcatDataset from ...data.dataset.weight import Reweighter @@ -155,7 +153,7 @@ class LSTM(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 83d217458..173d494bf 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -328,6 +328,7 @@ class Net(nn.Module): dnn_layers = [] drop_input = nn.Dropout(0.05) dnn_layers.append(drop_input) + hidden_units = None for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])): fc = nn.Linear(_input_dim, hidden_units) activation = nn.LeakyReLU(negative_slope=0.1, inplace=False) @@ -338,7 +339,7 @@ class Net(nn.Module): dnn_layers.append(drop_input) fc = nn.Linear(hidden_units, output_dim) dnn_layers.append(fc) - # optimizer + # optimizer # pylint: disable=W0631 self.dnn_layers = nn.ModuleList(dnn_layers) self._weight_init() diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 5d076b9fd..cebaeef96 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -4,7 +4,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as 
pd from typing import Text, Union @@ -435,7 +434,7 @@ class SFM(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index d9290977b..d50067639 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -378,7 +377,7 @@ class TabnetModel(Model): def metric_fn(self, pred, label): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index c649dfa0b..8c40683fe 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -15,7 +15,6 @@ from ...log import get_module_logger import torch import torch.nn as nn import torch.optim as optim -from torch.nn.utils import weight_norm from .pytorch_utils import count_parameters from ...model.base import Model @@ -158,7 +157,7 @@ class TCN(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index 3e0a15e04..13c125d27 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -158,7 +158,7 @@ class TCN(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return 
-self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index 8cb56930d..8dadefb68 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -5,20 +5,12 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy import random -from sklearn.metrics import roc_auc_score, mean_squared_error -import logging -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) -from ...log import get_module_logger, TimeInspector +from ...utils import get_or_create_path +from ...log import get_module_logger import torch import torch.nn as nn @@ -263,7 +255,7 @@ class TCTS(Model): x_valid, y_valid = df_valid["feature"], df_valid["label"] x_test, y_test = df_test["feature"], df_test["label"] - if save_path == None: + if save_path is None: save_path = get_or_create_path(save_path) best_loss = np.inf while best_loss > self.lowest_valid_performance: diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index 42df4e5c1..81c9ba145 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -6,10 +6,8 @@ import os import copy import math import json -import collections import numpy as np import pandas as pd -import seaborn as sns import matplotlib.pyplot as plt import torch @@ -24,7 +22,6 @@ except ImportError: from tqdm import tqdm -from qlib.utils import get_or_create_path from qlib.constant import EPS from qlib.log import get_module_logger from qlib.model.base import Model diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index da36cd5f6..66e5b2c4e 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -5,7 +5,6 @@ from __future__ 
import division from __future__ import print_function -import os import numpy as np import pandas as pd from typing import Text, Union @@ -17,11 +16,9 @@ from ...log import get_module_logger import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP # qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ” @@ -101,7 +98,7 @@ class TransformerModel(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index fbb47df7f..6cffded9c 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -5,7 +5,6 @@ from __future__ import division from __future__ import print_function -import os import numpy as np import pandas as pd import copy @@ -18,9 +17,8 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters from ...model.base import Model -from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -98,7 +96,7 @@ class TransformerModel(Model): mask = torch.isfinite(label) - if self.metric == "" or self.metric == "loss": + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_utils.py b/qlib/contrib/model/pytorch_utils.py index e22f8f754..224809dc2 100644 --- a/qlib/contrib/model/pytorch_utils.py +++ 
b/qlib/contrib/model/pytorch_utils.py @@ -26,11 +26,11 @@ def count_parameters(models_or_parameters, unit="m"): else: counts = sum(v.numel() for v in models_or_parameters) unit = unit.lower() - if unit == "kb" or unit == "k": + if unit in ("kb", "k"): counts /= 2 ** 10 - elif unit == "mb" or unit == "m": + elif unit in ("mb", "m"): counts /= 2 ** 20 - elif unit == "gb" or unit == "g": + elif unit in ("gb", "g"): counts /= 2 ** 30 elif unit is not None: raise ValueError("Unknown unit: {:}".format(unit)) diff --git a/qlib/contrib/model/tcn.py b/qlib/contrib/model/tcn.py index ba6a85b8f..aa61bd44a 100644 --- a/qlib/contrib/model/tcn.py +++ b/qlib/contrib/model/tcn.py @@ -1,6 +1,5 @@ # MIT License # Copyright (c) 2018 CMU Locus Lab -import torch import torch.nn as nn from torch.nn.utils import weight_norm diff --git a/qlib/contrib/online/__init__.py b/qlib/contrib/online/__init__.py index 71389882e..642cf8d3a 100644 --- a/qlib/contrib/online/__init__.py +++ b/qlib/contrib/online/__init__.py @@ -1,3 +1,5 @@ +# pylint: skip-file + ''' TODO: diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py index 7b07c4c07..d0b82df43 100644 --- a/qlib/contrib/online/manager.py +++ b/qlib/contrib/online/manager.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import yaml import pathlib import pandas as pd diff --git a/qlib/contrib/online/online_model.py b/qlib/contrib/online/online_model.py index 0e8c0cb19..1f7d455dd 100644 --- a/qlib/contrib/online/online_model.py +++ b/qlib/contrib/online/online_model.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# pylint: skip-file + import random import pandas as pd from ...data import D diff --git a/qlib/contrib/online/operator.py b/qlib/contrib/online/operator.py index 971dcda75..082e5da50 100644 --- a/qlib/contrib/online/operator.py +++ b/qlib/contrib/online/operator.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import fire import pandas as pd import pathlib diff --git a/qlib/contrib/online/user.py b/qlib/contrib/online/user.py index a7a8654d1..aade29596 100644 --- a/qlib/contrib/online/user.py +++ b/qlib/contrib/online/user.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import logging from ...log import get_module_logger diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py index 52dcd819e..3b4ec8c5d 100644 --- a/qlib/contrib/online/utils.py +++ b/qlib/contrib/online/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import pathlib import pickle import yaml diff --git a/qlib/contrib/ops/high_freq.py b/qlib/contrib/ops/high_freq.py index 3ce5c961f..d35dbf92f 100644 --- a/qlib/contrib/ops/high_freq.py +++ b/qlib/contrib/ops/high_freq.py @@ -1,12 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from pathlib import Path import numpy as np -import pandas as pd from datetime import datetime -import qlib -from qlib.data import D from qlib.data.cache import H from qlib.data.data import Cal from qlib.data.ops import ElemOperator diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 32c111ab9..3bd3eb65e 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -34,7 +34,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int { "Group%d" % (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply( - lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() + lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() # pylint: disable=W0640 ) for i in range(N) } diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 9aa3631af..c5f932978 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -282,8 +282,10 @@ class SubplotsGraph: if self._subplots_kwargs is None: self._init_subplots_kwargs() - self.__cols = self._subplots_kwargs.get("cols", 2) - self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols)) + self.__cols = self._subplots_kwargs.get("cols", 2) # pylint: disable=W0238 + self.__rows = self._subplots_kwargs.get( # pylint: disable=W0238 + "rows", math.ceil(len(self._df.columns) / self.__cols) + ) self._sub_graph_data = sub_graph_data if self._sub_graph_data is None: diff --git a/qlib/contrib/strategy/optimizer/base.py b/qlib/contrib/strategy/optimizer/base.py index e3f692014..715fd4981 100644 --- a/qlib/contrib/strategy/optimizer/base.py +++ b/qlib/contrib/strategy/optimizer/base.py @@ -10,4 +10,3 @@ class BaseOptimizer(abc.ABC): @abc.abstractmethod def __call__(self, *args, **kwargs) -> object: """Generate a optimized portfolio allocation""" - 
pass diff --git a/qlib/contrib/strategy/optimizer/enhanced_indexing.py b/qlib/contrib/strategy/optimizer/enhanced_indexing.py index 9e3a35748..9da609005 100644 --- a/qlib/contrib/strategy/optimizer/enhanced_indexing.py +++ b/qlib/contrib/strategy/optimizer/enhanced_indexing.py @@ -3,7 +3,6 @@ import numpy as np import cvxpy as cp -import pandas as pd from typing import Union, Optional, Dict, Any, List @@ -156,7 +155,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer): # factor deviation if self.f_dev is not None: - cons.extend([v >= -self.f_dev, v <= self.f_dev]) + cons.extend([v >= -self.f_dev, v <= self.f_dev]) # pylint: disable=E1130 # total turnover constraint t_cons = [] diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index 5dfef1510..9e84d5046 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -6,7 +6,6 @@ This order generator is for strategies based on WeightStrategyBase """ from ...backtest.position import Position from ...backtest.exchange import Exchange -from ...backtest.decision import BaseTradeDecision, TradeDecisionWO import pandas as pd import copy diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index 9587abd65..b0826b1ba 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -3,7 +3,6 @@ import os import copy import warnings -import cvxpy as cp import numpy as np import pandas as pd @@ -15,11 +14,10 @@ from qlib.model.base import BaseModel from qlib.strategy.base import BaseStrategy from qlib.backtest.position import Position from qlib.backtest.signal import Signal, create_signal_from -from qlib.backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO +from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO from qlib.log import get_module_logger from qlib.utils import get_pre_trading_date, load_dataset -from 
qlib.utils.resam import resam_ts_data -from qlib.contrib.strategy.order_generator import OrderGenWInteract, OrderGenWOInteract +from qlib.contrib.strategy.order_generator import OrderGenWOInteract from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer diff --git a/qlib/contrib/tuner/__init__.py b/qlib/contrib/tuner/__init__.py index e69de29bb..388083ed9 100644 --- a/qlib/contrib/tuner/__init__.py +++ b/qlib/contrib/tuner/__init__.py @@ -0,0 +1 @@ +# pylint: skip-file diff --git a/qlib/contrib/tuner/config.py b/qlib/contrib/tuner/config.py index 247fa6a4f..3a6ba4345 100644 --- a/qlib/contrib/tuner/config.py +++ b/qlib/contrib/tuner/config.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import yaml import copy import os diff --git a/qlib/contrib/tuner/launcher.py b/qlib/contrib/tuner/launcher.py index 711658c9a..36828e443 100644 --- a/qlib/contrib/tuner/launcher.py +++ b/qlib/contrib/tuner/launcher.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + # coding=utf-8 import argparse diff --git a/qlib/contrib/tuner/pipeline.py b/qlib/contrib/tuner/pipeline.py index ee92db529..7b651276d 100644 --- a/qlib/contrib/tuner/pipeline.py +++ b/qlib/contrib/tuner/pipeline.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import os import json import logging diff --git a/qlib/contrib/tuner/space.py b/qlib/contrib/tuner/space.py index 76f101671..67cc8a7f5 100644 --- a/qlib/contrib/tuner/space.py +++ b/qlib/contrib/tuner/space.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# pylint: skip-file + from hyperopt import hp diff --git a/qlib/contrib/tuner/tuner.py b/qlib/contrib/tuner/tuner.py index 114ee0a74..9c0db6494 100644 --- a/qlib/contrib/tuner/tuner.py +++ b/qlib/contrib/tuner/tuner.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: skip-file + import os import yaml import json diff --git a/qlib/data/base.py b/qlib/data/base.py index b18e7aa47..bbadf68a3 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -6,7 +6,6 @@ from __future__ import division from __future__ import print_function import abc -import pandas as pd from ..log import get_module_logger @@ -21,107 +20,107 @@ class Expression(abc.ABC): return str(self) def __gt__(self, other): - from .ops import Gt + from .ops import Gt # pylint: disable=C0415 return Gt(self, other) def __ge__(self, other): - from .ops import Ge + from .ops import Ge # pylint: disable=C0415 return Ge(self, other) def __lt__(self, other): - from .ops import Lt + from .ops import Lt # pylint: disable=C0415 return Lt(self, other) def __le__(self, other): - from .ops import Le + from .ops import Le # pylint: disable=C0415 return Le(self, other) def __eq__(self, other): - from .ops import Eq + from .ops import Eq # pylint: disable=C0415 return Eq(self, other) def __ne__(self, other): - from .ops import Ne + from .ops import Ne # pylint: disable=C0415 return Ne(self, other) def __add__(self, other): - from .ops import Add + from .ops import Add # pylint: disable=C0415 return Add(self, other) def __radd__(self, other): - from .ops import Add + from .ops import Add # pylint: disable=C0415 return Add(other, self) def __sub__(self, other): - from .ops import Sub + from .ops import Sub # pylint: disable=C0415 return Sub(self, other) def __rsub__(self, other): - from .ops import Sub + from .ops import Sub # pylint: disable=C0415 return Sub(other, self) def __mul__(self, other): - from .ops import Mul + from .ops import Mul # pylint: disable=C0415 return 
Mul(self, other) def __rmul__(self, other): - from .ops import Mul + from .ops import Mul # pylint: disable=C0415 return Mul(self, other) def __div__(self, other): - from .ops import Div + from .ops import Div # pylint: disable=C0415 return Div(self, other) def __rdiv__(self, other): - from .ops import Div + from .ops import Div # pylint: disable=C0415 return Div(other, self) def __truediv__(self, other): - from .ops import Div + from .ops import Div # pylint: disable=C0415 return Div(self, other) def __rtruediv__(self, other): - from .ops import Div + from .ops import Div # pylint: disable=C0415 return Div(other, self) def __pow__(self, other): - from .ops import Power + from .ops import Power # pylint: disable=C0415 return Power(self, other) def __and__(self, other): - from .ops import And + from .ops import And # pylint: disable=C0415 return And(self, other) def __rand__(self, other): - from .ops import And + from .ops import And # pylint: disable=C0415 return And(other, self) def __or__(self, other): - from .ops import Or + from .ops import Or # pylint: disable=C0415 return Or(self, other) def __ror__(self, other): - from .ops import Or + from .ops import Or # pylint: disable=C0415 return Or(other, self) @@ -144,7 +143,7 @@ class Expression(abc.ABC): pd.Series feature series: The index of the series is the calendar index """ - from .cache import H + from .cache import H # pylint: disable=C0415 # cache args = str(self), instrument, start_index, end_index, freq @@ -215,7 +214,7 @@ class Feature(Expression): def _load_internal(self, instrument, start_index, end_index, freq): # load - from .data import FeatureD + from .data import FeatureD # pylint: disable=C0415 return FeatureD.feature(instrument, str(self), start_index, end_index, freq) @@ -232,5 +231,3 @@ class ExpressionOps(Expression): This kind of feature will use operator for feature construction on the fly. 
""" - - pass diff --git a/qlib/data/cache.py b/qlib/data/cache.py index a156bded4..fc6518de5 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -33,8 +33,7 @@ from ..utils import ( from ..log import get_module_logger from .base import Feature - -from .ops import Operators +from .ops import Operators # pylint: disable=W0611 class QlibCacheException(RuntimeError): @@ -229,8 +228,8 @@ class CacheUtils: try: d["meta"]["last_visit"] = str(time.time()) d["meta"]["visits"] = d["meta"]["visits"] + 1 - except KeyError: - raise KeyError("Unknown meta keyword") + except KeyError as key_e: + raise KeyError("Unknown meta keyword") from key_e pickle.dump(d, f, protocol=C.dump_protocol_version) except Exception as e: get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}") @@ -239,7 +238,7 @@ class CacheUtils: def acquire(lock, lock_name): try: lock.acquire() - except redis_lock.AlreadyAcquired: + except redis_lock.AlreadyAcquired as lock_acquired: raise QlibCacheException( f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now. You can use the following command to clear your redis keys and rerun your commands: @@ -249,7 +248,7 @@ class CacheUtils: > quit If the issue is not resolved, use "keys *" to find if multiple keys exist. If so, try using "flushall" to clear all the keys. 
""" - ) + ) from lock_acquired @staticmethod @contextlib.contextmanager @@ -507,7 +506,7 @@ class DiskExpressionCache(ExpressionCache): _instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower()) cache_path = _instrument_dir.joinpath(_cache_uri) # get calendar - from .data import Cal + from .data import Cal # pylint: disable=C0415 _calendar = Cal.calendar(freq=freq) @@ -599,7 +598,7 @@ class DiskExpressionCache(ExpressionCache): last_update_time = d["info"]["last_update"] # get newest calendar - from .data import Cal, ExpressionD + from .data import Cal, ExpressionD # pylint: disable=C0415 whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq) # calendar since last updated. @@ -753,7 +752,7 @@ class DiskDatasetCache(DatasetCache): if disk_cache == 0: # In this case, server only checks the expression cache. # The client will load the cache data by itself. - from .data import LocalDatasetProvider + from .data import LocalDatasetProvider # pylint: disable=C0415 LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq) return "" @@ -895,7 +894,7 @@ class DiskDatasetCache(DatasetCache): :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function. """ # get calendar - from .data import Cal + from .data import Cal # pylint: disable=C0415 cache_path = Path(cache_path) _calendar = Cal.calendar(freq=freq) @@ -970,14 +969,14 @@ class DiskDatasetCache(DatasetCache): index_data = im.get_index() self.logger.debug("Updating dataset: {}".format(d)) - from .data import Inst + from .data import Inst # pylint: disable=C0415 if Inst.get_inst_type(instruments) == Inst.DICT: self.logger.info(f"The file {cache_uri} has dict cache. 
Skip updating") return 1 # get newest calendar - from .data import Cal + from .data import Cal # pylint: disable=C0415 whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq) # The calendar since last updated @@ -994,7 +993,7 @@ class DiskDatasetCache(DatasetCache): current_index = len(whole_calendar) - len(new_calendar) + 1 # To avoid recursive import - from .data import ExpressionD + from .data import ExpressionD # pylint: disable=C0415 # The existing data length lft_etd = rght_etd = 0 diff --git a/qlib/data/data.py b/qlib/data/data.py index 587d21d8d..8080eb66c 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -5,17 +5,13 @@ from __future__ import division from __future__ import print_function -import os import re import abc -import time import copy import queue import bisect import numpy as np import pandas as pd -from multiprocessing import Pool -from typing import Iterable, Union from typing import List, Union # For supporting multiprocessing in outer code, joblib is used @@ -23,13 +19,10 @@ from joblib import delayed from .cache import H from ..config import C -from .base import Feature -from .ops import Operators from .inst_processor import InstProcessor from ..log import get_module_logger -from ..utils.time import Freq -from .cache import DiskDatasetCache, DiskExpressionCache +from .cache import DiskDatasetCache from ..utils import ( Wrapper, init_instance_by_config, @@ -43,6 +36,7 @@ from ..utils import ( time_to_slc_point, ) from ..utils.paral import ParallelExt +from .ops import Operators # pylint: disable=W0611 class ProviderBackendMixin: @@ -144,10 +138,10 @@ class CalendarProvider(abc.ABC): if start_time not in calendar_index: try: start_time = calendar[bisect.bisect_left(calendar, start_time)] - except IndexError: + except IndexError as index_e: raise IndexError( "`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`" - ) + ) from index_e start_index = calendar_index[start_time] 
if end_time not in calendar_index: end_time = calendar[bisect.bisect_right(calendar, end_time) - 1] @@ -246,7 +240,7 @@ class InstrumentProvider(abc.ABC): """ if isinstance(market, list): return market - from .filter import SeriesDFilter + from .filter import SeriesDFilter # pylint: disable=C0415 if filter_pipe is None: filter_pipe = [] @@ -672,7 +666,7 @@ class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin): # filter filter_pipe = instruments["filter_pipe"] for filter_config in filter_pipe: - from . import filter as F + from . import filter as F # pylint: disable=C0415 filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config) _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq) @@ -1003,8 +997,8 @@ class ClientDatasetProvider(DatasetProvider): if return_uri: return df, feature_uri return df - except AttributeError: - raise IOError("Unable to fetch instruments from remote server!") + except AttributeError as attribute_e: + raise IOError("Unable to fetch instruments from remote server!") from attribute_e class BaseProvider: @@ -1110,7 +1104,7 @@ class ClientProvider(BaseProvider): return isinstance(instance, cls) - from .client import Client + from .client import Client # pylint: disable=C0415 self.client = Client(C.flask_server, C.flask_port) self.logger = get_module_logger(self.__class__.__name__) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 5cb81e8c9..3f8b7dcf0 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -52,7 +52,6 @@ class Dataset(Serializable): - User prepare data for model based on previous status. 
""" - pass def prepare(self, **kwargs) -> object: """ @@ -68,7 +67,6 @@ class Dataset(Serializable): object: return the object """ - pass class DatasetH(Dataset): @@ -348,7 +346,7 @@ class TSDataSampler: flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool) self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) - self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + self.data_index = self.data_index[np.where(self.flt_data is True)[0]] self.idx_map = self.idx_map2arr(self.idx_map) self.start_idx, self.end_idx = self.data_index.slice_locs( diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index a6877f013..eab889e85 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -2,24 +2,16 @@ # Licensed under the MIT License. # coding=utf-8 -import abc -import bisect -import logging import warnings -from inspect import getfullargspec from typing import Callable, Union, Tuple, List, Iterator, Optional import pandas as pd -import numpy as np from ...log import get_module_logger, TimeInspector -from ...data import D -from ...config import C -from ...utils import parse_config, transform_end_date, init_instance_by_config +from ...utils import init_instance_by_config from ...utils.serial import Serializable from .utils import fetch_df_by_index, fetch_df_by_col from ...utils import lazy_sort_index -from pathlib import Path from .loader import DataLoader from . import processor as processor_module @@ -228,7 +220,7 @@ class DataHandler(Serializable): proc_func: Callable = None, ): # This method is extracted for sharing in subclasses - from .storage import BaseHandlerStorage + from .storage import BaseHandlerStorage # pylint: disable=C0415 # Following conflictions may occurs # - Does [20200101", "20210101"] mean selecting this slice or these two days? 
@@ -627,7 +619,6 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - from .storage import BaseHandlerStorage return self._fetch_data( data_storage=self._get_df_by_key(data_key), diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 73538703f..c80d60bab 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -51,7 +51,6 @@ class DataLoader(abc.ABC): pd.DataFrame: data load from the under layer source """ - pass class DLWParser(DataLoader): @@ -129,7 +128,6 @@ class DLWParser(DataLoader): pd.DataFrame: the queried dataframe. """ - pass def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if self.is_group: @@ -308,7 +306,7 @@ class DataLoaderDH(DataLoader): is_group will be used to describe whether the key of handler_config is group """ - from qlib.data.dataset.handler import DataHandler + from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415 if is_group: self.handlers = { diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index d3d07f822..ec3fa5506 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -42,7 +42,6 @@ class Processor(Serializable): processor, i.e. `df`. """ - pass @abc.abstractmethod def __call__(self, df: pd.DataFrame): @@ -57,7 +56,6 @@ class Processor(Serializable): df : pd.DataFrame The raw_df of handler or result from previous processor. 
""" - pass def is_for_infer(self) -> bool: """ @@ -201,7 +199,7 @@ class MinMaxNorm(Processor): self.fit_end_time = fit_end_time self.fields_group = fields_group - def fit(self, df): + def fit(self, df: pd.DataFrame = None): df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") cols = get_group_columns(df, self.fields_group) self.min_val = np.nanmin(df[cols].values, axis=0) @@ -232,7 +230,7 @@ class ZScoreNorm(Processor): self.fit_end_time = fit_end_time self.fields_group = fields_group - def fit(self, df): + def fit(self, df: pd.DataFrame = None): df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") cols = get_group_columns(df, self.fields_group) self.mean_train = np.nanmean(df[cols].values, axis=0) @@ -272,7 +270,7 @@ class RobustZScoreNorm(Processor): self.fields_group = fields_group self.clip_outlier = clip_outlier - def fit(self, df): + def fit(self, df: pd.DataFrame = None): df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") self.cols = get_group_columns(df, self.fields_group) X = df[self.cols].values @@ -351,6 +349,6 @@ class HashStockFormat(Processor): """Process the storage of from df into hasing stock format""" def __call__(self, df: pd.DataFrame): - from .storage import HasingStockStorage + from .storage import HasingStockStorage # pylint: disable=C0415 return HasingStockStorage.from_df(df) diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 1af78e92a..42f003269 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np from .handler import DataHandler -from typing import Tuple, Union, List, Callable +from typing import Union, List, Callable from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col @@ -109,7 +109,7 @@ class HasingStockStorage(BaseHandlerStorage): stock_selector = selector[self.stock_level] elif 
isinstance(selector, (list, str)) and self.stock_level == 0: stock_selector = selector - elif level == "instrument" or level == self.stock_level: + elif level in ("instrument", self.stock_level): if isinstance(selector, tuple): stock_selector = selector[0] elif isinstance(selector, (list, str)): diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 4b8fedb0b..390546666 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -63,7 +63,7 @@ def fetch_df_by_index( Data of the given index. """ # level = None -> use selector directly - if level == None: + if level is None: return df.loc(axis=0)[selector] # Try to get the right index idx_slc = (selector, slice(None, None)) @@ -75,7 +75,7 @@ def fetch_df_by_index( return df.loc[ pd.IndexSlice[idx_slc], ] - else: + else: # pylint: disable=W0120 return df else: return df.loc[ @@ -84,7 +84,7 @@ def fetch_df_by_index( def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame: - from .handler import DataHandler + from .handler import DataHandler # pylint: disable=C0415 if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW: return df @@ -136,7 +136,7 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]: returns """ # avoid recursive import - from .handler import DataHandler + from .handler import DataHandler # pylint: disable=C0415 h_conf = task["dataset"]["kwargs"].get("handler") if h_conf is not None: diff --git a/qlib/data/dataset/weight.py b/qlib/data/dataset/weight.py index 8e2a6b959..ee8208053 100644 --- a/qlib/data/dataset/weight.py +++ b/qlib/data/dataset/weight.py @@ -1,13 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import pandas as pd -import numpy as np -from typing import Union, List, Tuple -from ...data.dataset import TSDataSampler -from ...data.dataset.utils import get_level_index -from ...utils import lazy_sort_index - class Reweighter: def __init__(self, *args, **kwargs): diff --git a/qlib/data/filter.py b/qlib/data/filter.py index d2e5aa901..c8c36c099 100644 --- a/qlib/data/filter.py +++ b/qlib/data/filter.py @@ -62,7 +62,7 @@ class SeriesDFilter(BaseDFilter): Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rule """ - def __init__(self, fstart_time=None, fend_time=None): + def __init__(self, fstart_time=None, fend_time=None, keep=False): """Init function for filter base class. Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time. @@ -72,10 +72,13 @@ class SeriesDFilter(BaseDFilter): the time for the filter rule to start filter the instruments. fend_time: str the time for the filter rule to stop filter the instruments. + keep: bool + whether to keep the instruments of which features don't exist in the filter time span. """ super(SeriesDFilter, self).__init__() self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None + self.keep = keep def _getTimeBound(self, instruments): """Get time bound for all instruments. @@ -330,12 +333,9 @@ class ExpressionDFilter(SeriesDFilter): filter the feature ending by this time. rule_expression: str an input expression for the rule. - keep: bool - whether to keep the instruments of which features don't exist in the filter time span. 
""" - super(ExpressionDFilter, self).__init__(fstart_time, fend_time) + super(ExpressionDFilter, self).__init__(fstart_time, fend_time, keep=keep) self.rule_expression = rule_expression - self.keep = keep def _getFilterSeries(self, instruments, fstart, fend): # do not use dataset cache diff --git a/qlib/data/inst_processor.py b/qlib/data/inst_processor.py index 27b356722..9022f57db 100644 --- a/qlib/data/inst_processor.py +++ b/qlib/data/inst_processor.py @@ -17,7 +17,6 @@ class InstProcessor: df : pd.DataFrame The raw_df of handler or result from previous processor. """ - pass def __str__(self): return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}" diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 3e83e6829..555a29ba4 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -5,8 +5,6 @@ from __future__ import division from __future__ import print_function -import sys -import abc import numpy as np import pandas as pd @@ -15,7 +13,6 @@ from scipy.stats import percentileofscore from .base import Expression, ExpressionOps, Feature -from ..config import C from ..log import get_module_logger from ..utils import get_callable_kwargs @@ -331,7 +328,7 @@ class NpPairOperator(PairOperator): res = getattr(np, self.func)(series_left, series_right) except ValueError as e: get_module_logger("ops").debug(warning_info) - raise ValueError(f"{str(e)}. \n\t{warning_info}") + raise ValueError(f"{str(e)}. 
\n\t{warning_info}") from e else: if check_length and len(series_left) != len(series_right): get_module_logger("ops").debug(warning_info) @@ -1430,21 +1427,20 @@ class PairRolling(ExpressionOps): return max(left_br, right_br) def get_extended_window_size(self): + if isinstance(self.feature_left, Expression): + ll, lr = self.feature_left.get_extended_window_size() + else: + ll, lr = 0, 0 + if isinstance(self.feature_right, Expression): + rl, rr = self.feature_right.get_extended_window_size() + else: + rl, rr = 0, 0 if self.N == 0: get_module_logger(self.__class__.__name__).warning( "The PairRolling(ATTR, 0) will not be accurately calculated" ) return -np.inf, max(lr, rr) else: - if isinstance(self.feature_left, Expression): - ll, lr = self.feature_left.get_extended_window_size() - else: - ll, lr = 0, 0 - - if isinstance(self.feature_right, Expression): - rl, rr = self.feature_right.get_extended_window_size() - else: - rl, rr = 0, 0 return max(ll, rl) + self.N - 1, max(lr, rr) diff --git a/qlib/log.py b/qlib/log.py index debc24fa4..8533a7ba3 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,7 +13,7 @@ from .config import C class MetaLogger(type): - def __new__(mcs, name, bases, attrs): + def __new__(mcs, name, bases, attrs): # pylint: disable=C0204 wrapper_dict = logging.Logger.__dict__.copy() for key in wrapper_dict: if key not in attrs and key != "__reduce__": @@ -164,7 +164,7 @@ class LogFilter(logging.Filter): if isinstance(self.param, str): allow = not self.match_msg(self.param, record.msg) elif isinstance(self.param, list): - allow = not any([self.match_msg(p, record.msg) for p in self.param]) + allow = not any(self.match_msg(p, record.msg) for p in self.param) return allow @@ -201,7 +201,7 @@ def set_global_logger_level(level: int, return_orig_handler_level: bool = False) """ _handler_level_map = {} - qlib_logger = logging.root.manager.loggerDict.get("qlib", None) + qlib_logger = logging.root.manager.loggerDict.get("qlib", None) # pylint: disable=E1101 if 
qlib_logger is not None: for _handler in qlib_logger.handlers: _handler_level_map[_handler] = _handler.level diff --git a/qlib/model/base.py b/qlib/model/base.py index 7047b5f44..009a3bd14 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -13,7 +13,6 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): @abc.abstractmethod def predict(self, *args, **kwargs) -> object: """Make predictions after modeling things""" - pass def __call__(self, *args, **kwargs) -> object: """leverage Python syntactic sugar to make the models' behaviors like functions""" diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index a31c647d1..ba6f9f807 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -13,7 +13,7 @@ reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object} """ from qlib.model.ens.ensemble import Ensemble, RollingEnsemble -from typing import Callable, Union +from typing import Callable from joblib import Parallel, delayed diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py index 57cc7929a..a490d7744 100644 --- a/qlib/model/interpret/base.py +++ b/qlib/model/interpret/base.py @@ -27,6 +27,9 @@ class FeatureInt: class LightGBMFInt(FeatureInt): """LightGBM (F)eature (Int)erpreter""" + def __init__(self): + self.model = None + def get_feature_importance(self, *args, **kwargs) -> pd.Series: """get feature importance @@ -35,6 +38,8 @@ class LightGBMFInt(FeatureInt): parameters reference: https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance """ - return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values( + return pd.Series( + self.model.feature_importance(*args, **kwargs), index=self.model.feature_name() + ).sort_values( # pylint: disable=E1101 ascending=False ) diff --git a/qlib/model/meta/dataset.py b/qlib/model/meta/dataset.py index 4b56dd1ba..823842897 100644 --- 
a/qlib/model/meta/dataset.py +++ b/qlib/model/meta/dataset.py @@ -4,8 +4,6 @@ import abc from qlib.model.meta.task import MetaTask from typing import Dict, Union, List, Tuple, Text -from ...workflow.task.gen import RollingGen, task_generator -from ...data.dataset.handler import DataHandler from ...utils.serial import Serializable @@ -73,4 +71,3 @@ class MetaTaskDataset(Serializable, metaclass=abc.ABCMeta): seg : Text the name of the segment """ - pass diff --git a/qlib/model/meta/model.py b/qlib/model/meta/model.py index 224600daa..1f13dba34 100644 --- a/qlib/model/meta/model.py +++ b/qlib/model/meta/model.py @@ -2,10 +2,8 @@ # Licensed under the MIT License. import abc -from qlib.contrib.meta.data_selection.dataset import MetaDatasetDS -from typing import Union, List, Tuple +from typing import List -from qlib.model.meta.task import MetaTask from .dataset import MetaTaskDataset @@ -23,7 +21,6 @@ class MetaModel(metaclass=abc.ABCMeta): """ The training process of the meta-model. """ - pass @abc.abstractmethod def inference(self, *args, **kwargs) -> object: @@ -35,7 +32,6 @@ class MetaModel(metaclass=abc.ABCMeta): object: Some information to guide the model learning """ - pass class MetaTaskModel(MetaModel): diff --git a/qlib/model/meta/task.py b/qlib/model/meta/task.py index f6c2f26f4..f59198830 100644 --- a/qlib/model/meta/task.py +++ b/qlib/model/meta/task.py @@ -1,9 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import abc -from typing import Union, List, Tuple - from qlib.data.dataset import Dataset from ...utils import init_instance_by_config diff --git a/qlib/model/riskmodel/base.py b/qlib/model/riskmodel/base.py index bb067e3d5..7afacfe8f 100644 --- a/qlib/model/riskmodel/base.py +++ b/qlib/model/riskmodel/base.py @@ -91,7 +91,7 @@ class RiskModel(BaseModel): "return_decomposed_components" in inspect.getfullargspec(self._predict).args ), "This risk model does not support return decomposed components of the covariance matrix " - F, cov_b, var_u = self._predict(X, return_decomposed_components=True) + F, cov_b, var_u = self._predict(X, return_decomposed_components=True) # pylint: disable=E1123 return F, cov_b, var_u # estimate covariance diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index a65801bbb..6de94b755 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -12,17 +12,13 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model """ import socket -import time -import re from typing import Callable, List from tqdm.auto import tqdm from qlib.data.dataset import Dataset -from qlib.log import get_module_logger from qlib.model.base import Model -from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config, auto_filter_kwargs, fill_placeholder +from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder from qlib.workflow import R -from qlib.workflow.record_temp import SignalRecord from qlib.workflow.recorder import Recorder from qlib.workflow.task.manage import TaskManager, run_task from qlib.data.dataset.weight import Reweighter diff --git a/qlib/rl/env.py b/qlib/rl/env.py index 77da90718..6173a27a8 100644 --- a/qlib/rl/env.py +++ b/qlib/rl/env.py @@ -7,7 +7,6 @@ from typing import Union from ..backtest.executor import BaseExecutor from .interpreter import StateInterpreter, ActionInterpreter from ..utils import init_instance_by_config -from .interpreter import 
BaseInterpreter class BaseRLEnv: diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index bf641343b..a2d5e198a 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -6,12 +6,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition -from typing import List, Tuple, Union -import pandas as pd +from typing import Tuple, Union -from ..model.base import BaseModel -from ..data.dataset import DatasetH -from ..data.dataset.utils import convert_index_format from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 13f202a31..a54f9a296 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -139,8 +139,8 @@ def parse_config(config): # Check whether the str can be parsed try: return yaml.safe_load(config) - except BaseException: - raise ValueError("cannot parse config!") + except BaseException as base_exp: + raise ValueError("cannot parse config!") from base_exp #################### Other #################### @@ -436,7 +436,7 @@ def is_tradable_date(cur_date): date : pandas.Timestamp current date """ - from ..data import D + from ..data import D # pylint: disable=C0415 return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date()) @@ -453,7 +453,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): """ - from ..data import D + from ..data import D # pylint: disable=C0415 start = get_date_by_shift(trading_date, left_shift, future=future) end = get_date_by_shift(trading_date, right_shift, future=future) @@ -476,7 +476,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq=" when align is "left"/"right", it will try to align to left/right nearest trading date before 
shifting when `trading_date` is not a trading date """ - from qlib.data import D + from qlib.data import D # pylint: disable=C0415 cal = D.calendar(future=future, freq=freq) trading_date = pd.to_datetime(trading_date) @@ -529,7 +529,7 @@ def transform_end_date(end_date=None, freq="day"): date : pandas.Timestamp current date """ - from ..data import D + from ..data import D # pylint: disable=C0415 last_date = D.calendar(freq=freq)[-1] if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)): @@ -810,7 +810,7 @@ def fill_placeholder(config: dict, config_extend: dict): elif isinstance(now_item, dict): item_keys = now_item.keys() for key in item_keys: - if isinstance(now_item[key], list) or isinstance(now_item[key], dict): + if isinstance(now_item[key], (list, dict)): item_queue.append(now_item[key]) tail += 1 elif isinstance(now_item[key], str): diff --git a/qlib/utils/exceptions.py b/qlib/utils/exceptions.py index c869f5d73..9fa5c6dfe 100644 --- a/qlib/utils/exceptions.py +++ b/qlib/utils/exceptions.py @@ -10,16 +10,10 @@ class QlibException(Exception): class RecorderInitializationError(QlibException): """Error type for re-initialization when starting an experiment""" - pass - class LoadObjectError(QlibException): """Error type for Recorder when can not load object""" - pass - class ExpAlreadyExistError(Exception): """Experiment already exists""" - - pass diff --git a/qlib/utils/file.py b/qlib/utils/file.py index 9fa83d7f1..1e17a574a 100644 --- a/qlib/utils/file.py +++ b/qlib/utils/file.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import contextlib import os import shutil import tempfile diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 4b5bb8456..26d59c4fe 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -153,8 +153,8 @@ class Index: """ try: return self.index_map[self._convert_type(item)] - except IndexError: - raise KeyError(f"{item} can't be found in {self}") + except IndexError as index_e: + raise KeyError(f"{item} can't be found in {self}") from index_e def __or__(self, other: "Index"): return Index(idx_list=list(set(self.idx_list) | set(other.idx_list))) diff --git a/qlib/utils/objm.py b/qlib/utils/objm.py index c125a6ae1..aa9bed564 100644 --- a/qlib/utils/objm.py +++ b/qlib/utils/objm.py @@ -101,8 +101,10 @@ class FileManager(ObjManager): def create_path(self) -> str: try: return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep) - except AttributeError: - raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented") + except AttributeError as attribute_e: + raise NotImplementedError( + f"If path is not given, the `create_path` function should be implemented" + ) from attribute_e def save_obj(self, obj, name): with (self.path / name).open("wb") as f: diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 94c1c7164..e9e3cacc7 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -70,12 +70,12 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No the feature with higher or equal frequency """ - from ..data.data import D + from ..data.data import D # pylint: disable=C0415 try: _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache) _freq = freq - except (ValueError, KeyError): + except (ValueError, KeyError) as value_key_e: _, norm_freq = Freq.parse(freq) if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]: try: @@ -88,7 +88,7 @@ def get_higher_eq_freq_feature(instruments, fields, 
start_time=None, end_time=No _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) _freq = "1min" else: - raise ValueError(f"freq {freq} is not supported") + raise ValueError(f"freq {freq} is not supported") from value_key_e return _result, _freq @@ -172,7 +172,7 @@ def resam_ts_data( selector_datetime = slice(start_time, end_time) - from ..data.dataset.utils import get_level_index + from ..data.dataset.utils import get_level_index # pylint: disable=C0415 feature = lazy_sort_index(ts_feature) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 47009b792..a528fa67a 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from contextlib import contextmanager -from typing import Text, Optional, Any, Dict, Text, Optional +from typing import Text, Optional, Any, Dict from .expm import ExpManager from .exp import Experiment from .recorder import Recorder diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index f0be8f4e8..ecb3e9fdf 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys, os +import sys +import os from pathlib import Path import qlib diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 9bf0b2262..b9f420015 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -2,10 +2,10 @@ # Licensed under the MIT License. 
from typing import Dict, List, Union -import mlflow, logging +import mlflow +import logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException -from pathlib import Path from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger @@ -271,7 +271,7 @@ class MLflowExperiment(Experiment): return self.active_recorder - def end(self, recorder_status): + def end(self, recorder_status=Recorder.STATUS_S): if self.active_recorder is not None: self.active_recorder.end_run(recorder_status) self.active_recorder = None @@ -299,8 +299,10 @@ class MLflowExperiment(Experiment): run = self._client.get_run(recorder_id) recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) return recorder - except MlflowException: - raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.") + except MlflowException as mlflow_exp: + raise ValueError( + "No valid recorder has been found, please make sure the input recorder id is correct." + ) from mlflow_exp elif recorder_name is not None: logger.warning( f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name." @@ -332,7 +334,7 @@ class MLflowExperiment(Experiment): except MlflowException as e: raise Exception( f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct." - ) + ) from e UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! 
@@ -362,10 +364,10 @@ class MLflowExperiment(Experiment): ) rids = [] recorders = [] - for i in range(len(runs)): - recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) + for run in runs: + recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run) if status is None or recorder.status == status: - rids.append(runs[i].info.run_id) + rids.append(run.info.run_id) recorders.append(recorder) if rtype == Experiment.RT_D: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index f16c58ddb..5d9896e5c 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -6,9 +6,7 @@ import mlflow from filelock import FileLock from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode from mlflow.entities import ViewType -import os, logging -from pathlib import Path -from contextlib import contextmanager +import os from typing import Optional, Text from .exp import MLflowExperiment, Experiment @@ -203,7 +201,7 @@ class ExpManager: # So we supported it in the interface wrapper pr = urlparse(self.uri) if pr.scheme == "file": - with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f: + with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f: # pylint: disable=E0110 return self.create_exp(experiment_name), True # NOTE: for other schemes like http, we double check to avoid create exp conflicts try: @@ -363,7 +361,7 @@ class MLflowExpManager(ExpManager): experiment_id = self.client.create_experiment(experiment_name) except MlflowException as e: if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS): - raise ExpAlreadyExistError() + raise ExpAlreadyExistError() from e raise e experiment = MLflowExperiment(experiment_id, experiment_name, self.uri) @@ -387,10 +385,10 @@ class MLflowExpManager(ExpManager): raise MlflowException("No valid experiment has been found.") experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) return experiment - except MlflowException: + except MlflowException as e: raise 
ValueError( "No valid experiment has been found, please make sure the input experiment id is correct." - ) + ) from e elif experiment_name is not None: try: exp = self.client.get_experiment_by_name(experiment_name) @@ -401,9 +399,9 @@ class MLflowExpManager(ExpManager): except MlflowException as e: raise ValueError( "No valid experiment has been found, please make sure the input experiment name is correct." - ) + ) from e - def search_records(self, experiment_ids, **kwargs): + def search_records(self, experiment_ids=None, **kwargs): filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") @@ -425,7 +423,7 @@ class MLflowExpManager(ExpManager): except MlflowException as e: raise Exception( f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct." 
- ) + ) from e def list_experiments(self): # retrieve all the existing experiments diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index e9f0fe9d2..aeeb111b2 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -83,15 +83,14 @@ For simplicity """ import logging -from typing import Callable, Dict, List, Union +from typing import Callable, List, Union import pandas as pd from qlib import get_module_logger from qlib.data.data import D from qlib.log import set_global_logger_level from qlib.model.ens.ensemble import AverageEnsemble -from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR -from qlib.utils import flatten_dict +from qlib.model.trainer import Trainer, TrainerR from qlib.utils.serial import Serializable from qlib.workflow.online.strategy import OnlineStrategy from qlib.workflow.task.collect import MergeCollector diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 7a923ebad..bda068dbf 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -5,9 +5,7 @@ OnlineStrategy module is an element of online serving. 
""" -from copy import deepcopy -from typing import List, Tuple, Union -from qlib.data.data import D +from typing import List, Union from qlib.log import get_module_logger from qlib.model.ens.group import RollingGroup from qlib.utils import transform_end_date diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 1e8c7d750..d4c4df517 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -148,7 +148,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): self.rmdl = loader_cls(rec=record) latest_date = D.calendar(freq=freq)[-1] - if to_date == None: + if to_date is None: to_date = latest_date to_date = pd.Timestamp(to_date) @@ -191,7 +191,9 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): else: hist_ref = self.hist_ref - start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq) + start_time_buffer = get_date_by_shift( + self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq # pylint: disable=E1130 + ) start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) seg = {"test": (start_time, self.to_date)} return self.rmdl.get_dataset( diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 75ff3c4fd..c390ca009 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -8,10 +8,8 @@ This allows us to use efficient submodels as the market-style changing. """ from typing import List, Union -from qlib.data.dataset import TSDatasetH from qlib.log import get_module_logger -from qlib.utils import get_callable_kwargs from qlib.utils.exceptions import LoadObjectError from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 1186920a7..3d64e268e 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -2,15 +2,18 @@ # Licensed under the MIT License. 
import os -from qlib.utils.serial import Serializable -import mlflow, logging -import shutil, os, pickle, tempfile, codecs, pickle +import mlflow +import logging +import shutil +import pickle +import tempfile from pathlib import Path from datetime import datetime +from qlib.utils.serial import Serializable from qlib.utils.exceptions import LoadObjectError from qlib.utils.paral import AsyncCaller -from ..utils.objm import FileManager + from ..log import TimeInspector, get_module_logger from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository @@ -355,7 +358,7 @@ class MLflowRecorder(Recorder): shutil.rmtree(Path(path).absolute().parent) return data except Exception as e: - raise LoadObjectError(str(e)) + raise LoadObjectError(str(e)) from e @AsyncCaller.async_dec(ac_attr="async_log") def log_params(self, **kwargs): diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 97b9abb46..7ef7b4ed9 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -82,7 +82,6 @@ class TaskGen(metaclass=abc.ABCMeta): typing.List[dict]: A list of tasks """ - pass def __call__(self, *args, **kwargs): """ diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index a8c147f0d..77fd9fa2e 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -189,7 +189,7 @@ class TimeAdjuster: """ if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} - elif isinstance(segment, tuple) or isinstance(segment, list): + elif isinstance(segment, (tuple, list)): return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end") else: raise NotImplementedError(f"This type of input is not supported")