1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00
Files
qlib/tests/model/test_general_nn.py
2024-07-08 06:11:51 +00:00

67 lines
2.0 KiB
Python

import unittest
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
from qlib.data.dataset import DatasetH, TSDatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.tests import TestAutoData
class TestNN(TestAutoData):
    """Exercise ``GeneralPTNN`` on a time-series and a tabular dataset."""

    def test_both_dataset(self):
        """Fit and predict with a default ``GeneralPTNN`` on both dataset kinds.

        A single ``DataHandlerLP`` backs both a ``TSDatasetH`` and a
        ``DatasetH``.  Nothing is asserted about the predictions: the test
        passes when ``fit`` and ``predict`` finish without raising.
        """
        # Expressions to load; each group is a pair of lists
        # (presumably expressions followed by their column names -- confirm
        # against the QlibDataLoader documentation).
        loader_config = {
            "feature": [
                ["$high", "$close", "$low"],
                ["H", "C", "L"],
            ],
            "label": [
                ["Ref($close, -2)/Ref($close, -1) - 1"],
                ["LABEL0"],
            ],
        }

        # The loader is referenced by its class name as a string.
        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": loader_config,
                "freq": "day",
            },
        }

        # TODO: processors
        learn_processors = [
            {"class": "DropnaLabel"},
            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
        ]

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": data_loader,
            "learn_processors": learn_processors,
        }

        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }

        handler = DataHandlerLP(**data_handler_config)

        # Both datasets share the same handler and segments: the time-series
        # dataset is built first, then the tabular one.
        datasets = (
            TSDatasetH(handler=handler, segments=segments),
            DatasetH(handler=handler, segments=segments),
        )

        # A fresh model with default arguments is trained for every dataset.
        for dataset in datasets:
            model = GeneralPTNN()
            model.fit(dataset)
            model.predict(dataset)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()