From 719074d30673d9ba256bf69df396854c3e498c47 Mon Sep 17 00:00:00 2001 From: wangershi Date: Thu, 25 Feb 2021 19:20:14 +0800 Subject: [PATCH 01/11] touch file --- scripts/data_collector/fund/README.md | 49 ++++++++++++++++++++ scripts/data_collector/fund/collector.py | 0 scripts/data_collector/fund/requirements.txt | 0 3 files changed, 49 insertions(+) create mode 100644 scripts/data_collector/fund/README.md create mode 100644 scripts/data_collector/fund/collector.py create mode 100644 scripts/data_collector/fund/requirements.txt diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md new file mode 100644 index 000000000..c7b91a3f5 --- /dev/null +++ b/scripts/data_collector/fund/README.md @@ -0,0 +1,49 @@ +# Collect Fund Data + +> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)* + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + + +### CN Data + +#### 1d from East Money + +```bash + +# download from yahoo finance +python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + + +# dump data +cd qlib/scripts +python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol + +``` + +### using data + +```python +import qlib +from qlib.data import D + +qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN") +df = D.features(D.instruments("all"), ["$close"], freq="day") +``` + + +### Help +```bash +pythono collector.py collector_data --help +``` + +## Parameters + +- interval: 1min or 1d +- region: CN or US diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/data_collector/fund/requirements.txt b/scripts/data_collector/fund/requirements.txt new file mode 100644 index 000000000..e69de29bb From 6e5639621710bae40d508357d77b9d60a104aff8 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 12:24:26 +0800 Subject: [PATCH 02/11] add crawler --- scripts/data_collector/fund/README.md | 4 +- scripts/data_collector/fund/collector.py | 408 +++++++++++++++++++++++ scripts/data_collector/utils.py | 37 ++ 3 files changed, 447 insertions(+), 2 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index c7b91a3f5..b14938a3d 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -17,8 +17,8 @@ pip install -r requirements.txt ```bash -# download from yahoo finance -python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# download from eastmoney.com +python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # dump data diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index e69de29bb..404b6af0e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -0,0 +1,408 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import abc +import sys +import copy +import time +import datetime +import importlib +import json +from abc import ABC +from pathlib import Path +from typing import Iterable, Type +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor + +import fire +import requests +import numpy as np +import pandas as pd +from tqdm import tqdm +from loguru import logger +from dateutil.tz import tzlocal +from qlib.utils import code_to_fname, fname_to_code + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.utils import get_calendar_list, get_en_fund_symbols + +INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" +REGION_CN = "CN" +REGION_US = "US" + + +class FundData: + START_DATETIME = pd.Timestamp("2000-01-01") + HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) + END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + INTERVAL_1d = "1d" + + def __init__( + self, + timezone: str = None, + start=None, + end=None, + interval="1d", + delay=0, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + timezone: str + The timezone where the data is located + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1d], default 1d + start: str + start datetime, default None + end: str + end datetime, default None + show_1min_logging: bool + show 1min logging, by default False; if True, there may be many warning logs + """ + self._timezone = tzlocal() if timezone is None else timezone + self._delay = delay + self._interval = interval + self._show_1min_logging = show_1min_logging + self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME + self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) + if self._interval != self.INTERVAL_1d: + raise ValueError(f"interval error: {self._interval}") + + # using for 1min + self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) + self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt + + def _sleep(self): + time.sleep(self._delay) + + @staticmethod + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + error_msg = f"{symbol}-{interval}-{start}-{end}" + + try: + _resp = None + + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) + + if resp.status_code != 200: + raise ValueError("request error") + try: + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + _resp = pd.DataFrame( + data["Data"]["LSJZList"] + ) + + except Exception as e: + logger.warning(f"request error: {e}") + raise + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data(self, symbol: str) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self._sleep() + _remote_interval = self._interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + show_1min_logging=self._show_1min_logging, + ) + + if self._interval == self.INTERVAL_1d: + _result = _get_simple(self.start_datetime, self.end_datetime) + else: + raise ValueError(f"cannot support {self._interval}") + return _result + + +class FundCollector: + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=2, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1min + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + """ + self.save_dir = Path(save_dir).expanduser().resolve() + self.save_dir.mkdir(parents=True, exist_ok=True) + + self._delay = delay + self.max_workers = max_workers + self._max_collector_count = max_collector_count + self._mini_symbol_map = {} + self._interval = interval + self._check_small_data = check_data_length + + self.fund_list = sorted(set(self.get_fund_list())) + if limit_nums is not None: + try: + self.fund_list = self.fund_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") + + self.fund_data = FundData( + timezone=self._timezone, + start=start, + end=end, + interval=interval, + delay=delay, + show_1min_logging=show_1min_logging, + ) + + @property + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("rewrite min_numbers_trading") + + @abc.abstractmethod + def get_fund_list(self): + raise NotImplementedError("rewrite get_fund_list") + + @property + @abc.abstractmethod + def _timezone(self): + raise NotImplementedError("rewrite get_timezone") + + def save_fund(self, symbol, df: pd.DataFrame): + """save fund data to file + + Parameters + ---------- + symbol: str + fund code + df : pd.DataFrame + df.columns must contain "symbol" and "datetime" + """ + if df.empty: + logger.warning(f"{symbol} is empty") + return + + symbol = code_to_fname(symbol) + stock_path = self.save_dir.joinpath(f"{symbol}.csv") + df["symbol"] = symbol + if stock_path.exists(): + _old_df = pd.read_csv(stock_path) + df = _old_df.append(df, sort=False) + df.to_csv(stock_path, index=False) + + def _save_small_data(self, symbol, df): + if len(df) <= self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") + _temp = self._mini_symbol_map.setdefault(symbol, []) + _temp.append(df.copy()) + return None + else: + if symbol in self._mini_symbol_map: + self._mini_symbol_map.pop(symbol) + return symbol + + def _get_data(self, symbol): + _result = None + df = self.fund_data.get_data(symbol) + if isinstance(df, pd.DataFrame): + if not df.empty: + if self._check_small_data: + if self._save_small_data(symbol, df) is not None: + _result = symbol + self.save_fund(symbol, df) + else: + _result = symbol + self.save_fund(symbol, df) + return _result + + def _collector(self, fund_list): + + error_symbol = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(fund_list)) as p_bar: + for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): + if _result is None: + error_symbol.append(_symbol) + p_bar.update() + print(error_symbol) + logger.info(f"error symbol nums: {len(error_symbol)}") + logger.info(f"current get symbol nums: {len(fund_list)}") + error_symbol.extend(self._mini_symbol_map.keys()) + return sorted(set(error_symbol)) + + def collector_data(self): + """collector data""" + logger.info("start collector fund data......") + fund_list = self.fund_list + for i in range(self._max_collector_count): + if not fund_list: + break + logger.info(f"getting data: {i+1}") + fund_list = self._collector(fund_list) + logger.info(f"{i+1} finish.") + for _symbol, _df_list in self._mini_symbol_map.items(): + self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) + if self._mini_symbol_map: + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + +class FundollectorCN(FundCollector, ABC): + def get_fund_list(self): + logger.info("get cn fund symbols......") + symbols = get_en_fund_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + @property + def _timezone(self): + return "Asia/Shanghai" + + +class FundCollectorCN1d(FundollectorCN): + @property + def min_numbers_trading(self): + return 252 / 4 + +class Run: + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + region: str + region, value from ["CN", "US"], default "CN" + """ + if source_dir is None: + source_dir = CUR_DIR.joinpath("source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.region = region + + def download_data( + self, + max_collector_count=2, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + show_1min_logging=False, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + """ + + _class = getattr( + self._cur_module, f"FundCollector{self.region.upper()}{interval}" + ) # type: Type[FundCollector] + _class( + self.source_dir, + max_workers=self.max_workers, + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, + show_1min_logging=show_1min_logging, + ).collector_data() + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5f34aae7d..3319025fc 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -34,6 +34,7 @@ _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None _US_SYMBOLS = None +_EN_FUND_SYMBOLS = None _CALENDAR_MAP = {} # NOTE: Until 2020-10-20 20:00:00 @@ -220,6 +221,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: return _US_SYMBOLS +def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: + """get en fund symbols + + Returns + ------- + fund symbols in China + """ + global _EN_FUND_SYMBOLS + + @deco_retry + def _get_eastmoney(): + url = "http://fund.eastmoney.com/js/fundcode_search.js" + resp = requests.get(url) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [] + for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + data = sub_data.replace("\"","").replace("'","") + # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] + _symbols.append(data.split(",")[0]) + except Exception as e: + logger.warning(f"request error: {e}") + raise + if len(_symbols) < 8000: + raise ValueError("request error") + return _symbols + + if _EN_FUND_SYMBOLS is None: + _all_symbols = _get_eastmoney() + + _EN_FUND_SYMBOLS = sorted(set(_all_symbols)) + + return _EN_FUND_SYMBOLS + + def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix From db80b620d8408998cab30e4a6d51b9d038264c7e Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 17:03:14 +0800 Subject: [PATCH 03/11] ready for collector --- scripts/data_collector/fund/collector.py | 79 ++++++-------------- scripts/data_collector/fund/requirements.txt | 10 +++ 2 files changed, 32 insertions(+), 57 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 404b6af0e..f9b2a6775 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -20,20 +20,16 @@ import pandas as pd from tqdm import tqdm from loguru import logger from dateutil.tz import tzlocal -from qlib.utils import code_to_fname, fname_to_code CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_calendar_list, get_en_fund_symbols +from data_collector.utils import get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" -REGION_US = "US" - class FundData: START_DATETIME = pd.Timestamp("2000-01-01") - HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) INTERVAL_1d = "1d" @@ -44,7 +40,6 @@ class FundData: end=None, interval="1d", delay=0, - show_1min_logging: bool = False, ): """ @@ -60,22 +55,15 @@ class FundData: start datetime, default None end: str end datetime, default None - show_1min_logging: bool - show 1min logging, by default False; if True, there may be many warning logs """ self._timezone = tzlocal() if timezone is None else timezone self._delay = delay self._interval = interval - self._show_1min_logging = show_1min_logging self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) if self._interval != self.INTERVAL_1d: raise ValueError(f"interval error: {self._interval}") - # using for 1min - self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) - self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) @@ -92,33 +80,26 @@ class FundData: time.sleep(self._delay) @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + def get_data_from_remote(symbol, interval, start, end): error_msg = f"{symbol}-{interval}-{start}-{end}" try: - _resp = None - # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end) + url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end) resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) if resp.status_code != 200: raise ValueError("request error") - try: - data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + data = json.loads(resp.text.split("(")[-1].split(")")[0]) - # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html - SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") - _resp = pd.DataFrame( - data["Data"]["LSJZList"] - ) - - except Exception as e: - logger.warning(f"request error: {e}") - raise + # TODO: should we sort the value by datetime? + _resp = pd.DataFrame(data["Data"]["LSJZList"]) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() @@ -134,7 +115,6 @@ class FundData: interval=_remote_interval, start=start_, end=end_, - show_1min_logging=self._show_1min_logging, ) if self._interval == self.INTERVAL_1d: @@ -156,14 +136,13 @@ class FundCollector: delay=0, check_data_length: bool = False, limit_nums: int = None, - show_1min_logging: bool = False, ): """ Parameters ---------- save_dir: str - stock save dir + fund save dir max_workers: int workers, default 4 max_collector_count: int @@ -180,8 +159,6 @@ class FundCollector: check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs """ self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) @@ -206,7 +183,6 @@ class FundCollector: end=end, interval=interval, delay=delay, - show_1min_logging=show_1min_logging, ) @property @@ -240,13 +216,14 @@ class FundCollector: logger.warning(f"{symbol} is empty") return - symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") + fund_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) + if fund_path.exists(): + # TODO: read the fund code as str, not int, like "000001" shouldn't be "1" + _old_df = pd.read_csv(fund_path) + # TODO: remove the duplicate date df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + df.to_csv(fund_path, index=False) def _save_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -274,7 +251,6 @@ class FundCollector: return _result def _collector(self, fund_list): - error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: with tqdm(total=len(fund_list)) as p_bar: @@ -301,7 +277,7 @@ class FundCollector: for _symbol, _df_list in self._mini_symbol_map.items(): self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") class FundollectorCN(FundCollector, ABC): @@ -322,30 +298,23 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): + def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): """ Parameters ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - normalize_dir: str - Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 region: str - region, value from ["CN", "US"], default "CN" + region, value from ["CN"], default "CN" """ if source_dir is None: source_dir = CUR_DIR.joinpath("source") self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - self._cur_module = importlib.import_module("collector") self.max_workers = max_workers self.region = region @@ -359,7 +328,6 @@ class Run: interval="1d", check_data_length=False, limit_nums=None, - show_1min_logging=False, ): """download data from Internet @@ -375,12 +343,10 @@ class Run: start datetime, default "2000-01-01" end: str end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` - check_data_length: bool + check_data_length: bool # if this param useful? check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs Examples --------- @@ -401,7 +367,6 @@ class Run: interval=interval, check_data_length=check_data_length, limit_nums=limit_nums, - show_1min_logging=show_1min_logging, ).collector_data() if __name__ == "__main__": diff --git a/scripts/data_collector/fund/requirements.txt b/scripts/data_collector/fund/requirements.txt index e69de29bb..11c6730c0 100644 --- a/scripts/data_collector/fund/requirements.txt +++ b/scripts/data_collector/fund/requirements.txt @@ -0,0 +1,10 @@ +loguru +fire +requests +numpy +pandas +tqdm +lxml +loguru +yahooquery +json \ No newline at end of file From 3082f6ac1ba201ebcdeb115f27e97f75d2017439 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 19:06:40 +0800 Subject: [PATCH 04/11] ready for dump_bin --- scripts/data_collector/fund/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index b14938a3d..bcbbbcba7 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -23,7 +23,7 @@ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d -- # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ ``` @@ -33,8 +33,8 @@ python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qli import qlib from qlib.data import D -qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN") -df = D.features(D.instruments("all"), ["$close"], freq="day") +qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data") +df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day") ``` @@ -45,5 +45,5 @@ pythono collector.py collector_data --help ## Parameters -- interval: 1min or 1d -- region: CN or US +- interval: 1d +- region: CN From 82353b20e18462d8b1a67742b2b6210dc80461ee Mon Sep 17 00:00:00 2001 From: wangershi Date: Mon, 1 Mar 2021 21:10:46 +0800 Subject: [PATCH 05/11] black format --- scripts/data_collector/fund/collector.py | 14 +++++++++----- scripts/data_collector/utils.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index f9b2a6775..a2b7089a1 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -28,6 +28,7 @@ from data_collector.utils import get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" + class FundData: START_DATETIME = pd.Timestamp("2000-01-01") END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) @@ -85,12 +86,14 @@ class FundData: try: # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end) + url = INDEX_BENCH_URL.format( + index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + ) resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) if resp.status_code != 200: raise ValueError("request error") - + data = json.loads(resp.text.split("(")[-1].split(")")[0]) # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html @@ -280,6 +283,7 @@ class FundCollector: logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + class FundollectorCN(FundCollector, ABC): def get_fund_list(self): logger.info("get cn fund symbols......") @@ -297,6 +301,7 @@ class FundCollectorCN1d(FundollectorCN): def min_numbers_trading(self): return 252 / 4 + class Run: def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): """ @@ -354,9 +359,7 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - _class = getattr( - self._cur_module, f"FundCollector{self.region.upper()}{interval}" - ) # type: Type[FundCollector] + _class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector] _class( self.source_dir, max_workers=self.max_workers, @@ -369,5 +372,6 @@ class Run: limit_nums=limit_nums, ).collector_data() + if __name__ == "__main__": fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 3319025fc..5d5822f91 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -239,7 +239,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: try: _symbols = [] for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): - data = sub_data.replace("\"","").replace("'","") + data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) except Exception as e: From 34b7da1dd88ad6886d36f44ddf7d8ccf44a66eda Mon Sep 17 00:00:00 2001 From: wangershi Date: Wed, 3 Mar 2021 22:49:48 +0800 Subject: [PATCH 06/11] add calendar list by threshold --- scripts/data_collector/utils.py | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5d5822f91..1a08c514f 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import re +import os import time import bisect import pickle @@ -14,6 +15,9 @@ import pandas as pd from lxml import etree from loguru import logger from yahooquery import Ticker +from tqdm import tqdm +from functools import partial +from concurrent.futures import ProcessPoolExecutor HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" @@ -94,6 +98,57 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar +def return_date_list(source_dir, date_field_name, file_path): + df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0) + + return df[date_field_name].to_list() + + +def get_calendar_list_by_ratio( + source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16 +) -> list: + """get calendar list by selecting the date when few funds trade in this day + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + date_field_name: str + date field name, default is date + threshold: float + threshold to exclude some days when few funds trade in this day, default 0.5 + max_workers: int + Concurrent number, default is 16 + + Returns + ------- + history calendar list + """ + logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......") + + _number_all_funds = len(os.listdir(source_dir)) + + _list_all_date = dict() + + _fun = partial(return_date_list, source_dir, date_field_name) + + with tqdm(total=_number_all_funds) as p_bar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for date_list in executor.map(_fun, os.listdir(source_dir)): + for date in date_list: + if date in _list_all_date.keys(): + _list_all_date[date] += 1 + else: + _list_all_date[date] = 0 + + p_bar.update() + + _threshold_number = int(_number_all_funds * threshold) + calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number] + + return calendar + + def get_hs_stock_symbols() -> list: """get SH/SZ stock symbols From 11412727ef9089863b88f4d58b332513350cb115 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 18:51:38 +0800 Subject: [PATCH 07/11] add normalizer --- scripts/data_collector/fund/README.md | 4 +- scripts/data_collector/fund/collector.py | 171 ++++++++++++++++++++++- scripts/data_collector/utils.py | 43 ++++-- 3 files changed, 200 insertions(+), 18 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index bcbbbcba7..c729b7eaa 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -20,10 +20,12 @@ pip install -r requirements.txt # download from eastmoney.com python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# normalize +python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ ``` diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index a2b7089a1..795d8848e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -23,7 +23,7 @@ from dateutil.tz import tzlocal CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_en_fund_symbols +from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" @@ -302,14 +302,149 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 +class FundNormalize: + COLUMNS = ["open", "close", "high", "low", "volume"] + DAILY_FORMAT = "%Y-%m-%d" + + def __init__( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + self._date_field_name = date_field_name + self._symbol_field_name = symbol_field_name + + self._calendar_list = self._get_calendar_list() + print (self._calendar_list) + + @staticmethod + def normalize_fund( + df: pd.DataFrame, + calendar_list: list = None, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + if df.empty: + return df + df = df.copy() + df.set_index(date_field_name, inplace=True) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + if calendar_list is not None: + df = df.reindex( + pd.DataFrame(index=calendar_list) + .loc[ + pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + + pd.Timedelta(hours=23, minutes=59) + ] + .index + ) + df.sort_index(inplace=True) + + df.index.names = [date_field_name] + return df.reset_index() + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + # normalize + df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + return df + + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") + + +class FundNormalize1d(FundNormalize, ABC): + DAILY_FORMAT = "%Y-%m-%d" + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + df = super(FundNormalize, self).normalize(df) + return df + + +class FundNormalizeCN: + def _get_calendar_list(self): + return get_calendar_list("ALL") + + +class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): + pass + + +class Normalize: + def __init__( + self, + source_dir: [str, Path], + target_dir: [str, Path], + normalize_class: Type[FundNormalize], + max_workers: int = 16, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data + normalize_class: Type[FundNormalize] + normalize class + max_workers: int + Concurrent number, default is 16 + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._target_dir.mkdir(parents=True, exist_ok=True) + + self._max_workers = max_workers + + self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + + def _executor(self, file_path: Path): + file_path = Path(file_path) + df = pd.read_csv(file_path) + df = self._normalize_obj.normalize(df) + if not df.empty: + df.to_csv(self._target_dir.joinpath(file_path.name), index=False) + + def normalize(self): + logger.info("normalize data......") + + with ProcessPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) + with tqdm(total=len(file_list)) as p_bar: + for _ in worker.map(self._executor, file_list): + p_bar.update() + + class Run: - def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): """ Parameters ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 region: str @@ -320,6 +455,11 @@ class Run: self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + self._cur_module = importlib.import_module("collector") self.max_workers = max_workers self.region = region @@ -372,6 +512,33 @@ class Run: limit_nums=limit_nums, ).collector_data() + def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + """normalize data + + Parameters + ---------- + interval: str + freq, value from [1d], default 1d + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ + """ + _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") + fc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + ) + fc.normalize() + if __name__ == "__main__": fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 1a08c514f..56d010974 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -98,14 +98,14 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar -def return_date_list(source_dir, date_field_name, file_path): - df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0) - - return df[date_field_name].to_list() +def return_date_list(source_dir, date_field_name: str, file_path: Path): + file_path = Path(file_path) + date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list() + return sorted(map(lambda x: pd.Timestamp(x), date_list)) def get_calendar_list_by_ratio( - source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16 + source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16 ) -> list: """get calendar list by selecting the date when few funds trade in this day @@ -117,6 +117,8 @@ def get_calendar_list_by_ratio( date field name, default is date threshold: float threshold to exclude some days when few funds trade in this day, default 0.5 + minimum_count: int + minimum count of funds should trade in one day max_workers: int Concurrent number, default is 16 @@ -126,25 +128,36 @@ def get_calendar_list_by_ratio( """ logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......") - _number_all_funds = len(os.listdir(source_dir)) + source_dir = Path(source_dir).expanduser() + file_list = list(source_dir.glob("*.csv")) - _list_all_date = dict() + _number_all_funds = len(file_list) + logger.info(f"count how many funds trade in this day......") + _dict_count_trade = dict() # dict{date:count} _fun = partial(return_date_list, source_dir, date_field_name) - with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, os.listdir(source_dir)): + for date_list in executor.map(_fun, file_list[:_number_all_funds]): for date in date_list: - if date in _list_all_date.keys(): - _list_all_date[date] += 1 - else: - _list_all_date[date] = 0 + if date not in _dict_count_trade.keys(): + _dict_count_trade[date] = 0 + + _dict_count_trade[date] += 1 p_bar.update() + + logger.info(f"count how many funds have founded in this day......") + _dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + with tqdm(total=_number_all_funds) as p_bar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for date_list in executor.map(_fun, file_list[:_number_all_funds]): + oldest_date = sorted(date_list)[0] # this fund haven't found before this day + for date in _dict_count_founding.keys(): + if date < oldest_date: + _dict_count_founding[date] -= 1 - _threshold_number = int(_number_all_funds * threshold) - calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number] + calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)] return calendar From 6bcd88973b3ff906ac61e174185a28ea2b1d8ea0 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 19:32:37 +0800 Subject: [PATCH 08/11] resolve one bug --- scripts/data_collector/fund/collector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 795d8848e..08773cb9f 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -324,7 +324,6 @@ class FundNormalize: self._symbol_field_name = symbol_field_name self._calendar_list = self._get_calendar_list() - print (self._calendar_list) @staticmethod def normalize_fund( @@ -368,7 +367,7 @@ class FundNormalize1d(FundNormalize, ABC): DAILY_FORMAT = "%Y-%m-%d" def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - df = super(FundNormalize, self).normalize(df) + df = super(FundNormalize1d, self).normalize(df) return df From 9df0361262eadd065ab783bb5349ef16203d04b4 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 19:35:50 +0800 Subject: [PATCH 09/11] black --- scripts/data_collector/utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 56d010974..ed14ad6e1 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -105,7 +105,11 @@ def return_date_list(source_dir, date_field_name: str, file_path: Path): def get_calendar_list_by_ratio( - source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16 + source_dir: [str, Path], + date_field_name: str = "date", + threshold: float = 0.5, + minimum_count: int = 10, + max_workers: int = 16, ) -> list: """get calendar list by selecting the date when few funds trade in this day @@ -134,7 +138,7 @@ def get_calendar_list_by_ratio( _number_all_funds = len(file_list) logger.info(f"count how many funds trade in this day......") - _dict_count_trade = dict() # dict{date:count} + _dict_count_trade = dict() # dict{date:count} _fun = partial(return_date_list, source_dir, date_field_name) with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: @@ -146,9 +150,9 @@ def get_calendar_list_by_ratio( _dict_count_trade[date] += 1 p_bar.update() - + logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: for date_list in executor.map(_fun, file_list[:_number_all_funds]): @@ -157,7 +161,11 @@ def get_calendar_list_by_ratio( if date < oldest_date: _dict_count_founding[date] -= 1 - calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)] + calendar = [ + date + for date in _dict_count_trade + if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) + ] return calendar From 4e7a147759286b8b20729a07381652030d428a58 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 14 Mar 2021 14:24:14 +0800 Subject: [PATCH 10/11] use base.py --- scripts/data_collector/base.py | 48 +-- scripts/data_collector/fund/collector.py | 417 ++++++----------------- 2 files changed, 121 insertions(+), 344 deletions(-) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index ccd8f59e5..12983f6a5 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -46,7 +46,7 @@ class BaseCollector(abc.ABC): Parameters ---------- save_dir: str - stock save dir + instrument save dir max_workers: int workers, default 4 max_collector_count: int @@ -77,11 +77,11 @@ class BaseCollector(abc.ABC): self.start_datetime = self.normalize_start_datetime(start) self.end_datetime = self.normalize_end_datetime(end) - self.stock_list = sorted(set(self.get_stock_list())) + self.instrument_list = sorted(set(self.get_instrument_list())) if limit_nums is not None: try: - self.stock_list = self.stock_list[: int(limit_nums)] + self.instrument_list = self.instrument_list[: int(limit_nums)] except Exception as e: logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") @@ -108,8 +108,8 @@ class BaseCollector(abc.ABC): raise NotImplementedError("rewrite min_numbers_trading") @abc.abstractmethod - def get_stock_list(self): - raise NotImplementedError("rewrite get_stock_list") + def get_instrument_list(self): + raise NotImplementedError("rewrite get_instrument_list") @abc.abstractmethod def normalize_symbol(self, symbol: str): @@ -158,27 +158,27 @@ class BaseCollector(abc.ABC): return _result def save_instrument(self, symbol, df: pd.DataFrame): - """save stock data to file + """save instrument data to file Parameters ---------- symbol: str - stock code + instrument code df : pd.DataFrame df.columns must contain "symbol" and "datetime" """ - if df.empty: + if df is None or df.empty: logger.warning(f"{symbol} is empty") return symbol = self.normalize_symbol(symbol) symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") + instrument_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) + if instrument_path.exists(): + _old_df = pd.read_csv(instrument_path) df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + df.to_csv(instrument_path, index=False) def cache_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -191,38 +191,38 @@ class BaseCollector(abc.ABC): self.mini_symbol_map.pop(symbol) return self.NORMAL_FLAG - def _collector(self, stock_list): + def _collector(self, instrument_list): error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(stock_list)) as p_bar: - for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)): + with tqdm(total=len(instrument_list)) as p_bar: + for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)): if _result != self.NORMAL_FLAG: error_symbol.append(_symbol) p_bar.update() print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(stock_list)}") + logger.info(f"current get symbol nums: {len(instrument_list)}") error_symbol.extend(self.mini_symbol_map.keys()) return sorted(set(error_symbol)) def collector_data(self): """collector data""" logger.info("start collector data......") - stock_list = self.stock_list + instrument_list = self.instrument_list for i in range(self.max_collector_count): - if not stock_list: + if not instrument_list: break logger.info(f"getting data: {i+1}") - stock_list = self._collector(stock_list) + instrument_list = self._collector(instrument_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self.mini_symbol_map.items(): self.save_instrument( _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]) ) if self.mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}") - logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") + logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}") + logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") class BaseNormalize(abc.ABC): @@ -386,9 +386,9 @@ class BaseRun(abc.ABC): Examples --------- # get daily data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # get 1m data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] @@ -416,7 +416,7 @@ class BaseRun(abc.ABC): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d """ _class = getattr(self._cur_module, self.normalize_class_name) yc = Normalize( diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 08773cb9f..1e0d2d8bf 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -11,123 +11,24 @@ import json from abc import ABC from pathlib import Path from typing import Iterable, Type -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import fire import requests import numpy as np import pandas as pd -from tqdm import tqdm from loguru import logger from dateutil.tz import tzlocal +from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.base import BaseCollector, BaseNormalize, BaseRun from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" -REGION_CN = "CN" -class FundData: - START_DATETIME = pd.Timestamp("2000-01-01") - END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - INTERVAL_1d = "1d" - - def __init__( - self, - timezone: str = None, - start=None, - end=None, - interval="1d", - delay=0, - ): - """ - - Parameters - ---------- - timezone: str - The timezone where the data is located - delay: float - time.sleep(delay), default 0 - interval: str - freq, value from [1d], default 1d - start: str - start datetime, default None - end: str - end datetime, default None - """ - self._timezone = tzlocal() if timezone is None else timezone - self._delay = delay - self._interval = interval - self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME - self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) - if self._interval != self.INTERVAL_1d: - raise ValueError(f"interval error: {self._interval}") - - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) - self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) - - @staticmethod - def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): - try: - dt = pd.Timestamp(dt, tz=timezone).timestamp() - dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") - except ValueError as e: - pass - return dt - - def _sleep(self): - time.sleep(self._delay) - - @staticmethod - def get_data_from_remote(symbol, interval, start, end): - error_msg = f"{symbol}-{interval}-{start}-{end}" - - try: - # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format( - index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end - ) - resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - - if resp.status_code != 200: - raise ValueError("request error") - - data = json.loads(resp.text.split("(")[-1].split(")")[0]) - - # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html - SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") - - # TODO: should we sort the value by datetime? - _resp = pd.DataFrame(data["Data"]["LSJZList"]) - - if isinstance(_resp, pd.DataFrame): - return _resp.reset_index() - except Exception as e: - logger.warning(f"{error_msg}:{e}") - - def get_data(self, symbol: str) -> [pd.DataFrame]: - def _get_simple(start_, end_): - self._sleep() - _remote_interval = self._interval - return self.get_data_from_remote( - symbol, - interval=_remote_interval, - start=start_, - end=end_, - ) - - if self._interval == self.INTERVAL_1d: - _result = _get_simple(self.start_datetime, self.end_datetime) - else: - raise ValueError(f"cannot support {self._interval}") - return _result - - -class FundCollector: +class FundCollector(BaseCollector): def __init__( self, save_dir: [str, Path], @@ -163,134 +64,108 @@ class FundCollector: limit_nums: int using for debug, by default None """ - self.save_dir = Path(save_dir).expanduser().resolve() - self.save_dir.mkdir(parents=True, exist_ok=True) - - self._delay = delay - self.max_workers = max_workers - self._max_collector_count = max_collector_count - self._mini_symbol_map = {} - self._interval = interval - self._check_small_data = check_data_length - - self.fund_list = sorted(set(self.get_fund_list())) - if limit_nums is not None: - try: - self.fund_list = self.fund_list[: int(limit_nums)] - except Exception as e: - logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") - - self.fund_data = FundData( - timezone=self._timezone, + super(FundCollector, self).__init__( + save_dir=save_dir, start=start, end=end, interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, ) - @property - @abc.abstractmethod - def min_numbers_trading(self): - # daily, one year: 252 / 4 - # us 1min, a week: 6.5 * 60 * 5 - # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("rewrite min_numbers_trading") + self.init_datetime() - @abc.abstractmethod - def get_fund_list(self): - raise NotImplementedError("rewrite get_fund_list") + def init_datetime(self): + if self.interval == self.INTERVAL_1min: + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + elif self.interval == self.INTERVAL_1d: + pass + else: + raise ValueError(f"interval error: {self.interval}") + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt @property @abc.abstractmethod def _timezone(self): raise NotImplementedError("rewrite get_timezone") - def save_fund(self, symbol, df: pd.DataFrame): - """save fund data to file + @staticmethod + def get_data_from_remote(symbol, interval, start, end): + error_msg = f"{symbol}-{interval}-{start}-{end}" - Parameters - ---------- - symbol: str - fund code - df : pd.DataFrame - df.columns must contain "symbol" and "datetime" - """ - if df.empty: - logger.warning(f"{symbol} is empty") - return + try: + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format( + index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + ) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - fund_path = self.save_dir.joinpath(f"{symbol}.csv") - df["symbol"] = symbol - if fund_path.exists(): - # TODO: read the fund code as str, not int, like "000001" shouldn't be "1" - _old_df = pd.read_csv(fund_path) - # TODO: remove the duplicate date - df = _old_df.append(df, sort=False) - df.to_csv(fund_path, index=False) + if resp.status_code != 200: + raise ValueError("request error") - def _save_small_data(self, symbol, df): - if len(df) <= self.min_numbers_trading: - logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") - _temp = self._mini_symbol_map.setdefault(symbol, []) - _temp.append(df.copy()) - return None + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + # TODO: should we sort the value by datetime? + _resp = pd.DataFrame(data["Data"]["LSJZList"]) + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self.sleep() + _remote_interval = interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + ) + + if interval == self.INTERVAL_1d: + _result = _get_simple(start_datetime, end_datetime) else: - if symbol in self._mini_symbol_map: - self._mini_symbol_map.pop(symbol) - return symbol - - def _get_data(self, symbol): - _result = None - df = self.fund_data.get_data(symbol) - if isinstance(df, pd.DataFrame): - if not df.empty: - if self._check_small_data: - if self._save_small_data(symbol, df) is not None: - _result = symbol - self.save_fund(symbol, df) - else: - _result = symbol - self.save_fund(symbol, df) + raise ValueError(f"cannot support {interval}") return _result - def _collector(self, fund_list): - error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(fund_list)) as p_bar: - for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): - if _result is None: - error_symbol.append(_symbol) - p_bar.update() - print(error_symbol) - logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(fund_list)}") - error_symbol.extend(self._mini_symbol_map.keys()) - return sorted(set(error_symbol)) - def collector_data(self): """collector data""" - logger.info("start collector fund data......") - fund_list = self.fund_list - for i in range(self._max_collector_count): - if not fund_list: - break - logger.info(f"getting data: {i+1}") - fund_list = self._collector(fund_list) - logger.info(f"{i+1} finish.") - for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") - logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + super(FundCollector, self).collector_data() class FundollectorCN(FundCollector, ABC): - def get_fund_list(self): + def get_instrument_list(self): logger.info("get cn fund symbols......") symbols = get_en_fund_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols + def normalize_symbol(self, symbol): + return symbol + @property def _timezone(self): return "Asia/Shanghai" @@ -302,29 +177,9 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 -class FundNormalize: - COLUMNS = ["open", "close", "high", "low", "volume"] +class FundNormalize(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self._date_field_name = date_field_name - self._symbol_field_name = symbol_field_name - - self._calendar_list = self._get_calendar_list() - @staticmethod def normalize_fund( df: pd.DataFrame, @@ -357,11 +212,6 @@ class FundNormalize: df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) return df - @abc.abstractmethod - def _get_calendar_list(self): - """Get benchmark calendar""" - raise NotImplementedError("") - class FundNormalize1d(FundNormalize, ABC): DAILY_FORMAT = "%Y-%m-%d" @@ -380,62 +230,8 @@ class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): pass -class Normalize: - def __init__( - self, - source_dir: [str, Path], - target_dir: [str, Path], - normalize_class: Type[FundNormalize], - max_workers: int = 16, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - source_dir: str or Path - The directory where the raw data collected from the Internet is saved - target_dir: str or Path - Directory for normalize data - normalize_class: Type[FundNormalize] - normalize class - max_workers: int - Concurrent number, default is 16 - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - if not (source_dir and target_dir): - raise ValueError("source_dir and target_dir cannot be None") - self._source_dir = Path(source_dir).expanduser() - self._target_dir = Path(target_dir).expanduser() - self._target_dir.mkdir(parents=True, exist_ok=True) - - self._max_workers = max_workers - - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) - - def _executor(self, file_path: Path): - file_path = Path(file_path) - df = pd.read_csv(file_path) - df = self._normalize_obj.normalize(df) - if not df.empty: - df.to_csv(self._target_dir.joinpath(file_path.name), index=False) - - def normalize(self): - logger.info("normalize data......") - - with ProcessPoolExecutor(max_workers=self._max_workers) as worker: - file_list = list(self._source_dir.glob("*.csv")) - with tqdm(total=len(file_list)) as p_bar: - for _ in worker.map(self._executor, file_list): - p_bar.update() - - -class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): """ Parameters @@ -446,23 +242,26 @@ class Run: Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d region: str region, value from ["CN"], default "CN" """ - if source_dir is None: - source_dir = CUR_DIR.joinpath("source") - self.source_dir = Path(source_dir).expanduser().resolve() - self.source_dir.mkdir(parents=True, exist_ok=True) - - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - - self._cur_module = importlib.import_module("collector") - self.max_workers = max_workers + super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region + @property + def collector_class_name(self): + return f"FundCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"FundNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + def download_data( self, max_collector_count=2, @@ -498,26 +297,13 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - _class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector] - _class( - self.source_dir, - max_workers=self.max_workers, - max_collector_count=max_collector_count, - delay=delay, - start=start, - end=end, - interval=interval, - check_data_length=check_data_length, - limit_nums=limit_nums, - ).collector_data() + super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) - def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters ---------- - interval: str - freq, value from [1d], default 1d date_field_name: str date field name, default date symbol_field_name: str @@ -527,16 +313,7 @@ class Run: --------- $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ """ - _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") - fc = Normalize( - source_dir=self.source_dir, - target_dir=self.normalize_dir, - normalize_class=_class, - max_workers=self.max_workers, - date_field_name=date_field_name, - symbol_field_name=symbol_field_name, - ) - fc.normalize() + super(Run, self).normalize_data(date_field_name, symbol_field_name) if __name__ == "__main__": From d3160e94399ef395fdd12116b416a84bea3a87b4 Mon Sep 17 00:00:00 2001 From: wangershi Date: Thu, 18 Mar 2021 21:15:45 +0800 Subject: [PATCH 11/11] remove some useless code --- scripts/data_collector/fund/collector.py | 12 ++---------- scripts/data_collector/utils.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 1e0d2d8bf..10800a7a3 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -151,10 +151,6 @@ class FundCollector(BaseCollector): raise ValueError(f"cannot support {interval}") return _result - def collector_data(self): - """collector data""" - super(FundCollector, self).collector_data() - class FundollectorCN(FundCollector, ABC): def get_instrument_list(self): @@ -213,12 +209,8 @@ class FundNormalize(BaseNormalize): return df -class FundNormalize1d(FundNormalize, ABC): - DAILY_FORMAT = "%Y-%m-%d" - - def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - df = super(FundNormalize1d, self).normalize(df) - return df +class FundNormalize1d(FundNormalize): + pass class FundNormalizeCN: diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index ed14ad6e1..e8c9b9dc4 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -98,9 +98,8 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar -def return_date_list(source_dir, date_field_name: str, file_path: Path): - file_path = Path(file_path) - date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list() +def return_date_list(date_field_name: str, file_path: Path): + date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list() return sorted(map(lambda x: pd.Timestamp(x), date_list)) @@ -139,10 +138,13 @@ def get_calendar_list_by_ratio( logger.info(f"count how many funds trade in this day......") _dict_count_trade = dict() # dict{date:count} - _fun = partial(return_date_list, source_dir, date_field_name) + _fun = partial(return_date_list, date_field_name) + all_oldest_list = [] with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, file_list[:_number_all_funds]): + for date_list in executor.map(_fun, file_list): + if date_list: + all_oldest_list.append(date_list[0]) for date in date_list: if date not in _dict_count_trade.keys(): _dict_count_trade[date] = 0 @@ -154,12 +156,10 @@ def get_calendar_list_by_ratio( logger.info(f"count how many funds have founded in this day......") _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: - with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, file_list[:_number_all_funds]): - oldest_date = sorted(date_list)[0] # this fund haven't found before this day - for date in _dict_count_founding.keys(): - if date < oldest_date: - _dict_count_founding[date] -= 1 + for oldest_date in all_oldest_list: + for date in _dict_count_founding.keys(): + if date < oldest_date: + _dict_count_founding[date] -= 1 calendar = [ date