1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-28 16:41:18 +08:00

Compare commits

..

5 Commits

Author SHA1 Message Date
Linlang
5f9219acf2 add comments 2024-05-23 20:47:32 +08:00
Linlang
d0b84d5696 fix pylint error 2024-05-23 16:21:54 +08:00
Linlang
fb54d08236 optimize get_data code 2024-05-23 16:01:57 +08:00
Linlang
9bb4259080 fix get v0 data error 2024-05-23 07:19:04 +08:00
Linlang
117f67d6e1 fix get data error 2024-05-23 06:57:59 +08:00
21 changed files with 30 additions and 128 deletions

View File

@@ -45,9 +45,6 @@ jobs:
- name: Qlib installation test
run: |
# On 2024-05-30 scs released a new version, 3.2.4.post2.
# That release causes the CI to fail, so we have limited the version of scs for now.
python -m pip install "scs<=3.2.4"
python -m pip install pyqlib
- name: Install Lightgbm for MacOS
@@ -68,8 +65,5 @@ jobs:
cd qlib
- name: Test workflow by config
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -72,10 +72,8 @@ jobs:
black . -l 120 --check --diff
- name: Make html with sphinx
# Since Read the Docs builds on Ubuntu 22.04, we only need to test that the build passes on Ubuntu 22.04.
if: ${{ matrix.os == 'ubuntu-22.04' }}
run: |
cd docs
cd docs
sphinx-build -W --keep-going -b html . _build
cd ..
@@ -161,16 +159,11 @@ jobs:
# Run after data downloads
- name: Check Qlib ipynb with nbconvert
# Running the nbconvert check on a macos-11 system results in a "Kernel died" error, so we've temporarily disabled macos-11 here.
if: ${{ matrix.os != 'macos-11' }}
run: |
# add more ipynb files in future
jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb
- name: Test workflow by config (install from source)
# On macos-11 system, it will lead to "Segmentation fault: 11" error,
# which may be caused by the excessive memory overhead of macos-11 system, so we disable macos-11 temporarily here.
if: ${{ matrix.os != 'macos-11' }}
run: |
python -m pip install numba
python qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml

View File

@@ -40,7 +40,7 @@ Recent released features
Features released before 2021 are not listed here.
<p align="center">
<img src="docs/_static/img/logo/1.png" />
<img src="http://fintech.msra.cn/images_v070/logo/1.png" />
</p>
Qlib is an open-source, AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and reinforcement learning.
@@ -166,7 +166,7 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
* Clone the repository and install ``Qlib`` as follows.
```bash
git clone https://github.com/microsoft/qlib.git && cd qlib
pip install . # `pip install -e .[dev]` is recommended for development. check details in docs/developer/code_standard_and_dev_guide.rst
pip install .
```
**Note**: You can install Qlib with `python setup.py install` as well. But it is not the recommended approach. It will skip `pip` and cause obscure problems. For example, **only** the command ``pip install .`` **can** overwrite the stable version installed by ``pip install pyqlib``, while the command ``python setup.py install`` **can't**.

View File

@@ -86,7 +86,7 @@ Example
},
}
# model initialization
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

View File

@@ -20,7 +20,7 @@ We use China stock market data for our example.
1. Prepare CSI300 weight:
```bash
wget https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/csi300_weight.zip
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
rm -f csi300_weight.zip
```

View File

@@ -161,7 +161,7 @@
" },\n",
"}\n",
"\n",
"# model initialization\n",
"# model initiaiton\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.5.99"
__version__ = "0.9.4.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -160,10 +160,6 @@ class ALSTM(Model):
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View File

@@ -35,7 +35,7 @@ class Client:
def connect_server(self):
"""Connect to server."""
try:
self.sio.connect(f"ws://{self.server_host}:{self.server_port}")
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
except socketio.exceptions.ConnectionError:
self.logger.error("Cannot connect to server - check your network or server status")

View File

@@ -616,7 +616,7 @@ class DatasetProvider(abc.ABC):
data = pd.DataFrame(obj)
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]

View File

