1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix CI error

This commit is contained in:
Linlang
2024-03-05 17:24:03 +08:00
parent 6ea921bd84
commit 8cf7bb3aaf
12 changed files with 24 additions and 35 deletions

View File

@@ -324,7 +324,6 @@ class TRAModel(Model):
class LSTM(nn.Module):
"""LSTM Model
Args:
@@ -414,7 +413,6 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -475,7 +473,6 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -162,13 +162,15 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
benchmark_config=(
{}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
}
),
)

View File

@@ -622,9 +622,11 @@ class Indicator:
print(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
(
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
),
fulfill_rate,
price_advantage,
positive_rate,

View File

@@ -3,6 +3,7 @@ Here is a batch of evaluation functions.
The interface should be redesigned carefully in the future.
"""
import pandas as pd
from typing import Tuple
from qlib import get_module_logger

View File

@@ -511,7 +511,6 @@ class TRAModel(Model):
class RNN(nn.Module):
"""RNN Model
Args:
@@ -601,7 +600,6 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -649,7 +647,6 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -373,7 +373,6 @@ class WeightStrategyBase(BaseSignalStrategy):
class EnhancedIndexingStrategy(WeightStrategyBase):
"""Enhanced Indexing Strategy
Enhanced indexing combines the arts of active management and passive management,

View File

@@ -71,15 +71,11 @@ def fetch_df_by_index(
if fetch_orig:
for slc in idx_slc:
if slc != slice(None, None):
return df.loc[
pd.IndexSlice[idx_slc],
] # noqa: E231
return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231
else: # pylint: disable=W0120
return df
else:
return df.loc[
pd.IndexSlice[idx_slc],
] # noqa: E231
return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:

View File

@@ -30,7 +30,6 @@ class Ensemble:
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
@@ -64,7 +63,6 @@ class SingleKeyEnsemble(Ensemble):
class RollingEnsemble(Ensemble):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".

View File

@@ -247,9 +247,7 @@ class ShrinkCovEstimator(RiskModel):
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
v3 = z.T.dot(z) / t - var_mkt * S
roff3 = (
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
)
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
roff = 2 * roff1 - roff3
rho = rdiag + roff

View File

@@ -90,7 +90,6 @@ class OnlineStrategy:
class RollingStrategy(OnlineStrategy):
"""
This example strategy always uses the latest rolling model sas online models.
"""

View File

@@ -146,9 +146,7 @@ class DumpDataBase:
return (
self._include_fields
if self._include_fields
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else df_columns
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
)
@staticmethod

View File

@@ -132,9 +132,11 @@ class DumpPitData:
return (
set(self._include_fields)
if self._include_fields
else set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
else (
set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
)
)
def get_filenames(self, symbol, field, interval):