def save_instrument(self, symbol, df: pd.DataFrame):
    """save instrument data to file

    The rows are written to ``<save_dir>/<symbol>.csv``. If that file already
    exists, its rows are loaded first and the new rows are appended after them.

    Parameters
    ----------
    symbol: str
        instrument code; it is passed through ``self.normalize_symbol`` and
        ``code_to_fname`` to build the file name
    df : pd.DataFrame
        collected rows; the "symbol" column is overwritten in place with the
        file-name form of the symbol. ``None`` or an empty frame is skipped
        with a warning.
    """
    if df is None or df.empty:
        logger.warning(f"{symbol} is empty")
        return

    symbol = code_to_fname(self.normalize_symbol(symbol))
    instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
    df["symbol"] = symbol
    if instrument_path.exists():
        _old_df = pd.read_csv(instrument_path)
        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
        # pd.concat yields the same row-wise stacking of old and new rows.
        df = pd.concat([_old_df, df], sort=False)
    df.to_csv(instrument_path, index=False)
instrument list: {list(self.mini_symbol_map.keys())}") + logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") class BaseNormalize(abc.ABC): @@ -386,9 +386,9 @@ class BaseRun(abc.ABC): Examples --------- # get daily data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # get 1m data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] @@ -416,7 +416,7 @@ class BaseRun(abc.ABC): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d """ _class = getattr(self._cur_module, self.normalize_class_name) yc = Normalize( diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md new file mode 100644 index 000000000..c729b7eaa --- /dev/null +++ b/scripts/data_collector/fund/README.md @@ -0,0 +1,51 @@ +# Collect Fund Data + +> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. 
For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)* + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + + +### CN Data + +#### 1d from East Money + +```bash + +# download from eastmoney.com +python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + +# normalize +python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ + +# dump data +cd qlib/scripts +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ + +``` + +### using data + +```python +import qlib +from qlib.data import D + +qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data") +df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day") +``` + + +### Help +```bash +python collector.py collector_data --help +``` + +## Parameters + +- interval: 1d +- region: CN diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py new file mode 100644 index 000000000..10800a7a3 --- /dev/null +++ b/scripts/data_collector/fund/collector.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
class FundCollector(BaseCollector):
    """Collector that downloads fund net-value history through ``INDEX_BENCH_URL``.

    Concrete subclasses must provide ``_timezone``; the remaining abstract
    hooks come from ``BaseCollector``.
    """

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            fund save dir
        max_workers: int
            workers, default 4
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None
        """
        super(FundCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

        self.init_datetime()

    def init_datetime(self):
        """Clamp the start for 1min data, then convert both window bounds.

        Raises
        ------
        ValueError
            if ``self.interval`` is neither ``INTERVAL_1min`` nor ``INTERVAL_1d``
        """
        if self.interval == self.INTERVAL_1min:
            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
        elif self.interval != self.INTERVAL_1d:
            raise ValueError(f"interval error: {self.interval}")

        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)

    @staticmethod
    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
        """Interpret ``dt`` in ``timezone`` and re-express it in the local timezone.

        Best effort: when pandas raises ``ValueError`` the input is returned
        exactly as it was passed in.
        """
        try:
            epoch_seconds = pd.Timestamp(dt, tz=timezone).timestamp()
            return pd.Timestamp(epoch_seconds, tz=tzlocal(), unit="s")
        except ValueError:
            # Return the untouched input rather than a half-converted epoch float.
            # NOTE(review): presumably hit when ``dt`` already carries tzinfo - confirm.
            return dt

    @property
    @abc.abstractmethod
    def _timezone(self):
        raise NotImplementedError("rewrite get_timezone")

    @staticmethod
    def get_data_from_remote(symbol, interval, start, end, timeout=30):
        """Request the historical net values of one fund.

        Parameters
        ----------
        symbol: str
            fund code, substituted into ``INDEX_BENCH_URL``
        interval: str
            only used in the warning message
        start, end
            window bounds, formatted verbatim into the ``startDate`` / ``endDate``
            query fields
            NOTE(review): callers pass the values produced by ``convert_datetime``;
            confirm the remote API accepts that textual form.
        timeout: float
            seconds handed to ``requests.get`` so a stalled connection cannot
            block a worker indefinitely, default 30

        Returns
        -------
        pd.DataFrame or None
            ``data["Data"]["LSJZList"]`` as a frame with its index reset; ``None``
            (after a warning is logged) when the request or parsing fails or the
            fund reports 每*份收益 instead of a net value
        """
        error_msg = f"{symbol}-{interval}-{start}-{end}"

        try:
            # TODO: numberOfHistoricalDaysToCrawl should be big enough
            url = INDEX_BENCH_URL.format(
                index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
            )
            resp = requests.get(
                url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=timeout
            )

            if resp.status_code != 200:
                raise ValueError("request error")

            # The body is JSONP, ``callback({...})``: keep what lies between the
            # first "(" and the last ")" so parentheses inside the JSON survive.
            text = resp.text
            data = json.loads(text[text.index("(") + 1 : text.rindex(")")])

            # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
            SYType = data["Data"]["SYType"]
            if SYType in ("每万份收益", "每百份收益", "每百万份收益"):
                raise ValueError("The fund contains 每*份收益")

            # TODO: should we sort the value by datetime?
            return pd.DataFrame(data["Data"]["LSJZList"]).reset_index()
        except Exception as e:
            # Best effort per symbol: log and report "no data" instead of aborting the pool.
            logger.warning(f"{error_msg}:{e}")
        return None

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> pd.DataFrame:
        """Sleep via ``self.sleep()`` and fetch the 1d data for ``symbol``.

        Returns whatever ``get_data_from_remote`` returns (a frame or ``None``).

        Raises
        ------
        ValueError
            if ``interval`` is not ``INTERVAL_1d``
        """
        if interval != self.INTERVAL_1d:
            raise ValueError(f"cannot support {interval}")

        self.sleep()
        return self.get_data_from_remote(
            symbol,
            interval=interval,
            start=start_datetime,
            end=end_datetime,
        )
self._symbol_field_name) + return df + + +class FundNormalize1d(FundNormalize): + pass + + +class FundNormalizeCN: + def _get_calendar_list(self): + return get_calendar_list("ALL") + + +class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): + pass + + +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d + region: str + region, value from ["CN"], default "CN" + """ + super().__init__(source_dir, normalize_dir, max_workers, interval) + self.region = region + + @property + def collector_class_name(self): + return f"FundCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"FundNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + + def download_data( + self, + max_collector_count=2, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool # if this param useful? 
+ check data length, by default False + limit_nums: int + using for debug, by default None + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + """ + + super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) + + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + """normalize data + + Parameters + ---------- + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ + """ + super(Run, self).normalize_data(date_field_name, symbol_field_name) + + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/fund/requirements.txt b/scripts/data_collector/fund/requirements.txt new file mode 100644 index 000000000..11c6730c0 --- /dev/null +++ b/scripts/data_collector/fund/requirements.txt @@ -0,0 +1,9 @@ +loguru +fire +requests +numpy +pandas +tqdm +lxml +loguru +yahooquery \ No newline at end of file diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5f34aae7d..e8c9b9dc4 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
def return_date_list(date_field_name: str, file_path: Path):
    """Read one CSV and return its date column as ascending ``pd.Timestamp`` values.

    The file is loaded with its first column as the index; every value of the
    ``date_field_name`` column is converted with ``pd.Timestamp`` and the
    resulting list is sorted.
    """
    frame = pd.read_csv(file_path, sep=",", index_col=0)
    stamps = [pd.Timestamp(raw) for raw in frame[date_field_name].to_list()]
    stamps.sort()
    return stamps
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get en fund symbols

    The symbol list is downloaded from ``fund.eastmoney.com`` on the first call
    and cached in the module-level ``_EN_FUND_SYMBOLS``; later calls return the
    cached list without another request.

    Parameters
    ----------
    qlib_data_path: str or Path
        not used by this function

    Returns
    -------
        fund symbols in China (sorted, de-duplicated)
    """
    global _EN_FUND_SYMBOLS

    @deco_retry
    def _get_eastmoney():
        url = "http://fund.eastmoney.com/js/fundcode_search.js"
        # Without a timeout a stalled connection never raises, so the wrapper
        # around this function would never see a failure to react to.
        resp = requests.get(url, timeout=30)
        if resp.status_code != 200:
            raise ValueError("request error")
        try:
            # Drop everything up to the outer "= [" and the closing "];", then
            # pull out each inner "[...]" record.
            payload = resp.content.decode().split("= [")[-1].replace("];", "")
            # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
            _symbols = [
                record.replace('"', "").replace("'", "").split(",")[0]
                for record in re.findall(r"[\[](.*?)[\]]", payload)
            ]
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        # Sanity check: a much shorter list is treated as a failed download.
        if len(_symbols) < 8000:
            raise ValueError("request error")
        return _symbols

    if _EN_FUND_SYMBOLS is None:
        _all_symbols = _get_eastmoney()

        _EN_FUND_SYMBOLS = sorted(set(_all_symbols))

    return _EN_FUND_SYMBOLS