diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index c7b91a3f5..b14938a3d 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -17,8 +17,8 @@ pip install -r requirements.txt ```bash -# download from yahoo finance -python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# download from eastmoney.com +python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # dump data diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index e69de29bb..404b6af0e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -0,0 +1,408 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import abc +import sys +import copy +import time +import datetime +import importlib +import json +from abc import ABC +from pathlib import Path +from typing import Iterable, Type +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor + +import fire +import requests +import numpy as np +import pandas as pd +from tqdm import tqdm +from loguru import logger +from dateutil.tz import tzlocal +from qlib.utils import code_to_fname, fname_to_code + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.utils import get_calendar_list, get_en_fund_symbols + +INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" +REGION_CN = "CN" +REGION_US = "US" + + +class FundData: + START_DATETIME = pd.Timestamp("2000-01-01") + HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) + END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + INTERVAL_1d = "1d" + + def __init__( + self, + timezone: str = None, + start=None, + end=None, + interval="1d", + delay=0, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + timezone: str + The timezone where the data is located + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1d], default 1d + start: str + start datetime, default None + end: str + end datetime, default None + show_1min_logging: bool + show 1min logging, by default False; if True, there may be many warning logs + """ + self._timezone = tzlocal() if timezone is None else timezone + self._delay = delay + self._interval = interval + self._show_1min_logging = show_1min_logging + self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME + self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) + if self._interval != self.INTERVAL_1d: + raise ValueError(f"interval error: {self._interval}") + + # using for 1min + self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) + self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt + + def _sleep(self): + time.sleep(self._delay) + + @staticmethod + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + error_msg = f"{symbol}-{interval}-{start}-{end}" + + try: + _resp = None + + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) + + if resp.status_code != 200: + raise ValueError("request error") + try: + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + _resp = pd.DataFrame( + data["Data"]["LSJZList"] + ) + + except Exception as e: + logger.warning(f"request error: {e}") + raise + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data(self, symbol: str) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self._sleep() + _remote_interval = self._interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + show_1min_logging=self._show_1min_logging, + ) + + if self._interval == self.INTERVAL_1d: + _result = _get_simple(self.start_datetime, self.end_datetime) + else: + raise ValueError(f"cannot support {self._interval}") + return _result + + +class FundCollector: + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=2, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1min + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + """ + self.save_dir = Path(save_dir).expanduser().resolve() + self.save_dir.mkdir(parents=True, exist_ok=True) + + self._delay = delay + self.max_workers = max_workers + self._max_collector_count = max_collector_count + self._mini_symbol_map = {} + self._interval = interval + self._check_small_data = check_data_length + + self.fund_list = sorted(set(self.get_fund_list())) + if limit_nums is not None: + try: + self.fund_list = self.fund_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") + + self.fund_data = FundData( + timezone=self._timezone, + start=start, + end=end, + interval=interval, + delay=delay, + show_1min_logging=show_1min_logging, + ) + + @property + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("rewrite min_numbers_trading") + + @abc.abstractmethod + def get_fund_list(self): + raise NotImplementedError("rewrite get_fund_list") + + @property + @abc.abstractmethod + def _timezone(self): + raise NotImplementedError("rewrite get_timezone") + + def save_fund(self, symbol, df: pd.DataFrame): + """save fund data to file + + Parameters + ---------- + symbol: str + fund code + df : pd.DataFrame + df.columns must contain "symbol" and "datetime" + """ + if df.empty: + logger.warning(f"{symbol} is empty") + return + + symbol = code_to_fname(symbol) + stock_path = self.save_dir.joinpath(f"{symbol}.csv") + df["symbol"] = symbol + if stock_path.exists(): + _old_df = pd.read_csv(stock_path) + df = _old_df.append(df, sort=False) + df.to_csv(stock_path, index=False) + + def _save_small_data(self, symbol, df): + if len(df) <= self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") + _temp = self._mini_symbol_map.setdefault(symbol, []) + _temp.append(df.copy()) + return None + else: + if symbol in self._mini_symbol_map: + self._mini_symbol_map.pop(symbol) + return symbol + + def _get_data(self, symbol): + _result = None + df = self.fund_data.get_data(symbol) + if isinstance(df, pd.DataFrame): + if not df.empty: + if self._check_small_data: + if self._save_small_data(symbol, df) is not None: + _result = symbol + self.save_fund(symbol, df) + else: + _result = symbol + self.save_fund(symbol, df) + return _result + + def _collector(self, fund_list): + + error_symbol = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(fund_list)) as p_bar: + for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): + if _result is None: + error_symbol.append(_symbol) + p_bar.update() + print(error_symbol) + logger.info(f"error symbol nums: {len(error_symbol)}") + logger.info(f"current get symbol nums: {len(fund_list)}") + error_symbol.extend(self._mini_symbol_map.keys()) + return sorted(set(error_symbol)) + + def collector_data(self): + """collector data""" + logger.info("start collector fund data......") + fund_list = self.fund_list + for i in range(self._max_collector_count): + if not fund_list: + break + logger.info(f"getting data: {i+1}") + fund_list = self._collector(fund_list) + logger.info(f"{i+1} finish.") + for _symbol, _df_list in self._mini_symbol_map.items(): + self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) + if self._mini_symbol_map: + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + +class FundollectorCN(FundCollector, ABC): + def get_fund_list(self): + logger.info("get cn fund symbols......") + symbols = get_en_fund_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + @property + def _timezone(self): + return "Asia/Shanghai" + + +class FundCollectorCN1d(FundollectorCN): + @property + def min_numbers_trading(self): + return 252 / 4 + +class Run: + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + region: str + region, value from ["CN", "US"], default "CN" + """ + if source_dir is None: + source_dir = CUR_DIR.joinpath("source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.region = region + + def download_data( + self, + max_collector_count=2, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + show_1min_logging=False, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + """ + + _class = getattr( + self._cur_module, f"FundCollector{self.region.upper()}{interval}" + ) # type: Type[FundCollector] + _class( + self.source_dir, + max_workers=self.max_workers, + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, + show_1min_logging=show_1min_logging, + ).collector_data() + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5f34aae7d..3319025fc 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -34,6 +34,7 @@ _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None _US_SYMBOLS = None +_EN_FUND_SYMBOLS = None _CALENDAR_MAP = {} # NOTE: Until 2020-10-20 20:00:00 @@ -220,6 +221,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: return _US_SYMBOLS +def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: + """get en fund symbols + + Returns + ------- + fund symbols in China + """ + global _EN_FUND_SYMBOLS + + @deco_retry + def _get_eastmoney(): + url = "http://fund.eastmoney.com/js/fundcode_search.js" + resp = requests.get(url) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [] + for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + data = sub_data.replace("\"","").replace("'","") + # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] + _symbols.append(data.split(",")[0]) + except Exception as e: + logger.warning(f"request error: {e}") + raise + if len(_symbols) < 8000: + raise ValueError("request error") + return _symbols + + if _EN_FUND_SYMBOLS is None: + _all_symbols = _get_eastmoney() + + _EN_FUND_SYMBOLS = sorted(set(_all_symbols)) + + return _EN_FUND_SYMBOLS + + def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix