mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Merge pull request #463 from zhupr/support_extend_data
Support extend data
This commit is contained in:
22
README.md
22
README.md
@@ -159,6 +159,28 @@ Users could create the same dataset with it.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* *trading_date*: start of trading day
|
||||
* *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
|
||||
|
||||
@@ -67,6 +67,34 @@ After running the above command, users can find china-stock and us-stock data in
|
||||
|
||||
When ``Qlib`` is initialized with this dataset, users could build and evaluate their own models with it. Please refer to `Initialization <../start/initialization.html>`_ for more details.
|
||||
|
||||
Automatic update of daily frequency data
|
||||
----------------------------------------
|
||||
|
||||
**It is recommended that users update the data manually once (\-\-trading_date 2021-05-25) and then set it to update automatically.**
|
||||
|
||||
For more information refer to: `yahoo collector <https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#Automatic-update-of-daily-frequency-data>`_
|
||||
|
||||
- Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
- use *crontab*: `crontab -e`
|
||||
- set up timed tasks:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
|
||||
- **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
- Manual update of data
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
|
||||
- *trading_date*: start of trading day
|
||||
- *end_date*: end of trading day(not included)
|
||||
|
||||
|
||||
|
||||
Converting CSV Format into Qlib Format
|
||||
-------------------------------------------
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ Here are the results of each benchmark model running on Qlib's `Alpha360` and `A
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import plotly.tools as tls
|
||||
import plotly.graph_objs as go
|
||||
|
||||
import statsmodels.api as sm
|
||||
@@ -80,9 +79,35 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
:param dist:
|
||||
:return:
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=(8, 5))
|
||||
_mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax)
|
||||
return tls.mpl_to_plotly(_mpl_fig)
|
||||
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
|
||||
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
|
||||
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
|
||||
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
|
||||
plt.close(_plt_fig)
|
||||
qqplot_data = _plt_fig.gca().lines
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[0].get_xdata(),
|
||||
"y": qqplot_data[0].get_ydata(),
|
||||
"mode": "markers",
|
||||
"marker": {"color": "#19d3f3"},
|
||||
}
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[1].get_xdata(),
|
||||
"y": qqplot_data[1].get_ydata(),
|
||||
"mode": "lines",
|
||||
"line": {"color": "#636efa"},
|
||||
}
|
||||
)
|
||||
del qqplot_data
|
||||
return fig
|
||||
|
||||
|
||||
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
|
||||
|
||||
@@ -7,12 +7,13 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from typing import Type, Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
@@ -22,9 +23,9 @@ class BaseCollector(abc.ABC):
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
|
||||
DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
@@ -35,10 +36,10 @@ class BaseCollector(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_workers=1,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -48,7 +49,7 @@ class BaseCollector(abc.ABC):
|
||||
save_dir: str
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
@@ -59,8 +60,8 @@ class BaseCollector(abc.ABC):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -72,7 +73,7 @@ class BaseCollector(abc.ABC):
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
@@ -99,14 +100,6 @@ class BaseCollector(abc.ABC):
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
@@ -132,7 +125,7 @@ class BaseCollector(abc.ABC):
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
pd.DataFrame, "symbol" and "date"in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
@@ -151,7 +144,7 @@ class BaseCollector(abc.ABC):
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
if self.check_data_length > 0:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
@@ -181,8 +174,8 @@ class BaseCollector(abc.ABC):
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
if len(df) < self.check_data_length:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
@@ -194,12 +187,12 @@ class BaseCollector(abc.ABC):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
res = Parallel(n_jobs=self.max_workers)(
|
||||
delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
|
||||
)
|
||||
for _symbol, _result in zip(instrument_list, res):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
@@ -217,20 +210,16 @@ class BaseCollector(abc.ABC):
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
_df = pd.concat(_df_list, sort=False)
|
||||
if not _df.empty:
|
||||
self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -242,7 +231,7 @@ class BaseNormalize(abc.ABC):
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self.kwargs = kwargs
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -251,7 +240,7 @@ class BaseNormalize(abc.ABC):
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@@ -265,6 +254,7 @@ class Normalize:
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -288,16 +278,23 @@ class Normalize:
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
self._end_date = kwargs.get("end_date", None)
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
self._normalize_obj = normalize_class(
|
||||
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
||||
)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
@@ -311,7 +308,7 @@ class Normalize:
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -321,7 +318,7 @@ class BaseRun(abc.ABC):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
@@ -361,7 +358,7 @@ class BaseRun(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -378,8 +375,8 @@ class BaseRun(abc.ABC):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
@@ -404,7 +401,7 @@ class BaseRun(abc.ABC):
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -426,5 +423,6 @@ class BaseRun(abc.ABC):
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
**kwargs,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
@@ -19,12 +19,31 @@ CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
# 2020-11-27 Announcement title change
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
|
||||
|
||||
REQ_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
|
||||
}
|
||||
|
||||
|
||||
@deco_retry
|
||||
def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
if exclude_status is None:
|
||||
exclude_status = []
|
||||
method_func = getattr(requests, method)
|
||||
_resp = method_func(url, headers=REQ_HEADERS)
|
||||
_status = _resp.status_code
|
||||
if _status not in exclude_status and _status != 200:
|
||||
raise ValueError(f"response status: {_status}, url={url}")
|
||||
return _resp
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
@@ -134,9 +153,8 @@ class CSIIndex(IndexBase):
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = requests.get(url)
|
||||
resp = retry_request(url)
|
||||
_text = resp.text
|
||||
|
||||
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
@@ -147,7 +165,7 @@ class CSIIndex(IndexBase):
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
|
||||
content = retry_request(f"http://www.csindex.com.cn{excel_url}", exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
@@ -201,7 +219,7 @@ class CSIIndex(IndexBase):
|
||||
-------
|
||||
[url1, url2]
|
||||
"""
|
||||
resp = requests.get(self.changes_url)
|
||||
resp = retry_request(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
@@ -221,7 +239,7 @@ class CSIIndex(IndexBase):
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
context = requests.get(self.new_companies_url).content
|
||||
context = retry_request(self.new_companies_url).content
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
@@ -292,7 +310,7 @@ def get_instruments(
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("collector")
|
||||
_cur_module = importlib.import_module("data_collector.cn_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## fill 1min data
|
||||
|
||||
```bash
|
||||
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- ata_1min_dir: csv data
|
||||
- qlib_data_1d_dir: qlib data directory
|
||||
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
|
||||
- date_field_name: date field name, by default *date*
|
||||
- symbol_field_name: symbol field name, by default *symbol*
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from qlib.data import D
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
|
||||
csv_files = list(data_1min_dir.glob("*.csv"))
|
||||
min_date = None
|
||||
max_date = None
|
||||
with tqdm(total=len(csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)):
|
||||
if not _result.empty:
|
||||
_dates = pd.to_datetime(_result[date_field_name])
|
||||
|
||||
_tmp_min = _dates.min()
|
||||
min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min
|
||||
_tmp_max = _dates.max()
|
||||
max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
|
||||
p_bar.update()
|
||||
return min_date, max_date
|
||||
|
||||
|
||||
def get_symbols(data_1min_dir: Path):
|
||||
return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob("*.csv")))
|
||||
|
||||
|
||||
def fill_1min_using_1d(
|
||||
data_1min_dir: [str, Path],
|
||||
qlib_data_1d_dir: [str, Path],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_1min_dir: str
|
||||
1min data dir
|
||||
qlib_data_1d_dir: str
|
||||
1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
|
||||
max_workers: int
|
||||
ThreadPoolExecutor(max_workers), by default 16
|
||||
date_field_name: str
|
||||
date field name, by default date
|
||||
symbol_field_name: str
|
||||
symbol field name, by default symbol
|
||||
|
||||
"""
|
||||
data_1min_dir = Path(data_1min_dir).expanduser().resolve()
|
||||
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
|
||||
|
||||
min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
|
||||
symbols_1min = get_symbols(data_1min_dir)
|
||||
|
||||
qlib.init(provider_uri=str(qlib_data_1d_dir))
|
||||
data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")
|
||||
|
||||
miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
|
||||
if not miss_symbols:
|
||||
logger.warning("More symbols in 1min than 1d, no padding required")
|
||||
return
|
||||
|
||||
logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")
|
||||
tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
|
||||
columns = tmp_df.columns
|
||||
_si = tmp_df[symbol_field_name].first_valid_index()
|
||||
is_lower = tmp_df.loc[_si][symbol_field_name].islower()
|
||||
for symbol in tqdm(miss_symbols):
|
||||
if is_lower:
|
||||
symbol = symbol.lower()
|
||||
index_1d = data_1d.loc(axis=0)[symbol.upper()].index
|
||||
index_1min = generate_minutes_calendar_from_daily(index_1d)
|
||||
index_1min.name = date_field_name
|
||||
_df = pd.DataFrame(columns=columns, index=index_1min)
|
||||
if date_field_name in _df.columns:
|
||||
del _df[date_field_name]
|
||||
_df.reset_index(inplace=True)
|
||||
_df[symbol_field_name] = symbol
|
||||
_df["paused_num"] = 0
|
||||
_df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(fill_1min_using_1d)
|
||||
@@ -0,0 +1,5 @@
|
||||
fire
|
||||
pandas
|
||||
loguru
|
||||
tqdm
|
||||
pyqlib
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
@@ -3,18 +3,13 @@
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
@@ -38,7 +33,7 @@ class FundCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -59,8 +54,8 @@ class FundCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -168,9 +163,7 @@ class FundollectorCN(FundCollector, ABC):
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
@@ -261,7 +254,7 @@ class Run(BaseRun):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -278,8 +271,8 @@ class Run(BaseRun):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
check_data_length: int # if this param useful?
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@ def get_instruments(
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("collector")
|
||||
_cur_module = importlib.import_module("data_collector.us_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -10,7 +9,7 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -47,7 +46,7 @@ _CALENDAR_MAP = {}
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_calendar_list(bench_code="CSI300") -> list:
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
|
||||
- [Collector Data](#collector-data)
|
||||
- [Get Qlib data](#get-qlib-databin-file)
|
||||
- [Collector *YahooFinance* data to qlib](#collector-yahoofinance-data-to-qlib)
|
||||
- [Automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- [Using qlib data](#using-qlib-data)
|
||||
|
||||
|
||||
# Collect Data From Yahoo Finance
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
@@ -18,113 +26,170 @@ pip install -r requirements.txt
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
> `qlib-data` from *YahooFinance*, is the data that has been dumped and can be used directly in `qlib`
|
||||
|
||||
### CN Data
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data*
|
||||
- `version`: dataset version, value from [`v1`, `v2`], by default `v1`
|
||||
- `v2` end date is *2021-06*, `v1` end date is *2020-09*
|
||||
- user can append data to `v2`: [automatic update of daily frequency data](#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
- **the [benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks) for qlib use `v1`**, *due to the unstable access to historical data by YahooFinance, there are some differences between `v2` and `v1`*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
- `region`: `cn` or `us`, by default `cn`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
|
||||
# cn 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
# us 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us --interval 1d
|
||||
# us 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1min --region us --interval 1min
|
||||
```
|
||||
|
||||
#### 1d from yahoo
|
||||
### Collector *YahooFinance* data to qlib
|
||||
> collector *YahooFinance* data and *dump* into `qlib` format
|
||||
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
|
||||
|
||||
```bash
|
||||
- parameters:
|
||||
- `source_dir`: save the directory
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `delay`: `time.sleep(delay)`, by default *0.5*
|
||||
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
|
||||
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
|
||||
- `max_workers`: get the number of concurrent symbols, it is not recommended to change this parameter in order to maintain the integrity of the symbol data, by default *1*
|
||||
- `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
- `max_collector_count`: number of *"failed"* symbol retries, by default 2
|
||||
- examples:
|
||||
```bash
|
||||
# cn 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# cn 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --delay 1 --interval 1min --region CN
|
||||
# us 1d data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region US
|
||||
# us 1min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1min --delay 1 --interval 1min --region US
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`
|
||||
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
- `max_workers`: number of concurrent, by default *1*
|
||||
- `interval`: `1d` or `1min`, by default `1d`
|
||||
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `CN` or `US`, by default `CN`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
|
||||
- `qlib_data_1d_dir`: qlib directory(1d data)
|
||||
```
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data from YahooFinance
|
||||
|
||||
```
|
||||
- examples:
|
||||
```bash
|
||||
# normalize 1d cn
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
|
||||
# normalize 1min cn
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/qlib_cn_1d --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 1mih: 1min}`
|
||||
- `max_workers`: number of threads, by default *16*
|
||||
- `include_fields`: dump fields, by default `""`
|
||||
- `exclude_fields`: fields not dumped, by default `"""
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- examples:
|
||||
```bash
|
||||
# dump 1d cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,symbol
|
||||
# dump 1min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,symbol
|
||||
```
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
### Automatic update of daily frequency data(from yahoo finance)
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1d --normalize_dir ~/.qlib/stock_data/source/cn_1d_nor --region CN --interval 1d
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
* set up timed tasks:
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
|
||||
```
|
||||
* **script path**: *scripts/data_collector/yahoo/collector.py*
|
||||
|
||||
```
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
```
|
||||
* `trading_date`: start of trading day
|
||||
* `end_date`: end of trading day(not included)
|
||||
* `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
|
||||
### 1d from qlib
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
```
|
||||
|
||||
#### 1min from yahoo
|
||||
|
||||
```bash
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/cn_1min --normalize_dir ~/.qlib/stock_data/source/cn_1min_nor --region CN --interval 1min
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1min --freq 1min --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
### 1min from qlib
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --interval 1min --region cn
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
```
|
||||
|
||||
### US Data
|
||||
|
||||
#### 1d from yahoo
|
||||
|
||||
```bash
|
||||
|
||||
# download from yahoo finance
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/us_1d --region US --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d --normalize_dir ~/.qlib/stock_data/source/us_1d_nor --region US --interval 1d
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
#### 1d from qlib
|
||||
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d --region us
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
# using
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
```
|
||||
* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
|
||||
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
* `region`: region, value from ["CN", "US"], default "CN"
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
## Using qlib data
|
||||
|
||||
## Parameters
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
# 1d data cn
|
||||
# freq=day, freq default day
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
# 1min data cn
|
||||
# freq=1min
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
|
||||
# get 100 symbols
|
||||
df = D.features(inst[:100], ["$close"], freq="1min")
|
||||
# get all symbol data
|
||||
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
# 1d data us
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
# 1min data us
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1min", region="cn")
|
||||
inst = D.list_instruments(D.instruments("all"), freq="1min", as_list=True)
|
||||
# get 100 symbols
|
||||
df = D.features(inst[:100], ["$close"], freq="1min")
|
||||
# get all symbol data
|
||||
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
```
|
||||
|
||||
- interval: 1min or 1d
|
||||
- region: CN or US
|
||||
|
||||
@@ -8,8 +8,9 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from abc import ABC
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from typing import Iterable
|
||||
|
||||
import fire
|
||||
import requests
|
||||
@@ -18,13 +19,18 @@ import pandas as pd
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
|
||||
from dump_bin import DumpDataUpdate
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
|
||||
from data_collector.utils import (
|
||||
deco_retry,
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
@@ -44,7 +50,7 @@ class YahooCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -65,8 +71,8 @@ class YahooCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, by default None
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -92,10 +98,6 @@ class YahooCollector(BaseCollector):
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@@ -140,40 +142,39 @@ class YahooCollector(BaseCollector):
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
@deco_retry(retry_sleep=self.delay)
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
resp = self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
if resp is None or resp.empty:
|
||||
raise ValueError(f"get data error: {symbol}--{start_}--{end_}")
|
||||
return resp
|
||||
|
||||
_result = None
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
try:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
except ValueError as e:
|
||||
pass
|
||||
elif interval == self.INTERVAL_1min:
|
||||
_res = []
|
||||
_start = self.start_datetime
|
||||
while _start < self.end_datetime:
|
||||
_tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime)
|
||||
try:
|
||||
_resp = _get_simple(_start, _tmp_end)
|
||||
_res.append(_resp)
|
||||
except ValueError as e:
|
||||
pass
|
||||
_start = _tmp_end
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
@@ -207,10 +208,6 @@ class YahooCollectorCN(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorCN1d(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
@@ -244,13 +241,12 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
|
||||
|
||||
class YahooCollectorCN1min(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 4 * 5
|
||||
def get_instrument_list(self):
|
||||
symbols = super(YahooCollectorCN1min, self).get_instrument_list()
|
||||
return symbols + ["000300.ss", "000905.ss", "000903.ss"]
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
@@ -276,15 +272,11 @@ class YahooCollectorUS(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorUS1d(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorUS1min(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 6.5 * 5
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
@@ -297,6 +289,7 @@ class YahooNormalize(BaseNormalize):
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
last_close: float = None,
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
@@ -318,7 +311,10 @@ class YahooNormalize(BaseNormalize):
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
df["change"] = _tmp_series / _tmp_shift_series - 1
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
@@ -367,6 +363,17 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
"""get first close value
|
||||
|
||||
Notes
|
||||
-----
|
||||
For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data
|
||||
"""
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
return _close
|
||||
|
||||
def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""manual adjust data: All fields (except change) are standardized according to the close of the first day"""
|
||||
if df.empty:
|
||||
@@ -374,45 +381,112 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = df.copy()
|
||||
df.sort_values(self._date_field_name, inplace=True)
|
||||
df = df.set_index(self._date_field_name)
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
_close = self._get_first_close(df)
|
||||
for _col in df.columns:
|
||||
if _col == self._symbol_field_name:
|
||||
# NOTE: retain original adjclose, required for incremental updates
|
||||
if _col in [self._symbol_field_name, "adjclose", "change"]:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] * _close
|
||||
elif _col != "change":
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df[_col] = df[_col] / _close
|
||||
return df.reset_index()
|
||||
|
||||
|
||||
class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
def __init__(
|
||||
self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_qlib_data_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
|
||||
self._first_close_field = "first_close"
|
||||
self._ori_close_field = "ori_close"
|
||||
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
|
||||
|
||||
def _get_old_data(self, qlib_data_dir: [str, Path]):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
|
||||
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
|
||||
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
|
||||
df.columns = [self._ori_close_field, self._first_close_field]
|
||||
return df
|
||||
|
||||
def _get_close(self, df: pd.DataFrame, field_name: str):
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_close = _df.loc[_df.last_valid_index()][field_name]
|
||||
return _close
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._first_close_field)
|
||||
except KeyError:
|
||||
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
|
||||
return _close
|
||||
|
||||
def _get_last_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._ori_close_field)
|
||||
except KeyError:
|
||||
_close = None
|
||||
return _close
|
||||
|
||||
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
try:
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_date = _df.index.max()
|
||||
except KeyError:
|
||||
_date = None
|
||||
return _date
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_last_close = self._get_last_close(df)
|
||||
# reindex
|
||||
_last_date = self._get_last_date(df)
|
||||
if _last_date is not None:
|
||||
df = df.set_index(self._date_field_name)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
_max_date = df.index.max()
|
||||
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
|
||||
df = df[df[self._date_field_name] > _last_date]
|
||||
if df.empty:
|
||||
return pd.DataFrame()
|
||||
_si = df["close"].first_valid_index()
|
||||
if _si > df.index[0]:
|
||||
logger.warning(
|
||||
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
|
||||
)
|
||||
# normalize
|
||||
df = self.normalize_yahoo(
|
||||
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
|
||||
)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
|
||||
|
||||
class YahooNormalize1min(YahooNormalize, ABC):
|
||||
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
# Whether the trading day of 1min data is consistent with 1d
|
||||
CONSISTENT_1d = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize]
|
||||
self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
@@ -427,24 +501,40 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
if not (data_1d is None or data_1d.empty):
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
|
||||
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
data_1d = data_1d_obj.normalize(data_1d)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df = df.sort_values(self._date_field_name)
|
||||
symbol = df.iloc[0][self._symbol_field_name]
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
@@ -452,9 +542,13 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
data_1d = data_1d.set_index(self._date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: yahoo 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
df.set_index("date_tmp", inplace=True)
|
||||
df.loc[:, "factor"] = data_1d["factor"]
|
||||
df.loc[:, "factor"] = data_1d["close"] / df["close"]
|
||||
df.loc[:, "paused"] = data_1d["paused"]
|
||||
df.reset_index("date_tmp", drop=True, inplace=True)
|
||||
|
||||
@@ -478,6 +572,54 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
|
||||
if self.CALC_PAUSED_NUM:
|
||||
df = self.calc_paused_num(df)
|
||||
return df
|
||||
|
||||
def calc_paused_num(self, df: pd.DataFrame):
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
self._date_field_name,
|
||||
self._symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -485,12 +627,67 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("US_ALL")
|
||||
|
||||
@@ -499,10 +696,10 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CONSISTENT_1d = False
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: support 1min
|
||||
raise ValueError("Does not support 1min")
|
||||
|
||||
@@ -514,7 +711,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
|
||||
|
||||
class YahooNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
@@ -523,28 +720,30 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
CONSISTENT_1d = True
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_1min_from_daily(self.calendar_list_1d)
|
||||
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
if "." not in symbol:
|
||||
_exchange = symbol[:2]
|
||||
_exchange = "ss" if _exchange == "sh" else _exchange
|
||||
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
|
||||
symbol = symbol[2:] + "." + _exchange
|
||||
return symbol
|
||||
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -554,7 +753,7 @@ class Run(BaseRun):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
@@ -578,10 +777,10 @@ class Run(BaseRun):
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
delay=0.5,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length=False,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -591,16 +790,23 @@ class Run(BaseRun):
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
time.sleep(delay), default 0.5
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
start datetime, default "2000-01-01"; closed interval(including start)
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Notes
|
||||
-----
|
||||
check_data_length, example:
|
||||
daily, one year: 252 // 4
|
||||
us 1min, a week: 6.5 * 60 * 5
|
||||
cn 1min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -612,7 +818,13 @@ class Run(BaseRun):
|
||||
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
|
||||
)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
end_date: str = None,
|
||||
qlib_data_1d_dir: str = None,
|
||||
):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -621,12 +833,205 @@ class Run(BaseRun):
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
end_date: str
|
||||
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
|
||||
qlib_data_1d_dir: str
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
|
||||
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
if self.interval.lower() == "1min":
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
raise ValueError(
|
||||
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
)
|
||||
|
||||
def normalize_data_1d_extend(
|
||||
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
|
||||
):
|
||||
"""normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
|
||||
Notes
|
||||
-----
|
||||
Steps to extend yahoo qlib data:
|
||||
|
||||
1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>
|
||||
|
||||
2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>
|
||||
|
||||
3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d
|
||||
|
||||
4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
|
||||
|
||||
5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_qlib_data_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data_1d_extend --old_qlib_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"{self.normalize_class_name}Extend")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
old_qlib_data_dir=old_qlib_data_dir,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
def download_today_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0.5,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download today data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0.5
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Notes
|
||||
-----
|
||||
Download today's data:
|
||||
start_time = datetime.datetime.now().date(); closed interval(including start)
|
||||
end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval(excluding end)
|
||||
|
||||
check_data_length, example:
|
||||
daily, one year: 252 // 4
|
||||
us 1min, a week: 6.5 * 60 * 5
|
||||
cn 1min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1m
|
||||
"""
|
||||
start = datetime.datetime.now().date()
|
||||
end = pd.Timestamp(start + pd.Timedelta(days=1)).date()
|
||||
self.download_data(
|
||||
max_collector_count,
|
||||
delay,
|
||||
start.strftime("%Y-%m-%d"),
|
||||
end.strftime("%Y-%m-%d"),
|
||||
check_data_length,
|
||||
limit_nums,
|
||||
)
|
||||
|
||||
def update_data_to_bin(
|
||||
self,
|
||||
qlib_data_1d_dir: str,
|
||||
trading_date: str = None,
|
||||
end_date: str = None,
|
||||
check_data_length: int = None,
|
||||
delay: float = 1,
|
||||
):
|
||||
"""update yahoo data to bin
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
|
||||
trading_date: str
|
||||
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
end_date: str
|
||||
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
Notes
|
||||
-----
|
||||
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
# get 1m data
|
||||
"""
|
||||
|
||||
if self.interval.lower() != "1d":
|
||||
logger.warning(f"currently supports 1d data updates: --interval 1d")
|
||||
|
||||
# start/end date
|
||||
if trading_date is None:
|
||||
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
logger.warning(f"trading_date is None, use the current date: {trading_date}")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download qlib 1d data
|
||||
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
|
||||
if not exists_qlib_data(qlib_data_1d_dir):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
|
||||
# download data from yahoo
|
||||
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
|
||||
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
|
||||
# NOTE: a larger max_workers setting here would be faster
|
||||
self.max_workers = (
|
||||
max(multiprocessing.cpu_count() - 2, 1)
|
||||
if self.max_workers is None or self.max_workers <= 1
|
||||
else self.max_workers
|
||||
)
|
||||
# normalize data
|
||||
self.normalize_data_1d_extend(qlib_data_1d_dir)
|
||||
|
||||
# dump bin
|
||||
_dump = DumpDataUpdate(
|
||||
csv_path=self.normalize_dir,
|
||||
qlib_dir=qlib_data_1d_dir,
|
||||
exclude_fields="symbol,date",
|
||||
max_workers=self.max_workers,
|
||||
)
|
||||
_dump.dump()
|
||||
|
||||
# parse index
|
||||
_region = self.region.lower()
|
||||
if _region not in ["cn", "us"]:
|
||||
logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored")
|
||||
return
|
||||
index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"]
|
||||
get_instruments = getattr(
|
||||
importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments"
|
||||
)
|
||||
for _index in index_list:
|
||||
get_instruments(str(qlib_data_1d_dir), _index)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,3 +6,4 @@ pandas
|
||||
tqdm
|
||||
lxml
|
||||
yahooquery
|
||||
joblib
|
||||
|
||||
@@ -401,6 +401,8 @@ class DumpDataUpdate(DumpDataBase):
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
# NOTE: all.txt only exists once for each stock
|
||||
# NOTE: if a stock corresponds to multiple different time ranges, user need to modify self._update_instruments
|
||||
self._update_instruments = (
|
||||
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
|
||||
.set_index([self.symbol_field_name])
|
||||
@@ -409,10 +411,9 @@ class DumpDataUpdate(DumpDataBase):
|
||||
|
||||
# load all csv files
|
||||
self._all_data = self._load_all_source_data() # type: pd.DataFrame
|
||||
self._update_calendars = sorted(
|
||||
self._new_calendar_list = self._old_calendar_list + sorted(
|
||||
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
|
||||
)
|
||||
self._new_calendar_list = self._old_calendar_list + self._update_calendars
|
||||
|
||||
def _load_all_source_data(self):
|
||||
# NOTE: Need more memory
|
||||
@@ -452,8 +453,16 @@ class DumpDataUpdate(DumpDataBase):
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
# exists stock, will append data
|
||||
_update_calendars = (
|
||||
_df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_START_FIELD]][
|
||||
self.date_field_name
|
||||
]
|
||||
.sort_values()
|
||||
.to_list()
|
||||
)
|
||||
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
|
||||
Reference in New Issue
Block a user