1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

ready for collector

This commit is contained in:
wangershi
2021-02-28 17:03:14 +08:00
parent 6e56396217
commit db80b620d8
2 changed files with 32 additions and 57 deletions

View File

@@ -20,20 +20,16 @@ import pandas as pd
from tqdm import tqdm
from loguru import logger
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname, fname_to_code
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import get_calendar_list, get_en_fund_symbols
from data_collector.utils import get_en_fund_symbols
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
REGION_CN = "CN"
REGION_US = "US"
class FundData:
START_DATETIME = pd.Timestamp("2000-01-01")
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
INTERVAL_1d = "1d"
@@ -44,7 +40,6 @@ class FundData:
end=None,
interval="1d",
delay=0,
show_1min_logging: bool = False,
):
"""
@@ -60,22 +55,15 @@ class FundData:
start datetime, default None
end: str
end datetime, default None
show_1min_logging: bool
show 1min logging, by default False; if True, there may be many warning logs
"""
self._timezone = tzlocal() if timezone is None else timezone
self._delay = delay
self._interval = interval
self._show_1min_logging = show_1min_logging
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval != self.INTERVAL_1d:
raise ValueError(f"interval error: {self._interval}")
# using for 1min
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@@ -92,33 +80,26 @@ class FundData:
time.sleep(self._delay)
@staticmethod
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
def get_data_from_remote(symbol, interval, start, end):
error_msg = f"{symbol}-{interval}-{start}-{end}"
try:
_resp = None
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end)
url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end)
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
if resp.status_code != 200:
raise ValueError("request error")
try:
data = json.loads(resp.text.split("(")[-1].split(")")[0])
data = json.loads(resp.text.split("(")[-1].split(")")[0])
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
_resp = pd.DataFrame(
data["Data"]["LSJZList"]
)
except Exception as e:
logger.warning(f"request error: {e}")
raise
# TODO: should we sort the value by datetime?
_resp = pd.DataFrame(data["Data"]["LSJZList"])
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
@@ -134,7 +115,6 @@ class FundData:
interval=_remote_interval,
start=start_,
end=end_,
show_1min_logging=self._show_1min_logging,
)
if self._interval == self.INTERVAL_1d:
@@ -156,14 +136,13 @@ class FundCollector:
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
show_1min_logging: bool = False,
):
"""
Parameters
----------
save_dir: str
stock save dir
fund save dir
max_workers: int
workers, default 4
max_collector_count: int
@@ -180,8 +159,6 @@ class FundCollector:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1min_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
@@ -206,7 +183,6 @@ class FundCollector:
end=end,
interval=interval,
delay=delay,
show_1min_logging=show_1min_logging,
)
@property
@@ -240,13 +216,14 @@ class FundCollector:
logger.warning(f"{symbol} is empty")
return
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
fund_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
if fund_path.exists():
# TODO: read the fund code as str, not int, like "000001" shouldn't be "1"
_old_df = pd.read_csv(fund_path)
# TODO: remove the duplicate date
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
df.to_csv(fund_path, index=False)
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
@@ -274,7 +251,6 @@ class FundCollector:
return _result
def _collector(self, fund_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(fund_list)) as p_bar:
@@ -301,7 +277,7 @@ class FundCollector:
for _symbol, _df_list in self._mini_symbol_map.items():
self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
if self._mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}")
logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}")
class FundollectorCN(FundCollector, ABC):
@@ -322,30 +298,23 @@ class FundCollectorCN1d(FundollectorCN):
return 252 / 4
class Run:
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
def __init__(self, source_dir=None, max_workers=4, region=REGION_CN):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
region: str
region, value from ["CN", "US"], default "CN"
region, value from ["CN"], default "CN"
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
self.region = region
@@ -359,7 +328,6 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
show_1min_logging=False,
):
"""download data from Internet
@@ -375,12 +343,10 @@ class Run:
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check_data_length: bool # if this param useful?
check data length, by default False
limit_nums: int
using for debug, by default None
show_1min_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
Examples
---------
@@ -401,7 +367,6 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
show_1min_logging=show_1min_logging,
).collector_data()
if __name__ == "__main__":

View File

@@ -0,0 +1,10 @@
loguru
fire
requests
numpy
pandas
tqdm
lxml
loguru
yahooquery
json