mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix Models (#483)
* fix gat dataset * fix tft model * Update tft.py * Fix tft.py Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
pandas==1.1.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pickle(self, path: Union[Path, str]):
|
||||
"""
|
||||
Tensorflow model can't be dumped directly.
|
||||
So the data should be save seperatedly
|
||||
|
||||
**TODO**: Please implement the function to load the files
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Union[Path, str]
|
||||
the target path to be dumped
|
||||
"""
|
||||
# save tensorflow model
|
||||
# path = Path(path)
|
||||
# path.mkdir(parents=True)
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
super(TFTModel, self).to_pickle(path / "qlib_model")
|
||||
|
||||
Reference in New Issue
Block a user