1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Fix Models (#483)

* fix gat dataset

* fix tft model

* Update tft.py

* Fix tft.py

Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
This commit is contained in:
you-n-g
2021-09-30 13:11:06 +08:00
committed by GitHub
parent b6a8bd5b80
commit fc243fd29b
8 changed files with 50 additions and 8 deletions

View File

@@ -308,6 +308,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
- **NOTE**: Each baseline has different environment dependencies, so please make sure that your Python version matches its requirements (e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`).
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supports *Linux* for now. Other operating systems will be supported in the future. In addition, it does not yet support running the same model multiple times in parallel; this will also be fixed in future development.)

View File

@@ -8,7 +8,7 @@
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run on a GPU; otherwise an error will be raised.
4. New datasets should be registered in ``data_formatters``; for details, please refer to the source code.

View File

@@ -1,3 +1,2 @@
tensorflow-gpu==1.15.0
numpy == 1.19.4
pandas==1.1.0
pandas==1.1.0

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
print("Training completed at {}.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
dataset for finetuning
"""
pass
def to_pickle(self, path: Union[Path, str]):
    """
    Dump the qlib model wrapper to ``path / "qlib_model"``.

    A Tensorflow model can't be dumped directly, so the data should be
    saved separately. Saving the Tensorflow model itself is not implemented
    yet; only the wrapper, with ``self.model`` detached, is written.

    **TODO**: Please implement the function to load the files

    Parameters
    ----------
    path : Union[Path, str]
        the target directory to be dumped into (created if it is missing)
    """
    # Accept plain strings as the signature promises: the `/` operator used
    # below is only defined for Path objects.
    path = Path(path)
    # `path` is treated as a directory holding the dumped pieces, so it has to
    # exist before anything can be written inside it.
    path.mkdir(parents=True, exist_ok=True)

    # save tensorflow model
    # TODO: not implemented yet
    # self.model.save(path)

    # save qlib model wrapper
    # The Tensorflow model cannot be pickled, so detach it only for the
    # duration of the dump and put it back afterwards; dumping must not
    # destroy the trained model held by this instance.
    model = getattr(self, "model", None)
    self.model = None
    try:
        super(TFTModel, self).to_pickle(path / "qlib_model")
    finally:
        self.model = model

View File

@@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
# calculate number of samples in each batch
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values

View File

@@ -122,3 +122,22 @@ class Serializable:
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
@staticmethod
def general_dump(obj, path: Union[Path, str]):
    """
    Dump an arbitrary object to ``path``.

    ``Serializable`` instances are dumped through their own ``to_pickle``
    method; every other object is written with ``pickle.dump``.

    Parameters
    ----------
    obj : object
        the object to be dumped
    path : Union[Path, str]
        the target path the data will be dumped
    """
    target = Path(path)
    if not isinstance(obj, Serializable):
        # plain objects: pickle them straight into the target file
        with target.open("wb") as fh:
            pickle.dump(obj, fh)
        return
    # Serializable objects know how to dump themselves
    obj.to_pickle(target)

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
@@ -307,8 +308,8 @@ class MLflowRecorder(Recorder):
else:
temp_dir = Path(tempfile.mkdtemp()).resolve()
for name, data in kwargs.items():
with (temp_dir / name).open("wb") as f:
pickle.dump(data, f)
path = temp_dir / name
Serializable.general_dump(data, path)
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
shutil.rmtree(temp_dir)

View File

@@ -52,7 +52,7 @@ REQUIRED = [
"statsmodels",
"xlrd>=1.0.0",
"plotly==4.12.0",
"matplotlib==3.3",
"matplotlib>=3.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
"mlflow>=1.12.1",