mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Fix Models (#483)
* fix gat dataset * fix tft model * Update tft.py * Fix tft.py Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
This commit is contained in:
@@ -308,6 +308,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` argument can take any number of the models listed above (the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, so please make sure that your Python version meets the requirements (e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supports *Linux* for now. Other OSes will be supported in the future. In addition, it does not yet support running the same model multiple times in parallel; this will also be fixed in future development.)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run on a GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``; for details, please see the source.
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
pandas==1.1.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
@@ -243,7 +245,7 @@ class TFTModel(ModelFT):
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
@@ -289,3 +291,24 @@ class TFTModel(ModelFT):
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pickle(self, path: Union[Path, str]):
    """
    Dump the qlib model wrapper into the directory `path`.

    Tensorflow model can't be dumped directly.
    So the data should be saved separately.
    Only the qlib wrapper is pickled (to ``path / "qlib_model"``);
    the tensorflow model itself is not saved yet.

    **TODO**: Please implement the function to load the files

    Parameters
    ----------
    path : Union[Path, str]
        the target path to be dumped
    """
    # `path` may be a plain str; the `/` join below requires a Path.
    path = Path(path)
    # NOTE(review): presumably the parent `to_pickle` opens the target file
    # directly, so the containing directory has to exist first - confirm.
    path.mkdir(parents=True, exist_ok=True)

    # save tensorflow model
    # TODO: self.model.save(path)

    # save qlib model wrapper
    # The tensorflow model is detached only while pickling and restored
    # afterwards, so saving does not destroy the trained model in memory.
    model = getattr(self, "model", None)
    self.model = None
    try:
        super(TFTModel, self).to_pickle(path / "qlib_model")
    finally:
        self.model = model
|
||||
|
||||
@@ -27,7 +27,6 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
# calculate number of samples in each batch
|
||||
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values
|
||||
|
||||
@@ -122,3 +122,22 @@ class Serializable:
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@staticmethod
def general_dump(obj, path: Union[Path, str]):
    """
    Dump an arbitrary object to `path`.

    `Serializable` instances are dumped through their own `to_pickle`
    method; any other object is written with `pickle.dump`.

    Parameters
    ----------
    obj : object
        the object to be dumped
    path : Union[Path, str]
        the target path the data will be dumped
    """
    target = Path(path)
    if not isinstance(obj, Serializable):
        # plain objects: pickle straight into the target file
        with target.open("wb") as fout:
            pickle.dump(obj, fout)
        return
    # Serializable objects know how to dump themselves
    obj.to_pickle(target)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils.serial import Serializable
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
@@ -307,8 +308,8 @@ class MLflowRecorder(Recorder):
|
||||
else:
|
||||
temp_dir = Path(tempfile.mkdtemp()).resolve()
|
||||
for name, data in kwargs.items():
|
||||
with (temp_dir / name).open("wb") as f:
|
||||
pickle.dump(data, f)
|
||||
path = temp_dir / name
|
||||
Serializable.general_dump(data, path)
|
||||
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user