mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
306 lines
9.8 KiB
Python
306 lines
9.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import abc
|
|
import sys
|
|
import datetime
|
|
import json
|
|
from abc import ABC
|
|
from pathlib import Path
|
|
|
|
import fire
|
|
import requests
|
|
import pandas as pd
|
|
from loguru import logger
|
|
from dateutil.tz import tzlocal
|
|
from qlib.config import REG_CN as REGION_CN
|
|
|
|
# Absolute directory that contains this script.
CUR_DIR = Path(__file__).resolve().parent

# Put the directory two levels above this script on the import path; the
# ``data_collector`` imports that follow rely on it when run as a script.
_COLLECTOR_ROOT = CUR_DIR.parent.parent
sys.path.append(str(_COLLECTOR_ROOT))
|
|
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
|
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
|
|
|
# URL template for the fund history request. The placeholders (index_code,
# numberOfHistoricalDaysToCrawl, startDate, endDate) are filled in by
# FundCollector.get_data_from_remote via str.format.
INDEX_BENCH_URL = (
    "http://api.fund.eastmoney.com/f10/lsjz"
    "?callback=jQuery_"
    "&fundCode={index_code}"
    "&pageIndex=1"
    "&pageSize={numberOfHistoricalDaysToCrawl}"
    "&startDate={startDate}"
    "&endDate={endDate}"
)
|
|
|
|
|
|
class FundCollector(BaseCollector):
    """Collector that downloads per-fund history records through ``INDEX_BENCH_URL``.

    Concrete subclasses must implement the ``_timezone`` property. Attributes
    such as ``interval``, ``start_datetime``, ``end_datetime``,
    ``INTERVAL_1d``, ``INTERVAL_1min``, ``DEFAULT_START_DATETIME_1MIN`` and the
    ``sleep`` method are not defined here; they are expected to be provided by
    ``BaseCollector``.
    """

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: int = None,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            fund save dir
        max_workers: int
            workers, default 4
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, default 1d; ``init_datetime`` accepts 1min and 1d, but
            ``get_data`` only supports 1d
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None
        """
        super(FundCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

        self.init_datetime()

    def init_datetime(self):
        """Validate ``self.interval`` and convert the start/end datetimes.

        Raises
        ------
        ValueError
            if ``self.interval`` is neither ``INTERVAL_1min`` nor ``INTERVAL_1d``
        """
        if self.interval == self.INTERVAL_1min:
            # 1min requests may not start before the default 1min start datetime.
            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
        elif self.interval != self.INTERVAL_1d:
            raise ValueError(f"interval error: {self.interval}")

        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)

    @staticmethod
    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
        """Interpret ``dt`` in ``timezone`` and re-express it in the local timezone.

        Parameters
        ----------
        dt: pd.Timestamp, datetime.date or str
            value to convert
        timezone
            timezone passed to ``pd.Timestamp(dt, tz=timezone)``

        Returns
        -------
        pd.Timestamp localized with ``tzlocal()``; if pandas raises
        ``ValueError`` during the conversion, ``dt`` is returned unchanged.
        """
        try:
            epoch_seconds = pd.Timestamp(dt, tz=timezone).timestamp()
            return pd.Timestamp(epoch_seconds, tz=tzlocal(), unit="s")
        except ValueError:
            # Best effort: a value pandas cannot convert is handed back as given
            # (never a half-converted epoch float).
            return dt

    @property
    @abc.abstractmethod
    def _timezone(self):
        """Timezone used by ``init_datetime``; must be overridden by subclasses."""
        raise NotImplementedError("rewrite get_timezone")

    @staticmethod
    def _strip_jsonp(text):
        """Return the JSON text wrapped by a ``callback(...)`` JSONP envelope.

        The payload is everything between the first "(" and the last ")", so
        parentheses inside JSON string values do not truncate it. Text that
        already starts with "{" or "[" is treated as plain JSON and returned
        stripped.

        Raises
        ------
        ValueError
            if no ``(...)`` envelope can be found
        """
        payload = text.strip()
        if payload.startswith(("{", "[")):
            return payload
        start = payload.find("(")
        end = payload.rfind(")")
        if start == -1 or end < start:
            raise ValueError("response is not a JSONP payload")
        return payload[start + 1 : end]

    @staticmethod
    def get_data_from_remote(symbol, interval, start, end, timeout=30):
        """Request the history of one fund and return it as a DataFrame.

        Parameters
        ----------
        symbol
            fund code, substituted for ``index_code`` in ``INDEX_BENCH_URL``
        interval
            only used in the warning message
        start, end
            substituted for ``startDate`` / ``endDate`` in ``INDEX_BENCH_URL``
        timeout: float
            seconds passed to ``requests.get`` so that a stalled connection
            cannot block the worker forever, default 30

        Returns
        -------
        pd.DataFrame built from ``data["Data"]["LSJZList"]`` with the index
        reset, or None when anything fails (the failure is logged as a warning).
        """
        error_msg = f"{symbol}-{interval}-{start}-{end}"

        try:
            # TODO: numberOfHistoricalDaysToCrawl should be bigger enough
            url = INDEX_BENCH_URL.format(
                index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
            )
            resp = requests.get(
                url,
                headers={"referer": "http://fund.eastmoney.com/110022.html"},
                timeout=timeout,
            )

            if resp.status_code != 200:
                raise ValueError("request error")

            data = json.loads(FundCollector._strip_jsonp(resp.text))

            # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
            SYType = data["Data"]["SYType"]
            if SYType in ("每万份收益", "每百份收益", "每百万份收益"):
                raise Exception("The fund contains 每*份收益")

            # TODO: should we sort the value by datetime?
            return pd.DataFrame(data["Data"]["LSJZList"]).reset_index()
        except Exception as e:
            # Deliberate catch-all: any failure for a single symbol is logged
            # and reported to the caller as None instead of propagating.
            logger.warning(f"{error_msg}:{e}")
        return None

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> [pd.DataFrame]:
        """Fetch the data of ``symbol`` between ``start_datetime`` and ``end_datetime``.

        Calls ``self.sleep()`` before the request and delegates to
        ``get_data_from_remote``.

        Raises
        ------
        ValueError
            if ``interval`` is not ``INTERVAL_1d``
        """
        if interval != self.INTERVAL_1d:
            raise ValueError(f"cannot support {interval}")

        self.sleep()
        return self.get_data_from_remote(
            symbol,
            interval=interval,
            start=start_datetime,
            end=end_datetime,
        )
|
|
|
|
|
|
class FundollectorCN(FundCollector, ABC):
    """Fund collector bound to the "Asia/Shanghai" timezone.

    The instrument list comes from ``get_en_fund_symbols`` and symbols are
    used as-is (no normalisation).
    """

    @property
    def _timezone(self):
        """Timezone name handed to ``convert_datetime``."""
        return "Asia/Shanghai"

    def normalize_symbol(self, symbol):
        """Return ``symbol`` unchanged."""
        return symbol

    def get_instrument_list(self):
        """Return the symbols produced by ``get_en_fund_symbols``, logging the count."""
        logger.info("get cn fund symbols......")
        symbols = get_en_fund_symbols()
        logger.info(f"get {len(symbols)} symbols.")
        return symbols
|
|
|
|
|
|
class FundCollectorCN1d(FundollectorCN):
    """Concrete collector whose name matches ``Run.collector_class_name`` for region CN and interval 1d."""
|
|
|
|
|
|
class FundNormalize(BaseNormalize):
    """Normalizer that indexes fund data by date, de-duplicates and sorts it."""

    DAILY_FORMAT = "%Y-%m-%d"

    @staticmethod
    def normalize_fund(
        df: pd.DataFrame,
        calendar_list: list = None,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """Return ``df`` de-duplicated on its date column and sorted by date.

        Parameters
        ----------
        df: pd.DataFrame
            raw data; an empty frame is returned as-is
        calendar_list: list
            if not None, the rows are re-indexed onto the calendar entries
            lying between the first and the last date found in ``df``
        date_field_name: str
            name of the date column, default date
        symbol_field_name: str
            symbol field name, default symbol (not used by this method)

        Returns
        -------
        pd.DataFrame with ``date_field_name`` restored as a regular column
        """
        if df.empty:
            return df

        frame = df.copy().set_index(date_field_name)
        frame.index = pd.to_datetime(frame.index)
        # Keep only the first row for every repeated date.
        frame = frame[~frame.index.duplicated(keep="first")]

        if calendar_list is not None:
            first_day = pd.Timestamp(frame.index.min()).date()
            # Upper bound: the last date plus 23h59m, exactly as in the slice expression.
            last_day = pd.Timestamp(frame.index.max()).date() + pd.Timedelta(hours=23, minutes=59)
            calendar_index = pd.DataFrame(index=calendar_list).loc[first_day:last_day].index
            frame = frame.reindex(calendar_index)

        frame = frame.sort_index()
        frame.index.names = [date_field_name]
        return frame.reset_index()

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply ``normalize_fund`` using this instance's calendar and field names."""
        return self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
