1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix Models (#483)

* fix gat dataset

* fix tft model

* Update tft.py

* Fix tft.py

Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
This commit is contained in:
you-n-g
2021-09-30 13:11:06 +08:00
committed by GitHub
parent b6a8bd5b80
commit fc243fd29b
8 changed files with 50 additions and 8 deletions

View File

@@ -8,7 +8,7 @@
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run in GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.

View File

@@ -1,3 +1,2 @@
tensorflow-gpu==1.15.0
numpy == 1.19.4
pandas==1.1.0
pandas==1.1.0

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
print("Training completed at {}.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
dataset for finetuning
"""
pass
def to_pickle(self, path: Union[Path, str]):
"""
Tensorflow model can't be dumped directly.
So the data should be save seperatedly
**TODO**: Please implement the function to load the files
Parameters
----------
path : Union[Path, str]
the target path to be dumped
"""
# save tensorflow model
# path = Path(path)
# path.mkdir(parents=True)
# self.model.save(path)
# save qlib model wrapper
self.model = None
super(TFTModel, self).to_pickle(path / "qlib_model")