mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix some typo in doc/comments (#1389)
* fix typo in docstrings * fix typo * fix typo * fix black lint * fix black lint
This commit is contained in:
@@ -56,7 +56,7 @@ class ADARNN(Model):
|
|||||||
n_splits=2,
|
n_splits=2,
|
||||||
GPU=0,
|
GPU=0,
|
||||||
seed=None,
|
seed=None,
|
||||||
**kwargs
|
**_
|
||||||
):
|
):
|
||||||
# Set logger.
|
# Set logger.
|
||||||
self.logger = get_module_logger("ADARNN")
|
self.logger = get_module_logger("ADARNN")
|
||||||
@@ -81,7 +81,7 @@ class ADARNN(Model):
|
|||||||
self.optimizer = optimizer.lower()
|
self.optimizer = optimizer.lower()
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.n_splits = n_splits
|
self.n_splits = n_splits
|
||||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
@@ -213,7 +213,8 @@ class ADARNN(Model):
|
|||||||
weight_mat = self.transform_type(out_weight_list)
|
weight_mat = self.transform_type(out_weight_list)
|
||||||
return weight_mat, None
|
return weight_mat, None
|
||||||
|
|
||||||
def calc_all_metrics(self, pred):
|
@staticmethod
|
||||||
|
def calc_all_metrics(pred):
|
||||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||||
res = {}
|
res = {}
|
||||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||||
@@ -259,8 +260,6 @@ class ADARNN(Model):
|
|||||||
|
|
||||||
save_path = get_or_create_path(save_path)
|
save_path = get_or_create_path(save_path)
|
||||||
stop_steps = 0
|
stop_steps = 0
|
||||||
best_score = -np.inf
|
|
||||||
best_epoch = 0
|
|
||||||
evals_result["train"] = []
|
evals_result["train"] = []
|
||||||
evals_result["valid"] = []
|
evals_result["valid"] = []
|
||||||
|
|
||||||
@@ -400,7 +399,7 @@ class AdaRNN(nn.Module):
|
|||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.trans_loss = trans_loss
|
self.trans_loss = trans_loss
|
||||||
self.len_seq = len_seq
|
self.len_seq = len_seq
|
||||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||||
in_size = self.n_input
|
in_size = self.n_input
|
||||||
|
|
||||||
features = nn.ModuleList()
|
features = nn.ModuleList()
|
||||||
@@ -499,7 +498,8 @@ class AdaRNN(nn.Module):
|
|||||||
res = self.softmax(weight).squeeze()
|
res = self.softmax(weight).squeeze()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_features(self, output_list):
|
@staticmethod
|
||||||
|
def get_features(output_list):
|
||||||
fea_list_src, fea_list_tar = [], []
|
fea_list_src, fea_list_tar = [], []
|
||||||
for fea in output_list:
|
for fea in output_list:
|
||||||
fea_list_src.append(fea[0 : fea.size(0) // 2])
|
fea_list_src.append(fea[0 : fea.size(0) // 2])
|
||||||
@@ -561,7 +561,7 @@ class TransferLoss:
|
|||||||
"""
|
"""
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||||
|
|
||||||
def compute(self, X, Y):
|
def compute(self, X, Y):
|
||||||
"""Compute adaptation loss
|
"""Compute adaptation loss
|
||||||
@@ -676,7 +676,8 @@ class MMD_loss(nn.Module):
|
|||||||
self.fix_sigma = None
|
self.fix_sigma = None
|
||||||
self.kernel_type = kernel_type
|
self.kernel_type = kernel_type
|
||||||
|
|
||||||
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
@staticmethod
|
||||||
|
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||||
n_samples = int(source.size()[0]) + int(target.size()[0])
|
n_samples = int(source.size()[0]) + int(target.size()[0])
|
||||||
total = torch.cat([source, target], dim=0)
|
total = torch.cat([source, target], dim=0)
|
||||||
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||||
@@ -691,7 +692,8 @@ class MMD_loss(nn.Module):
|
|||||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||||
return sum(kernel_val)
|
return sum(kernel_val)
|
||||||
|
|
||||||
def linear_mmd(self, X, Y):
|
@staticmethod
|
||||||
|
def linear_mmd(X, Y):
|
||||||
delta = X.mean(axis=0) - Y.mean(axis=0)
|
delta = X.mean(axis=0) - Y.mean(axis=0)
|
||||||
loss = delta.dot(delta.T)
|
loss = delta.dot(delta.T)
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -428,7 +428,7 @@ class EnhancedIndexingStrategy(WeightStrategyBase):
|
|||||||
specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0])
|
specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0])
|
||||||
|
|
||||||
if not factor_exp.index.equals(specific_risk.index):
|
if not factor_exp.index.equals(specific_risk.index):
|
||||||
# NOTE: for stocks missing specific_risk, we always assume it have the highest volatility
|
# NOTE: for stocks missing specific_risk, we always assume it has the highest volatility
|
||||||
specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())
|
specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max())
|
||||||
|
|
||||||
universe = factor_exp.index.tolist()
|
universe = factor_exp.index.tolist()
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class StructuredCovEstimator(RiskModel):
|
|||||||
`B` is the regression coefficients matrix for all observations (row) on
|
`B` is the regression coefficients matrix for all observations (row) on
|
||||||
all factors (columns), and `U` is the residual matrix with shape like `X`.
|
all factors (columns), and `U` is the residual matrix with shape like `X`.
|
||||||
|
|
||||||
Therefore the structured covariance can be estimated by
|
Therefore, the structured covariance can be estimated by
|
||||||
cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U))
|
cov(X.T) = F @ cov(B.T) @ F.T + diag(var(U))
|
||||||
|
|
||||||
In the finance domain, there are mainly three methods to design `F` [1][2]:
|
In the finance domain, there are mainly three methods to design `F` [1][2]:
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class QlibRecorder:
|
|||||||
|
|
||||||
The arguments of this function are not set to be rigid, and they will be different with different implementations of
|
The arguments of this function are not set to be rigid, and they will be different with different implementations of
|
||||||
``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the
|
``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the
|
||||||
example code of the this method with the ``MLflowExpManager``:
|
example code of the method with the ``MLflowExpManager``:
|
||||||
|
|
||||||
.. code-block:: Python
|
.. code-block:: Python
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ class RecordTemp:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
artifact_path = None
|
artifact_path = None
|
||||||
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls`
|
depend_cls = None # the dependant class of the record; the record will depend on the results generated by
|
||||||
|
# `depend_cls`
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_path(cls, path=None):
|
def get_path(cls, path=None):
|
||||||
@@ -119,7 +120,7 @@ class RecordTemp:
|
|||||||
Check if the records are properly generated and saved.
|
Check if the records are properly generated and saved.
|
||||||
It is useful in the following examples
|
It is useful in the following examples
|
||||||
|
|
||||||
- checking if the depended files complete before generating new things.
|
- checking if the dependant files complete before generating new things.
|
||||||
- checking if the final files are completed
|
- checking if the final files are completed
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -186,7 +187,7 @@ class SignalRecord(RecordTemp):
|
|||||||
return raw_label
|
return raw_label
|
||||||
|
|
||||||
def generate(self, **kwargs):
|
def generate(self, **kwargs):
|
||||||
# generate prediciton
|
# generate prediction
|
||||||
pred = self.model.predict(self.dataset)
|
pred = self.model.predict(self.dataset)
|
||||||
if isinstance(pred, pd.Series):
|
if isinstance(pred, pd.Series):
|
||||||
pred = pred.to_frame("score")
|
pred = pred.to_frame("score")
|
||||||
@@ -285,7 +286,8 @@ class HFSignalRecord(SignalRecord):
|
|||||||
|
|
||||||
class SigAnaRecord(ACRecordTemp):
|
class SigAnaRecord(ACRecordTemp):
|
||||||
"""
|
"""
|
||||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
This is the Signal Analysis Record class that generates the analysis results such as IC and IR.
|
||||||
|
This class inherits the ``RecordTemp`` class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
artifact_path = "sig_analysis"
|
artifact_path = "sig_analysis"
|
||||||
@@ -382,7 +384,7 @@ class PortAnaRecord(ACRecordTemp):
|
|||||||
indicator_analysis_freq : str|List[str]
|
indicator_analysis_freq : str|List[str]
|
||||||
indicator analysis freq of report
|
indicator analysis freq of report
|
||||||
indicator_analysis_method : str, optional, default by None
|
indicator_analysis_method : str, optional, default by None
|
||||||
the candidated values include 'mean', 'amount_weighted', 'value_weighted'
|
the candidate values include 'mean', 'amount_weighted', 'value_weighted'
|
||||||
"""
|
"""
|
||||||
super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs)
|
super().__init__(recorder=recorder, skip_existing=skip_existing, **kwargs)
|
||||||
|
|
||||||
@@ -456,9 +458,9 @@ class PortAnaRecord(ACRecordTemp):
|
|||||||
pred = self.load("pred.pkl")
|
pred = self.load("pred.pkl")
|
||||||
|
|
||||||
# replace the "<PRED>" with prediction saved before
|
# replace the "<PRED>" with prediction saved before
|
||||||
placehorder_value = {"<PRED>": pred}
|
placeholder_value = {"<PRED>": pred}
|
||||||
for k in "executor_config", "strategy_config":
|
for k in "executor_config", "strategy_config":
|
||||||
setattr(self, k, fill_placeholder(getattr(self, k), placehorder_value))
|
setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value))
|
||||||
|
|
||||||
# if the backtesting time range is not set, it will automatically extract time range from the prediction file
|
# if the backtesting time range is not set, it will automatically extract time range from the prediction file
|
||||||
dt_values = pred.index.get_level_values("datetime")
|
dt_values = pred.index.get_level_values("datetime")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ cd qlib/scripts/data_collector/pit/
|
|||||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
|
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
|
||||||
```
|
```
|
||||||
|
|
||||||
Downloading all data from the stock is very time consuming. If you just want run a quick test on a few stocks, you can run the command below
|
Downloading all data from the stock is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below
|
||||||
```bash
|
```bash
|
||||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"
|
python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"
|
||||||
```
|
```
|
||||||
|
|||||||
Reference in New Issue
Block a user