1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge pull request #292 from wangershi/addFund

Add fund data as an example
This commit is contained in:
you-n-g
2021-03-19 11:05:29 +08:00
committed by GitHub
5 changed files with 510 additions and 24 deletions

View File

@@ -46,7 +46,7 @@ class BaseCollector(abc.ABC):
Parameters
----------
save_dir: str
stock save dir
instrument save dir
max_workers: int
workers, default 4
max_collector_count: int
@@ -77,11 +77,11 @@ class BaseCollector(abc.ABC):
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)
self.stock_list = sorted(set(self.get_stock_list()))
self.instrument_list = sorted(set(self.get_instrument_list()))
if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
self.instrument_list = self.instrument_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
@@ -108,8 +108,8 @@ class BaseCollector(abc.ABC):
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewrite get_stock_list")
def get_instrument_list(self):
raise NotImplementedError("rewrite get_instrument_list")
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
@@ -158,27 +158,27 @@ class BaseCollector(abc.ABC):
return _result
def save_instrument(self, symbol, df: pd.DataFrame):
"""save stock data to file
"""save instrument data to file
Parameters
----------
symbol: str
stock code
instrument code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
if df is None or df.empty:
logger.warning(f"{symbol} is empty")
return
symbol = self.normalize_symbol(symbol)
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
if instrument_path.exists():
_old_df = pd.read_csv(instrument_path)
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
df.to_csv(instrument_path, index=False)
def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
@@ -191,38 +191,38 @@ class BaseCollector(abc.ABC):
self.mini_symbol_map.pop(symbol)
return self.NORMAL_FLAG
def _collector(self, stock_list):
def _collector(self, instrument_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
with tqdm(total=len(instrument_list)) as p_bar:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
logger.info(f"current get symbol nums: {len(instrument_list)}")
error_symbol.extend(self.mini_symbol_map.keys())
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
logger.info("start collector data......")
stock_list = self.stock_list
instrument_list = self.instrument_list
for i in range(self.max_collector_count):
if not stock_list:
if not instrument_list:
break
logger.info(f"getting data: {i+1}")
stock_list = self._collector(stock_list)
instrument_list = self._collector(instrument_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
class BaseNormalize(abc.ABC):
@@ -386,9 +386,9 @@ class BaseRun(abc.ABC):
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
@@ -416,7 +416,7 @@ class BaseRun(abc.ABC):
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, self.normalize_class_name)
yc = Normalize(

View File

@@ -0,0 +1,51 @@
# Collect Fund Data
> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
### CN Data
#### 1d from East Money
```bash
# download from eastmoney.com
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# normalize
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
```
### using data
```python
import qlib
from qlib.data import D
qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
```
### Help
```bash
pythono collector.py collector_data --help
```
## Parameters
- interval: 1d
- region: CN

View File

@@ -0,0 +1,312 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import sys
import copy
import time
import datetime
import importlib
import json
from abc import ABC
from pathlib import Path
from typing import Iterable, Type
import fire
import requests
import numpy as np
import pandas as pd
from loguru import logger
from dateutil.tz import tzlocal
from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_en_fund_symbols
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
class FundCollector(BaseCollector):
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="1d",
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
):
"""
Parameters
----------
save_dir: str
fund save dir
max_workers: int
workers, default 4
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1min
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
"""
super(FundCollector, self).__init__(
save_dir=save_dir,
start=start,
end=end,
interval=interval,
max_workers=max_workers,
max_collector_count=max_collector_count,
delay=delay,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
self.init_datetime()
def init_datetime(self):
if self.interval == self.INTERVAL_1min:
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
elif self.interval == self.INTERVAL_1d:
pass
else:
raise ValueError(f"interval error: {self.interval}")
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
try:
dt = pd.Timestamp(dt, tz=timezone).timestamp()
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
except ValueError as e:
pass
return dt
@property
@abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
@staticmethod
def get_data_from_remote(symbol, interval, start, end):
error_msg = f"{symbol}-{interval}-{start}-{end}"
try:
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
url = INDEX_BENCH_URL.format(
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
)
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
if resp.status_code != 200:
raise ValueError("request error")
data = json.loads(resp.text.split("(")[-1].split(")")[0])
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
# TODO: should we sort the value by datetime?
_resp = pd.DataFrame(data["Data"]["LSJZList"])
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
except Exception as e:
logger.warning(f"{error_msg}:{e}")
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> [pd.DataFrame]:
def _get_simple(start_, end_):
self.sleep()
_remote_interval = interval
return self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
if interval == self.INTERVAL_1d:
_result = _get_simple(start_datetime, end_datetime)
else:
raise ValueError(f"cannot support {interval}")
return _result
class FundollectorCN(FundCollector, ABC):
def get_instrument_list(self):
logger.info("get cn fund symbols......")
symbols = get_en_fund_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def normalize_symbol(self, symbol):
return symbol
@property
def _timezone(self):
return "Asia/Shanghai"
class FundCollectorCN1d(FundollectorCN):
@property
def min_numbers_trading(self):
return 252 / 4
class FundNormalize(BaseNormalize):
DAILY_FORMAT = "%Y-%m-%d"
@staticmethod
def normalize_fund(
df: pd.DataFrame,
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
if df.empty:
return df
df = df.copy()
df.set_index(date_field_name, inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
if calendar_list is not None:
df = df.reindex(
pd.DataFrame(index=calendar_list)
.loc[
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
+ pd.Timedelta(hours=23, minutes=59)
]
.index
)
df.sort_index(inplace=True)
df.index.names = [date_field_name]
return df.reset_index()
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
return df
class FundNormalize1d(FundNormalize):
pass
class FundNormalizeCN:
def _get_calendar_list(self):
return get_calendar_list("ALL")
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
pass
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
interval: str
freq, value from [1min, 1d], default 1d
region: str
region, value from ["CN"], default "CN"
"""
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region
@property
def collector_class_name(self):
return f"FundCollector{self.region.upper()}{self.interval}"
@property
def normalize_class_name(self):
return f"FundNormalize{self.region.upper()}{self.interval}"
@property
def default_base_dir(self) -> [Path, str]:
return CUR_DIR
def download_data(
self,
max_collector_count=2,
delay=0,
start=None,
end=None,
interval="1d",
check_data_length=False,
limit_nums=None,
):
"""download data from Internet
Parameters
----------
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1d
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool # if this param useful?
check data length, by default False
limit_nums: int
using for debug, by default None
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
"""
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data
Parameters
----------
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
"""
super(Run, self).normalize_data(date_field_name, symbol_field_name)
if __name__ == "__main__":
fire.Fire(Run)

View File

@@ -0,0 +1,10 @@
loguru
fire
requests
numpy
pandas
tqdm
lxml
loguru
yahooquery
json

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
import os
import time
import bisect
import pickle
@@ -14,6 +15,9 @@ import pandas as pd
from lxml import etree
from loguru import logger
from yahooquery import Ticker
from tqdm import tqdm
from functools import partial
from concurrent.futures import ProcessPoolExecutor
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
@@ -34,6 +38,7 @@ _BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_US_SYMBOLS = None
_EN_FUND_SYMBOLS = None
_CALENDAR_MAP = {}
# NOTE: Until 2020-10-20 20:00:00
@@ -93,6 +98,78 @@ def get_calendar_list(bench_code="CSI300") -> list:
return calendar
def return_date_list(date_field_name: str, file_path: Path):
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
return sorted(map(lambda x: pd.Timestamp(x), date_list))
def get_calendar_list_by_ratio(
source_dir: [str, Path],
date_field_name: str = "date",
threshold: float = 0.5,
minimum_count: int = 10,
max_workers: int = 16,
) -> list:
"""get calendar list by selecting the date when few funds trade in this day
Parameters
----------
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
date_field_name: str
date field name, default is date
threshold: float
threshold to exclude some days when few funds trade in this day, default 0.5
minimum_count: int
minimum count of funds should trade in one day
max_workers: int
Concurrent number, default is 16
Returns
-------
history calendar list
"""
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
source_dir = Path(source_dir).expanduser()
file_list = list(source_dir.glob("*.csv"))
_number_all_funds = len(file_list)
logger.info(f"count how many funds trade in this day......")
_dict_count_trade = dict() # dict{date:count}
_fun = partial(return_date_list, date_field_name)
all_oldest_list = []
with tqdm(total=_number_all_funds) as p_bar:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for date_list in executor.map(_fun, file_list):
if date_list:
all_oldest_list.append(date_list[0])
for date in date_list:
if date not in _dict_count_trade.keys():
_dict_count_trade[date] = 0
_dict_count_trade[date] += 1
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
for oldest_date in all_oldest_list:
for date in _dict_count_founding.keys():
if date < oldest_date:
_dict_count_founding[date] -= 1
calendar = [
date
for date in _dict_count_trade
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
]
return calendar
def get_hs_stock_symbols() -> list:
"""get SH/SZ stock symbols
@@ -220,6 +297,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
return _US_SYMBOLS
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get en fund symbols
Returns
-------
fund symbols in China
"""
global _EN_FUND_SYMBOLS
@deco_retry
def _get_eastmoney():
url = "http://fund.eastmoney.com/js/fundcode_search.js"
resp = requests.get(url)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = []
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
data = sub_data.replace('"', "").replace("'", "")
# TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
_symbols.append(data.split(",")[0])
except Exception as e:
logger.warning(f"request error: {e}")
raise
if len(_symbols) < 8000:
raise ValueError("request error")
return _symbols
if _EN_FUND_SYMBOLS is None:
_all_symbols = _get_eastmoney()
_EN_FUND_SYMBOLS = sorted(set(_all_symbols))
return _EN_FUND_SYMBOLS
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
"""symbol suffix to prefix