diff --git a/README.md b/README.md index 1c6d94ddd..6f416d420 100644 --- a/README.md +++ b/README.md @@ -82,13 +82,10 @@ This table demonstrates the supported Python version of `Qlib`: 2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future. ### Install with pip -**Note**: Due to latest numpy release: version 1.20.0, unexpected errors will occur if you install or run Qlib with `numpy==1.20.0`. We recommend to use lower version of `numpy==1.19.5` for now and we will fix this incompatibility in the neaar future. - Users can easily install ``Qlib`` by pip according to the following command. ```bash - pip install numpy==1.19.5 - pip install pyqlib --ignore-installed numpy + pip install pyqlib ``` **Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch. Please install qlib with the methods below. @@ -121,7 +118,12 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor ## Data Preparation Load and prepare data by running the following code: ```bash + # get 1d data python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn + + # get 1min data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min + ``` This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in diff --git a/docs/advanced/serial.rst b/docs/advanced/serial.rst new file mode 100644 index 000000000..8c0f83746 --- /dev/null +++ b/docs/advanced/serial.rst @@ -0,0 +1,42 @@ +.. _serial: + +================================= +Serialization +================================= +.. currentmodule:: qlib + +Introduction +=================== +``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them. + +Serializable Class +======================== + +``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format. +When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk. + +Example +========================== +``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``. +Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows. + +.. code-block:: Python + + ##=============dump dataset============= + dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH + + ##=============reload dataset============= + with open("dataset.pkl", "rb") as file_dataset: + dataset = pickle.load(file_dataset) + +.. note:: + Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc. + + After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk). + +A more detailed example is in this `link `_. + + +API +=================== +Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_. diff --git a/docs/component/data.rst b/docs/component/data.rst index dd32c5cd8..4b0962d49 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -31,7 +31,7 @@ Qlib Format Data We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper `_ for detailed information. Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data. -``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link `_: +``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link `_: ======================== ================= ================ Dataset US Market China Market @@ -41,6 +41,7 @@ Alpha360 √ √ Alpha158 √ √ ======================== ================= ================ +Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link `_. Qlib Format Dataset -------------------- @@ -48,8 +49,12 @@ Qlib Format Dataset .. code-block:: bash + # download 1d python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn + # download 1min + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min + In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command: .. code-block:: bash @@ -167,7 +172,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t - If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps: - - Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_. + - Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_. - Initialize ``Qlib`` in US-stock mode Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows. diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst index 5e01140cf..3882161bc 100644 --- a/docs/component/recorder.rst +++ b/docs/component/recorder.rst @@ -94,6 +94,52 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc - ``SignalRecord``: This class generates the `prediction` results of the model. - ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model. + +Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label. + +.. code-block:: Python + + from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return + + ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0]) + long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0]) + - ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_. +Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label. + +.. code-block:: Python + + from qlib.contrib.strategy.strategy import TopkDropoutStrategy + from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, + ) + + # backtest + STRATEGY_CONFIG = { + "topk": 50, + "n_drop": 5, + } + BACKTEST_CONFIG = { + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": BENCHMARK, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + } + + strategy = TopkDropoutStrategy(**STRATEGY_CONFIG) + report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG) + + # analysis + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"]) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + print(analysis_df) + For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_. diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst index 96a764de1..9c8481862 100644 --- a/docs/component/workflow.rst +++ b/docs/component/workflow.rst @@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``. test: [2017-01-01, 2020-08-01] record: - class: SignalRecord - module_path: qlib.workflow.record_temp - kwargs: {} + module_path: qlib.workflow.record_temp + kwargs: {} - class: PortAnaRecord - module_path: qlib.workflow.record_temp - kwargs: - config: *port_analysis_config + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below. diff --git a/docs/index.rst b/docs/index.rst index 15a36b489..3fa35fc60 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,6 +49,7 @@ Document Structure Building Formulaic Alphas Online & Offline mode + Serialization .. toctree:: :maxdepth: 3 diff --git a/docs/reference/api.rst b/docs/reference/api.rst index f21a9f518..3167d8a62 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -152,4 +152,14 @@ Recorder Record Template -------------------- .. automodule:: qlib.workflow.record_temp + :members: + + +Utils +==================== + +Serializable +-------------------- + +.. automodule:: qlib.utils.serial.Serializable :members: \ No newline at end of file diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md new file mode 100644 index 000000000..30c2e19db --- /dev/null +++ b/examples/highfreq/README.md @@ -0,0 +1,28 @@ +# High-Frequency Dataset + +This dataset is an example for RL high frequency trading. + +## Get High-Frequency Data + +Get high-frequency data by running the following command: +```bash + python workflow.py get_data +``` + +## Dump & Reload & Reinitialize the Dataset + + +The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format. + +### About Reinitialization + +After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states. + +The example is given in `workflow.py`, users can run the code as follows. + +### Run the Code + +Run the example by running the following command: +```bash + python workflow.py dump_and_load_dataset +``` \ No newline at end of file diff --git a/examples/highfreq/__init__.py b/examples/highfreq/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py index 1f0ddb28c..d35650514 100644 --- a/examples/highfreq/highfreq_handler.py +++ b/examples/highfreq/highfreq_handler.py @@ -62,9 +62,9 @@ class HighFreqHandler(DataHandlerLP): def get_normalized_price_feature(price_field, shift=0): """Get normalized price feature ops""" if shift == 0: - template_norm = "{0}/Ref(DayLast({1}), 240)" + template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)" else: - template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)" + template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)" feature_ops = template_norm.format( template_if.format( @@ -90,7 +90,7 @@ class HighFreqHandler(DataHandlerLP): names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] fields += [ - "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format( "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( template_paused.format("$volume"), template_paused.format(simpson_vwap), @@ -101,7 +101,7 @@ class HighFreqHandler(DataHandlerLP): ] names += ["$volume"] fields += [ - "Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format( "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( template_paused.format("$volume"), template_paused.format(simpson_vwap), @@ -112,7 +112,7 @@ class HighFreqHandler(DataHandlerLP): ] names += ["$volume_1"] - fields += [template_paused.format("Date($close)")] + fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))] names += ["date"] return fields, names @@ -149,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler): # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" fields += [ - template_fillnan.format(template_paused.format("$close")), + "Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))), ] names += ["$close0"] fields += [ - template_if.format( - template_fillnan.format(template_paused.format("$close")), - template_paused.format(simpson_vwap), + "Cut({0}, 240, None)".format( + template_if.format( + template_fillnan.format(template_paused.format("$close")), + template_paused.format(simpson_vwap), + ) ) ] names += ["$vwap0"] fields += [ - "If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format( + "Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format( template_paused.format("$volume"), template_paused.format(simpson_vwap), template_paused.format("$low"), diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py index 85ed63285..66a084f9f 100644 --- a/examples/highfreq/highfreq_ops.py +++ b/examples/highfreq/highfreq_ops.py @@ -8,6 +8,20 @@ from qlib.data.data import Cal def get_calendar_day(freq="day", future=False): + """Load High-Freq Calendar Date Using Memcache. + + Parameters + ---------- + freq : str + frequency of read calendar file. + future : bool + whether including future trading day. + + Returns + ------- + _calendar: + array of date. + """ flag = f"{freq}_future_{future}_day" if flag in H["c"]: _calendar = H["c"][flag] @@ -18,6 +32,19 @@ def get_calendar_day(freq="day", future=False): class DayLast(ElemOperator): + """DayLast Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value equals the last value of its day + """ + def _load_internal(self, instrument, start_index, end_index, freq): _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) @@ -25,18 +52,57 @@ class DayLast(ElemOperator): class FFillNan(ElemOperator): + """FFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a forward fill nan feature + """ + def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.fillna(method="ffill") class BFillNan(ElemOperator): + """BFillNan Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a backfoward fill nan feature + """ + def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.fillna(method="bfill") class Date(ElemOperator): + """Date Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + a series of that each value is the date corresponding to feature.index + """ + def _load_internal(self, instrument, start_index, end_index, freq): _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) @@ -44,6 +110,22 @@ class Date(ElemOperator): class Select(PairOperator): + """Select Operator + + Parameters + ---------- + feature_left : Expression + feature instance, select condition + feature_right : Expression + feature instance, select value + + Returns + ---------- + feature: + value(feature_right) that meets the condition(feature_left) + + """ + def _load_internal(self, instrument, start_index, end_index, freq): series_condition = self.feature_left.load(instrument, start_index, end_index, freq) series_feature = self.feature_right.load(instrument, start_index, end_index, freq) @@ -51,6 +133,58 @@ class Select(PairOperator): class IsNull(ElemOperator): + """IsNull Operator + + Parameters + ---------- + feature : Expression + feature instance + + Returns + ---------- + feature: + A series indicating whether the feature is nan + """ + def _load_internal(self, instrument, start_index, end_index, freq): series = self.feature.load(instrument, start_index, end_index, freq) return series.isnull() + + +class Cut(ElemOperator): + """Cut Operator + + Parameters + ---------- + feature : Expression + feature instance + l : int + l > 0, delete the first l elements of feature (default is None, which means 0) + r : int + r < 0, delete the last -r elements of feature (default is None, which means 0) + Returns + ---------- + feature: + A series with the first l and last -r elements deleted from the feature. + Note: It is deleted from the raw data, not the sliced data + """ + + def __init__(self, feature, l=None, r=None): + self.l = l + self.r = r + if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0): + raise ValueError("Cut operator l shoud > 0 and r should < 0") + + super(Cut, self).__init__(feature) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + return series.iloc[self.l : self.r] + + def get_extended_window_size(self): + ll = 0 if self.l is None else self.l + rr = 0 if self.r is None else abs(self.r) + lft_etd, rght_etd = self.feature.get_extended_window_size() + lft_etd = lft_etd + ll + rght_etd = rght_etd + rr + return lft_etd, rght_etd diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 6649079d8..01de59c0e 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -9,7 +9,7 @@ import qlib import pickle import numpy as np import pandas as pd -from qlib.config import HIGH_FREQ_CONFIG +from qlib.config import REG_CN, HIGH_FREQ_CONFIG from qlib.contrib.model.gbdt import LGBModel from qlib.contrib.data.handler import Alpha158 from qlib.contrib.strategy.strategy import TopkDropoutStrategy @@ -24,17 +24,17 @@ from qlib.data.ops import Operators from qlib.data.data import Cal from qlib.tests.data import GetData -from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull +from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut class HighfreqWorkflow(object): - SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None} + SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} MARKET = "all" BENCHMARK = "SH000300" - start_time = "2020-09-14 00:00:00" + start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" train_end_time = "2020-11-30 16:00:00" test_start_time = "2020-12-01 00:00:00" @@ -123,8 +123,7 @@ class HighfreqWorkflow(object): backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) print(backtest_train, backtest_test) - del xtrain, xtest - del backtest_train, backtest_test + return def dump_and_load_dataset(self): """dump and load dataset state on disk""" @@ -146,18 +145,39 @@ class HighfreqWorkflow(object): dataset_backtest = pickle.load(file_dataset_backtest) self._prepare_calender_cache() - ##=============reload_dataset============= - dataset.init(init_type=DataHandlerLP.IT_LS) - dataset_backtest.init() + ##=============reinit dataset============= + dataset.init( + handler_kwargs={ + "init_type": DataHandlerLP.IT_LS, + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) + dataset_backtest.init( + handler_kwargs={ + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) ##=============get data============= - xtrain, xtest = dataset.prepare(["train", "test"]) - backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"]) + xtest = dataset.prepare(["test"]) + backtest_test = dataset_backtest.prepare(["test"]) - print(xtrain, xtest) - print(backtest_train, backtest_test) - del xtrain, xtest - del backtest_train, backtest_test + print(xtest, backtest_test) + return if __name__ == "__main__": diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 6d166646c..d5dab8917 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -99,7 +99,7 @@ if __name__ == "__main__": }, } - # model initiaiton + # model initialization model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) @@ -112,12 +112,14 @@ if __name__ == "__main__": with R.start(experiment_name="workflow"): R.log_params(**flatten_dict(task)) model.fit(dataset) + R.save_objects(**{"params.pkl": model}) # prediction recorder = R.get_recorder() sr = SignalRecord(model, dataset, recorder) sr.generate() - # backtest + # backtest. If users want to use backtest based on their own prediction, + # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. par = PortAnaRecord(recorder, port_analysis_config) par.generate() diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index f2cfbdc36..bbbb61851 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -130,7 +130,7 @@ class ALSTM(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.ALSTM_model.to(self.device) def mse(self, pred, label): @@ -238,7 +238,7 @@ class ALSTM(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -270,7 +270,7 @@ class ALSTM(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature") diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index d2a5db8f1..725568de8 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -135,7 +135,7 @@ class ALSTM(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.ALSTM_model.to(self.device) def mse(self, pred, label): @@ -225,7 +225,7 @@ class ALSTM(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -257,7 +257,7 @@ class ALSTM(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 9e5aa3e28..07048e1bc 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -142,7 +142,7 @@ class GATs(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.GAT_model.to(self.device) def mse(self, pred, label): @@ -275,7 +275,7 @@ class GATs(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -307,7 +307,7 @@ class GATs(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature") diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index c3b8a2f06..1e94f56e4 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -164,7 +164,7 @@ class GATs(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.GAT_model.to(self.device) def mse(self, pred, label): @@ -297,7 +297,7 @@ class GATs(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -329,7 +329,7 @@ class GATs(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index db8257093..84f863b9f 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -130,7 +130,7 @@ class GRU(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.gru_model.to(self.device) def mse(self, pred, label): @@ -238,7 +238,7 @@ class GRU(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -270,7 +270,7 @@ class GRU(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature") diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index b6afc068c..bb6618b85 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -135,7 +135,7 @@ class GRU(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.GRU_model.to(self.device) def mse(self, pred, label): @@ -225,7 +225,7 @@ class GRU(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -257,7 +257,7 @@ class GRU(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 8eb390a98..163d500ec 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -130,7 +130,7 @@ class LSTM(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.lstm_model.to(self.device) def mse(self, pred, label): @@ -238,7 +238,7 @@ class LSTM(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -270,7 +270,7 @@ class LSTM(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature") diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index 79987ee0f..cf4f8fb9f 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -135,7 +135,7 @@ class LSTM(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.LSTM_model.to(self.device) def mse(self, pred, label): @@ -225,7 +225,7 @@ class LSTM(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -257,7 +257,7 @@ class LSTM(Model): torch.cuda.empty_cache() def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index ee23404fe..16fcea9ff 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -150,7 +150,7 @@ class DNNModelPytorch(Model): eps=1e-08, ) - self._fitted = False + self.fitted = False self.dnn_model.to(self.device) def fit( @@ -180,7 +180,7 @@ class DNNModelPytorch(Model): evals_result["valid"] = [] # train self.logger.info("training...") - self._fitted = True + self.fitted = True # return # prepare training data x_train_values = torch.from_numpy(x_train.values).float() @@ -265,7 +265,7 @@ class DNNModelPytorch(Model): raise NotImplementedError("loss {} is not supported!".format(loss_type)) def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test_pd = dataset.prepare("test", col_set="feature") x_test = torch.from_numpy(x_test_pd.values).float().to(self.device) diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index ae175a202..d5169e6c7 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -302,7 +302,7 @@ class SFM(Model): else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) - self._fitted = False + self.fitted = False self.sfm_model.to(self.device) def test_epoch(self, data_x, data_y): @@ -386,7 +386,7 @@ class SFM(Model): # train self.logger.info("training...") - self._fitted = True + self.fitted = True for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) @@ -435,7 +435,7 @@ class SFM(Model): raise ValueError("unknown metric `%s`" % self.metric) def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature") diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index ef1c8e2a8..62e32d701 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -88,6 +88,7 @@ class TabnetModel(Model): "\nGPU : {}" "\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain) ) + self.fitted = False np.random.seed(self.seed) torch.manual_seed(self.seed) @@ -187,7 +188,7 @@ class TabnetModel(Model): evals_result["valid"] = [] self.logger.info("training...") - self._fitted = True + self.fitted = True for epoch_idx in range(self.n_epochs): self.logger.info("epoch: %s" % (epoch_idx)) @@ -212,7 +213,7 @@ class TabnetModel(Model): self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) def predict(self, dataset): - if not self._fitted: + if not self.fitted: raise ValueError("model is not fitted yet!") x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) diff --git a/qlib/data/data.py b/qlib/data/data.py index 71915a3c3..762467da3 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -478,13 +478,13 @@ class DatasetProvider(abc.ABC): data = pd.DataFrame(obj) _calendar = Cal.calendar(freq=freq) - data.index = _calendar[data.index.values.astype(np.int)] + data.index = _calendar[data.index.values.astype(int)] data.index.names = ["datetime"] if spans is None: return data else: - mask = np.zeros(len(data), dtype=np.bool) + mask = np.zeros(len(data), dtype=bool) for begin, end in spans: mask |= (data.index >= begin) & (data.index <= end) return data[mask] diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 6b98baf8f..8ff8c1210 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -87,9 +87,42 @@ class DatasetH(Dataset): """ super().__init__(handler, segments) - def init(self, **kwargs): - """Initialize the DatasetH, Only parameters belonging to handler.init will be passed in""" - self.handler.init(**kwargs) + def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None): + """ + Initialize the DatasetH + + Parameters + ---------- + handler_kwargs : dict + Config of DataHanlder, which could include the following arguments: + + - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. + + - arguments of DataHandler.init, such as 'enable_cache', etc. + + segment_kwargs : dict + Config of segments which is same as 'segments' in DatasetH.setup_data + + """ + if handler_kwargs: + if not isinstance(handler_kwargs, dict): + raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}") + kwargs_init = {} + kwargs_conf_data = {} + conf_data_arg = {"instruments", "start_time", "end_time"} + for k, v in handler_kwargs.items(): + if k in conf_data_arg: + kwargs_conf_data.update({k: v}) + else: + kwargs_init.update({k: v}) + + self.handler.conf_data(**kwargs_conf_data) + self.handler.init(**kwargs_init) + + if segment_kwargs: + if not isinstance(segment_kwargs, dict): + raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}") + self.segments = segment_kwargs.copy() def setup_data(self, handler: Union[dict, DataHandler], segments: dict): """ diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 8f128d382..f0bc0b780 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -26,6 +26,7 @@ def task_train(task_config: dict, experiment_name): R.log_params(**flatten_dict(task_config)) model.fit(dataset) recorder = R.get_recorder() + R.save_objects(**{"params.pkl": model}) # generate records: prediction, backtest, and analysis for record in task_config["record"]: diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 6f8a0c9e0..3bf6a2c96 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -86,7 +86,6 @@ class GetData: @staticmethod def _delete_qlib_data(file_dir: Path): - logger.info(f"delete {file_dir}") rm_dirs = [] for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]: _p = file_dir.joinpath(_name) @@ -133,7 +132,11 @@ class GetData: Examples --------- + # get 1d data python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn + + # get 1min data + python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn ------- """ diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 520369ef5..b9fd9123c 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -21,7 +21,7 @@ pip install -r requirements.txt ### CN Data -#### 1d +#### 1d from yahoo ```bash @@ -33,18 +33,26 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol - -# using -import qlib -from qlib.data import D - -qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_cn_1d", region="cn") -df = D.features(D.instruments("all"), ["$close"], freq="day") +python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol ``` -#### 1min +### 1d from qlib +```bash +python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn +``` + +### using data + +```python +import qlib +from qlib.data import D + +qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn") +df = D.features(D.instruments("all"), ["$close"], freq="day") +``` + +#### 1min from yahoo ```bash @@ -56,20 +64,28 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1mi # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/stock_data/source/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol +python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol +``` -# using +### 1min from qlib +```bash +python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --interval 1min --region cn +``` + +### using data + +```python import qlib from qlib.data import D -qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_cn_1min", region="CN") +qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn") df = D.features(D.instruments("all"), ["$close"], freq="1min") ``` ### US Data -#### 1d +#### 1d from yahoo ```bash @@ -82,12 +98,22 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d # dump data cd qlib/scripts python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol +``` +#### 1d from qlib + +```bash +python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us +``` + +### using data + +```python # using import qlib from qlib.data import D -qlib.init(provider_uri="~/.qlib/stock_data/source/qlib_us_1d", region="us") +qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us") df = D.features(D.instruments("all"), ["$close"], freq="day") ``` diff --git a/setup.py b/setup.py index 142731d07..f759945fd 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ if not _CYTHON_INSTALLED: # What packages are required for this module to be executed? # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. REQUIRED = [ - "numpy>=1.12.0,<=1.19.5", + "numpy>=1.12.0", "pandas>=0.25.1", "scipy>=1.0.0", "requests>=2.18.0",