mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge branch 'main' into Fix_collector_doc

This commit is contained in:
Kenneth Tang
2021-02-16 14:16:53 +08:00
30 changed files with 448 additions and 92 deletions


@@ -82,13 +82,10 @@ This table demonstrates the supported Python version of `Qlib`:
2. For Python 3.9, `Qlib` supports running workflows such as training models, running backtests, and plotting most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting *model performance* is not supported yet; this will be fixed when the dependent packages are upgraded.
### Install with pip
**Note**: Due to the latest numpy release (version 1.20.0), unexpected errors will occur if you install or run Qlib with `numpy==1.20.0`. We recommend using the lower version `numpy==1.19.5` for now, and we will fix this incompatibility in the near future.
Users can easily install ``Qlib`` by pip according to the following command.
```bash
pip install numpy==1.19.5
pip install pyqlib --ignore-installed numpy
pip install pyqlib
```
**Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch, please install qlib with the methods below.
@@ -121,7 +118,12 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
## Data Preparation
Load and prepare data by running the following code:
```bash
# get 1d data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# get 1min data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
```
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in

42
docs/advanced/serial.rst Normal file

@@ -0,0 +1,42 @@
.. _serial:
=================================
Serialization
=================================
.. currentmodule:: qlib
Introduction
===================
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor``, ``Model``, etc. to disk and reloading it.
Serializable Class
========================
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose names **do not** start with `_` will be saved to disk.
Example
==========================
``Qlib``'s serializable classes include ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclasses of ``qlib.utils.serial.Serializable``.
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
.. code-block:: Python

    import pickle

    ##=============dump dataset=============
    dataset.to_pickle(path="dataset.pkl")  # dataset is an instance of qlib.data.dataset.DatasetH

    ##=============reload dataset=============
    with open("dataset.pkl", "rb") as file_dataset:
        dataset = pickle.load(file_dataset)

.. note::
Only the state of ``DatasetH`` should be saved to disk, such as the `mean` and `variance` used for data normalization.
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
API
===================
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.


@@ -31,7 +31,7 @@ Qlib Format Data
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
======================== ================= ================
Dataset US Market China Market
@@ -41,6 +41,7 @@ Alpha360 √ √
Alpha158 √ √
======================== ================= ================
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
Qlib Format Dataset
--------------------
@@ -48,8 +49,12 @@ Qlib Format Dataset
.. code-block:: bash
# download 1d
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# download 1min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
.. code-block:: bash
@@ -167,7 +172,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Suppose that users have prepared their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.


@@ -94,6 +94,52 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc
- ``SignalRecord``: This class generates the `prediction` results of the model.
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
.. code-block:: Python
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
- ``PortAnaRecord``: This class generates the results of `backtest`. For detailed information about `backtest` and the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
Here is a simple example of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
.. code-block:: Python
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
# backtest
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
# analysis
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
print(analysis_df)
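As a rough illustration of the signal metrics mentioned above, the following standalone sketch computes an IC-style and a Rank-IC-style value for a single day using only the standard library. It assumes the common definitions (Pearson correlation for IC, Pearson correlation of ranks for Rank IC); the helper names and sample numbers are hypothetical, and ``calc_ic`` itself may differ in details such as grouping by date.

```python
from math import sqrt


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / sqrt(var_x * var_y)


def ranks(values):
    """Rank values from 1..n (assumes no ties, to keep the sketch short)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = rank
    return result


# Hypothetical scores and realized returns for five instruments on one day.
pred = [0.9, 0.1, 0.4, 0.7, 0.3]
label = [0.05, -0.02, 0.01, 0.02, 0.00]

ic = pearson(pred, label)                      # linear agreement
rank_ic = pearson(ranks(pred), ranks(label))   # agreement of orderings only
print(f"IC={ic:.3f}, Rank IC={rank_ic:.3f}")
```

Because the sample prediction and label happen to sort the instruments in the same order, the rank-based value is at its maximum here, while the linear value is lower.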
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.


@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.


@@ -49,6 +49,7 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
Serialization <advanced/serial.rst>
.. toctree::
:maxdepth: 3


@@ -152,4 +152,14 @@ Recorder
Record Template
--------------------
.. automodule:: qlib.workflow.record_temp
:members:
Utils
====================
Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
:members:


@@ -0,0 +1,28 @@
# High-Frequency Dataset
This dataset is an example for RL high frequency trading.
## Get High-Frequency Data
Get high-frequency data by running the following command:
```bash
python workflow.py get_data
```
## Dump & Reload & Reinitialize the Dataset
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in `workflow.py`. `DatasetH` is a subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped to or loaded from disk in `pickle` format.
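
To make the "state, not data" idea concrete, here is a minimal standalone sketch. It is not Qlib's implementation: the `Normalizer` class and the `dump_state`/`load_state` helpers are hypothetical, and the sketch only mimics the documented rule that attributes whose names do not start with `_` are saved.

```python
import pickle


class Normalizer:
    """Hypothetical object; not a Qlib class."""

    def __init__(self, mean, std):
        self.mean = mean                 # state: cheap to store, needed later
        self.std = std                   # state
        self._data = list(range(1000))   # bulky data: regenerated after loading


def dump_state(obj) -> bytes:
    """Pickle only the attributes whose names do not start with '_'."""
    state = {k: v for k, v in vars(obj).items() if not k.startswith("_")}
    return pickle.dumps(state)


def load_state(obj, blob: bytes) -> None:
    """Restore previously pickled attributes onto an existing object."""
    vars(obj).update(pickle.loads(blob))


blob = dump_state(Normalizer(mean=0.5, std=2.0))
saved = pickle.loads(blob)
print(sorted(saved))  # ['mean', 'std']
```

The underscore-prefixed `_data` attribute never reaches the pickle, which keeps the dump small and is why the data has to be regenerated after loading.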
### About Reinitialization
After reloading `Dataset` from disk, `Qlib` also supports reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
An example is given in `workflow.py`; users can run it as follows.
### Run the Code
Run the example by running the following command:
```bash
python workflow.py dump_and_load_dataset
```
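
As a rough sketch of how handler arguments might be routed during reinitialization, the snippet below separates the data-range settings (`instruments`, `start_time`, `end_time`) from the remaining handler options. The helper name `split_handler_kwargs` and the `"<init_type>"` placeholder value are assumptions made for illustration; this is not part of Qlib's API.

```python
# Keys that describe which data to load rather than how to initialize.
CONF_DATA_ARGS = {"instruments", "start_time", "end_time"}


def split_handler_kwargs(handler_kwargs: dict):
    """Split kwargs into (data-range settings, other init options)."""
    if not isinstance(handler_kwargs, dict):
        raise TypeError(
            f"param handler_kwargs must be type dict, not {type(handler_kwargs)}"
        )
    conf_data = {k: v for k, v in handler_kwargs.items() if k in CONF_DATA_ARGS}
    init_args = {k: v for k, v in handler_kwargs.items() if k not in CONF_DATA_ARGS}
    return conf_data, init_args


conf_data, init_args = split_handler_kwargs(
    {
        "init_type": "<init_type>",  # placeholder, stands in for a real option
        "start_time": "2021-01-19 00:00:00",
        "end_time": "2021-01-25 16:00:00",
    }
)
print(conf_data)
print(init_args)
```

Keeping the two groups apart lets the time range and instruments be reset first, so that the subsequent initialization loads data for the new range.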


@@ -62,9 +62,9 @@ class HighFreqHandler(DataHandlerLP):
def get_normalized_price_feature(price_field, shift=0):
"""Get normalized price feature ops"""
if shift == 0:
template_norm = "{0}/Ref(DayLast({1}), 240)"
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
else:
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
feature_ops = template_norm.format(
template_if.format(
@@ -90,7 +90,7 @@ class HighFreqHandler(DataHandlerLP):
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
fields += [
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
@@ -101,7 +101,7 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume"]
fields += [
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
@@ -112,7 +112,7 @@ class HighFreqHandler(DataHandlerLP):
]
names += ["$volume_1"]
fields += [template_paused.format("Date($close)")]
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
names += ["date"]
return fields, names
@@ -149,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler):
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
fields += [
template_fillnan.format(template_paused.format("$close")),
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
]
names += ["$close0"]
fields += [
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
"Cut({0}, 240, None)".format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
)
)
]
names += ["$vwap0"]
fields += [
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),


@@ -8,6 +8,20 @@ from qlib.data.data import Cal
def get_calendar_day(freq="day", future=False):
"""Load High-Freq Calendar Date Using Memcache.
Parameters
----------
freq : str
frequency of read calendar file.
future : bool
whether including future trading day.
Returns
-------
_calendar:
array of date.
"""
flag = f"{freq}_future_{future}_day"
if flag in H["c"]:
_calendar = H["c"][flag]
@@ -18,6 +32,19 @@ def get_calendar_day(freq="day", future=False):
class DayLast(ElemOperator):
"""DayLast Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series in which each value equals the last value of its day
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -25,18 +52,57 @@ class DayLast(ElemOperator):
class FFillNan(ElemOperator):
"""FFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a forward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="ffill")
class BFillNan(ElemOperator):
"""BFillNan Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a backward fill nan feature
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="bfill")
class Date(ElemOperator):
"""Date Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
a series in which each value is the date corresponding to feature.index
"""
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -44,6 +110,22 @@ class Date(ElemOperator):
class Select(PairOperator):
"""Select Operator
Parameters
----------
feature_left : Expression
feature instance, select condition
feature_right : Expression
feature instance, select value
Returns
----------
feature:
value(feature_right) that meets the condition(feature_left)
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
@@ -51,6 +133,58 @@ class Select(PairOperator):
class IsNull(ElemOperator):
"""IsNull Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
feature:
A series indicating whether the feature is nan
"""
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.isnull()
class Cut(ElemOperator):
    """Cut Operator

    Parameters
    ----------
    feature : Expression
        feature instance
    l : int
        l > 0, delete the first l elements of feature (default is None, which means 0)
    r : int
        r < 0, delete the last -r elements of feature (default is None, which means 0)

    Returns
    ----------
    feature:
        A series with the first l and last -r elements deleted from the feature.
        Note: It is deleted from the raw data, not the sliced data
    """

    def __init__(self, feature, l=None, r=None):
        self.l = l
        self.r = r
        if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
            raise ValueError("Cut operator l should > 0 and r should < 0")
        super(Cut, self).__init__(feature)

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        return series.iloc[self.l : self.r]

    def get_extended_window_size(self):
        ll = 0 if self.l is None else self.l
        rr = 0 if self.r is None else abs(self.r)
        lft_etd, rght_etd = self.feature.get_extended_window_size()
        lft_etd = lft_etd + ll
        rght_etd = rght_etd + rr
        return lft_etd, rght_etd
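
The slicing behaviour of the `Cut` operator can be tried out on a plain list. The sketch below is a standalone approximation: `cut` and `extended_window` are hypothetical helpers that mirror the validation, the `[l:r]` slice, and the window-size bookkeeping shown above, without depending on Qlib or pandas.

```python
def cut(values, l=None, r=None):
    """Drop the first l and the last -r elements of a sequence."""
    if (l is not None and l <= 0) or (r is not None and r >= 0):
        raise ValueError("l should be > 0 and r should be < 0")
    return values[l:r]  # None on either side leaves that end untouched


def extended_window(inner, l=None, r=None):
    """Extra elements needed on each side so the trimmed result is complete."""
    left, right = inner
    return left + (l or 0), right + (abs(r) if r is not None else 0)


print(cut(list(range(10)), l=2, r=-3))   # [2, 3, 4, 5, 6]
print(cut(list(range(10)), l=4))         # [4, 5, 6, 7, 8, 9]
print(extended_window((240, 0), l=240))  # (480, 0)
```

The window bookkeeping matters because the operator trims the raw series: an enclosing expression has to request `l` extra leading (and `-r` extra trailing) elements to end up with the range it actually wants.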


@@ -9,7 +9,7 @@ import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import HIGH_FREQ_CONFIG
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
@@ -24,17 +24,17 @@ from qlib.data.ops import Operators
from qlib.data.data import Cal
from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
class HighfreqWorkflow(object):
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
MARKET = "all"
BENCHMARK = "SH000300"
start_time = "2020-09-14 00:00:00"
start_time = "2020-09-15 00:00:00"
end_time = "2021-01-18 16:00:00"
train_end_time = "2020-11-30 16:00:00"
test_start_time = "2020-12-01 00:00:00"
@@ -123,8 +123,7 @@ class HighfreqWorkflow(object):
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
return
def dump_and_load_dataset(self):
"""dump and load dataset state on disk"""
@@ -146,18 +145,39 @@ class HighfreqWorkflow(object):
dataset_backtest = pickle.load(file_dataset_backtest)
self._prepare_calender_cache()
##=============reload_dataset=============
dataset.init(init_type=DataHandlerLP.IT_LS)
dataset_backtest.init()
##=============reinit dataset=============
dataset.init(
handler_kwargs={
"init_type": DataHandlerLP.IT_LS,
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.init(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
##=============get data=============
xtrain, xtest = dataset.prepare(["train", "test"])
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
xtest = dataset.prepare(["test"])
backtest_test = dataset_backtest.prepare(["test"])
print(xtrain, xtest)
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
print(xtest, backtest_test)
return
if __name__ == "__main__":


@@ -99,7 +99,7 @@ if __name__ == "__main__":
},
}
# model initiaiton
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
@@ -112,12 +112,14 @@ if __name__ == "__main__":
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()


@@ -130,7 +130,7 @@ class ALSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.ALSTM_model.to(self.device)
def mse(self, pred, label):
@@ -238,7 +238,7 @@ class ALSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ class ALSTM(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")


@@ -135,7 +135,7 @@ class ALSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.ALSTM_model.to(self.device)
def mse(self, pred, label):
@@ -225,7 +225,7 @@ class ALSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ class ALSTM(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)


@@ -142,7 +142,7 @@ class GATs(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GAT_model.to(self.device)
def mse(self, pred, label):
@@ -275,7 +275,7 @@ class GATs(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -307,7 +307,7 @@ class GATs(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")


@@ -164,7 +164,7 @@ class GATs(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GAT_model.to(self.device)
def mse(self, pred, label):
@@ -297,7 +297,7 @@ class GATs(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -329,7 +329,7 @@ class GATs(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)


@@ -130,7 +130,7 @@ class GRU(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.gru_model.to(self.device)
def mse(self, pred, label):
@@ -238,7 +238,7 @@ class GRU(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ class GRU(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")


@@ -135,7 +135,7 @@ class GRU(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.GRU_model.to(self.device)
def mse(self, pred, label):
@@ -225,7 +225,7 @@ class GRU(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ class GRU(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)


@@ -130,7 +130,7 @@ class LSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.lstm_model.to(self.device)
def mse(self, pred, label):
@@ -238,7 +238,7 @@ class LSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ class LSTM(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")


@@ -135,7 +135,7 @@ class LSTM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.LSTM_model.to(self.device)
def mse(self, pred, label):
@@ -225,7 +225,7 @@ class LSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ class LSTM(Model):
torch.cuda.empty_cache()
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)


@@ -150,7 +150,7 @@ class DNNModelPytorch(Model):
eps=1e-08,
)
self._fitted = False
self.fitted = False
self.dnn_model.to(self.device)
def fit(
@@ -180,7 +180,7 @@ class DNNModelPytorch(Model):
evals_result["valid"] = []
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
# return
# prepare training data
x_train_values = torch.from_numpy(x_train.values).float()
@@ -265,7 +265,7 @@ class DNNModelPytorch(Model):
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test_pd = dataset.prepare("test", col_set="feature")
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)


@@ -302,7 +302,7 @@ class SFM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self._fitted = False
self.fitted = False
self.sfm_model.to(self.device)
def test_epoch(self, data_x, data_y):
@@ -386,7 +386,7 @@ class SFM(Model):
# train
self.logger.info("training...")
self._fitted = True
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
@@ -435,7 +435,7 @@ class SFM(Model):
raise ValueError("unknown metric `%s`" % self.metric)
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")


@@ -88,6 +88,7 @@ class TabnetModel(Model):
"\nGPU : {}"
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
)
self.fitted = False
np.random.seed(self.seed)
torch.manual_seed(self.seed)
@@ -187,7 +188,7 @@ class TabnetModel(Model):
evals_result["valid"] = []
self.logger.info("training...")
self._fitted = True
self.fitted = True
for epoch_idx in range(self.n_epochs):
self.logger.info("epoch: %s" % (epoch_idx))
@@ -212,7 +213,7 @@ class TabnetModel(Model):
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
def predict(self, dataset):
if not self._fitted:
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)


@@ -478,13 +478,13 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(np.int)]
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
if spans is None:
return data
else:
mask = np.zeros(len(data), dtype=np.bool)
mask = np.zeros(len(data), dtype=bool)
for begin, end in spans:
mask |= (data.index >= begin) & (data.index <= end)
return data[mask]


@@ -87,9 +87,42 @@ class DatasetH(Dataset):
"""
super().__init__(handler, segments)
def init(self, **kwargs):
"""Initialize the DatasetH, Only parameters belonging to handler.init will be passed in"""
self.handler.init(**kwargs)
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
"""
Initialize the DatasetH
Parameters
----------
handler_kwargs : dict
Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
- arguments of DataHandler.init, such as 'enable_cache', etc.
segment_kwargs : dict
Config of segments, which is the same as 'segments' in DatasetH.setup_data
"""
if handler_kwargs:
if not isinstance(handler_kwargs, dict):
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
kwargs_init = {}
kwargs_conf_data = {}
conf_data_arg = {"instruments", "start_time", "end_time"}
for k, v in handler_kwargs.items():
if k in conf_data_arg:
kwargs_conf_data.update({k: v})
else:
kwargs_init.update({k: v})
self.handler.conf_data(**kwargs_conf_data)
self.handler.init(**kwargs_init)
if segment_kwargs:
if not isinstance(segment_kwargs, dict):
raise TypeError(f"param segment_kwargs must be type dict, not {type(segment_kwargs)}")
self.segments = segment_kwargs.copy()
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
"""


@@ -26,6 +26,7 @@ def task_train(task_config: dict, experiment_name):
R.log_params(**flatten_dict(task_config))
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
# generate records: prediction, backtest, and analysis
for record in task_config["record"]:


@@ -86,7 +86,6 @@ class GetData:
@staticmethod
def _delete_qlib_data(file_dir: Path):
logger.info(f"delete {file_dir}")
rm_dirs = []
for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]:
_p = file_dir.joinpath(_name)
@@ -133,7 +132,11 @@ class GetData:
Examples
---------
# get 1d data
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
# get 1min data
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn
-------
"""


@@ -21,7 +21,7 @@ pip install -r requirements.txt
### CN Data
#### 1d
#### 1d from yahoo
```bash
@@ -33,18 +33,26 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
# using
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_cn_1d", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
```
#### 1min
### 1d from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
```
### using data
```python
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")
```
#### 1min from yahoo
```bash
@@ -56,20 +64,28 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1mi
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/stock_data/source/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol
```
# using
### 1min from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --interval 1min --region cn
```
### using data
```python
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_cn_1min", region="CN")
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="1min")
```
### US Data
#### 1d
#### 1d from yahoo
```bash
@@ -82,12 +98,22 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
```
#### 1d from qlib
```bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us
```
### using data
```python
# using
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_us_1d", region="us")
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
df = D.features(D.instruments("all"), ["$close"], freq="day")
```


@@ -30,7 +30,7 @@ if not _CYTHON_INSTALLED:
# What packages are required for this module to be executed?
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
REQUIRED = [
"numpy>=1.12.0,<=1.19.5",
"numpy>=1.12.0",
"pandas>=0.25.1",
"scipy>=1.0.0",
"requests>=2.18.0",