mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
13 Commits
v0.9.5
...
fix_get_we
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d8aca7723 | ||
|
|
47bd13295b | ||
|
|
ebc0ca893e | ||
|
|
3a348aec9f | ||
|
|
37b908792b | ||
|
|
73ec0f4003 | ||
|
|
155c17f8ff | ||
|
|
41b94059aa | ||
|
|
7db83d84b7 | ||
|
|
35e0fdd1c0 | ||
|
|
598017f634 | ||
|
|
907c888c23 | ||
|
|
02fe6b6974 |
6
.github/workflows/test_qlib_from_pip.yml
vendored
6
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -45,6 +45,9 @@ jobs:
|
||||
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# This will cause the CI to fail, so we have limited the version of scs for now.
|
||||
python -m pip install "scs<=3.2.4"
|
||||
python -m pip install pyqlib
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
@@ -65,5 +68,8 @@ jobs:
|
||||
cd qlib
|
||||
|
||||
- name: Test workflow by config
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
9
.github/workflows/test_qlib_from_source.yml
vendored
9
.github/workflows/test_qlib_from_source.yml
vendored
@@ -72,8 +72,10 @@ jobs:
|
||||
black . -l 120 --check --diff
|
||||
|
||||
- name: Make html with sphinx
|
||||
# Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04.
|
||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
||||
run: |
|
||||
cd docs
|
||||
cd docs
|
||||
sphinx-build -W --keep-going -b html . _build
|
||||
cd ..
|
||||
|
||||
@@ -159,11 +161,16 @@ jobs:
|
||||
|
||||
# Run after data downloads
|
||||
- name: Check Qlib ipynb with nbconvert
|
||||
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
# add more ipynb files in future
|
||||
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
|
||||
|
||||
- name: Test workflow by config (install from source)
|
||||
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
|
||||
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
|
||||
if: ${{ matrix.os != 'macos-11' }}
|
||||
run: |
|
||||
python -m pip install numba
|
||||
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
@@ -40,7 +40,7 @@ Recent released features
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
|
||||
<img src="docs/_static/img/logo/1.png" />
|
||||
</p>
|
||||
|
||||
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
|
||||
@@ -166,7 +166,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
```bash
|
||||
git clone https://github.com/microsoft/qlib.git && cd qlib
|
||||
pip install .
|
||||
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
|
||||
```
|
||||
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ Example
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ We use China stock market data for our example.
|
||||
1. Prepare CSI300 weight:
|
||||
|
||||
```bash
|
||||
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
|
||||
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
|
||||
@@ -161,7 +161,7 @@
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.5"
|
||||
__version__ = "0.9.5.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -160,6 +160,10 @@ class ALSTM(Model):
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
elif self.metric == "mse":
|
||||
mask = ~torch.isnan(label)
|
||||
weight = torch.ones_like(label)
|
||||
return -self.mse(pred[mask], label[mask], weight[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class Client:
|
||||
def connect_server(self):
|
||||
"""Connect to server."""
|
||||
try:
|
||||
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
|
||||
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
|
||||
except socketio.exceptions.ConnectionError:
|
||||
self.logger.error("Cannot connect to server - check your network or server status")
|
||||
|
||||
|
||||
@@ -616,7 +616,7 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
|
||||
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
|
||||
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
@@ -403,7 +403,7 @@ class TSDataSampler:
|
||||
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
|
||||
axis=0,
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
from qlib.data.dataset import DataHandler
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
|
||||
@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
|
||||
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
# model & dataset initialization
|
||||
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
|
||||
reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
|
||||
@@ -242,7 +242,7 @@ class TimeAdjuster:
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
Shift the datatime of segment
|
||||
Shift the datetime of segment
|
||||
|
||||
If there are None (which indicates unbounded index) in the segment, this method will return None.
|
||||
|
||||
|
||||
@@ -301,6 +301,7 @@ class Normalize:
|
||||
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
||||
)
|
||||
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
|
||||
@@ -9,7 +9,7 @@ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage of the dataset
|
||||
> *Crypto dateset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
> *Crypto dataset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from typing import Iterable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
@@ -190,17 +189,43 @@ def get_hs_stock_symbols() -> list:
|
||||
global _HS_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None)
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
return _res
|
||||
"""
|
||||
Get the stock pool from a web page and process it into the format required by yahooquery.
|
||||
Format of data retrieved from the web page: 600519, 000001
|
||||
The data format required by yahooquery: 600519.ss, 000001.sz
|
||||
|
||||
Returns
|
||||
-------
|
||||
set: Returns the set of symbol codes.
|
||||
|
||||
Examples:
|
||||
-------
|
||||
{600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}
|
||||
"""
|
||||
url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"
|
||||
try:
|
||||
resp = requests.get(url, timeout=None)
|
||||
resp.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise requests.exceptions.HTTPError(f"Request to {url} failed with status code {resp.status_code}") from e
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"] for _v in resp.json()["data"]["diff"]]
|
||||
except Exception as e:
|
||||
logger.warning("An error occurred while extracting data from the response.")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 3900:
|
||||
raise ValueError("The complete list of stocks is not available.")
|
||||
|
||||
# Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched.
|
||||
_symbols = [
|
||||
_symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None
|
||||
for _symbol in _symbols
|
||||
]
|
||||
_symbols = [_symbol for _symbol in _symbols if _symbol is not None]
|
||||
|
||||
return set(_symbols)
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
symbols = set()
|
||||
|
||||
@@ -796,6 +796,9 @@ class Run(BaseRun):
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
|
||||
raise ValueError(f"end_date: {end} is greater than the current date.")
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(
|
||||
|
||||
7
setup.py
7
setup.py
@@ -46,7 +46,7 @@ if not _CYTHON_INSTALLED:
|
||||
REQUIRED = [
|
||||
"numpy>=1.12.0, <1.24",
|
||||
"pandas>=0.25.1",
|
||||
"scipy>=1.0.0",
|
||||
"scipy>=1.7.3",
|
||||
"requests>=2.18.0",
|
||||
"sacred>=0.7.4",
|
||||
"python-socketio",
|
||||
@@ -82,7 +82,7 @@ REQUIRED = [
|
||||
"dill",
|
||||
"dataclasses;python_version<'3.7'",
|
||||
"filelock",
|
||||
"jinja2<3.1.0", # for passing the readthedocs workflow.
|
||||
"jinja2",
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
@@ -166,6 +166,9 @@ setup(
|
||||
"lxml",
|
||||
"baostock",
|
||||
"yahooquery",
|
||||
# 2024-05-30 scs has released a new version: 3.2.4.post2,
|
||||
# this version, causes qlib installation to fail, so we've limited the scs version a bit for now.
|
||||
"scs<=3.2.4",
|
||||
"beautifulsoup4",
|
||||
# In version 0.4.11 of tianshou, the code:
|
||||
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
|
||||
@@ -5,8 +5,9 @@ import unittest
|
||||
import pytest
|
||||
import sys
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
from qlib.data.dataset import TSDatasetH, TSDataSampler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
@@ -98,6 +99,54 @@ class TestDataset(TestAutoData):
|
||||
print(idx[i])
|
||||
|
||||
|
||||
class TestTSDataSampler(unittest.TestCase):
|
||||
def test_TSDataSampler(self):
|
||||
"""
|
||||
Test TSDataSampler for issue #1716
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
assert len(dataset[0]) == 2
|
||||
self.assertTrue(np.isnan(dataset[0][0]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[1][1], dataset[2][0])
|
||||
self.assertEqual(dataset[2][1], dataset[3][0])
|
||||
|
||||
def test_TSDataSampler2(self):
|
||||
"""
|
||||
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
for i in range(3):
|
||||
self.assertFalse(np.isnan(dataset[0][i]))
|
||||
self.assertFalse(np.isnan(dataset[1][i]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[0][2], dataset[1][1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def train(uri_path: str = None):
|
||||
model performance
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# To test __repr__
|
||||
|
||||
Reference in New Issue
Block a user