diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index bcbbbcba7..c729b7eaa 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -20,10 +20,12 @@ pip install -r requirements.txt # download from eastmoney.com python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# normalize +python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ ``` diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index a2b7089a1..795d8848e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -23,7 +23,7 @@ from dateutil.tz import tzlocal CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_en_fund_symbols +from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" @@ -302,14 +302,149 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 +class FundNormalize: + COLUMNS = ["open", "close", "high", "low", "volume"] + DAILY_FORMAT = "%Y-%m-%d" + + def __init__( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + self._date_field_name = date_field_name + self._symbol_field_name = symbol_field_name + + self._calendar_list = self._get_calendar_list() + print (self._calendar_list) + + @staticmethod + def normalize_fund( + df: pd.DataFrame, + calendar_list: list = None, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + if df.empty: + return df + df = df.copy() + df.set_index(date_field_name, inplace=True) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + if calendar_list is not None: + df = df.reindex( + pd.DataFrame(index=calendar_list) + .loc[ + pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + + pd.Timedelta(hours=23, minutes=59) + ] + .index + ) + df.sort_index(inplace=True) + + df.index.names = [date_field_name] + return df.reset_index() + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + # normalize + df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + return df + + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") + + +class FundNormalize1d(FundNormalize, ABC): + DAILY_FORMAT = "%Y-%m-%d" + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + df = super(FundNormalize, self).normalize(df) + return df + + +class FundNormalizeCN: + def _get_calendar_list(self): + return get_calendar_list("ALL") + + +class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): + pass + + +class Normalize: + def __init__( + self, + source_dir: [str, Path], + target_dir: [str, Path], + normalize_class: Type[FundNormalize], + max_workers: int = 16, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data + normalize_class: Type[FundNormalize] + normalize class + max_workers: int + Concurrent number, default is 16 + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._target_dir.mkdir(parents=True, exist_ok=True) + + self._max_workers = max_workers + + self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + + def _executor(self, file_path: Path): + file_path = Path(file_path) + df = pd.read_csv(file_path) + df = self._normalize_obj.normalize(df) + if not df.empty: + df.to_csv(self._target_dir.joinpath(file_path.name), index=False) + + def normalize(self): + logger.info("normalize data......") + + with ProcessPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) + with tqdm(total=len(file_list)) as p_bar: + for _ in worker.map(self._executor, file_list): + p_bar.update() + + class Run: - def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): """ Parameters ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 region: str @@ -320,6 +455,11 @@ class Run: self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + self._cur_module = importlib.import_module("collector") self.max_workers = max_workers self.region = region @@ -372,6 +512,33 @@ class Run: limit_nums=limit_nums, ).collector_data() + def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + """normalize data + + Parameters + ---------- + interval: str + freq, value from [1d], default 1d + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ + """ + _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") + fc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + ) + fc.normalize() + if __name__ == "__main__": fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 1a08c514f..56d010974 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -98,14 +98,14 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar -def return_date_list(source_dir, date_field_name, file_path): - df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0) - - return df[date_field_name].to_list() +def return_date_list(source_dir, date_field_name: str, file_path: Path): + file_path = Path(file_path) + date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list() + return sorted(map(lambda x: pd.Timestamp(x), date_list)) def get_calendar_list_by_ratio( - source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16 + source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16 ) -> list: """get calendar list by selecting the date when few funds trade in this day @@ -117,6 +117,8 @@ def get_calendar_list_by_ratio( date field name, default is date threshold: float threshold to exclude some days when few funds trade in this day, default 0.5 + minimum_count: int + minimum count of funds should trade in one day max_workers: int Concurrent number, default is 16 @@ -126,25 +128,36 @@ def get_calendar_list_by_ratio( """ logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......") - _number_all_funds = len(os.listdir(source_dir)) + source_dir = Path(source_dir).expanduser() + file_list = list(source_dir.glob("*.csv")) - _list_all_date = dict() + _number_all_funds = len(file_list) + logger.info(f"count how many funds trade in this day......") + _dict_count_trade = dict() # dict{date:count} _fun = partial(return_date_list, source_dir, date_field_name) - with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, os.listdir(source_dir)): + for date_list in executor.map(_fun, file_list[:_number_all_funds]): for date in date_list: - if date in _list_all_date.keys(): - _list_all_date[date] += 1 - else: - _list_all_date[date] = 0 + if date not in _dict_count_trade.keys(): + _dict_count_trade[date] = 0 + + _dict_count_trade[date] += 1 p_bar.update() + + logger.info(f"count how many funds have founded in this day......") + _dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + with tqdm(total=_number_all_funds) as p_bar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for date_list in executor.map(_fun, file_list[:_number_all_funds]): + oldest_date = sorted(date_list)[0] # this fund haven't found before this day + for date in _dict_count_founding.keys(): + if date < oldest_date: + _dict_count_founding[date] -= 1 - _threshold_number = int(_number_all_funds * threshold) - calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number] + calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)] return calendar