|
|
|
|
|
class FundNormalize1d(FundNormalize):
    """1d variant of ``FundNormalize``; adds no behaviour of its own."""
|
|
|
|
|
|
class FundNormalizeCN:
    """Mixin that supplies the calendar list for the normalizer classes."""

    def _get_calendar_list(self):
        """Return the result of ``get_calendar_list("ALL")``."""
        return get_calendar_list("ALL")
|
|
|
|
|
|
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
    """Concrete normalizer whose name matches ``Run.normalize_class_name`` for region CN and interval 1d."""
|
|
|
|
|
|
class Run(BaseRun):
    """Command-line entry point that selects the collector/normalizer class by region and interval."""

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        max_workers: int
            Concurrent number, default is 4
        interval: str
            freq, default 1d; this file only defines the 1d collector and normalizer classes
        region: str
            region, value from ["CN"], default "CN"
        """
        super().__init__(source_dir, normalize_dir, max_workers, interval)
        self.region = region

    @property
    def collector_class_name(self):
        """Collector class name built from the upper-cased region and the interval."""
        return f"FundCollector{self.region.upper()}{self.interval}"

    @property
    def normalize_class_name(self):
        """Normalizer class name built from the upper-cased region and the interval."""
        return f"FundNormalize{self.region.upper()}{self.interval}"

    @property
    def default_base_dir(self) -> [Path, str]:
        """Directory of this script (``CUR_DIR``)."""
        return CUR_DIR

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        interval="1d",
        check_data_length: int = None,
        limit_nums=None,
    ):
        """download data from Internet

        All arguments are forwarded positionally to ``BaseRun.download_data``.

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, default 1d; ``FundCollector.get_data`` rejects any other value
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
        """
        super().download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
        """normalize data

        Both arguments are forwarded positionally to ``BaseRun.normalize_data``.

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
        """
        super().normalize_data(date_field_name, symbol_field_name)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand the Run class to python-fire when this file is executed as a script.
    fire.Fire(Run)
|