diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index 24af42123..1d602d7fe 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -155,6 +155,8 @@ class NestedDecisionExecutionWorkflow: }, } + exp_name = "nested" + port_analysis_config = { "executor": { "class": "NestedExecutor", @@ -230,7 +232,7 @@ class NestedDecisionExecutionWorkflow: qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None) def _train_model(self, model, dataset): - with R.start(experiment_name="train"): + with R.start(experiment_name=self.exp_name): R.log_params(**flatten_dict(self.task)) model.fit(dataset) R.save_objects(**{"params.pkl": model}) @@ -257,7 +259,7 @@ class NestedDecisionExecutionWorkflow: self.port_analysis_config["strategy"] = strategy_config self.port_analysis_config["backtest"]["benchmark"] = self.benchmark - with R.start(experiment_name="backtest"): + with R.start(experiment_name=self.exp_name, resume=True): recorder = R.get_recorder() par = PortAnaRecord( recorder, @@ -382,7 +384,7 @@ class NestedDecisionExecutionWorkflow: } pa_conf["backtest"]["benchmark"] = self.benchmark - with R.start(experiment_name="backtest"): + with R.start(experiment_name=self.exp_name, resume=True): recorder = R.get_recorder() par = PortAnaRecord(recorder, pa_conf) par.generate() diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index dce1e80a8..902394a9c 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -536,7 +536,7 @@ class Exchange: deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) if deal_amount == 0: continue - elif deal_amount > 0: + if deal_amount > 0: # buy stock buy_order_list.append( Order( @@ -687,9 +687,7 @@ class Exchange: orig_deal_amount = order.deal_amount order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0) if vol_limit_min < orig_deal_amount: - 
self.logger.debug( - f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}" - ) + self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}") def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): """return the real order amount after cash limit for buying. diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 94aa84d6d..37098379b 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -194,7 +194,7 @@ class BaseExecutor: return return_value.get("execute_result") @abstractclassmethod - def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: + def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: """ Please refer to the doc of collect_data The only difference between `_collect_data` and `collect_data` is that some common steps are moved into diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 907da9975..ddb5c24e8 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -20,7 +20,7 @@ class BasePosition: Please refer to the `Position` class for the position """ - def __init__(self, cash=0.0, *args, **kwargs): + def __init__(self, *args, cash=0.0, **kwargs): self._settle_type = self.ST_NO def skip_update(self) -> bool: diff --git a/qlib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py index e5b61f8d6..371cb422a 100644 --- a/qlib/backtest/profit_attribution.py +++ b/qlib/backtest/profit_attribution.py @@ -156,16 +156,16 @@ def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df): group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df) group_ret = {} - for group_key in stock_weight_in_group: - stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index) - 
stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index) + for group_key, val in stock_weight_in_group.items(): + stock_weight_in_group_start_date = min(val.index) + stock_weight_in_group_end_date = max(val.index) temp_stock_ret_df = stock_ret_df[ (stock_ret_df.index >= stock_weight_in_group_start_date) & (stock_ret_df.index <= stock_weight_in_group_end_date) ] - group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1) + group_ret[group_key] = (temp_stock_ret_df * val).sum(axis=1) # If no weight is assigned, then the return of group will be np.nan group_ret[group_key][group_weight[group_key] == 0.0] = np.nan diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 69ae720a2..f846aea9d 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -212,7 +212,8 @@ class PortfolioMetrics: path: str/ pathlib.Path() """ path = pathlib.Path(path) - r = pd.read_csv(open(path, "rb"), index_col=0) + with path.open("rb") as f: + r = pd.read_csv(f, index_col=0) r.index = pd.DatetimeIndex(r.index) index = r.index diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 25a0d9965..5fa02420d 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -205,10 +205,7 @@ class BaseInfrastructure: warnings.warn(f"infra {infra_name} is not found!") def has(self, infra_name): - if infra_name in self.get_support_infra() and hasattr(self, infra_name): - return True - else: - return False + return infra_name in self.get_support_infra() and hasattr(self, infra_name) def update(self, other): support_infra = other.get_support_infra() diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index af4893acf..e8a048762 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -63,9 +63,7 @@ def _get_date_parse_fn(target): get_date_parse_fn('20120101')('2017-01-01') => '20170101' get_date_parse_fn(20120101)('2017-01-01') => 20170101 """ - if isinstance(target, 
pd.Timestamp): - _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01') - elif isinstance(target, int): + if isinstance(target, int): _fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201 elif isinstance(target, str) and len(target) == 8: _fn = lambda x: str(x).replace("-", "")[:8] # '20200201' @@ -158,7 +156,7 @@ class MTSDatasetH(DatasetH): try: df = self.handler._learn.copy() # use copy otherwise recorder will fail # FIXME: currently we cannot support switching from `_learn` to `_infer` for inference - except: + except Exception: warnings.warn("cannot access `_learn`, will load raw data") df = self.handler._data.copy() df.index = df.index.swaplevel() diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 5a74757de..37853942a 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -371,7 +371,7 @@ def long_short_backtest( def t_run(): pred_FN = "./check_pred.csv" - pred = pd.read_csv(pred_FN) + pred: pd.DataFrame = pd.read_csv(pred_FN) pred["datetime"] = pd.to_datetime(pred["datetime"]) pred = pred.set_index([pred.columns[0], pred.columns[1]]) pred = pred.iloc[:9000] diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index aad01011c..aaa949b14 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -554,7 +554,7 @@ class AdaRNN(nn.Module): return fc_out -class TransferLoss(object): +class TransferLoss: def __init__(self, loss_type="cosine", input_dim=512): """ Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 26ac666f6..5849e613d 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -98,7 +98,6 @@ class DNNModelPytorch(Model): "\nlr_decay_steps : {}" "\noptimizer : {}" "\nloss_type : {}" - "\neval_steps : {}" "\nseed : {}" "\ndevice : {}" "\nuse_GPU : {}" @@ -113,7 +112,6 @@ class 
DNNModelPytorch(Model): lr_decay_steps, optimizer, loss, - eval_steps, seed, self.device, self.use_gpu, @@ -331,8 +329,8 @@ class Net(nn.Module): dnn_layers = [] drop_input = nn.Dropout(0.05) dnn_layers.append(drop_input) - for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])): - fc = nn.Linear(input_dim, hidden_units) + for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])): + fc = nn.Linear(_input_dim, hidden_units) activation = nn.LeakyReLU(negative_slope=0.1, inplace=False) bn = nn.BatchNorm1d(hidden_units) seq = nn.Sequential(fc, bn, activation) diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index 423474c1f..42df4e5c1 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -19,7 +19,7 @@ import torch.nn.functional as F try: from torch.utils.tensorboard import SummaryWriter -except: +except ImportError: SummaryWriter = None from tqdm import tqdm @@ -257,7 +257,7 @@ class TRAModel(Model): total_loss += loss.item() total_count += 1 - if self.use_daily_transport and len(P_all): + if self.use_daily_transport and len(P_all) > 0: P_all = pd.concat(P_all, axis=0) prob_all = pd.concat(prob_all, axis=0) choice_all = pd.concat(choice_all, axis=0) diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index edb7e018d..9aa3631af 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -15,7 +15,6 @@ from plotly.figure_factory import create_distplot class BaseGraph: - """ """ _name = None @@ -297,8 +296,8 @@ class SubplotsGraph: :return: """ - self._sub_graph_data = list() - self._subplot_titles = list() + self._sub_graph_data = [] + self._subplot_titles = [] for i, column_name in enumerate(self._df.columns): row = math.ceil((i + 1) / self.__cols) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index d2168219b..c35cf62a3 100644 --- a/qlib/data/dataset/__init__.py +++ 
b/qlib/data/dataset/__init__.py @@ -594,7 +594,7 @@ class TSDatasetH(DatasetH): flt_kwargs = deepcopy(kwargs) if flt_col is not None: flt_kwargs["col_set"] = flt_col - flt_data = self._prepare_seg(ext_slice, **flt_kwargs) + flt_data = super()._prepare_seg(ext_slice, **flt_kwargs) assert len(flt_data.columns) == 1 else: flt_data = None diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 7eb3a005f..049adece9 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1407,14 +1407,14 @@ class PairRolling(ExpressionOps): ) def get_extended_window_size(self): + ll, lr = self.feature_left.get_extended_window_size() + rl, rr = self.feature_right.get_extended_window_size() if self.N == 0: get_module_logger(self.__class__.__name__).warning( "The PairRolling(ATTR, 0) will not be accurately calculated" ) - return self.feature.get_extended_window_size() + return -np.inf, max(lr, rr) else: - ll, lr = self.feature_left.get_extended_window_size() - rl, rr = self.feature_right.get_extended_window_size() return max(ll, rl) + self.N - 1, max(lr, rr) diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 31f2712a2..4058b85c2 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -120,7 +120,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage): # If cache is enabled, then return cache directly if self.enable_read_cache: key = "orig_file" + str(self.uri) - if not key in H["c"]: + if key not in H["c"]: H["c"][key] = self._read_calendar() _calendar = H["c"][key] else: diff --git a/qlib/model/riskmodel/structured.py b/qlib/model/riskmodel/structured.py index 96b426ae7..d9a2ec130 100644 --- a/qlib/model/riskmodel/structured.py +++ b/qlib/model/riskmodel/structured.py @@ -50,7 +50,7 @@ class StructuredCovEstimator(RiskModel): num_factors (int): number of components to keep. 
kwargs: see `RiskModel` for more information """ - if "nan_option" in kwargs.keys(): + if "nan_option" in kwargs: assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format( kwargs["nan_option"] ) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 9b8bf2726..187a14817 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -254,21 +254,21 @@ class TrainerR(Trainer): recs.append(rec) return recs - def end_train(self, recs: list, **kwargs) -> List[Recorder]: + def end_train(self, models: list, **kwargs) -> List[Recorder]: """ Set STATUS_END tag to the recorders. Args: - recs (list): a list of trained recorders. + models (list): a list of trained recorders. Returns: List[Recorder]: the same list as the param. """ - if isinstance(recs, Recorder): - recs = [recs] - for rec in recs: + if isinstance(models, Recorder): + models = [models] + for rec in models: rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) - return recs + return models class DelayTrainerR(TrainerR): @@ -289,13 +289,13 @@ class DelayTrainerR(TrainerR): self.end_train_func = end_train_func self.delay = True - def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. Args: - recs (list): a list of Recorder, the tasks have been saved to them + models (list): a list of Recorder, the tasks have been saved to them end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. experiment_name (str): the experiment name, None for use default name. kwargs: the params for end_train_func. 
@@ -303,18 +303,18 @@ class DelayTrainerR(TrainerR): Returns: List[Recorder]: a list of Recorders """ - if isinstance(recs, Recorder): - recs = [recs] + if isinstance(models, Recorder): + models = [models] if end_train_func is None: end_train_func = self.end_train_func if experiment_name is None: experiment_name = self.experiment_name - for rec in recs: + for rec in models: if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END: continue end_train_func(rec, experiment_name, **kwargs) rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) - return recs + return models class TrainerRM(Trainer): diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 0c169c022..2a7281203 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import re +import sys import qlib import shutil import zipfile @@ -101,7 +102,7 @@ class GetData: f"\nAre you sure you want to delete, yes(Y/y), no (N/n):" ) if str(flag) not in ["Y", "y"]: - exit() + sys.exit() for _p in rm_dirs: logger.warning(f"delete: {_p}") shutil.rmtree(_p) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 9c8098473..04a1931b4 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -654,16 +654,13 @@ def exists_qlib_data(qlib_dir): def check_qlib_data(qlib_config): inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments") for _p in inst_dir.glob("*.txt"): - try: - assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, ( - f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:" - f"\n\tIf you are using the data provided by qlib: " - f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset" - f"\n\tIf you are using your own data, please dump the data again: " - f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format" - ) - except AssertionError: - raise + assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, ( + f"\nThe 
{str(_p.resolve())} of qlib data is not equal to 3 columns:" + f"\n\tIf you are using the data provided by qlib: " + f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset" + f"\n\tIf you are using your own data, please dump the data again: " + f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format" + ) def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: diff --git a/qlib/utils/exceptions.py b/qlib/utils/exceptions.py index ed9d567be..c869f5d73 100644 --- a/qlib/utils/exceptions.py +++ b/qlib/utils/exceptions.py @@ -4,8 +4,7 @@ # Base exception class class QlibException(Exception): - def __init__(self, message): - super(QlibException, self).__init__(message) + pass class RecorderInitializationError(QlibException): diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index 48b427a28..464b3c122 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -80,8 +80,7 @@ class AsyncCaller: data = self._q.get() if data == self.STOP_MARK: break - else: - data() + data() def __call__(self, func, *args, **kwargs): self._q.put(partial(func, *args, **kwargs)) diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index d4a19b655..94c1c7164 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -187,7 +187,7 @@ def resam_ts_data( if isinstance(feature.index, pd.MultiIndex): if callable(method): method_func = method - return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs)) + return feature.groupby(level="instrument").apply(method_func, **method_kwargs) elif isinstance(method, str): return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) else: diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index cae050b64..c2fd93fff 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -416,6 +416,11 @@ class QlibRecorder: # Case 5 recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', 
experiment_name='test') + + Here are some things users may be concerned about + - Q: Which recorder will be returned if multiple recorders match the query (e.g. query with experiment_name)? + - A: If the MLflow backend is used, the recorder with the latest `start_time` will be returned, because MLflow's `search_runs` function guarantees it. + Parameters ---------- recorder_id : str diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index f3b2916d7..9bf0b2262 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -287,6 +287,9 @@ class MLflowExperiment(Experiment): """ Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will raise errors. + + Quoting docs of search_runs from MLflow + > The default ordering is to sort by start_time DESC, then run_id. """ assert ( recorder_id is not None or recorder_name is not None diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index a04bee3eb..1186920a7 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -355,7 +355,7 @@ class MLflowRecorder(Recorder): shutil.rmtree(Path(path).absolute().parent) return data except Exception as e: - raise LoadObjectError(message=str(e)) + raise LoadObjectError(str(e)) @AsyncCaller.async_dec(ac_attr="async_log") def log_params(self, **kwargs):