mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
7 Commits
fix_logo_d
...
fix_get_we
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d8aca7723 | ||
|
|
47bd13295b | ||
|
|
ebc0ca893e | ||
|
|
3a348aec9f | ||
|
|
37b908792b | ||
|
|
73ec0f4003 | ||
|
|
155c17f8ff |
3
.github/workflows/test_qlib_from_pip.yml
vendored
3
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -68,5 +68,8 @@ jobs:
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
9
.github/workflows/test_qlib_from_source.yml
vendored
9
.github/workflows/test_qlib_from_source.yml
vendored
@@ -72,8 +72,10 @@ jobs:
|
||||
black . -l 120 --check --diff
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
|
||||
@@ -159,11 +161,16 @@ jobs:
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
@@ -86,7 +86,7 @@ Example
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ We use China stock market data for our example.
|
||||
1. Prepare CSI300 weight:
|
||||
|
||||
```bash
|
||||
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
|
||||
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
|
||||
@@ -161,7 +161,7 @@
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
|
||||
@@ -160,6 +160,10 @@ class ALSTM(Model):
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
elif self.metric == "mse":
|
||||
mask = ~torch.isnan(label)
|
||||
weight = torch.ones_like(label)
|
||||
return -self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -616,7 +616,7 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
|
||||
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
|
||||
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
@@ -403,7 +403,7 @@ class TSDataSampler:
|
||||
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
|
||||
axis=0,
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
|
||||
@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
|
||||
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
# model & dataset initialization
|
||||
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
|
||||
reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
|
||||
@@ -242,7 +242,7 @@ class TimeAdjuster:
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
Shift the datatime of segment
|
||||
Shift the datetime of segment
|
||||
|
||||
If there are None (which indicates unbounded index) in the segment, this method will return None.
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage of the dataset
|
||||
> *Crypto dateset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
> *Crypto dataset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
@@ -796,6 +796,9 @@ class Run(BaseRun):
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
|
||||
raise ValueError(f"end_date: {end} is greater than the current date.")
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(
|
||||
|
||||
4
setup.py
4
setup.py
@@ -46,7 +46,7 @@ if not _CYTHON_INSTALLED:
|
||||
REQUIRED = [
|
||||
"numpy>=1.12.0, <1.24",
|
||||
"pandas>=0.25.1",
|
||||
"scipy>=1.0.0",
|
||||
"scipy>=1.7.3",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"python-socketio",
|
||||
@@ -82,7 +82,7 @@ REQUIRED = [
|
||||
"dill",
|
||||
"dataclasses;python_version<'3.7'",
|
||||
"filelock",
|
||||
"jinja2<3.1.0", # for passing the readthedocs workflow.
|
||||
"jinja2",
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
|
||||
@@ -5,8 +5,9 @@ import unittest
|
||||
import pytest
|
||||
import sys
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
from qlib.data.dataset import TSDatasetH, TSDataSampler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
@@ -98,6 +99,54 @@ class TestDataset(TestAutoData):
|
||||
print(idx[i])
|
||||
|
||||
|
||||
class TestTSDataSampler(unittest.TestCase):
|
||||
def test_TSDataSampler(self):
|
||||
"""
|
||||
Test TSDataSampler for issue #1716
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
assert len(dataset[0]) == 2
|
||||
self.assertTrue(np.isnan(dataset[0][0]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[1][1], dataset[2][0])
|
||||
self.assertEqual(dataset[2][1], dataset[3][0])
|
||||
|
||||
def test_TSDataSampler2(self):
|
||||
"""
|
||||
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
for i in range(3):
|
||||
self.assertFalse(np.isnan(dataset[0][i]))
|
||||
self.assertFalse(np.isnan(dataset[1][i]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[0][2], dataset[1][1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def train(uri_path: str = None):
|
||||
model performance
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# To test __repr__
|
||||
|
||||
Reference in New Issue
Block a user