mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Ptnn4both datatypes and alignment tests (#1827)
* Init model for both dataset * Remove some deprecated code * Add model template; * We must align with previous results * We choose another mode as the initial version * Almost success to run GRU * Successfully run training * Passed general_nn test * gru test * Alignment test passed * comment * fix readme & minor errors * general nn updates & benchmarks * Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml --------- Co-authored-by: Young <afe.young@gmail.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
76
tests/model/test_general_nn.py
Normal file
76
tests/model/test_general_nn.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import unittest
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class TestNN(TestAutoData):
|
||||
def test_both_dataset(self):
|
||||
try:
|
||||
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
|
||||
from qlib.data.dataset import DatasetH, TSDatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
except ImportError:
|
||||
print("Import error.")
|
||||
return
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"instruments": "csi300",
|
||||
"data_loader": {
|
||||
"class": "QlibDataLoader", # Assuming QlibDataLoader is a string reference to the class
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
|
||||
"label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
|
||||
},
|
||||
"freq": "day",
|
||||
},
|
||||
},
|
||||
# TODO: processors
|
||||
"learn_processors": [
|
||||
{
|
||||
"class": "DropnaLabel",
|
||||
},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
}
|
||||
segments = {
|
||||
"train": ["2008-01-01", "2014-12-31"],
|
||||
"valid": ["2015-01-01", "2016-12-31"],
|
||||
"test": ["2017-01-01", "2020-08-01"],
|
||||
}
|
||||
data_handler = DataHandlerLP(**data_handler_config)
|
||||
|
||||
# time-series dataset
|
||||
tsds = TSDatasetH(handler=data_handler, segments=segments)
|
||||
|
||||
# tabular dataset
|
||||
tbds = DatasetH(handler=data_handler, segments=segments)
|
||||
|
||||
model_l = [
|
||||
GeneralPTNN(
|
||||
n_epochs=2,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
|
||||
pt_model_kwargs={
|
||||
"d_feat": 3,
|
||||
"hidden_size": 8,
|
||||
"num_layers": 1,
|
||||
"dropout": 0.0,
|
||||
},
|
||||
),
|
||||
GeneralPTNN(
|
||||
n_epochs=2,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
|
||||
pt_model_kwargs={
|
||||
"input_dim": 3,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for ds, model in list(zip((tsds, tbds), model_l)):
|
||||
model.fit(ds) # It works
|
||||
model.predict(ds) # It works
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user