1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/model/test_general_nn.py
Linlang a0cef033cb update python version (#1868)
* update python version

* fix: Correct selector handling and add time filtering in storage.py

* fix: convert index and columns to list in repr methods

* feat: Add Makefile for managing project prerequisites

* feat: Add Cython extensions for rolling and expanding operations

* resolve install error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* fix lint error

* update build package

* update makefile

* update ci yaml

* fix docs build error

* fix ubuntu install error

* fix docs build error

* fix install error

* fix install error

* fix install error

* fix install error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error

* fix pylint error E1123

* fix pylint error R0917

* fix pytest error

* fix pytest error

* fix pytest error

* update code

* update code

* fix ci error

* fix pylint error

* fix black error

* fix pytest error

* fix CI error

* fix CI error

* add python version to CI

* add python version to CI

* add python version to CI

* fix pylint error

* fix pytest general nn error

* fix CI error

* optimize code

* add comments

* Extend macOS version

* remove build package

---------

Co-authored-by: Young <afe.young@gmail.com>
2024-12-17 11:30:06 +08:00

81 lines
2.6 KiB
Python

import unittest
from qlib.tests import TestAutoData
class TestNN(TestAutoData):
    """Smoke test for ``GeneralPTNN`` on both a time-series and a tabular dataset.

    A GRU backbone is trained on a ``TSDatasetH`` and an MLP backbone on a
    ``DatasetH``. Both datasets are built from the same ``DataHandlerLP``.
    Each model is fitted for two epochs and then used for prediction.
    """

    def test_both_dataset(self):
        """Fit and predict ``GeneralPTNN`` on ``TSDatasetH`` and ``DatasetH``.

        The test is reported as skipped (rather than silently passing) when the
        PyTorch-based model or the dataset classes cannot be imported.
        """
        try:
            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
            from qlib.data.dataset import DatasetH, TSDatasetH
            from qlib.data.dataset.handler import DataHandlerLP
        except ImportError as err:
            # A skip keeps the missing dependency visible in the test summary;
            # printing and returning would have counted the test as passed.
            raise unittest.SkipTest(f"Import error: {err}") from err

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": {
                # Loader is specified by class name plus its constructor kwargs.
                "class": "QlibDataLoader",
                "kwargs": {
                    "config": {
                        # Three input features -> d_feat / input_dim of 3 below.
                        "feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
                        "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
                    },
                    "freq": "day",
                },
            },
            # TODO: processors
            "learn_processors": [
                {
                    "class": "DropnaLabel",
                },
                {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
            ],
        }
        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }
        data_handler = DataHandlerLP(**data_handler_config)

        # Both datasets share one handler so the models see identical data.
        tsds = TSDatasetH(handler=data_handler, segments=segments)  # time-series
        tbds = DatasetH(handler=data_handler, segments=segments)  # tabular

        # Model order matches the dataset order in the loop below:
        # GRU <-> time-series dataset, MLP <-> tabular dataset.
        model_l = [
            GeneralPTNN(
                n_epochs=2,
                batch_size=32,
                n_jobs=0,  # presumably keeps data loading in-process; confirm
                pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
                pt_model_kwargs={
                    "d_feat": 3,
                    "hidden_size": 8,
                    "num_layers": 1,
                    "dropout": 0.0,
                },
            ),
            GeneralPTNN(
                n_epochs=2,
                batch_size=32,
                n_jobs=0,
                pt_model_uri="qlib.contrib.model.pytorch_nn.Net",  # it is a MLP
                pt_model_kwargs={
                    "input_dim": 3,
                },
            ),
        ]

        for ds, model in zip((tsds, tbds), model_l):
            # subTest reports each dataset/model pair independently, so a
            # failure on one pair does not hide the result of the other.
            with self.subTest(dataset=type(ds).__name__):
                model.fit(ds)
                model.predict(ds)
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): discover and run the
    # TestCase classes defined in this module.
    unittest.main()