From 4e7a147759286b8b20729a07381652030d428a58 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 14 Mar 2021 14:24:14 +0800 Subject: [PATCH] use base.py --- scripts/data_collector/base.py | 48 +-- scripts/data_collector/fund/collector.py | 417 ++++++----------------- 2 files changed, 121 insertions(+), 344 deletions(-) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index ccd8f59e5..12983f6a5 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -46,7 +46,7 @@ class BaseCollector(abc.ABC): Parameters ---------- save_dir: str - stock save dir + instrument save dir max_workers: int workers, default 4 max_collector_count: int @@ -77,11 +77,11 @@ class BaseCollector(abc.ABC): self.start_datetime = self.normalize_start_datetime(start) self.end_datetime = self.normalize_end_datetime(end) - self.stock_list = sorted(set(self.get_stock_list())) + self.instrument_list = sorted(set(self.get_instrument_list())) if limit_nums is not None: try: - self.stock_list = self.stock_list[: int(limit_nums)] + self.instrument_list = self.instrument_list[: int(limit_nums)] except Exception as e: logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") @@ -108,8 +108,8 @@ class BaseCollector(abc.ABC): raise NotImplementedError("rewrite min_numbers_trading") @abc.abstractmethod - def get_stock_list(self): - raise NotImplementedError("rewrite get_stock_list") + def get_instrument_list(self): + raise NotImplementedError("rewrite get_instrument_list") @abc.abstractmethod def normalize_symbol(self, symbol: str): @@ -158,27 +158,27 @@ class BaseCollector(abc.ABC): return _result def save_instrument(self, symbol, df: pd.DataFrame): - """save stock data to file + """save instrument data to file Parameters ---------- symbol: str - stock code + instrument code df : pd.DataFrame df.columns must contain "symbol" and "datetime" """ - if df.empty: + if df is None or df.empty: logger.warning(f"{symbol} is empty") return symbol = self.normalize_symbol(symbol) symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") + instrument_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) + if instrument_path.exists(): + _old_df = pd.read_csv(instrument_path) df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + df.to_csv(instrument_path, index=False) def cache_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -191,38 +191,38 @@ class BaseCollector(abc.ABC): self.mini_symbol_map.pop(symbol) return self.NORMAL_FLAG - def _collector(self, stock_list): + def _collector(self, instrument_list): error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(stock_list)) as p_bar: - for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)): + with tqdm(total=len(instrument_list)) as p_bar: + for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)): if _result != self.NORMAL_FLAG: error_symbol.append(_symbol) p_bar.update() print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(stock_list)}") + logger.info(f"current get symbol nums: {len(instrument_list)}") error_symbol.extend(self.mini_symbol_map.keys()) return sorted(set(error_symbol)) def collector_data(self): """collector data""" logger.info("start collector data......") - stock_list = self.stock_list + instrument_list = self.instrument_list for i in range(self.max_collector_count): - if not stock_list: + if not instrument_list: break logger.info(f"getting data: {i+1}") - stock_list = self._collector(stock_list) + instrument_list = self._collector(instrument_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self.mini_symbol_map.items(): self.save_instrument( _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]) ) if self.mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}") - logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") + logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}") + logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") class BaseNormalize(abc.ABC): @@ -386,9 +386,9 @@ class BaseRun(abc.ABC): Examples --------- # get daily data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # get 1m data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] @@ -416,7 +416,7 @@ class BaseRun(abc.ABC): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d """ _class = getattr(self._cur_module, self.normalize_class_name) yc = Normalize( diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 08773cb9f..1e0d2d8bf 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -11,123 +11,24 @@ import json from abc import ABC from pathlib import Path from typing import Iterable, Type -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import fire import requests import numpy as np import pandas as pd -from tqdm import tqdm from loguru import logger from dateutil.tz import tzlocal +from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.base import BaseCollector, BaseNormalize, BaseRun from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" -REGION_CN = "CN" -class FundData: - START_DATETIME = pd.Timestamp("2000-01-01") - END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - INTERVAL_1d = "1d" - - def __init__( - self, - timezone: str = None, - start=None, - end=None, - interval="1d", - delay=0, - ): - """ - - Parameters - ---------- - timezone: str - The timezone where the data is located - delay: float - time.sleep(delay), default 0 - interval: str - freq, value from [1d], default 1d - start: str - start datetime, default None - end: str - end datetime, default None - """ - self._timezone = tzlocal() if timezone is None else timezone - self._delay = delay - self._interval = interval - self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME - self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) - if self._interval != self.INTERVAL_1d: - raise ValueError(f"interval error: {self._interval}") - - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) - self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) - - @staticmethod - def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): - try: - dt = pd.Timestamp(dt, tz=timezone).timestamp() - dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") - except ValueError as e: - pass - return dt - - def _sleep(self): - time.sleep(self._delay) - - @staticmethod - def get_data_from_remote(symbol, interval, start, end): - error_msg = f"{symbol}-{interval}-{start}-{end}" - - try: - # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format( - index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end - ) - resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - - if resp.status_code != 200: - raise ValueError("request error") - - data = json.loads(resp.text.split("(")[-1].split(")")[0]) - - # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html - SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") - - # TODO: should we sort the value by datetime? - _resp = pd.DataFrame(data["Data"]["LSJZList"]) - - if isinstance(_resp, pd.DataFrame): - return _resp.reset_index() - except Exception as e: - logger.warning(f"{error_msg}:{e}") - - def get_data(self, symbol: str) -> [pd.DataFrame]: - def _get_simple(start_, end_): - self._sleep() - _remote_interval = self._interval - return self.get_data_from_remote( - symbol, - interval=_remote_interval, - start=start_, - end=end_, - ) - - if self._interval == self.INTERVAL_1d: - _result = _get_simple(self.start_datetime, self.end_datetime) - else: - raise ValueError(f"cannot support {self._interval}") - return _result - - -class FundCollector: +class FundCollector(BaseCollector): def __init__( self, save_dir: [str, Path], @@ -163,134 +64,108 @@ class FundCollector: limit_nums: int using for debug, by default None """ - self.save_dir = Path(save_dir).expanduser().resolve() - self.save_dir.mkdir(parents=True, exist_ok=True) - - self._delay = delay - self.max_workers = max_workers - self._max_collector_count = max_collector_count - self._mini_symbol_map = {} - self._interval = interval - self._check_small_data = check_data_length - - self.fund_list = sorted(set(self.get_fund_list())) - if limit_nums is not None: - try: - self.fund_list = self.fund_list[: int(limit_nums)] - except Exception as e: - logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") - - self.fund_data = FundData( - timezone=self._timezone, + super(FundCollector, self).__init__( + save_dir=save_dir, start=start, end=end, interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, ) - @property - @abc.abstractmethod - def min_numbers_trading(self): - # daily, one year: 252 / 4 - # us 1min, a week: 6.5 * 60 * 5 - # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("rewrite min_numbers_trading") + self.init_datetime() - @abc.abstractmethod - def get_fund_list(self): - raise NotImplementedError("rewrite get_fund_list") + def init_datetime(self): + if self.interval == self.INTERVAL_1min: + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + elif self.interval == self.INTERVAL_1d: + pass + else: + raise ValueError(f"interval error: {self.interval}") + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt @property @abc.abstractmethod def _timezone(self): raise NotImplementedError("rewrite get_timezone") - def save_fund(self, symbol, df: pd.DataFrame): - """save fund data to file + @staticmethod + def get_data_from_remote(symbol, interval, start, end): + error_msg = f"{symbol}-{interval}-{start}-{end}" - Parameters - ---------- - symbol: str - fund code - df : pd.DataFrame - df.columns must contain "symbol" and "datetime" - """ - if df.empty: - logger.warning(f"{symbol} is empty") - return + try: + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format( + index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + ) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - fund_path = self.save_dir.joinpath(f"{symbol}.csv") - df["symbol"] = symbol - if fund_path.exists(): - # TODO: read the fund code as str, not int, like "000001" shouldn't be "1" - _old_df = pd.read_csv(fund_path) - # TODO: remove the duplicate date - df = _old_df.append(df, sort=False) - df.to_csv(fund_path, index=False) + if resp.status_code != 200: + raise ValueError("request error") - def _save_small_data(self, symbol, df): - if len(df) <= self.min_numbers_trading: - logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") - _temp = self._mini_symbol_map.setdefault(symbol, []) - _temp.append(df.copy()) - return None + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + # TODO: should we sort the value by datetime? + _resp = pd.DataFrame(data["Data"]["LSJZList"]) + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self.sleep() + _remote_interval = interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + ) + + if interval == self.INTERVAL_1d: + _result = _get_simple(start_datetime, end_datetime) else: - if symbol in self._mini_symbol_map: - self._mini_symbol_map.pop(symbol) - return symbol - - def _get_data(self, symbol): - _result = None - df = self.fund_data.get_data(symbol) - if isinstance(df, pd.DataFrame): - if not df.empty: - if self._check_small_data: - if self._save_small_data(symbol, df) is not None: - _result = symbol - self.save_fund(symbol, df) - else: - _result = symbol - self.save_fund(symbol, df) + raise ValueError(f"cannot support {interval}") return _result - def _collector(self, fund_list): - error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(fund_list)) as p_bar: - for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): - if _result is None: - error_symbol.append(_symbol) - p_bar.update() - print(error_symbol) - logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(fund_list)}") - error_symbol.extend(self._mini_symbol_map.keys()) - return sorted(set(error_symbol)) - def collector_data(self): """collector data""" - logger.info("start collector fund data......") - fund_list = self.fund_list - for i in range(self._max_collector_count): - if not fund_list: - break - logger.info(f"getting data: {i+1}") - fund_list = self._collector(fund_list) - logger.info(f"{i+1} finish.") - for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") - logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + super(FundCollector, self).collector_data() class FundollectorCN(FundCollector, ABC): - def get_fund_list(self): + def get_instrument_list(self): logger.info("get cn fund symbols......") symbols = get_en_fund_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols + def normalize_symbol(self, symbol): + return symbol + @property def _timezone(self): return "Asia/Shanghai" @@ -302,29 +177,9 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 -class FundNormalize: - COLUMNS = ["open", "close", "high", "low", "volume"] +class FundNormalize(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self._date_field_name = date_field_name - self._symbol_field_name = symbol_field_name - - self._calendar_list = self._get_calendar_list() - @staticmethod def normalize_fund( df: pd.DataFrame, @@ -357,11 +212,6 @@ class FundNormalize: df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) return df - @abc.abstractmethod - def _get_calendar_list(self): - """Get benchmark calendar""" - raise NotImplementedError("") - class FundNormalize1d(FundNormalize, ABC): DAILY_FORMAT = "%Y-%m-%d" @@ -380,62 +230,8 @@ class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): pass -class Normalize: - def __init__( - self, - source_dir: [str, Path], - target_dir: [str, Path], - normalize_class: Type[FundNormalize], - max_workers: int = 16, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - source_dir: str or Path - The directory where the raw data collected from the Internet is saved - target_dir: str or Path - Directory for normalize data - normalize_class: Type[FundNormalize] - normalize class - max_workers: int - Concurrent number, default is 16 - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - if not (source_dir and target_dir): - raise ValueError("source_dir and target_dir cannot be None") - self._source_dir = Path(source_dir).expanduser() - self._target_dir = Path(target_dir).expanduser() - self._target_dir.mkdir(parents=True, exist_ok=True) - - self._max_workers = max_workers - - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) - - def _executor(self, file_path: Path): - file_path = Path(file_path) - df = pd.read_csv(file_path) - df = self._normalize_obj.normalize(df) - if not df.empty: - df.to_csv(self._target_dir.joinpath(file_path.name), index=False) - - def normalize(self): - logger.info("normalize data......") - - with ProcessPoolExecutor(max_workers=self._max_workers) as worker: - file_list = list(self._source_dir.glob("*.csv")) - with tqdm(total=len(file_list)) as p_bar: - for _ in worker.map(self._executor, file_list): - p_bar.update() - - -class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): """ Parameters @@ -446,23 +242,26 @@ class Run: Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d region: str region, value from ["CN"], default "CN" """ - if source_dir is None: - source_dir = CUR_DIR.joinpath("source") - self.source_dir = Path(source_dir).expanduser().resolve() - self.source_dir.mkdir(parents=True, exist_ok=True) - - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - - self._cur_module = importlib.import_module("collector") - self.max_workers = max_workers + super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region + @property + def collector_class_name(self): + return f"FundCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"FundNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + def download_data( self, max_collector_count=2, @@ -498,26 +297,13 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - _class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector] - _class( - self.source_dir, - max_workers=self.max_workers, - max_collector_count=max_collector_count, - delay=delay, - start=start, - end=end, - interval=interval, - check_data_length=check_data_length, - limit_nums=limit_nums, - ).collector_data() + super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) - def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters ---------- - interval: str - freq, value from [1d], default 1d date_field_name: str date field name, default date symbol_field_name: str @@ -527,16 +313,7 @@ class Run: --------- $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ """ - _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") - fc = Normalize( - source_dir=self.source_dir, - target_dir=self.normalize_dir, - normalize_class=_class, - max_workers=self.max_workers, - date_field_name=date_field_name, - symbol_field_name=symbol_field_name, - ) - fc.normalize() + super(Run, self).normalize_data(date_field_name, symbol_field_name) if __name__ == "__main__":