@@ -403,7 +403,7 @@ class TSDataSampler:
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from qlib.data.dataset import DataHandler
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
"""
get the level index of `df` given `level`

View File

@@ -41,7 +41,7 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initialization
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)

View File

@@ -242,7 +242,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datetime of segment
Shift the datatime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.

View File

@@ -301,7 +301,6 @@ class Normalize:
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
)
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
df = self._normalize_obj.normalize(df)
if df is not None and not df.empty:
if self._end_date is not None:

View File

@@ -9,7 +9,7 @@ pip install -r requirements.txt
```
## Usage of the dataset
> *Crypto dataset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
> *Crypto dateset only support Data retrieval function but not support backtest function due to the lack of OHLC data.*
## Collector Data

View File

@@ -15,6 +15,7 @@ from typing import Iterable, Tuple, List
import numpy as np
import pandas as pd
from lxml import etree
from loguru import logger
from yahooquery import Ticker
from tqdm import tqdm
@@ -189,43 +190,17 @@ def get_hs_stock_symbols() -> list:
global _HS_SYMBOLS # pylint: disable=W0603
def _get_symbol():
    """
    Fetch the stock codes from the web API and convert them to yahooquery symbols.

    The API returns bare codes such as 600519 or 000001; yahooquery expects
    a market suffix, e.g. 600519.ss or 000001.sz.

    Returns
    -------
        set: the suffixed symbol codes.

    Raises
    ------
        requests.exceptions.HTTPError: if the response carries an error status.
        ValueError: if fewer than 3900 codes are returned.

    Examples:
    -------
        {600000.ss, 600001.ss, 600002.ss, 600003.ss, ...}
    """
    url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"

    try:
        resp = requests.get(url, timeout=None)
        resp.raise_for_status()
    except requests.exceptions.HTTPError as http_err:
        # Re-raise with the URL and status code in the message, keeping the original as the cause.
        raise requests.exceptions.HTTPError(
            f"Request to {url} failed with status code {resp.status_code}"
        ) from http_err

    try:
        payload = resp.json()
        codes = [entry["f12"] for entry in payload["data"]["diff"]]
    except Exception:
        # Log a hint, then let the original exception propagate unchanged.
        logger.warning("An error occurred while extracting data from the response.")
        raise

    # The size check runs on the raw list, before any code is filtered out.
    if len(codes) < 3900:
        raise ValueError("The complete list of stocks is not available.")

    # Add the suffix yahooquery requires; codes with any other leading digit are dropped.
    symbols = set()
    for code in codes:
        if code.startswith("6"):
            symbols.add(code + ".ss")
        elif code.startswith(("0", "3")):
            symbols.add(code + ".sz")
    return symbols
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None)
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101
)
)
time.sleep(3)
return _res
if _HS_SYMBOLS is None:
symbols = set()

View File

@@ -796,9 +796,6 @@ class Run(BaseRun):
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
raise ValueError(f"end_date: {end} is greater than the current date.")
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
def normalize_data(

View File

@@ -46,7 +46,7 @@ if not _CYTHON_INSTALLED:
REQUIRED = [
"numpy>=1.12.0, <1.24",
"pandas>=0.25.1",
"scipy>=1.7.3",
"scipy>=1.0.0",
"requests>=2.18.0",
"sacred>=0.7.4",
"python-socketio",
@@ -82,7 +82,7 @@ REQUIRED = [
"dill",
"dataclasses;python_version<'3.7'",
"filelock",
"jinja2",
"jinja2<3.1.0", # for passing the readthedocs workflow.
"gym",
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
"protobuf<=3.20.1;python_version<='3.8'",
@@ -166,9 +166,6 @@ setup(
"lxml",
"baostock",
"yahooquery",
# On 2024-05-30 scs released a new version, 3.2.4.post2.
# That version causes the qlib installation to fail, so we've limited the scs version for now.
"scs<=3.2.4",
"beautifulsoup4",
# In version 0.4.11 of tianshou, the code:
# logits, hidden = self.actor(batch.obs, state=state, info=batch.info)

View File

@@ -5,9 +5,8 @@ import unittest
import pytest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH, TSDataSampler
from qlib.data.dataset import TSDatasetH
import numpy as np
import pandas as pd
import time
from qlib.data.dataset.handler import DataHandlerLP
@@ -99,54 +98,6 @@ class TestDataset(TestAutoData):
print(idx[i])
class TestTSDataSampler(unittest.TestCase):
    """Checks on the windows obtained by indexing a ``TSDataSampler``."""

    # Both tests share the same month-end dates and instrument codes.
    DATETIME_LIST = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
    INSTRUMENTS = ["000001", "000002", "000003", "000004", "000005"]

    @classmethod
    def _make_factor_frame(cls):
        """Return a one-column ("factor") frame of random values.

        The index is the full (datetime, instrument) product of
        ``DATETIME_LIST`` and ``INSTRUMENTS``.
        """
        index = pd.MultiIndex.from_product(
            [pd.to_datetime(cls.DATETIME_LIST), cls.INSTRUMENTS], names=["datetime", "instrument"]
        )
        data = np.random.randn(len(cls.DATETIME_LIST) * len(cls.INSTRUMENTS))
        return pd.DataFrame(data=data, index=index, columns=["factor"])

    @staticmethod
    def _print_head(dataset):
        """Print the first two samples of ``dataset`` for manual inspection."""
        print()
        for i in range(2):
            print(f"--------------dataset[{i}]--------------")
            print(dataset[i])

    def test_TSDataSampler(self):
        """
        Test TSDataSampler for issue #1716

        The sampler covers the whole date range with ``step_len=2``. The first
        row of the first sample must be NaN, and the second row of each sample
        must equal the first row of the next one.
        """
        test_df = self._make_factor_frame()
        dataset = TSDataSampler(test_df, self.DATETIME_LIST[0], self.DATETIME_LIST[-1], step_len=2)
        self._print_head(dataset)

        # assertEqual instead of a bare assert: it is not stripped under `python -O`
        # and reports both values on failure.
        self.assertEqual(len(dataset[0]), 2)

        self.assertTrue(np.isnan(dataset[0][0]))
        for i in range(3):
            self.assertEqual(dataset[i][1], dataset[i + 1][0])

    def test_TSDataSampler2(self):
        """
        Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front

        The sampler starts at the third date with ``step_len=3``. No row of the
        first two samples may be NaN, and the two samples must overlap by two rows.
        """
        test_df = self._make_factor_frame()
        dataset = TSDataSampler(test_df, self.DATETIME_LIST[2], self.DATETIME_LIST[-1], step_len=3)
        self._print_head(dataset)

        for i in range(3):
            self.assertFalse(np.isnan(dataset[0][i]))
            self.assertFalse(np.isnan(dataset[1][i]))
        self.assertEqual(dataset[0][1], dataset[1][0])
        self.assertEqual(dataset[0][2], dataset[1][1])
if __name__ == "__main__":
unittest.main(verbosity=10)

View File

@@ -27,7 +27,7 @@ def train(uri_path: str = None):
model performance
"""
# model initialization
# model initiaiton
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# To test __repr__