mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Adjust rolling api (#1594)
* Intermediate version * Fix yaml template & Successfully run rolling * Be compatible with benchmark * Get same results with previous linear model * Black formatting * Update black * Update the placeholder mechanism * Update CI * Update CI * Upgrade Black * Fix CI and simplify code * Fix CI * Move the data processing caching mechanism into utils. * Adjusting DDG-DA * Organize import
This commit is contained in:
@@ -6,7 +6,6 @@ from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
|
||||
# set random seed
|
||||
with open(config_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
@@ -30,7 +29,6 @@ def main(seed, config_file="configs/config_alstm.yaml"):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# set params from cmd
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--seed", type=int, default=1000, help="random seed")
|
||||
|
||||
@@ -96,7 +96,6 @@ class MTSDatasetH(DatasetH):
|
||||
drop_last=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
|
||||
self.seq_len = seq_len
|
||||
@@ -111,7 +110,6 @@ class MTSDatasetH(DatasetH):
|
||||
super().__init__(handler, segments, **kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
|
||||
super().setup_data()
|
||||
|
||||
# change index to <code, date>
|
||||
|
||||
@@ -45,7 +45,6 @@ class TRAModel(Model):
|
||||
avg_params=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@@ -93,7 +92,6 @@ class TRAModel(Model):
|
||||
self.global_step = -1
|
||||
|
||||
def train_epoch(self, data_set):
|
||||
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
|
||||
@@ -146,7 +144,6 @@ class TRAModel(Model):
|
||||
return total_loss
|
||||
|
||||
def test_epoch(self, data_set, return_pred=False):
|
||||
|
||||
self.model.eval()
|
||||
self.tra.eval()
|
||||
data_set.eval()
|
||||
@@ -204,7 +201,6 @@ class TRAModel(Model):
|
||||
return metrics, preds
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
best_score = -1
|
||||
@@ -380,7 +376,6 @@ class LSTM(nn.Module):
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
@@ -464,7 +459,6 @@ class Transformer(nn.Module):
|
||||
self.output_size = hidden_size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_drop(x)
|
||||
|
||||
if self.training and self.noise_level > 0:
|
||||
@@ -514,7 +508,6 @@ class TRA(nn.Module):
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
def forward(self, hidden, hist_loss):
|
||||
|
||||
preds = self.predictors(hidden)
|
||||
|
||||
if self.num_states == 1:
|
||||
|
||||
@@ -57,9 +57,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,9 +51,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,9 +51,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
signal: <PRED>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
Reference in New Issue
Block a user