1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix some typo in doc/comments (#1389)

* fix typo in docstrings

* fix typo

* fix typo

* fix black lint

* fix black lint
This commit is contained in:
YQ Tsui
2022-12-11 14:29:16 +08:00
committed by GitHub
parent 57f9813f85
commit 5e3924d7a6
6 changed files with 25 additions and 21 deletions

View File

@@ -56,7 +56,7 @@ class ADARNN(Model):
n_splits=2, n_splits=2,
GPU=0, GPU=0,
seed=None, seed=None,
**kwargs **_
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("ADARNN") self.logger = get_module_logger("ADARNN")
@@ -81,7 +81,7 @@ class ADARNN(Model):
self.optimizer = optimizer.lower() self.optimizer = optimizer.lower()
self.loss = loss self.loss = loss
self.n_splits = n_splits self.n_splits = n_splits
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed self.seed = seed
self.logger.info( self.logger.info(
@@ -213,7 +213,8 @@ class ADARNN(Model):
weight_mat = self.transform_type(out_weight_list) weight_mat = self.transform_type(out_weight_list)
return weight_mat, None return weight_mat, None
def calc_all_metrics(self, pred): @staticmethod
def calc_all_metrics(pred):
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)""" """pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
res = {} res = {}
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score)) ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
@@ -259,8 +260,6 @@ class ADARNN(Model):
save_path = get_or_create_path(save_path) save_path = get_or_create_path(save_path)
stop_steps = 0 stop_steps = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = [] evals_result["train"] = []
evals_result["valid"] = [] evals_result["valid"] = []
@@ -400,7 +399,7 @@ class AdaRNN(nn.Module):
self.model_type = model_type self.model_type = model_type
self.trans_loss = trans_loss self.trans_loss = trans_loss
self.len_seq = len_seq self.len_seq = len_seq
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
in_size = self.n_input in_size = self.n_input
features = nn.ModuleList() features = nn.ModuleList()
@@ -499,7 +498,8 @@ class AdaRNN(nn.Module):
res = self.softmax(weight).squeeze() res = self.softmax(weight).squeeze()
return res return res
def get_features(self, output_list): @staticmethod
def get_features(output_list):
fea_list_src, fea_list_tar = [], [] fea_list_src, fea_list_tar = [], []
for fea in output_list: for fea in output_list:
fea_list_src.append(fea[0 : fea.size(0) // 2]) fea_list_src.append(fea[0 : fea.size(0) // 2])
@@ -561,7 +561,7 @@ class TransferLoss:
""" """
self.loss_type = loss_type self.loss_type = loss_type
self.input_dim = input_dim self.input_dim = input_dim
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
def compute(self, X, Y): def compute(self, X, Y):
"""Compute adaptation loss """Compute adaptation loss
@@ -676,7 +676,8 @@ class MMD_loss(nn.Module):
self.fix_sigma = None self.fix_sigma = None
self.kernel_type = kernel_type self.kernel_type = kernel_type
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): @staticmethod
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.size()[0]) + int(target.size()[0]) n_samples = int(source.size()[0]) + int(target.size()[0])
total = torch.cat([source, target], dim=0) total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
@@ -691,7 +692,8 @@ class MMD_loss(nn.Module):
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
return sum(kernel_val) return sum(kernel_val)
def linear_mmd(self, X, Y): @staticmethod
def linear_mmd(X, Y):
delta = X.mean(axis=0) - Y.mean(axis=0) delta = X.mean(axis=0) - Y.mean(axis=0)
loss = delta.dot(delta.T) loss = delta.dot(delta.T)
return loss return loss

View File

@@ -428,7 +428,7 @@ class EnhancedIndexingStrategy(WeightStrategyBase):
specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0]) specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0])
if not factor_exp.index.equals(specific_risk.index): if not factor_exp.index.equals(specific_risk.index):
# NOTE: for stocks missing specific_risk, we always assume it have the highest volatility # NOTE: for stocks missing specific_risk, we always assume it has the highest volatility
specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max()) specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())
universe = factor_exp.index.tolist() universe = factor_exp.index.tolist()

