mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* update python version * fix: Correct selector handling and add time filtering in storage.py * fix: convert index and columns to list in repr methods * feat: Add Makefile for managing project prerequisites * feat: Add Cython extensions for rolling and expanding operations * resolve install error * fix lint error * fix lint error * fix lint error * fix lint error * fix lint error * update build package * update makefile * update ci yaml * fix docs build error * fix ubuntu install error * fix docs build error * fix install error * fix install error * fix install error * fix install error * fix pylint error * fix pylint error * fix pylint error * fix pylint error * fix pylint error E1123 * fix pylint error R0917 * fix pytest error * fix pytest error * fix pytest error * update code * update code * fix ci error * fix pylint error * fix black error * fix pytest error * fix CI error * fix CI error * add python version to CI * add python version to CI * add python version to CI * fix pylint error * fix pytest general nn error * fix CI error * optimize code * add comments * Extended macos version * remove build package --------- Co-authored-by: Young <afe.young@gmail.com>
81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
import unittest
|
|
from qlib.tests import TestAutoData
|
|
|
|
|
|
class TestNN(TestAutoData):
    """Smoke test for ``GeneralPTNN`` on time-series and tabular datasets."""

    def test_both_dataset(self):
        """Fit and predict with ``GeneralPTNN`` on both dataset flavours.

        One handler is shared by a ``TSDatasetH`` and a ``DatasetH``. The
        time-series dataset is paired with a GRU model and the tabular
        dataset with an MLP. No assertion is made on the predictions; the
        test only checks that ``fit`` and ``predict`` run without raising.

        The test is reported as skipped when the required qlib modules
        cannot be imported.
        """
        try:
            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
            from qlib.data.dataset import DatasetH, TSDatasetH
            from qlib.data.dataset.handler import DataHandlerLP
        except ImportError as err:
            # Skip explicitly so a missing dependency is not reported as a pass.
            self.skipTest(f"Import error: {err}")

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": {
                # The loader class is referenced by name (a string), not imported here.
                "class": "QlibDataLoader",
                "kwargs": {
                    "config": {
                        # Each entry is [expressions, column names].
                        "feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
                        "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
                    },
                    "freq": "day",
                },
            },
            # TODO: processors
            "learn_processors": [
                {
                    "class": "DropnaLabel",
                },
                {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
            ],
        }
        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }
        data_handler = DataHandlerLP(**data_handler_config)

        # time-series dataset
        tsds = TSDatasetH(handler=data_handler, segments=segments)

        # tabular dataset
        tbds = DatasetH(handler=data_handler, segments=segments)

        # Order matters: model_l[i] is trained on the i-th dataset of (tsds, tbds).
        model_l = [
            GeneralPTNN(
                n_epochs=2,
                batch_size=32,
                n_jobs=0,
                pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
                pt_model_kwargs={
                    # presumably the number of feature columns (H, C, L) -- TODO confirm
                    "d_feat": 3,
                    "hidden_size": 8,
                    "num_layers": 1,
                    "dropout": 0.0,
                },
            ),
            GeneralPTNN(
                n_epochs=2,
                batch_size=32,
                n_jobs=0,
                pt_model_uri="qlib.contrib.model.pytorch_nn.Net",  # it is a MLP
                pt_model_kwargs={
                    # presumably the number of feature columns (H, C, L) -- TODO confirm
                    "input_dim": 3,
                },
            ),
        ]

        # Pair each dataset with its model; both calls must complete without raising.
        for ds, model in zip((tsds, tbds), model_l):
            model.fit(ds)
            model.predict(ds)
|
|
|
|
|
|
# Script entry point: hand control to the unittest runner when executed directly.
if __name__ == "__main__":
    unittest.main()
|