1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00
Files
qlib/tests/model/test_general_nn.py
cyncyw c9ed050ef0 Ptnn4both datatypes and alignment tests (#1827)
* Init model for both dataset

* Remove some deprecated code

* Add model template;

* We must align with previous results

* We choose another mode as the initial version

* Almost succeeded in running GRU

* Successfully run training

* Passed general_nn test

* gru test

* Alignment test passed

* comment

* fix readme & minor errors

* general nn updates & benchmarks

* Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
2024-07-11 17:59:18 +08:00

77 lines
2.4 KiB
Python

import unittest
from qlib.tests import TestAutoData
class TestNN(TestAutoData):
    """Smoke tests for ``GeneralPTNN`` on the different qlib dataset types."""

    def test_both_dataset(self):
        """Fit and predict ``GeneralPTNN`` on a time-series and a tabular dataset.

        One ``DataHandlerLP`` is built: daily ``$high``/``$close``/``$low`` features
        for csi300, plus a ``Ref($close, -2)/Ref($close, -1) - 1`` label. That
        handler backs both a ``TSDatasetH`` and a ``DatasetH``. A GRU model
        (``pytorch_gru_ts.GRUModel``) is paired with the time-series dataset, and
        an MLP (``pytorch_nn.Net``) with the tabular one. Each is fitted for two
        epochs and then asked to predict on its dataset.

        The test is skipped, rather than silently reported as passing, when any of
        the required qlib model/dataset modules cannot be imported.
        """
        try:
            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
            from qlib.data.dataset import DatasetH, TSDatasetH
            from qlib.data.dataset.handler import DataHandlerLP
        except ImportError as err:
            # Previously this printed and returned, so the test showed up as
            # "passed" even though nothing was exercised.
            self.skipTest(f"GeneralPTNN test dependencies unavailable: {err}")

        # Single source of truth for the input features. Both models' input
        # width below is derived from it.
        feature_fields = ["$high", "$close", "$low"]
        n_feat = len(feature_fields)

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": {
                "class": "QlibDataLoader",  # Assuming QlibDataLoader is a string reference to the class
                "kwargs": {
                    "config": {
                        "feature": [feature_fields, ["H", "C", "L"]],
                        "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
                    },
                    "freq": "day",
                },
            },
            # TODO: processors
            "learn_processors": [
                {"class": "DropnaLabel"},
                {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
            ],
        }
        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }

        data_handler = DataHandlerLP(**data_handler_config)
        tsds = TSDatasetH(handler=data_handler, segments=segments)  # time-series dataset
        tbds = DatasetH(handler=data_handler, segments=segments)  # tabular dataset

        # Each dataset is paired explicitly with the model meant to consume it.
        cases = [
            (
                tsds,
                GeneralPTNN(
                    n_epochs=2,
                    pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
                    pt_model_kwargs={
                        "d_feat": n_feat,
                        "hidden_size": 8,
                        "num_layers": 1,
                        "dropout": 0.0,
                    },
                ),
            ),
            (
                tbds,
                GeneralPTNN(
                    n_epochs=2,
                    pt_model_uri="qlib.contrib.model.pytorch_nn.Net",  # it is a MLP
                    pt_model_kwargs={
                        "input_dim": n_feat,
                    },
                ),
            ),
        ]
        for ds, model in cases:
            # subTest reports a failure per dataset type instead of aborting
            # at the first one.
            with self.subTest(dataset=type(ds).__name__):
                model.fit(ds)
                model.predict(ds)
if __name__ == "__main__":
    # Allow running this module directly, in addition to test-runner discovery.
    unittest.main()