# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import abc
import sys
import datetime
import json
from abc import ABC
from pathlib import Path

import fire
import requests
import pandas as pd
from loguru import logger
from dateutil.tz import tzlocal
from qlib.constant import REG_CN as REGION_CN

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))

from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_en_fund_symbols

# Eastmoney fund net-value history endpoint. It is requested with ``callback=jQuery_``,
# so the body comes back wrapped as ``jQuery_( ... )`` and must be unwrapped before json.loads.
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"


def _strip_jsonp(text: str) -> str:
    """Return the JSON payload inside a JSONP response such as ``jQuery_({...})``.

    Text that already starts like bare JSON, or that has no ``( ... )`` wrapper, is returned
    with surrounding whitespace stripped but otherwise unchanged.
    """
    text = text.strip()
    if text.startswith(("{", "[")):
        return text
    start, end = text.find("("), text.rfind(")")
    if start == -1 or end <= start:
        return text
    # First "(" and LAST ")" so that parentheses inside the payload itself are preserved.
    return text[start + 1 : end]


class FundCollector(BaseCollector):
    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: int = None,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            fund save dir
        max_workers: int
            workers, default 4
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d], default 1d; note that ``get_data`` only supports 1d
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data
            length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of
            fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None
        """
        super(FundCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

        self.init_datetime()

    def init_datetime(self):
        """Clamp the start datetime for 1min data and convert start/end from the market timezone to local time.

        Raises
        ------
        ValueError
            If ``self.interval`` is neither 1min nor 1d.
        """
        if self.interval == self.INTERVAL_1min:
            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
        elif self.interval == self.INTERVAL_1d:
            pass
        else:
            raise ValueError(f"interval error: {self.interval}")

        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)

    @staticmethod
    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
        """Interpret ``dt`` in ``timezone`` and re-express the same instant in the local timezone.

        Best effort: if pandas cannot interpret ``dt`` (``ValueError``), the input is returned unchanged.
        """
        try:
            dt = pd.Timestamp(dt, tz=timezone).timestamp()
            dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
        except ValueError:
            pass
        return dt

    @property
    @abc.abstractmethod
    def _timezone(self):
        """Timezone name of the market the subclass collects (e.g. ``"Asia/Shanghai"``)."""
        raise NotImplementedError("rewrite get_timezone")

    @staticmethod
    def get_data_from_remote(symbol, interval, start, end):
        """Download the net-value history of a single fund from Eastmoney.

        Returns
        -------
        pd.DataFrame or None
            The ``Data.LSJZList`` records, with an extra ``index`` column added by ``reset_index``.
            ``None`` if the request, parsing or validation fails; the failure is logged as a warning.
        """
        error_msg = f"{symbol}-{interval}-{start}-{end}"

        try:
            # TODO: numberOfHistoricalDaysToCrawl should be big enough to cover [start, end]
            url = INDEX_BENCH_URL.format(
                index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
            )
            # NOTE(review): timeout=None lets a stalled connection block this worker indefinitely.
            resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None)

            if resp.status_code != 200:
                raise ValueError("request error")

            data = json.loads(_strip_jsonp(resp.text))

            # Some funds don't show a net value but an income per 10k / 100 / 1M units instead,
            # example: http://fundf10.eastmoney.com/jjjz_010288.html
            SYType = data["Data"]["SYType"]
            if SYType in {"每万份收益", "每百份收益", "每百万份收益"}:
                raise ValueError("The fund contains 每*份收益")

            # TODO: should we sort the value by datetime?
            return pd.DataFrame(data["Data"]["LSJZList"]).reset_index()
        except Exception as e:
            # Deliberate best effort: the caller treats None as "no data" for this symbol.
            logger.warning(f"{error_msg}:{e}")
            return None

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> [pd.DataFrame]:
        """Fetch ``symbol`` for [start_datetime, end_datetime], sleeping ``delay`` before the request.

        Raises
        ------
        ValueError
            If ``interval`` is not 1d (the only frequency the remote endpoint provides here).
        """
        if interval != self.INTERVAL_1d:
            raise ValueError(f"cannot support {interval}")
        self.sleep()
        return self.get_data_from_remote(
            symbol,
            interval=interval,
            start=start_datetime,
            end=end_datetime,
        )


# NOTE: "FundollectorCN" (missing "C") is the historical class name; it is kept because
# FundCollectorCN1d subclasses it and external code may import it.
class FundollectorCN(FundCollector, ABC):
    def get_instrument_list(self):
        """Return the list of CN fund symbols to collect."""
        logger.info("get cn fund symbols......")
        symbols = get_en_fund_symbols()
        logger.info(f"get {len(symbols)} symbols.")
        return symbols

    def normalize_symbol(self, symbol):
        """Fund symbols are used as-is."""
        return symbol

    @property
    def _timezone(self):
        return "Asia/Shanghai"


class FundCollectorCN1d(FundollectorCN):
    pass


class FundNormalize(BaseNormalize):
    DAILY_FORMAT = "%Y-%m-%d"

    @staticmethod
    def normalize_fund(
        df: pd.DataFrame,
        calendar_list: list = None,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """Deduplicate one fund's rows by date, optionally align them to a calendar, and sort by date.

        Parameters
        ----------
        df: pd.DataFrame
            raw data containing a ``date_field_name`` column
        calendar_list: list
            trading calendar; when given, rows are reindexed to the calendar entries between the fund's
            first and last date (missing days become NaN rows)
        date_field_name: str
            name of the date column
        symbol_field_name: str
            name of the symbol column (currently unused)

        Returns
        -------
        pd.DataFrame
            the normalized frame with the date restored as a regular column; ``df`` itself if it is empty
        """
        if df.empty:
            return df
        df = df.copy()
        df.set_index(date_field_name, inplace=True)
        df.index = pd.to_datetime(df.index)
        df = df[~df.index.duplicated(keep="first")]
        if calendar_list is not None:
            first_day = pd.Timestamp(df.index.min()).date()
            # Extend the upper bound to the end of the last day. Adding a Timedelta directly to a
            # ``datetime.date`` silently drops the sub-day part, so go back to a Timestamp first.
            last_moment = pd.Timestamp(pd.Timestamp(df.index.max()).date()) + pd.Timedelta(hours=23, minutes=59)
            df = df.reindex(pd.DataFrame(index=calendar_list).loc[first_day:last_moment].index)
        df.sort_index(inplace=True)

        df.index.names = [date_field_name]
        return df.reset_index()

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize one symbol's data with this normalizer's calendar and field names."""
        df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
        return df