View File

@@ -18,7 +18,7 @@ class StructuredCovEstimator(RiskModel):
`B` is the regression coefficients matrix for all observations (row) on `B` is the regression coefficients matrix for all observations (row) on
all factors (columns), and `U` is the residual matrix with shape like `X`. all factors (columns), and `U` is the residual matrix with shape like `X`.
Therefore the structured covariance can be estimated by Therefore, the structured covariance can be estimated by
cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U)) cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U))
In finance domain, there are mainly three methods to design `F` [1][2]: In finance domain, there are mainly three methods to design `F` [1][2]:

View File

@@ -155,7 +155,7 @@ class QlibRecorder:
The arguments of this function are not set to be rigid, and they will be different with different implementation of The arguments of this function are not set to be rigid, and they will be different with different implementation of
``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the ``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the
example code of the this method with the ``MLflowExpManager``: example code of the method with the ``MLflowExpManager``:
.. code-block:: Python .. code-block:: Python

View File

@@ -30,7 +30,8 @@ class RecordTemp:
""" """
artifact_path = None artifact_path = None
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls` depend_cls = None # the dependant class of the record; the record will depend on the results generated by
# `depend_cls`
@classmethod @classmethod
def get_path(cls, path=None): def get_path(cls, path=None):
@@ -119,7 +120,7 @@ class RecordTemp:
Check if the records is properly generated and saved. Check if the records is properly generated and saved.
It is useful in following examples It is useful in following examples
- checking if the depended files complete before generating new things. - checking if the dependant files complete before generating new things.
- checking if the final files is completed - checking if the final files is completed
Parameters Parameters
@@ -186,7 +187,7 @@ class SignalRecord(RecordTemp):
return raw_label return raw_label
def generate(self, **kwargs): def generate(self, **kwargs):
# generate prediciton # generate prediction
pred = self.model.predict(self.dataset) pred = self.model.predict(self.dataset)
if isinstance(pred, pd.Series): if isinstance(pred, pd.Series):
pred = pred.to_frame("score") pred = pred.to_frame("score")
@@ -285,7 +286,8 @@ class HFSignalRecord(SignalRecord):
class SigAnaRecord(ACRecordTemp): class SigAnaRecord(ACRecordTemp):
""" """
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. This is the Signal Analysis Record class that generates the analysis results such as IC and IR.
This class inherits the ``RecordTemp`` class.
""" """
artifact_path = "sig_analysis" artifact_path = "sig_analysis"
@@ -382,7 +384,7 @@ class PortAnaRecord(ACRecordTemp):
indicator_analysis_freq : str|List[str] indicator_analysis_freq : str|List[str]
indicator analysis freq of report indicator analysis freq of report
indicator_analysis_method : str, optional, default by None indicator_analysis_method : str, optional, default by None
the candidated values include 'mean', 'amount_weighted', 'value_weighted' the candidate values include 'mean', 'amount_weighted', 'value_weighted'
""" """
super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs) super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs)
@@ -456,9 +458,9 @@ class PortAnaRecord(ACRecordTemp):
pred = self.load("pred.pkl") pred = self.load("pred.pkl")
# replace the "<PRED>" with prediction saved before # replace the "<PRED>" with prediction saved before
placehorder_value = {"<PRED>": pred} placeholder_value = {"<PRED>": pred}
for k in "executor_config", "strategy_config": for k in "executor_config", "strategy_config":
setattr(self, k, fill_placeholder(getattr(self, k), placehorder_value)) setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value))
# if the backtesting time range is not set, it will automatically extract time range from the prediction file # if the backtesting time range is not set, it will automatically extract time range from the prediction file
dt_values = pred.index.get_level_values("datetime") dt_values = pred.index.get_level_values("datetime")

View File

@@ -19,7 +19,7 @@ cd qlib/scripts/data_collector/pit/
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
``` ```
Downloading all data from the stock is very time consuming. If you just want run a quick test on a few stocks, you can run the command below Downloading all data from the stock is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below
```bash ```bash
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*" python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"
``` ```