diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py new file mode 100644 index 000000000..ccd8f59e5 --- /dev/null +++ b/scripts/data_collector/base.py @@ -0,0 +1,430 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import abc +import time +import datetime +import importlib +from pathlib import Path +from typing import Type +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor + +import pandas as pd +from tqdm import tqdm +from loguru import logger +from qlib.utils import code_to_fname + + +class BaseCollector(abc.ABC): + + CACHE_FLAG = "CACHED" + NORMAL_FLAG = "NORMAL" + + DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01") + DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) + DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + + INTERVAL_1min = "1min" + INTERVAL_1d = "1d" + + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=2, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + """ + self.save_dir = Path(save_dir).expanduser().resolve() + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.delay = delay + self.max_workers = max_workers + self.max_collector_count = max_collector_count + self.mini_symbol_map = {} + self.interval = interval + self.check_small_data = check_data_length + + self.start_datetime = self.normalize_start_datetime(start) + self.end_datetime = self.normalize_end_datetime(end) + + self.stock_list = sorted(set(self.get_stock_list())) + + if limit_nums is not None: + try: + self.stock_list = self.stock_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") + + def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None): + return ( + pd.Timestamp(str(start_datetime)) + if start_datetime + else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}") + ) + + def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None): + return ( + pd.Timestamp(str(end_datetime)) + if end_datetime + else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}") + ) + + @property + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("rewrite min_numbers_trading") + + @abc.abstractmethod + def get_stock_list(self): + raise NotImplementedError("rewrite get_stock_list") + + @abc.abstractmethod + def normalize_symbol(self, symbol: str): + """normalize symbol""" + raise NotImplementedError("rewrite normalize_symbol") + + @abc.abstractmethod + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> pd.DataFrame: + """get data with symbol + + Parameters + ---------- + symbol: str + interval: str + value from [1min, 1d] + start_datetime: pd.Timestamp + end_datetime: pd.Timestamp + + Returns + --------- + pd.DataFrame, "symbol" in pd.columns + + """ + raise NotImplementedError("rewrite get_timezone") + + def sleep(self): + time.sleep(self.delay) + + def _simple_collector(self, symbol: str): + """ + + Parameters + ---------- + symbol: str + + """ + self.sleep() + df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime) + _result = self.NORMAL_FLAG + if self.check_small_data: + _result = self.cache_small_data(symbol, df) + if _result == self.NORMAL_FLAG: + self.save_instrument(symbol, df) + return _result + + def save_instrument(self, symbol, df: pd.DataFrame): + """save stock data to file + + Parameters + ---------- + symbol: str + stock code + df : pd.DataFrame + df.columns must contain "symbol" and "datetime" + """ + if df.empty: + logger.warning(f"{symbol} is empty") + return + + symbol = self.normalize_symbol(symbol) + symbol = code_to_fname(symbol) + stock_path = self.save_dir.joinpath(f"{symbol}.csv") + df["symbol"] = symbol + if stock_path.exists(): + _old_df = pd.read_csv(stock_path) + df = _old_df.append(df, sort=False) + df.to_csv(stock_path, index=False) + + def cache_small_data(self, symbol, df): + if len(df) <= self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") + _temp = self.mini_symbol_map.setdefault(symbol, []) + _temp.append(df.copy()) + return self.CACHE_FLAG + else: + if symbol in self.mini_symbol_map: + self.mini_symbol_map.pop(symbol) + return self.NORMAL_FLAG + + def _collector(self, stock_list): + + error_symbol = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(stock_list)) as p_bar: + for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)): + if _result != self.NORMAL_FLAG: + error_symbol.append(_symbol) + p_bar.update() + print(error_symbol) + logger.info(f"error symbol nums: {len(error_symbol)}") + logger.info(f"current get symbol nums: {len(stock_list)}") + error_symbol.extend(self.mini_symbol_map.keys()) + return sorted(set(error_symbol)) + + def collector_data(self): + """collector data""" + logger.info("start collector data......") + stock_list = self.stock_list + for i in range(self.max_collector_count): + if not stock_list: + break + logger.info(f"getting data: {i+1}") + stock_list = self._collector(stock_list) + logger.info(f"{i+1} finish.") + for _symbol, _df_list in self.mini_symbol_map.items(): + self.save_instrument( + _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]) + ) + if self.mini_symbol_map: + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}") + logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") + + +class BaseNormalize(abc.ABC): + def __init__( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + self._date_field_name = date_field_name + self._symbol_field_name = symbol_field_name + + self._calendar_list = self._get_calendar_list() + + @abc.abstractmethod + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + # normalize + raise NotImplementedError("") + + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") + + +class Normalize: + def __init__( + self, + source_dir: [str, Path], + target_dir: [str, Path], + normalize_class: Type[BaseNormalize], + max_workers: int = 16, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data + normalize_class: Type[YahooNormalize] + normalize class + max_workers: int + Concurrent number, default is 16 + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._target_dir.mkdir(parents=True, exist_ok=True) + + self._max_workers = max_workers + + self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + + def _executor(self, file_path: Path): + file_path = Path(file_path) + df = pd.read_csv(file_path) + df = self._normalize_obj.normalize(df) + if not df.empty: + df.to_csv(self._target_dir.joinpath(file_path.name), index=False) + + def normalize(self): + logger.info("normalize data......") + + with ProcessPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) + with tqdm(total=len(file_list)) as p_bar: + for _ in worker.map(self._executor, file_list): + p_bar.update() + + +class BaseRun(abc.ABC): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d + """ + if source_dir is None: + source_dir = Path(self.default_base_dir).joinpath("_source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = Path(self.default_base_dir).joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.interval = interval + + @property + @abc.abstractmethod + def collector_class_name(self): + raise NotImplementedError("rewrite normalize_symbol") + + @property + @abc.abstractmethod + def normalize_class_name(self): + raise NotImplementedError("rewrite normalize_symbol") + + @property + @abc.abstractmethod + def default_base_dir(self) -> [Path, str]: + raise NotImplementedError("rewrite normalize_symbol") + + def download_data( + self, + max_collector_count=2, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + # get 1m data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + """ + + _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] + _class( + self.source_dir, + max_workers=self.max_workers, + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, + ).collector_data() + + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + """normalize data + + Parameters + ---------- + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + """ + _class = getattr(self._cur_module, self.normalize_class_name) + yc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + ) + yc.normalize() diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 743f89462..eadc381ec 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -10,158 +10,26 @@ import importlib from abc import ABC from pathlib import Path from typing import Iterable, Type -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import fire import requests import numpy as np import pandas as pd -from tqdm import tqdm from loguru import logger from yahooquery import Ticker from dateutil.tz import tzlocal from qlib.utils import code_to_fname, fname_to_code +from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.base import BaseCollector, BaseNormalize, BaseRun from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" -REGION_CN = "CN" -REGION_US = "US" -class YahooData: - START_DATETIME = pd.Timestamp("2000-01-01") - HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) - END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - INTERVAL_1min = "1min" - INTERVAL_1d = "1d" - - def __init__( - self, - timezone: str = None, - start=None, - end=None, - interval="1d", - delay=0, - show_1min_logging: bool = False, - ): - """ - - Parameters - ---------- - timezone: str - The timezone where the data is located - delay: float - time.sleep(delay), default 0 - interval: str - freq, value from [1min, 1d], default 1min - start: str - start datetime, default None - end: str - end datetime, default None - show_1min_logging: bool - show 1min logging, by default False; if True, there may be many warning logs - """ - self._timezone = tzlocal() if timezone is None else timezone - self._delay = delay - self._interval = interval - self._show_1min_logging = show_1min_logging - self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME - self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) - if self._interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME) - elif self._interval == self.INTERVAL_1d: - pass - else: - raise ValueError(f"interval error: {self._interval}") - - # using for 1min - self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) - self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) - - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) - self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) - - @staticmethod - def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): - try: - dt = pd.Timestamp(dt, tz=timezone).timestamp() - dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") - except ValueError as e: - pass - return dt - - def _sleep(self): - time.sleep(self._delay) - - @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): - error_msg = f"{symbol}-{interval}-{start}-{end}" - - def _show_logging_func(): - if interval == YahooData.INTERVAL_1min and show_1min_logging: - logger.warning(f"{error_msg}:{_resp}") - - interval = "1m" if interval in ["1m", "1min"] else interval - try: - _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) - if isinstance(_resp, pd.DataFrame): - return _resp.reset_index() - elif isinstance(_resp, dict): - _temp_data = _resp.get(symbol, {}) - if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None - ): - _show_logging_func() - else: - _show_logging_func() - except Exception as e: - logger.warning(f"{error_msg}:{e}") - - def get_data(self, symbol: str) -> [pd.DataFrame]: - def _get_simple(start_, end_): - self._sleep() - _remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval - return self.get_data_from_remote( - symbol, - interval=_remote_interval, - start=start_, - end=end_, - show_1min_logging=self._show_1min_logging, - ) - - _result = None - if self._interval == self.INTERVAL_1d: - _result = _get_simple(self.start_datetime, self.end_datetime) - elif self._interval == self.INTERVAL_1min: - if self._next_datetime >= self._latest_datetime: - _result = _get_simple(self.start_datetime, self.end_datetime) - else: - _res = [] - - def _get_multi(start_, end_): - _resp = _get_simple(start_, end_) - if _resp is not None and not _resp.empty: - _res.append(_resp) - - for _s, _e in ( - (self.start_datetime, self._next_datetime), - (self._latest_datetime, self.end_datetime), - ): - _get_multi(_s, _e) - for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"): - _end = _start + pd.Timedelta(days=1) - _get_multi(_start, _end) - if _res: - _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) - else: - raise ValueError(f"cannot support {self._interval}") - return _result - - -class YahooCollector: +class YahooCollector(BaseCollector): def __init__( self, save_dir: [str, Path], @@ -173,7 +41,6 @@ class YahooCollector: delay=0, check_data_length: bool = False, limit_nums: int = None, - show_1min_logging: bool = False, ): """ @@ -197,131 +64,118 @@ class YahooCollector: check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs """ - self.save_dir = Path(save_dir).expanduser().resolve() - self.save_dir.mkdir(parents=True, exist_ok=True) - - self._delay = delay - self.max_workers = max_workers - self._max_collector_count = max_collector_count - self._mini_symbol_map = {} - self._interval = interval - self._check_small_data = check_data_length - - self.stock_list = sorted(set(self.get_stock_list())) - if limit_nums is not None: - try: - self.stock_list = self.stock_list[: int(limit_nums)] - except Exception as e: - logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") - - self.yahoo_data = YahooData( - timezone=self._timezone, + super(YahooCollector, self).__init__( + save_dir=save_dir, start=start, end=end, interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, delay=delay, - show_1min_logging=show_1min_logging, + check_data_length=check_data_length, + limit_nums=limit_nums, ) - @property - @abc.abstractmethod - def min_numbers_trading(self): - # daily, one year: 252 / 4 - # us 1min, a week: 6.5 * 60 * 5 - # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("rewrite min_numbers_trading") + self.init_datetime() - @abc.abstractmethod - def get_stock_list(self): - raise NotImplementedError("rewrite get_stock_list") + def init_datetime(self): + if self.interval == self.INTERVAL_1min: + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + elif self.interval == self.INTERVAL_1d: + pass + else: + raise ValueError(f"interval error: {self.interval}") + + # using for 1min + self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) + self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt @property @abc.abstractmethod def _timezone(self): raise NotImplementedError("rewrite get_timezone") - def save_stock(self, symbol, df: pd.DataFrame): - """save stock data to file + @staticmethod + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + error_msg = f"{symbol}-{interval}-{start}-{end}" - Parameters - ---------- - symbol: str - stock code - df : pd.DataFrame - df.columns must contain "symbol" and "datetime" - """ - if df.empty: - logger.warning(f"{symbol} is empty") - return + def _show_logging_func(): + if interval == YahooCollector.INTERVAL_1min and show_1min_logging: + logger.warning(f"{error_msg}:{_resp}") - symbol = self.normalize_symbol(symbol) - symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") - df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) - df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + interval = "1m" if interval in ["1m", "1min"] else interval + try: + _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + elif isinstance(_resp, dict): + _temp_data = _resp.get(symbol, {}) + if isinstance(_temp_data, str) or ( + isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None + ): + _show_logging_func() + else: + _show_logging_func() + except Exception as e: + logger.warning(f"{error_msg}:{e}") - def _save_small_data(self, symbol, df): - if len(df) <= self.min_numbers_trading: - logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") - _temp = self._mini_symbol_map.setdefault(symbol, []) - _temp.append(df.copy()) - return None - else: - if symbol in self._mini_symbol_map: - self._mini_symbol_map.pop(symbol) - return symbol + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> pd.DataFrame: + def _get_simple(start_, end_): + self.sleep() + _remote_interval = "1m" if interval == self.INTERVAL_1min else interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + ) - def _get_data(self, symbol): _result = None - df = self.yahoo_data.get_data(symbol) - if isinstance(df, pd.DataFrame): - if not df.empty: - if self._check_small_data: - if self._save_small_data(symbol, df) is not None: - _result = symbol - self.save_stock(symbol, df) - else: - _result = symbol - self.save_stock(symbol, df) - return _result + if interval == self.INTERVAL_1d: + _result = _get_simple(start_datetime, end_datetime) + elif interval == self.INTERVAL_1min: + if self._next_datetime >= self._latest_datetime: + _result = _get_simple(start_datetime, end_datetime) + else: + _res = [] - def _collector(self, stock_list): + def _get_multi(start_, end_): + _resp = _get_simple(start_, end_) + if _resp is not None and not _resp.empty: + _res.append(_resp) - error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(stock_list)) as p_bar: - for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)): - if _result is None: - error_symbol.append(_symbol) - p_bar.update() - print(error_symbol) - logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(stock_list)}") - error_symbol.extend(self._mini_symbol_map.keys()) - return sorted(set(error_symbol)) + for _s, _e in ( + (self.start_datetime, self._next_datetime), + (self._latest_datetime, self.end_datetime), + ): + _get_multi(_s, _e) + for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"): + _end = _start + pd.Timedelta(days=1) + _get_multi(_start, _end) + if _res: + _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + else: + raise ValueError(f"cannot support {self.interval}") + return pd.DataFrame() if _result is None else _result def collector_data(self): """collector data""" - logger.info("start collector yahoo data......") - stock_list = self.stock_list - for i in range(self._max_collector_count): - if not stock_list: - break - logger.info(f"getting data: {i+1}") - stock_list = self._collector(stock_list) - logger.info(f"{i+1} finish.") - for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") - logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") - + super(YahooCollector, self).collector_data() self.download_index_data() @abc.abstractmethod @@ -329,11 +183,6 @@ class YahooCollector: """download index data""" raise NotImplementedError("rewrite download_index_data") - @abc.abstractmethod - def normalize_symbol(self, symbol: str): - """normalize symbol""" - raise NotImplementedError("rewrite normalize_symbol") - class YahooCollectorCN(YahooCollector, ABC): def get_stock_list(self): @@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN): def download_index_data(self): # TODO: from MSN _format = "%Y%m%d" - _begin = self.yahoo_data.start_datetime.strftime(_format) - _end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format) + _begin = self.start_datetime.strftime(_format) + _end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format) for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: @@ -396,7 +245,7 @@ class YahooCollectorCN1min(YahooCollectorCN): def download_index_data(self): # TODO: 1m - logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data") + logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data") class YahooCollectorUS(YahooCollector, ABC): @@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS): return 60 * 6.5 * 5 -class YahooNormalize: +class YahooNormalize(BaseNormalize): COLUMNS = ["open", "close", "high", "low", "volume"] DAILY_FORMAT = "%Y-%m-%d" - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self._date_field_name = date_field_name - self._symbol_field_name = symbol_field_name - - self._calendar_list = self._get_calendar_list() - @staticmethod def normalize_yahoo( df: pd.DataFrame, @@ -498,11 +328,6 @@ class YahooNormalize: df = self.adjusted_price(df) return df - @abc.abstractmethod - def _get_calendar_list(self): - """Get benchmark calendar""" - raise NotImplementedError("") - @abc.abstractmethod def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: """adjusted price""" @@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC): # get 1d data from yahoo _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) - data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end) + data_1d = YahooCollector.get_data_from_remote( + self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end + ) if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 @@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): return get_calendar_list("ALL") -class Normalize: - def __init__( - self, - source_dir: [str, Path], - target_dir: [str, Path], - normalize_class: Type[YahooNormalize], - max_workers: int = 16, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - source_dir: str or Path - The directory where the raw data collected from the Internet is saved - target_dir: str or Path - Directory for normalize data - normalize_class: Type[YahooNormalize] - normalize class - max_workers: int - Concurrent number, default is 16 - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - if not (source_dir and target_dir): - raise ValueError("source_dir and target_dir cannot be None") - self._source_dir = Path(source_dir).expanduser() - self._target_dir = Path(target_dir).expanduser() - self._target_dir.mkdir(parents=True, exist_ok=True) - - self._max_workers = max_workers - - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) - - def _executor(self, file_path: Path): - file_path = Path(file_path) - df = pd.read_csv(file_path) - df = self._normalize_obj.normalize(df) - if not df.empty: - df.to_csv(self._target_dir.joinpath(file_path.name), index=False) - - def normalize(self): - logger.info("normalize data......") - - with ProcessPoolExecutor(max_workers=self._max_workers) as worker: - file_list = list(self._source_dir.glob("*.csv")) - with tqdm(total=len(file_list)) as p_bar: - for _ in worker.map(self._executor, file_list): - p_bar.update() - - -class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): """ Parameters @@ -789,23 +562,26 @@ class Run: Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d region: str region, value from ["CN", "US"], default "CN" """ - if source_dir is None: - source_dir = CUR_DIR.joinpath("source") - self.source_dir = Path(source_dir).expanduser().resolve() - self.source_dir.mkdir(parents=True, exist_ok=True) - - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - - self._cur_module = importlib.import_module("collector") - self.max_workers = max_workers + super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region + @property + def collector_class_name(self): + return f"YahooCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"YahooNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + def download_data( self, max_collector_count=2, @@ -815,7 +591,6 @@ class Run: interval="1d", check_data_length=False, limit_nums=None, - show_1min_logging=False, ): """download data from Internet @@ -835,8 +610,6 @@ class Run: check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs Examples --------- @@ -846,29 +619,13 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - _class = getattr( - self._cur_module, f"YahooCollector{self.region.upper()}{interval}" - ) # type: Type[YahooCollector] - _class( - self.source_dir, - max_workers=self.max_workers, - max_collector_count=max_collector_count, - delay=delay, - start=start, - end=end, - interval=interval, - check_data_length=check_data_length, - limit_nums=limit_nums, - show_1min_logging=show_1min_logging, - ).collector_data() + super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) - def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters ---------- - interval: str - freq, value from [1min, 1d], default 1d date_field_name: str date field name, default date symbol_field_name: str @@ -878,16 +635,7 @@ class Run: --------- $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d """ - _class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}") - yc = Normalize( - source_dir=self.source_dir, - target_dir=self.normalize_dir, - normalize_class=_class, - max_workers=self.max_workers, - date_field_name=date_field_name, - symbol_field_name=symbol_field_name, - ) - yc.normalize() + super(Run, self).normalize_data(date_field_name, symbol_field_name) if __name__ == "__main__":