class FundNormalize1d(FundNormalize):
    pass


class FundNormalizeCN:
    # Mixin providing the calendar used to align CN fund data.
    def _get_calendar_list(self):
        return get_calendar_list("ALL")


class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
    pass


class Run(BaseRun):
    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        max_workers: int
            Concurrent number, default is 4
        interval: str
            freq, value from [1min, 1d], default 1d; only 1d is supported by the fund collector
        region: str
            region, value from ["CN"], default "CN"
        """
        super().__init__(source_dir, normalize_dir, max_workers, interval)
        self.region = region

    @property
    def collector_class_name(self):
        # e.g. "FundCollectorCN1d"
        return f"FundCollector{self.region.upper()}{self.interval}"

    @property
    def normalize_class_name(self):
        # e.g. "FundNormalizeCN1d"
        return f"FundNormalize{self.region.upper()}{self.interval}"

    @property
    def default_base_dir(self) -> [Path, str]:
        return CUR_DIR

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        check_data_length: int = None,
        limit_nums=None,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data
            length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of
            fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None

        The frequency is taken from the ``interval`` given to ``Run`` (``--interval`` on the command line).

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
        """
        super().download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
        """
        super().normalize_data(date_field_name, symbol_field_name)


if __name__ == "__main__":
    fire.Fire(Run)