mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
* Init model for both dataset * Remove some deprecated code * Add model template; * We must align with previous results * We choose another mode as the initial version * Almost success to run GRU * Successfully run training * Passed general_nn test * gru test * Alignment test passed * comment * fix readme & minor errors * general nn updates & benchmarks * Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml --------- Co-authored-by: Young <afe.young@gmail.com> Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
import unittest
|
|
from qlib.tests import TestAutoData
|
|
|
|
|
|
class TestNN(TestAutoData):
    """Smoke test: ``GeneralPTNN`` can train and predict on both dataset flavours.

    A single ``DataHandlerLP`` feeds two datasets:

    * a time-series dataset (``TSDatasetH``), paired with a GRU backbone;
    * a tabular dataset (``DatasetH``), paired with an MLP backbone.

    Each model is fitted for two epochs and then asked to predict.
    """

    def test_both_dataset(self):
        """Fit and predict a ``GeneralPTNN`` on a ``TSDatasetH`` and on a ``DatasetH``.

        The test is skipped (rather than silently reported as passing) when
        any of the required model/dataset modules cannot be imported.
        """
        try:
            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
            from qlib.data.dataset import DatasetH, TSDatasetH
            from qlib.data.dataset.handler import DataHandlerLP
        except ImportError as e:
            # skipTest raises unittest.SkipTest, so the runner reports this as
            # skipped instead of counting an empty run as a pass.
            self.skipTest(f"Import error: {e}")

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": {
                # The loader is given by class name plus kwargs; qlib builds
                # the instance from this config.
                "class": "QlibDataLoader",
                "kwargs": {
                    "config": {
                        # Three feature expressions -> three feature columns.
                        "feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
                        "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
                    },
                    "freq": "day",
                },
            },
            # TODO: processors
            "learn_processors": [
                {
                    "class": "DropnaLabel",
                },
                {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
            ],
        }

        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }
        data_handler = DataHandlerLP(**data_handler_config)

        # time-series dataset
        tsds = TSDatasetH(handler=data_handler, segments=segments)

        # tabular dataset
        tbds = DatasetH(handler=data_handler, segments=segments)

        # Both backbones take 3 input features, matching the H/C/L columns above.
        model_l = [
            GeneralPTNN(
                n_epochs=2,
                pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
                pt_model_kwargs={
                    "d_feat": 3,
                    "hidden_size": 8,
                    "num_layers": 1,
                    "dropout": 0.0,
                },
            ),
            GeneralPTNN(
                n_epochs=2,
                pt_model_uri="qlib.contrib.model.pytorch_nn.Net",  # it is a MLP
                pt_model_kwargs={
                    "input_dim": 3,
                },
            ),
        ]

        # Pair the time-series dataset with the GRU and the tabular dataset with
        # the MLP; the check is that fit and predict complete without raising.
        for ds, model in zip((tsds, tbds), model_l):
            model.fit(ds)
            model.predict(ds)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: discover and run the TestCase classes
    # defined in this module (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")
|