mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix automatic update of daily frequency data
This commit is contained in:
@@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
@@ -186,20 +187,12 @@ class BaseCollector(abc.ABC):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
if self.max_workers is not None and self.max_workers > 1:
|
||||
logger.info(f"concurrent collector, max_workers: {self.max_workers}")
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
else:
|
||||
for _symbol in instrument_list:
|
||||
_result = self._simple_collector(_symbol)
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
res = Parallel(n_jobs=self.max_workers)(
|
||||
delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
|
||||
)
|
||||
for _symbol, _result in zip(instrument_list, res):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
@@ -365,7 +358,7 @@ class BaseRun(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -382,8 +375,8 @@ class BaseRun(abc.ABC):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
|
||||
@@ -254,7 +254,7 @@ class Run(BaseRun):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=None,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -140,22 +140,24 @@ pip install -r requirements.txt
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *qlib/scripts/data_collector/yahoo/collector.py*
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
* `trading_date`: start of trading day
|
||||
* `end_date`: end of trading day (not included)
|
||||
* `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, the symbol will be re-fetched; the maximum number of fetches is set by the `max_collector_count` parameter
|
||||
|
||||
* qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin parameters:
|
||||
* *source_dir*: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* *normalize_dir*: Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
* *qlib_data_1d_dir*: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* *trading_date*: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* *end_date*: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
* *region*: region, value from ["CN", "US"], default "CN"
|
||||
* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
|
||||
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* `normalize_dir`: Directory for normalized data, default "Path(__file__).parent/normalize"
|
||||
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (excluding end)
|
||||
* `region`: region, value from ["CN", "US"], default "CN"
|
||||
|
||||
|
||||
## Using qlib data
|
||||
|
||||
@@ -8,6 +8,7 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from abc import ABC
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
@@ -49,7 +50,7 @@ class YahooCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -70,8 +71,8 @@ class YahooCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, by default None
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -311,8 +312,8 @@ class YahooNormalize(BaseNormalize):
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None and isinstance(last_close, (int, float)):
|
||||
_tmp_shift_series.iloc[0] = last_close
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
df["change"] = _tmp_series / _tmp_shift_series - 1
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
@@ -408,8 +409,9 @@ class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
|
||||
self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir)
|
||||
self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT)
|
||||
self._first_close_field = "first_close"
|
||||
self._ori_close_field = "ori_close"
|
||||
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
|
||||
|
||||
def _get_old_data(self, qlib_data_dir: [str, Path]):
|
||||
import qlib
|
||||
@@ -417,32 +419,34 @@ class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
|
||||
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
|
||||
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
|
||||
df = D.features(D.instruments("all"), ["$close/$factor"])
|
||||
df.columns = ["close"]
|
||||
return D.calendar()[-1], df
|
||||
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
|
||||
df.columns = [self._ori_close_field, self._first_close_field]
|
||||
return df
|
||||
|
||||
def _get_close(self, df: pd.DataFrame, field_name: str):
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_close = _df.loc[_df.last_valid_index()][field_name]
|
||||
return _close
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
try:
|
||||
_df = self._old_close.loc(axis=0)[_symbol.upper()]
|
||||
_close = _df.loc[_df.first_valid_index()]["close"]
|
||||
_close = self._get_close(df, field_name=self._first_close_field)
|
||||
except KeyError:
|
||||
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
|
||||
return _close
|
||||
|
||||
def _get_last_close(self, df: pd.DataFrame) -> float:
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
try:
|
||||
_df = self._old_close.loc(axis=0)[_symbol.upper()]
|
||||
_close = _df.loc[_df.last_valid_index()]["close"]
|
||||
_close = self._get_close(df, field_name=self._ori_close_field)
|
||||
except KeyError:
|
||||
_close = None
|
||||
return _close
|
||||
|
||||
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
try:
|
||||
_df = self._old_close.loc(axis=0)[_symbol.upper()]
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_date = _df.index.max()
|
||||
except KeyError:
|
||||
_date = None
|
||||
@@ -901,7 +905,7 @@ class Run(BaseRun):
|
||||
def download_today_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
delay=0.5,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
@@ -912,7 +916,7 @@ class Run(BaseRun):
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
time.sleep(delay), default 0.5
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
@@ -947,7 +951,14 @@ class Run(BaseRun):
|
||||
limit_nums,
|
||||
)
|
||||
|
||||
def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, end_date: str = None):
|
||||
def update_data_to_bin(
|
||||
self,
|
||||
qlib_data_1d_dir: str,
|
||||
trading_date: str = None,
|
||||
end_date: str = None,
|
||||
check_data_length: int = None,
|
||||
delay: float = 1,
|
||||
):
|
||||
"""update yahoo data to bin
|
||||
|
||||
Parameters
|
||||
@@ -959,7 +970,10 @@ class Run(BaseRun):
|
||||
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
end_date: str
|
||||
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
Notes
|
||||
-----
|
||||
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
|
||||
@@ -987,8 +1001,14 @@ class Run(BaseRun):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
|
||||
# download data from yahoo
|
||||
self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1)
|
||||
|
||||
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
|
||||
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
|
||||
# NOTE: a larger max_workers setting here would be faster
|
||||
self.max_workers = (
|
||||
max(multiprocessing.cpu_count() - 2, 1)
|
||||
if self.max_workers is None or self.max_workers <= 1
|
||||
else self.max_workers
|
||||
)
|
||||
# normalize data
|
||||
self.normalize_data_1d_extend(qlib_data_1d_dir)
|
||||
|
||||
|
||||
@@ -7,3 +7,4 @@ tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
joblib
|
||||
|
||||
Reference in New Issue
Block a user