1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

..

20 Commits

Author SHA1 Message Date
Linlang
3d8aca7723 change weight data download url 2024-06-21 12:38:47 +08:00
Fivele-Li
47bd13295b Fix Yahoo daily data format inconsistent (#1517)
* Fix FutureWarning: Passing unit-less datetime64 dtype to .astype is deprecated and will raise in a future version. Pass 'datetime64[ns]' instead

* align index format while end date contains current day data

* fix black

* fix black

* optimize code

* optimize code

* optimize code

* fix ci error

* check ci error

* fix ci error

* check ci error

* check ci error

* check ci error

* check ci error

* check ci error

* check ci error

* fix ci error

* fix ci error

* fix ci error

* fix ci error

* fix ci error

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-21 11:22:23 +08:00
陈屹华
ebc0ca893e Fix TSDataSampler Slicing Bug #1716 (#1803)
* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug with simplyer implmentation#1716
 with Simplified Implementation

* Refactor: Fix CI errors by addressing pylint formatting issues

* Refactor: Remove extraneous whitespace for improved code formatting with Black
2024-06-21 09:25:23 +08:00
Lee Yuntong
3a348aec9f Fix typo (#1811)
Co-authored-by: LeeYuntong <nukuihayu@outlook.com>
2024-06-20 18:12:07 +08:00
Lee Yuntong
37b908792b Fix typo (#1809)
Co-authored-by: LeeYuntong <nukuihayu@outlook.com>
2024-06-19 17:31:57 +08:00
raikiriww
73ec0f4003 Add "mse" metric option to ALSTM.metric_fn (#1810) 2024-06-19 17:31:47 +08:00
Linlang
155c17f8ff fix logo display error (#1804) 2024-06-06 13:39:49 +08:00
Yang
41b94059aa fix panic during normalizing the invalid data (#1698)
* fix panic during normalizing the invalid data

* fix yaml load

* change error to warning

* change error code

* optimize code

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-02 06:54:39 +08:00
block-gpt
7db83d84b7 Update utils.py for typo (#1751)
Fix typo

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-01 19:33:23 +08:00
Hao Zhao
35e0fdd1c0 fix the bug that the HS_SYMBOLS_URL is 404 (#1758)
* fix the bug that the HS_SYMBOLS_URL is 404

* fix bug

* format with black

* fix pylint error

* change error code

* fix ci error

* fix ci error

* optimize code

* optimize code

* add comments

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-06-01 08:07:34 +08:00
you-n-g
598017f634 Update Dev in README.md (#1800) 2024-05-29 17:44:18 +08:00
igeni
907c888c23 changed concat of strings to f-strings and redundant type conversion was removed (#1767)
Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
2024-05-28 12:13:12 +08:00
Linlang
02fe6b6974 bump verison 2024-05-24 16:38:48 +08:00
Linlang
b892b21045 update version 2024-05-24 15:14:49 +08:00
Linlang
155f80323c fix get data error (#1793)
* fix get data error

* fix get v0 data error

* optimize get_data code

* fix pylint error

* add comments
2024-05-24 12:59:50 +08:00
you-n-g
63021018d6 Update README.md's dataset 2024-05-21 08:15:18 +08:00
Linlang
f79a0eeaff fix docs (#1788)
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-21 04:23:55 +08:00
Linlang
8a087d0db9 fix docs (#1721)
* fix docs

* modify file extension

* modify file extension

---------

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 19:19:45 +08:00
playfund
2ae4be426a Delete redundant copy() code to speed up (#1732)
Delete redundant copy() code to speed up

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 18:45:07 +08:00
fei long
6ed83f7c04 data_collector: cn_index: fix missing dependencies package in requirements.txt (#1770)
add yahooquery and openpyxl in requirements.txt

Signed-off-by: YuLong Yao <feilongphone@gmail.com>
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 18:43:12 +08:00
25 changed files with 192 additions and 57 deletions

View File

@@ -45,6 +45,9 @@ jobs:
- name: Qlib installation test
run: |
# 2024-05-30 scs has released a new version: 3.2.4.post2,
# This will cause the CI to fail, so we have limited the version of scs for now.
python -m pip install "scs<=3.2.4"
python -m pip install pyqlib
- name: Install Lightgbm for MacOS
@@ -65,5 +68,8 @@ jobs:
cd qlib
- name: Test workflow by config
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -72,8 +72,10 @@ jobs:
black . -l 120 --check --diff
- name: Make html with sphinx
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
if: ${{ matrix.os == 'ubuntu-22.04' }}
run: |
cd docs
cd docs
sphinx-build -W --keep-going -b html . _build
cd ..
@@ -159,11 +161,16 @@ jobs:
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
if: ${{ matrix.os != 'macos-11' }}
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source)
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
python -m pip install numba
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -40,7 +40,7 @@ Recent released features
Features released before 2021 are not listed here.
<p align="center">
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
<img src="docs/_static/img/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
@@ -166,7 +166,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
* Clone the repository and install ``Qlib`` as follows.
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
pip install .
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
```
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
@@ -175,6 +175,20 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
## Data Preparation
❗ Due to more restrict data security policy. The offical dataset is disabled temporarily. You can try [this data source](https://github.com/chenditc/investment_data/releases) contributed by the community.
Here is an example to download the data updated on 20220720.
```bash
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
mkdir -p ~/.qlib/qlib_data/cn_data
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
rm -f qlib_bin.tar.gz
```
The official dataset below will resume in short future.
----
Load and prepare data by running the following code:
### Get with module

View File

@@ -86,7 +86,7 @@ Example
},
}
# model initiaiton
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

View File

@@ -5,3 +5,4 @@ scipy
scikit-learn
pandas
tianshou
sphinx_rtd_theme

View File

@@ -16,7 +16,7 @@ Current version of script with default value tries to connect localhost **via de
Run following command to install necessary libraries
```
pip install pytest coverage
pip install pytest coverage gdown
pip install arctic # NOTE: pip may fail to resolve the right package dependency !!! Please make sure the dependency are satisfied.
```
@@ -27,7 +27,8 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
2. Please follow following steps to download example data
```bash
cd examples/orderbook_data/
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
gdown https://drive.google.com/uc?id=15nZF7tFT_eKVZAcMFL1qPS4jGyJflH7e # Proxies may be necessary here.
python ../../scripts/get_data.py _unzip --file_path highfreq_orderbook_example_data.zip --target_dir .
```
3. Please import the example data to your mongo db

View File

@@ -20,7 +20,7 @@ We use China stock market data for our example.
1. Prepare CSI300 weight:
```bash
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
rm -f csi300_weight.zip
```

View File

@@ -161,7 +161,7 @@
" },\n",
"}\n",
"\n",
"# model initiaiton\n",
"# model initialization\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.4.99"
__version__ = "0.9.5.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -160,6 +160,10 @@ class ALSTM(Model):
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View File

@@ -35,7 +35,7 @@ class Client:
def connect_server(self):
"""Connect to server."""
try:
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
except socketio.exceptions.ConnectionError:
self.logger.error("Cannot connect to server - check your network or server status")

View File

@@ -536,7 +536,6 @@ class DatasetProvider(abc.ABC):
"""
if len(fields) == 0:
raise ValueError("fields cannot be empty")
fields = fields.copy()
column_names = [str(f) for f in fields]
return column_names
@@ -617,7 +616,7 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]

View File

@@ -403,7 +403,7 @@ class TSDataSampler:
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
# the data type will be changed
# The index of usable data is between start_idx and end_idx

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
"""
get the level index of `df` given `level`

View File

@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
# model & dataset initialization
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)

View File

@@ -12,15 +12,11 @@ import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from cryptography.fernet import Fernet
from qlib.utils import exists_qlib_data
class GetData:
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
# "?" is not included in the token.
TOKEN = b"gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
REMOTE_URL = "https://github.com/SunsetWolf/qlib_dataset/releases/download"
def __init__(self, delete_zip_file=False):
"""
@@ -33,9 +29,45 @@ class GetData:
self.delete_zip_file = delete_zip_file
def merge_remote_url(self, file_name: str):
fernet = Fernet(self.KEY)
token = fernet.decrypt(self.TOKEN).decode()
return f"{self.REMOTE_URL}/{file_name}?{token}"
"""
Generate download links.
Parameters
----------
file_name: str
The name of the file to be downloaded.
The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip),
if no version number is attached, it will be downloaded from v0 by default.
"""
return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}"
def download(self, url: str, target_path: [Path, str]):
"""
Download a file from the specified url.
Parameters
----------
url: str
The url of the data.
target_path: str
The location where the data is saved, including the file name.
"""
file_name = str(target_path).rsplit("/", maxsplit=1)[-1]
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
"""
@@ -70,21 +102,7 @@ class GetData:
target_path = target_dir.joinpath(_target_file_name)
url = self.merge_remote_url(file_name)
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chunk_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{os.path.basename(file_name)} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chunk in resp.iter_content(chunk_size=chunk_size):
fp.write(chunk)
p_bar.update(chunk_size)
self.download(url=url, target_path=target_path)
self._unzip(target_path, target_dir, delete_old)
if self.delete_zip_file:
@@ -99,7 +117,9 @@ class GetData:
return status
@staticmethod
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True):
file_path = Path(file_path)
target_dir = Path(target_dir)
if delete_old:
logger.warning(
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"

View File

@@ -242,7 +242,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datatime of segment
Shift the datetime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.

View File

@@ -301,6 +301,7 @@ class Normalize:
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
)
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
df = self._normalize_obj.normalize(df)
if df is not None and not df.empty:
if self._end_date is not None:

View File

@@ -5,3 +5,5 @@ pandas
lxml
loguru
tqdm
yahooquery
openpyxl

View File

@@ -9,7 +9,7 @@ pip install -r requirements.txt
```
## Usage of the dataset
> *Crypto dateset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
> *Crypto dataset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
## Collector Data

View File

@@ -15,7 +15,6 @@ from typing import Iterable, Tuple, List
import numpy as np
import pandas as pd
from lxml import etree
from loguru import logger
from yahooquery import Ticker
from tqdm import tqdm
@@ -190,17 +189,43 @@ def get_hs_stock_symbols() -> list:
global _HS_SYMBOLS # pylint: disable=W0603
def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None)
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101
)
)
time.sleep(3)
return _res
"""
Get the stock pool from a web page and process it into the format required by yahooquery.
Format of data retrieved from the web page: 600519, 000001
The data format required by yahooquery: 600519.ss, 000001.sz
Returns
-------
set: Returns the set of symbol codes.
Examples:
-------
{600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}
"""
url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"
try:
resp = requests.get(url, timeout=None)
resp.raise_for_status()
except requests.exceptions.HTTPError as e:
raise requests.exceptions.HTTPError(f"Request to {url} failed with status code {resp.status_code}") from e
try:
_symbols = [_v["f12"] for _v in resp.json()["data"]["diff"]]
except Exception as e:
logger.warning("An error occurred while extracting data from the response.")
raise
if len(_symbols) < 3900:
raise ValueError("The complete list of stocks is not available.")
# Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched.
_symbols = [
_symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None
for _symbol in _symbols
]
_symbols = [_symbol for _symbol in _symbols if _symbol is not None]
return set(_symbols)
if _HS_SYMBOLS is None:
symbols = set()

View File

@@ -796,6 +796,9 @@ class Run(BaseRun):
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
raise ValueError(f"end_date: {end} is greater than the current date.")
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
def normalize_data(

View File

@@ -46,7 +46,7 @@ if not _CYTHON_INSTALLED:
REQUIRED = [
"numpy>=1.12.0, <1.24",
"pandas>=0.25.1",
"scipy>=1.0.0",
"scipy>=1.7.3",
"requests>=2.18.0",
"sacred>=0.7.4",
"python-socketio",
@@ -82,7 +82,7 @@ REQUIRED = [
"dill",
"dataclasses;python_version<'3.7'",
"filelock",
"jinja2<3.1.0", # for passing the readthedocs workflow.
"jinja2",
"gym",
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
"protobuf<=3.20.1;python_version<='3.8'",
@@ -166,6 +166,9 @@ setup(
"lxml",
"baostock",
"yahooquery",
# 2024-05-30 scs has released a new version: 3.2.4.post2,
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
"scs<=3.2.4",
"beautifulsoup4",
# In version 0.4.11 of tianshou, the code:
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)

View File

@@ -5,8 +5,9 @@ import unittest
import pytest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
from qlib.data.dataset import TSDatasetH, TSDataSampler
import numpy as np
import pandas as pd
import time
from qlib.data.dataset.handler import DataHandlerLP
@@ -98,6 +99,54 @@ class TestDataset(TestAutoData):
print(idx[i])
class TestTSDataSampler(unittest.TestCase):
def test_TSDataSampler(self):
"""
Test TSDataSampler for issue #1716
"""
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
instruments = ["000001", "000002", "000003", "000004", "000005"]
index = pd.MultiIndex.from_product(
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
)
data = np.random.randn(len(datetime_list) * len(instruments))
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
print()
print("--------------dataset[0]--------------")
print(dataset[0])
print("--------------dataset[1]--------------")
print(dataset[1])
assert len(dataset[0]) == 2
self.assertTrue(np.isnan(dataset[0][0]))
self.assertEqual(dataset[0][1], dataset[1][0])
self.assertEqual(dataset[1][1], dataset[2][0])
self.assertEqual(dataset[2][1], dataset[3][0])
def test_TSDataSampler2(self):
"""
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
"""
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
instruments = ["000001", "000002", "000003", "000004", "000005"]
index = pd.MultiIndex.from_product(
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
)
data = np.random.randn(len(datetime_list) * len(instruments))
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
print()
print("--------------dataset[0]--------------")
print(dataset[0])
print("--------------dataset[1]--------------")
print(dataset[1])
for i in range(3):
self.assertFalse(np.isnan(dataset[0][i]))
self.assertFalse(np.isnan(dataset[1][i]))
self.assertEqual(dataset[0][1], dataset[1][0])
self.assertEqual(dataset[0][2], dataset[1][1])
if __name__ == "__main__":
unittest.main(verbosity=10)

View File

@@ -27,7 +27,7 @@ def train(uri_path: str = None):
model performance
"""
# model initiaiton
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# To test __repr__