From fc243fd29bd65cb366cd3fc8312831e172a85aac Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 30 Sep 2021 13:11:06 +0800 Subject: [PATCH] Fix Models (#483) * fix gat dataset * fix tft model * Update tft.py * Fix tft.py Co-authored-by: Pengrong Zhu --- README.md | 1 + examples/benchmarks/TFT/README.md | 2 +- examples/benchmarks/TFT/requirements.txt | 3 +-- examples/benchmarks/TFT/tft.py | 25 +++++++++++++++++++++++- qlib/contrib/model/pytorch_gats_ts.py | 1 - qlib/utils/serial.py | 19 ++++++++++++++++++ qlib/workflow/recorder.py | 5 +++-- setup.py | 2 +- 8 files changed, 50 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 255551404..db0b6124e 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi - Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder. - Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py). + - **NOTE**: Each baseline has different environment dependencies; please make sure that your Python version aligns with the requirements (e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`). ## Run multiple models `Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. 
Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.) diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md index 5a6a9f153..991066b7f 100644 --- a/examples/benchmarks/TFT/README.md +++ b/examples/benchmarks/TFT/README.md @@ -8,7 +8,7 @@ Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. ### Notes -1. Please be **aware** that this script can only support `Python 3.5 - 3.8`. +1. Please be **aware** that this script can only support `Python 3.6 - 3.7`. 2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine. 3. The model must run in GPU, or an error will be raised. 4. New datasets should be registered in ``data_formatters``, for detail please visit the source. diff --git a/examples/benchmarks/TFT/requirements.txt b/examples/benchmarks/TFT/requirements.txt index 04234aaed..f8bd00002 100644 --- a/examples/benchmarks/TFT/requirements.txt +++ b/examples/benchmarks/TFT/requirements.txt @@ -1,3 +1,2 @@ tensorflow-gpu==1.15.0 -numpy == 1.19.4 -pandas==1.1.0 \ No newline at end of file +pandas==1.1.0 diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index e1205b0e0..3908b2777 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from pathlib import Path +from typing import Union import numpy as np import pandas as pd import tensorflow.compat.v1 as tf @@ -243,7 +245,7 @@ class TFTModel(ModelFT): # extract_numerical_data(targets), extract_numerical_data(p90_forecast), # 0.9) tf.keras.backend.set_session(default_keras_session) - print("Training completed.".format(dte.datetime.now())) + print("Training completed at {}.".format(dte.datetime.now())) # ===========================Training Process=========================== def predict(self, dataset): @@ -289,3 +291,24 @@ class TFTModel(ModelFT): dataset for finetuning """ pass + + def to_pickle(self, path: Union[Path, str]): + """ + Tensorflow model can't be dumped directly. + So the data should be saved separately + + **TODO**: Please implement the function to load the files + + Parameters + ---------- + path : Union[Path, str] + the target path to be dumped + """ + # save tensorflow model + # path = Path(path) + # path.mkdir(parents=True) + # self.model.save(path) + + # save qlib model wrapper + self.model = None + super(TFTModel, self).to_pickle(path / "qlib_model") diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 09123cc5c..52b7183be 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel class DailyBatchSampler(Sampler): def __init__(self, data_source): - self.data_source = data_source # calculate number of samples in each batch self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index f801b125c..04d16ab7a 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -122,3 +122,22 @@ class Serializable: return dill else: raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") + + @staticmethod + def general_dump(obj, path: Union[Path, str]): + """ + A general 
dumping method for object + + Parameters + ---------- + obj : object + the object to be dumped + path : Union[Path, str] + the target path the data will be dumped + """ + path = Path(path) + if isinstance(obj, Serializable): + obj.to_pickle(path) + else: + with path.open("wb") as f: + pickle.dump(obj, f) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 1b391cbe2..57d7a0f22 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from qlib.utils.serial import Serializable import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path @@ -307,8 +308,8 @@ class MLflowRecorder(Recorder): else: temp_dir = Path(tempfile.mkdtemp()).resolve() for name, data in kwargs.items(): - with (temp_dir / name).open("wb") as f: - pickle.dump(data, f) + path = temp_dir / name + Serializable.general_dump(data, path) self.client.log_artifact(self.id, temp_dir / name, artifact_path) shutil.rmtree(temp_dir) diff --git a/setup.py b/setup.py index d43257353..21d56371e 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ REQUIRED = [ "statsmodels", "xlrd>=1.0.0", "plotly==4.12.0", - "matplotlib==3.3", + "matplotlib>=3.3", "tables>=3.6.1", "pyyaml>=5.3.1", "mlflow>=1.12.1",