From da9d1c8ac65ba4cfbf48acf2025a24a615f54a31 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 29 Oct 2020 13:22:49 +0800 Subject: [PATCH] Format with black --- qlib/__init__.py | 10 +-- qlib/config.py | 5 +- qlib/contrib/backtest/backtest.py | 9 +-- qlib/contrib/data/handler.py | 76 ++++++++++----------- qlib/contrib/data/processor.py | 5 +- qlib/contrib/evaluate.py | 6 +- qlib/contrib/model/gbdt.py | 51 ++++++++------- qlib/data/dataset/__init__.py | 25 ++++--- qlib/data/dataset/handler.py | 105 +++++++++++++++++------------- qlib/data/dataset/loader.py | 57 ++++++++-------- qlib/data/dataset/processor.py | 22 ++++--- qlib/data/dataset/utils.py | 3 +- qlib/model/base.py | 4 +- qlib/utils/__init__.py | 12 ++-- qlib/utils/objm.py | 15 +++-- qlib/utils/serial.py | 8 +-- qlib/workflow/__init__.py | 12 ++-- qlib/workflow/exp.py | 15 +++-- qlib/workflow/expm.py | 67 ++++++++++--------- qlib/workflow/record.py | 34 +++++----- 20 files changed, 290 insertions(+), 251 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index f2b2c28ac..154d4ea08 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -39,7 +39,7 @@ def init(default_conf="client", **kwargs): LOG.info(f"default_conf: {default_conf}.") C.set_mode(default_conf) - C.set_region(kwargs.get('region', C['region'] if 'region' in C else REG_CN )) + C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN)) for k, v in kwargs.items(): C[k] = v @@ -80,13 +80,13 @@ def init(default_conf="client", **kwargs): if "flask_server" in C: LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") - + # set up QlibRecorder default_uri = str(Path(os.getcwd()).resolve() / "mlruns") - current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri + current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri # exp manager module - module = get_module_by_module_path('qlib.workflow') - exp_manager = init_instance_by_config(C['exp_manager'], module) + module = get_module_by_module_path("qlib.workflow") + exp_manager = init_instance_by_config(C["exp_manager"], module) qr = QlibRecorder(exp_manager, default_uri, current_uri) R.register(qr) diff --git a/qlib/config.py b/qlib/config.py index db5fab69c..0e2a264af 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -125,10 +125,7 @@ _default_config = { "loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}}, }, # Defatult config for experiment manager - "exp_manager": { - "class": "MLflowExpManager", - "kwargs": {} - }, + "exp_manager": {"class": "MLflowExpManager", "kwargs": {}}, "exp_uri": None, } diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index a1b2e5bce..52e74e14b 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -46,10 +46,10 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark) benchmark code, default is SH000905 CSI500 """ # Convert format if the input format is not expected - if get_level_index(pred, level='datetime') == 1: + if get_level_index(pred, level="datetime") == 1: pred = pred.swaplevel().sort_index() if isinstance(pred, pd.Series): - pred = pred.to_frame('score') + pred = pred.to_frame("score") trade_account = Account(init_cash=account) _pred_dates = pred.index.get_level_values(level="datetime") @@ -80,8 +80,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark) # 1. Load the score_series at pred_date try: score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate - score_series = score.reset_index(level="datetime", - drop=True)["score"] # pd.Series(index:stock_id, data: score) + score_series = score.reset_index(level="datetime", drop=True)[ + "score" + ] # pd.Series(index:stock_id, data: score) except KeyError: LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date)) score_series = None diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 45e4855c1..2e7f2febc 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -16,21 +16,16 @@ class ALPHA360(DataHandlerLP): "kwargs": { "config": { "feature": { - "price": { - "windows": range(60) - }, - "volume": { - "windows": range(60) - }, + "price": {"windows": range(60)}, + "volume": {"windows": range(60)}, }, - "label": self.get_label_config() + "label": self.get_label_config(), }, - } + }, } - infer_processors = [{ - "class": "ConfigSectionProcessor", - "module_path": "qlib.contrib.data.processor" - }] # ConfigSectionProcessor will normalize LABEL0 + infer_processors = [ + {"class": "ConfigSectionProcessor", "module_path": "qlib.contrib.data.processor"} + ] # ConfigSectionProcessor will normalize LABEL0 super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors) def get_label_config(self): @@ -49,12 +44,7 @@ class Alpha158(DataHandlerLP): start_time=None, end_time=None, infer_processors=[], - learn_processors=["DropnaLabel", { - "class": "CSZScoreNorm", - "kwargs": { - "fields_group": "label" - } - }], + learn_processors=["DropnaLabel", {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}], fit_start_time=None, fit_end_time=None, ): @@ -65,11 +55,13 @@ class Alpha158(DataHandlerLP): klass, pkwargs = get_cls_kwargs(p, processor_module) # FIXME: It's hard code here!!!!! if isinstance(klass, (MinMaxNorm, ZscoreNorm)): - assert (fit_start_time is not None and fit_end_time is not None) - pkwargs.update({ - "fit_start_time": fit_start_time, - "fit_end_time": fit_end_time, - }) + assert fit_start_time is not None and fit_end_time is not None + pkwargs.update( + { + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + } + ) new_l.append({"class": klass.__name__, "kwargs": pkwargs}) else: new_l.append(p) @@ -81,18 +73,17 @@ class Alpha158(DataHandlerLP): data_loader = { "class": "QlibDataLoader", "kwargs": { - "config": { - "feature": self.get_feature_config(), - "label": self.get_label_config() - }, - } + "config": {"feature": self.get_feature_config(), "label": self.get_label_config()}, + }, } - super().__init__(instruments, - start_time, - end_time, - data_loader=data_loader, - infer_processors=infer_processors, - learn_processors=learn_processors) + super().__init__( + instruments, + start_time, + end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + ) def get_feature_config(self): conf = { @@ -247,7 +238,8 @@ class Alpha158(DataHandlerLP): if use("SUMD"): fields += [ "(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))" - "/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) for d in windows + "/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) + for d in windows ] names += ["SUMD%d" % d for d in windows] if use("VMA"): @@ -258,26 +250,30 @@ class Alpha158(DataHandlerLP): names += ["VSTD%d" % d for d in windows] if use("WVMA"): fields += [ - "Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" % - (d, d) for d in windows + "Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" + % (d, d) + for d in windows ] names += ["WVMA%d" % d for d in windows] if use("VSUMP"): fields += [ - "Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d) + "Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" + % (d, d) for d in windows ] names += ["VSUMP%d" % d for d in windows] if use("VSUMN"): fields += [ - "Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d) + "Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" + % (d, d) for d in windows ] names += ["VSUMN%d" % d for d in windows] if use("VSUMD"): fields += [ "(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))" - "/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) for d in windows + "/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) + for d in windows ] names += ["VSUMD%d" % d for d in windows] diff --git a/qlib/contrib/data/processor.py b/qlib/contrib/data/processor.py index 9fca094a4..35b242510 100644 --- a/qlib/contrib/data/processor.py +++ b/qlib/contrib/data/processor.py @@ -8,9 +8,10 @@ from ...data.dataset.processor import Processor, get_group_columns class ConfigSectionProcessor(Processor): - ''' + """ This processor is designed for Alpha158. And will be replaced by simple processors in the future - ''' + """ + def __init__(self, fields_group=None, **kwargs): super().__init__() # Options diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 6dcefbb80..a4b6d87dc 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -159,11 +159,11 @@ def get_exchange( if deal_price[0] != "$": deal_price = "$" + deal_price if extract_codes: - codes = sorted(pred.index.get_level_values('instrument').unique()) + codes = sorted(pred.index.get_level_values("instrument").unique()) else: codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks - dates = sorted(pred.index.get_level_values('datetime').unique()) + dates = sorted(pred.index.get_level_values("datetime").unique()) dates = np.append(dates, get_date_range(dates[-1], shift=shift)) exchange = Exchange( @@ -298,7 +298,7 @@ def long_short_backtest( "short": short_returns(excess), "long_short": long_short_returns} """ - if get_level_index(pred, level='datetime') == 1: + if get_level_index(pred, level="datetime") == 1: pred = pred.swaplevel().sort_index() if trade_unit is None: diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 2769f2282..61c617b8d 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -12,26 +12,29 @@ from ...data.dataset.handler import DataHandlerLP class LGBModel(Model): """LightGBM Model""" + def __init__(self, loss="mse", **kwargs): if loss not in {"mse", "binary"}: raise NotImplementedError - self._params = {'objective': loss} + self._params = {"objective": loss} self._params.update(kwargs) self.model = None - def fit(self, - dataset: DatasetH, - num_boost_round=1000, - early_stopping_rounds=50, - verbose_eval=20, - evals_result=dict(), - **kwargs): + def fit( + self, + dataset: DatasetH, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): - df_train, df_valid = dataset.prepare(['train', 'valid'], - col_set=['feature', 'label'], - data_key=DataHandlerLP.DK_L) - x_train, y_train = df_train['feature'], df_train['label'] - x_valid, y_valid = df_valid['feature'], df_valid['label'] + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_valid["feature"], df_valid["label"] # Lightgbm need 1D array as its label if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: @@ -41,20 +44,22 @@ class LGBModel(Model): dtrain = lgb.Dataset(x_train.values, label=y_train_1d) dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d) - self.model = lgb.train(self._params, - dtrain, - num_boost_round=num_boost_round, - valid_sets=[dtrain, dvalid], - valid_names=["train", "valid"], - early_stopping_rounds=early_stopping_rounds, - verbose_eval=verbose_eval, - evals_result=evals_result, - **kwargs) + self.model = lgb.train( + self._params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] def predict(self, dataset): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare('test', col_set='feature') + x_test = dataset.prepare("test", col_set="feature") return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index fcf17546f..d5b8a12e9 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -6,11 +6,12 @@ import pandas as pd class Dataset(Serializable): - ''' + """ Preparing data for model training and inferencing. - ''' + """ + def __init__(self, *args, **kwargs): - ''' + """ init is designed to finish following steps - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing @@ -18,7 +19,7 @@ class Dataset(Serializable): - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. The data could specify the info to caculate the essential data for preparation - ''' + """ self.setup_data(*args, **kwargs) super().__init__() @@ -51,14 +52,15 @@ class Dataset(Serializable): class DatasetH(Dataset): - ''' + """ Dataset with Data(H)anler User should try to put the data preprocessing functions into handler. Only following data processing functions should be placed in Dataset - The processing is related to specific model. - The processing is related to data split - ''' + """ + def __init__(self, handler: Union[dict, DataHandler], segments: list): """ Parameters @@ -96,10 +98,9 @@ class DatasetH(Dataset): self._handler = init_instance_by_config(handler, accept_types=DataHandler) self._segments = segments - def prepare(self, - segments: Union[List[str], Tuple[str], str, slice], - col_set=DataHandler.CS_ALL, - **kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]: + def prepare( + self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs + ) -> Union[List[pd.DataFrame], pd.DataFrame]: """ prepare the data for learning and inference @@ -124,9 +125,7 @@ class DatasetH(Dataset): [TODO:description] """ if isinstance(segments, (list, tuple)): - return [ - self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments - ] + return [self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments] elif isinstance(segments, str): return self._handler.fetch(slice(*self._segments[segments]), col_set=col_set, **kwargs) else: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 4f33ae73c..04715c892 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -25,7 +25,7 @@ from . import loader as data_loader_module # TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed. class DataHandler(Serializable): - ''' + """ The steps to using a handler 1. initialized data handler (call by `init`). 2. use the data @@ -46,13 +46,21 @@ class DataHandler(Serializable): SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 - ''' - def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, init_data=True): + """ + + def __init__( + self, + instruments, + start_time=None, + end_time=None, + data_loader: Tuple[dict, str, DataLoader] = None, + init_data=True, + ): # Set logger self.logger = get_module_logger("DataHandler") # Setup data loader - assert(data_loader is not None) # to make start_time end_time could have None default value + assert data_loader is not None # to make start_time end_time could have None default value self.data_loader = init_instance_by_config(data_loader, data_loader_module, accept_types=DataLoader) self.instruments = instruments @@ -62,7 +70,7 @@ class DataHandler(Serializable): self.init() super().__init__() - def init(self, enable_cache: bool=True): + def init(self, enable_cache: bool = True): """ initialize the data. In case of running intialization for multiple time, it will do nothing for the second time. @@ -83,7 +91,9 @@ class DataHandler(Serializable): self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) # TODO: cache - def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame: + def _fetch_df_by_index( + self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int] + ) -> pd.DataFrame: """ fetch data from `data` with `selector` and `level` @@ -100,7 +110,7 @@ class DataHandler(Serializable): idx_slc = idx_slc[1], idx_slc[0] return df.loc(axis=0)[idx_slc] - CS_ALL = '__all' + CS_ALL = "__all" def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: cln = len(df.columns.levels) @@ -111,10 +121,12 @@ class DataHandler(Serializable): else: return df.loc(axis=1)[col_set] - def fetch(self, - selector: Union[pd.Timestamp, slice, str], - level: Union[str, int] = 'datetime', - col_set: Union[str, List[str]] = CS_ALL) -> pd.DataFrame: + def fetch( + self, + selector: Union[pd.Timestamp, slice, str], + level: Union[str, int] = "datetime", + col_set: Union[str, List[str]] = CS_ALL, + ) -> pd.DataFrame: """ fetch data from underlying data source @@ -157,32 +169,35 @@ class DataHandler(Serializable): class DataHandlerLP(DataHandler): - ''' + """ DataHandler with **(L)earnable (P)rocessor** - ''' + """ + # data key - DK_R = 'raw' - DK_I = 'infer' - DK_L = 'learn' + DK_R = "raw" + DK_I = "infer" + DK_L = "learn" # process type - PTYPE_I = 'independent' + PTYPE_I = "independent" # - _proc_infer_df will processed by infer_processors # - _proc_learn_df will be processed by learn_processors - PTYPE_A = 'append' + PTYPE_A = "append" # - _proc_infer_df will processed by infer_processors # - _proc_learn_df will be processed by infer_processors + learn_processors # - (e.g. _proc_infer_df processed by learn_processors ) - def __init__(self, - instruments, - start_time=None, - end_time=None, - data_loader: Tuple[dict, str, DataLoader] = None, - infer_processors=[], - learn_processors=[], - process_type=PTYPE_A, - **kwargs): + def __init__( + self, + instruments, + start_time=None, + end_time=None, + data_loader: Tuple[dict, str, DataLoader] = None, + infer_processors=[], + learn_processors=[], + process_type=PTYPE_A, + **kwargs, + ): """ Parameters ---------- @@ -217,10 +232,11 @@ class DataHandlerLP(DataHandler): # Setup preprocessor self.infer_processors = [] # for lint self.learn_processors = [] # for lint - for pname in 'infer_processors', 'learn_processors': + for pname in "infer_processors", "learn_processors": for proc in locals()[pname]: - getattr(self, pname).append(init_instance_by_config(proc, processor_module, - accept_types=(processor_module.Processor,))) + getattr(self, pname).append( + init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,)) + ) self.process_type = process_type super().__init__(instruments, start_time, end_time, data_loader, **kwargs) @@ -240,8 +256,7 @@ class DataHandlerLP(DataHandler): """ self.process_data(with_fit=True) - - def process_data(self, with_fit: bool=False): + def process_data(self, with_fit: bool = False): """ process_data data. Fun `processor.fit` if necessary @@ -281,11 +296,11 @@ class DataHandlerLP(DataHandler): self._learn = _learn_df # init type - IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor - IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df - IT_LS = 'load_state' # The state of the object has been load by pickle + IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor + IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df + IT_LS = "load_state" # The state of the object has been load by pickle - def init(self, init_type: str=IT_FIT_SEQ, enable_cache: bool=False): + def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): """ Initialize the data of Qlib @@ -314,15 +329,17 @@ class DataHandlerLP(DataHandler): # TODO: Be able to cache handler data. Save the memory for data processing - def _get_df_by_key(self, data_key: str=DK_I) -> pd.DataFrame: - df = getattr(self, {self.DK_R: '_data', self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame: + df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) return df - def fetch(self, - selector: Union[pd.Timestamp, slice, str], - level: Union[str, int] = 'datetime', - col_set=DataHandler.CS_ALL, - data_key: str = DK_I) -> pd.DataFrame: + def fetch( + self, + selector: Union[pd.Timestamp, slice, str], + level: Union[str, int] = "datetime", + col_set=DataHandler.CS_ALL, + data_key: str = DK_I, + ) -> pd.DataFrame: """ fetch data from underlying data source @@ -345,7 +362,7 @@ class DataHandlerLP(DataHandler): df = self._fetch_df_by_index(df, selector, level) return self._fetch_df_by_col(df, col_set) - def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str=DK_I) -> list: + def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: """ get the column names diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index b94280a83..e4f2f8619 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -8,44 +8,46 @@ from typing import Tuple class DataLoader(ABC): - ''' + """ DataLoader is designed for loading raw data from original data source. - ''' + """ + @abstractmethod def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: """ - load the data as pd.DataFrame + load the data as pd.DataFrame - Parameters - ---------- - self : [TODO:type] - [TODO:description] - instruments : [TODO:type] - [TODO:description] - start_time : [TODO:type] - [TODO:description] - end_time : [TODO:type] - [TODO:description] + Parameters + ---------- + self : [TODO:type] + [TODO:description] + instruments : [TODO:type] + [TODO:description] + start_time : [TODO:type] + [TODO:description] + end_time : [TODO:type] + [TODO:description] - Returns - ------- - pd.DataFrame: - data load from the under layer source + Returns + ------- + pd.DataFrame: + data load from the under layer source - Example of the data: - The multi-index of the columns is optional. - feature label - $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 - datetime instrument - 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 - SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 - SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + Example of the data: + The multi-index of the columns is optional. + feature label + $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 + datetime instrument + 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 + SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 + SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 """ pass class QlibDataLoader(DataLoader): - '''Same as QlibDataLoader. The fields can be define by config''' + """Same as QlibDataLoader. The fields can be define by config""" + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None): """ Parameters @@ -65,7 +67,7 @@ class QlibDataLoader(DataLoader): Here is a few examples to describe the fields TODO: """ - self.is_group = isinstance(config, dict) + self.is_group = isinstance(config, dict) if self.is_group: self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()} @@ -88,6 +90,7 @@ class QlibDataLoader(DataLoader): df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time) df.columns = names return df + if self.is_group: df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1) else: diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 2ab012de2..3fc91f52c 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -30,8 +30,7 @@ def get_group_columns(df: pd.DataFrame, group: str): class Processor(Serializable): - - def fit(self, df: pd.DataFrame=None): + def fit(self, df: pd.DataFrame = None): """ learn data processing parameters @@ -40,7 +39,7 @@ class Processor(Serializable): df : pd.DataFrame When we fit and process data with processor one by one. The fit function reiles on the output of previous processor, i.e. `df`. - + """ pass @@ -81,16 +80,17 @@ class DropnaProcessor(Processor): class DropnaLabel(DropnaProcessor): - def __init__(self, group='label'): + def __init__(self, group="label"): super().__init__(group=group) def is_for_infer(self) -> bool: - '''The samples are dropped according to label. So it is not usable for inference''' + """The samples are dropped according to label. So it is not usable for inference""" return False class ProcessInf(Processor): - '''Process infinity ''' + """Process infinity """ + def __call__(self, df): def replace_inf(data): def process_inf(df): @@ -102,6 +102,7 @@ class ProcessInf(Processor): data = data.groupby("datetime").apply(process_inf) data.sort_index(inplace=True) return data + return replace_inf(df) @@ -126,6 +127,7 @@ class MinMaxNorm(Processor): if not ignore[i]: x[i] = (x[i] - min_val) / (max_val - min_val) return x + df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df @@ -151,17 +153,19 @@ class ZscoreNorm(Processor): if not ignore[i]: x[i] = (x[i] - mean_train) / std_train return x + df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df class CSZScoreNorm(Processor): - '''Cross Sectional ZScore Normalization''' + """Cross Sectional ZScore Normalization""" + def __init__(self, fields_group=None): self.fields_group = fields_group def __call__(self, df): # try not modify original dataframe - cols = get_group_columns(df,self.fields_group) - df[cols] = df[cols].groupby('datetime').apply(lambda df: (df - df.mean()).div(df.std())) + cols = get_group_columns(df, self.fields_group) + df[cols] = df[cols].groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std())) return df diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index c97256896..af0900867 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -24,9 +24,8 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int: return df.index.names.index(level) except (AttributeError, ValueError): # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument') - return ('datetime', 'instrument').index(level) + return ("datetime", "instrument").index(level) elif isinstance(level, int): return level else: raise NotImplementedError(f"This type of input is not supported") - diff --git a/qlib/model/base.py b/qlib/model/base.py index 11bd76d06..3a6ad504e 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -6,7 +6,7 @@ from ..data.dataset import Dataset class BaseModel(Serializable, metaclass=abc.ABCMeta): - '''Modeling things''' + """Modeling things""" @abc.abstractmethod def predict(self, *args, **kwargs) -> object: @@ -19,7 +19,7 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): class Model(BaseModel): - '''Learnable Models''' + """Learnable Models""" def fit(self, dataset: Dataset): """ diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index ef7bf63d6..87b43f456 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -165,7 +165,7 @@ def get_module_by_module_path(module_path): return module -def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): +def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): """ extract class and kwargs from config info @@ -184,8 +184,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): """ if isinstance(config, dict): # raise AttributeError - klass = getattr(module, config['class']) - kwargs = config['kwargs'] + klass = getattr(module, config["class"]) + kwargs = config["kwargs"] elif isinstance(config, str): klass = getattr(module, config) kwargs = {} @@ -194,7 +194,9 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): return klass, kwargs -def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]]=tuple([])) -> object: +def init_instance_by_config( + config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]] = tuple([]) +) -> object: """ get initialized instance with config @@ -647,4 +649,4 @@ def register_wrapper(wrapper, cls_or_obj): module = get_module_by_module_path("qlib.data") cls_or_obj = getattr(module, cls_or_obj) obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj - wrapper.register(obj) \ No newline at end of file + wrapper.register(obj) diff --git a/qlib/utils/objm.py b/qlib/utils/objm.py index d7c4f4cb1..eebd529c6 100644 --- a/qlib/utils/objm.py +++ b/qlib/utils/objm.py @@ -24,7 +24,7 @@ class ObjManager: def save_objs(self, obj_name_l): """ - save objects + save objects Parameters ---------- @@ -88,9 +88,10 @@ class ObjManager: class FileManager(ObjManager): - ''' + """ Use file system to manage objects - ''' + """ + def __init__(self, path=None): if path is None: self.path = Path(self.create_path()) @@ -99,12 +100,12 @@ class FileManager(ObjManager): def create_path(self) -> str: try: - return tempfile.mkdtemp(prefix=str(C['file_manager_path']) + os.sep) + return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep) except AttributeError: raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented") def save_obj(self, obj, name): - with (self.path / name).open('wb') as f: + with (self.path / name).open("wb") as f: pickle.dump(obj, f) def save_objs(self, obj_name_l): @@ -112,7 +113,7 @@ class FileManager(ObjManager): self.save_obj(obj, name) def load_obj(self, name): - with (self.path / name).open('rb') as f: + with (self.path / name).open("rb") as f: return pickle.load(f) def exists(self, name): @@ -123,7 +124,7 @@ class FileManager(ObjManager): def remove(self, fname=None): if fname is None: - for fp in self.path.glob('*'): + for fp in self.path.glob("*"): fp.unlink() self.path.rmdir() else: diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index a4825615f..04781d655 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -6,17 +6,17 @@ import pickle class Serializable: - ''' + """ Serializable behaves like pickle. But it only save the state whose name starts with `_` - ''' + """ def __getstate__(self) -> dict: - return {k: v for k, v in self.__dict__.items() if k.startswith('_') } + return {k: v for k, v in self.__dict__.items() if k.startswith("_")} def __setstate__(self, state: dict): self.__dict__.update(state) def to_pickle(self, path: [Path, str]): - with Path(path).open('wb') as f: + with Path(path).open("wb") as f: pickle.dump(self, f) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 06a646c84..7c9c1928f 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from .expm import * from ..utils import Wrapper + class QlibRecorder: def __init__(self, exp_manager, default_uri, current_uri): self.exp_manager = exp_manager @@ -16,16 +17,16 @@ class QlibRecorder: run = self.start_exp(experiment_name, self.current_uri) yield run self.end_exp() - + def start_exp(self, experiment_name=None): - return self.exp_manager.start_exp(experiment_name, self.current_uri) + return self.exp_manager.start_exp(experiment_name, self.current_uri) def end_exp(self): self.exp_manager.end_exp() - + def search_records(self, experiment_ids, **kwargs): return self.exp_manager.search_records(experiment_ids, **kwargs) - + def get_exp(self, experiment_id=None, experiment_name=None): return self.exp_manager.get_exp(experiment_id, experiment_name) @@ -52,12 +53,13 @@ class QlibRecorder: def log_metrics(self, step=None, **kwargs): self.exp_manager.active_recorder.log_metrics(step, **kwargs) - + def set_tags(self, **kwargs): self.exp_manager.active_recorder.set_tags(**kwargs) def delete_tag(self, key): self.exp_manager.active_recorder.delete_tag(key) + # global record R = Wrapper() diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 9e076aced..a63187e28 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -4,10 +4,12 @@ import mlflow from pathlib import Path + class Experiment: """ - Thie is the `Experiment` class for each experiment being run. The API is designed + Thie is the `Experiment` class for each experiment being run. The API is designed """ + def __init__(self): self.name = None self.id = None @@ -39,9 +41,10 @@ class MLflowExperiment(Experiment): """ Use mlflow to implement Experiment. """ + def search_records(self, **kwargs): - filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string') - run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type') - max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results') - order_by = kwargs.get('order_by') - return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by) \ No newline at end of file + filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") + run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") + max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + order_by = kwargs.get("order_by") + return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 36a945f42..00d25da48 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -8,15 +8,17 @@ from contextlib import contextmanager from .exp import MLflowExperiment from .record import MLflowRecorder + class ExpManager: """ This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ + def __init__(self): self.default_uri = None - self.active_recorder = None # only one recorder can running each time - self.experiments = dict() # store the experiment name --> Experiment object + self.active_recorder = None # only one recorder can running each time + self.experiments = dict() # store the experiment name --> Experiment object def start_exp(self, experiment_name=None, uri=None, **kwargs): """ @@ -88,7 +90,7 @@ class ExpManager: An experiment object. """ raise NotImplementedError(f"Please implement the `create_exp` method.") - + def get_exp(self, experiment_id=None, experiment_name=None): """ Retrieve an experiment by experiment_id from the backend store. @@ -111,7 +113,7 @@ class ExpManager: Parameters ---------- experiment_id : str - the experiment id. + the experiment id. """ raise NotImplementedError(f"Please implement the `create_exp` method.") @@ -142,12 +144,13 @@ class ExpManager: An Recorder object. """ raise NotImplementedError(f"Please implement the `get_recorder` method.") - + class MLflowExpManager(ExpManager): - ''' + """ Use mlflow to implement ExpManager. - ''' + """ + def __init__(self): super(MLflowExpManager, self).__init__() self.default_uri = None @@ -169,27 +172,31 @@ class MLflowExpManager(ExpManager): def end_exp(self): self.active_recorder.end_run() self.active_recorder = None - + def __create_exp(self, experiment_name=None, uri=None): # init experiment experiment = MLflowExperiment() # set the tracking uri if uri is None: - print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.') + print( + "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory." + ) else: self.current_uri = uri mlflow.set_tracking_uri(self.current_uri) # start the experiment if experiment_name is None: - print('No experiment name provided. The default experiment name is set as `experiment`.') - experiment_id = mlflow.create_experiment('experiment') + print("No experiment name provided. The default experiment name is set as `experiment`.") + experiment_id = mlflow.create_experiment("experiment") # set the active experiment - mlflow.set_experiment('experiment') - experiment_name = 'experiment' + mlflow.set_experiment("experiment") + experiment_name = "experiment" else: if experiment_name not in self.experiments: if mlflow.get_experiment_by_name(experiment_name) is not None: - raise Exception('The experiment has already been created before. Please pick another name or delete the files under uri.') + raise Exception( + "The experiment has already been created before. Please pick another name or delete the files under uri." + ) experiment_id = mlflow.create_experiment(experiment_name) else: experiment_id = self.experiments[experiment_name].id @@ -197,40 +204,42 @@ class MLflowExpManager(ExpManager): # set the active experiment mlflow.set_experiment(experiment_name) # set up experiment - experiment.id = experiment_id + experiment.id = experiment_id experiment.name = experiment_name return experiment - + def search_records(self, experiment_ids, **kwargs): - filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string') - run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type') - max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results') - order_by = kwargs.get('order_by') + filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") + run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") + max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + order_by = kwargs.get("order_by") return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) - + def get_exp(self, experiment_id=None, experiment_name=None): - assert experiment_id is not None or experiment_name is not None, 'Please provide at least one of the experiment id or name to retrieve an experiment.' + assert ( + experiment_id is not None or experiment_name is not None + ), "Please provide at least one of the experiment id or name to retrieve an experiment." if experiment_name is not None: return self.experiments[experiment_name] - elif: + elif experiment_id is not None: for name in self.experiments: if self.experiments[name].id == experiment_id: return self.experiments[name] else: - print('No valid experiment is found. Please make sure the id and name are correctly given.') + print("No valid experiment is found. Please make sure the id and name are correctly given.") def delete_exp(self, experiment_id): mlflow.delete_experiment(experiment_id) - self.experiments = {key:val for key, val in self.experiments.items() if val.id != experiment_id} + self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id} def get_uri(self, type): - if uri == 'default': + if uri == "default": return self.default_uri - elif uri == 'current': + elif uri == "current": return self.current_uri else: - raise ValueError('Input type is not supported. Please choose type default or current to get the uri.') + raise ValueError("Input type is not supported. Please choose type default or current to get the uri.") def get_recorder(self): - return self.active_recorder \ No newline at end of file + return self.active_recorder diff --git a/qlib/workflow/record.py b/qlib/workflow/record.py index 071c92691..e132710ca 100644 --- a/qlib/workflow/record.py +++ b/qlib/workflow/record.py @@ -6,6 +6,7 @@ import shutil, os, pickle, tempfile, codecs from pathlib import Path from ..utils.objm import FileManager + class Recorder: """ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow. @@ -16,7 +17,7 @@ class Recorder: self.experiment_id = experiment_id self.recorder_id = None self.recorder_name = None - + def set_recorder_name(self, rname): self.recorder_name = rname @@ -63,10 +64,9 @@ class Recorder: """ raise NotImplementedError(f"Please implement the `load_object` method.") - def start_run(self, run_id=None, experiment_id=None, - run_name=None, nested=False): + def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False): """ - Start running the Recorder. The return value can be used as a context manager within a `with` block; + Start running the Recorder. The return value can be used as a context manager within a `with` block; otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow) Parameters @@ -85,7 +85,7 @@ class Recorder: An active running object (e.g. mlflow.ActiveRun object). """ raise NotImplementedError(f"Please implement the `start_run` method.") - + def end_run(self): """ End an active Recorder. @@ -138,19 +138,19 @@ class Recorder: class MLflowRecorder(Recorder): - ''' + """ Use mlflow to implement a Recorder. - Due to the fact that mlflow will only log artifact from a file or directory, we decide to + Due to the fact that mlflow will only log artifact from a file or directory, we decide to use file manager to help maintain the objects in the project. - ''' + """ + def __init__(self, experiment_id): super(MLflowRecorder, self).__init__(experiment_id) self.fm = None self.temp_dir = None - def start_run(self, run_id=None, experiment_id=None, - run_name=None, nested=False): + def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False): if run_id is None: run_id = self.recorder_id if experiment_id is None: @@ -166,7 +166,7 @@ class MLflowRecorder(Recorder): self.temp_dir = tempfile.mkdtemp() self.fm = FileManager(Path(self.temp_dir).absolute()) return run - + def end_run(self): mlflow.end_run() shutil.rmtree(self.temp_dir) @@ -194,13 +194,13 @@ class MLflowRecorder(Recorder): client = mlflow.tracking.MlflowClient() path = client.download_artifacts(self.recorder_id, name) try: - with Path(path).open('rb') as f: + with Path(path).open("rb") as f: f.seek(0) return pickle.load(f) except: - with codecs.open(path, mode="r", encoding='utf-8') as f: - return f.read() - + with codecs.open(path, mode="r", encoding="utf-8") as f: + return f.read() + def log_params(self, **kwargs): keys = list(kwargs.keys()) if len(keys) == 0: @@ -214,7 +214,7 @@ class MLflowRecorder(Recorder): mlflow.log_metric(keys[0], kwargs.get(keys[0])) else: mlflow.log_metrics(dict(kwargs)) - + def set_tags(self, **kwargs): keys = list(kwargs.keys()) if len(keys) == 0: @@ -228,4 +228,4 @@ class MLflowRecorder(Recorder): def get_artifact_uri(self, artifact_path=None): if self.artifact_uri is not None: return self.artifact_uri - return mlflow.get_artifact_uri(artifact_path) \ No newline at end of file + return mlflow.get_artifact_uri(artifact_path)