mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
313 lines
9.6 KiB
Python
313 lines
9.6 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import abc
|
|
import sys
|
|
import copy
|
|
import time
|
|
import datetime
|
|
import importlib
|
|
import json
|
|
from abc import ABC
|
|
from pathlib import Path
|
|
from typing import Iterable, Type
|
|
|
|
import fire
|
|
import requests
|
|
import numpy as np
|
|
import pandas as pd
|
|
from loguru import logger
|
|
from dateutil.tz import tzlocal
|
|
from qlib.config import REG_CN as REGION_CN
|
|
|
|
# Directory that contains this collector script.
CUR_DIR = Path(__file__).resolve().parent
# Put the directory two levels above this script on sys.path so that the ``data_collector``
# package imports below can be resolved (presumably needed when the file is run directly as a
# script). These imports must stay AFTER this line.
sys.path.append(str(CUR_DIR.parent.parent))

from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_en_fund_symbols

# eastmoney JSONP endpoint returning a fund's historical value list; the placeholders are filled
# with ``str.format``. pageIndex is fixed at 1, so pageSize must be large enough to cover the
# whole requested date range in a single page.
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
|
|
|
|
|
class FundCollector(BaseCollector):
    """Collect fund history from the eastmoney ``f10/lsjz`` JSONP endpoint.

    Subclasses must implement the ``_timezone`` property; it is used to interpret the
    ``start``/``end`` bounds before they are re-expressed in the local timezone.
    """

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            fund save dir
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        interval: str
            freq, value from [1min, 1d], default 1d. Note that ``get_data`` only implements
            ``1d``; any other interval raises ``ValueError`` there.
        max_workers: int
            workers, default 4
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None
        """
        super(FundCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

        self.init_datetime()

    def init_datetime(self):
        """Validate the interval and convert ``start_datetime``/``end_datetime``.

        For the 1min interval the start is clamped to ``DEFAULT_START_DATETIME_1MIN``.
        Both bounds are then passed through ``convert_datetime`` with ``self._timezone``.

        Raises
        ------
        ValueError
            If ``self.interval`` is neither the 1min nor the 1d interval.
        """
        if self.interval == self.INTERVAL_1min:
            self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
        elif self.interval == self.INTERVAL_1d:
            pass
        else:
            raise ValueError(f"interval error: {self.interval}")

        self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
        self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)

    @staticmethod
    def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
        """Interpret ``dt`` as wall-clock time in ``timezone`` and return that instant in local time.

        Parameters
        ----------
        dt:
            Anything ``pd.Timestamp`` accepts.
        timezone:
            Timezone in which a naive ``dt`` is interpreted.

        Returns
        -------
        A ``pd.Timestamp`` localized with ``tzlocal()``; or ``dt`` itself, unchanged, when pandas
        raises ``ValueError`` during the conversion (e.g. ``dt`` is already tz-aware, or is
        None/NaT).
        """
        try:
            # Compute into a local so a failure in either step returns the ORIGINAL value,
            # never a half-converted POSIX float.
            epoch_seconds = pd.Timestamp(dt, tz=timezone).timestamp()
            return pd.Timestamp(epoch_seconds, tz=tzlocal(), unit="s")
        except ValueError:
            return dt

    @property
    @abc.abstractmethod
    def _timezone(self):
        raise NotImplementedError("rewrite get_timezone")

    @staticmethod
    def _strip_jsonp(text: str) -> str:
        """Return the JSON payload of a JSONP response such as ``jQuery_({...})``.

        Plain JSON (starting with ``{`` or ``[``) is returned as-is. Otherwise the text between
        the FIRST ``(`` and the LAST ``)`` is returned, so parentheses inside the payload do not
        truncate it.
        """
        payload = text.strip()
        if payload.startswith(("{", "[")):
            return payload
        open_pos = payload.find("(")
        close_pos = payload.rfind(")")
        if open_pos == -1 or close_pos < open_pos:
            return payload
        return payload[open_pos + 1 : close_pos]

    @staticmethod
    def get_data_from_remote(symbol, interval, start, end, timeout=60):
        """Download the history list of one fund.

        Parameters
        ----------
        symbol: str
            fund code, substituted into ``INDEX_BENCH_URL``
        interval: str
            only used in the warning message
        start, end:
            substituted into the URL's ``startDate``/``endDate``
        timeout: float
            seconds passed to ``requests.get`` so a stalled connection cannot hang forever

        Returns
        -------
        ``pd.DataFrame`` built from ``Data.LSJZList`` (with a fresh ``index`` column from
        ``reset_index``), or None when anything fails; failures are logged as warnings.
        """
        error_msg = f"{symbol}-{interval}-{start}-{end}"

        try:
            # TODO: numberOfHistoricalDaysToCrawl should be big enough to cover [start, end]
            url = INDEX_BENCH_URL.format(
                index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
            )
            resp = requests.get(
                url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=timeout
            )

            if resp.status_code != 200:
                raise ValueError(f"request error: status code {resp.status_code}")

            data = json.loads(FundCollector._strip_jsonp(resp.text))

            # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
            if data["Data"]["SYType"] in ("每万份收益", "每百份收益", "每百万份收益"):
                raise ValueError("The fund contains 每*份收益")

            # TODO: should we sort the value by datetime?
            return pd.DataFrame(data["Data"]["LSJZList"]).reset_index()
        except Exception as e:
            # Deliberate best-effort: one failing symbol must not abort the whole crawl.
            logger.warning(f"{error_msg}:{e}")
        return None

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> [pd.DataFrame]:
        """Sleep for the configured delay, then fetch ``symbol`` for the given range.

        Raises
        ------
        ValueError
            If ``interval`` is not the 1d interval (the only one implemented).
        """
        if interval != self.INTERVAL_1d:
            raise ValueError(f"cannot support {interval}")
        self.sleep()
        return self.get_data_from_remote(
            symbol,
            interval=interval,
            start=start_datetime,
            end=end_datetime,
        )
|
|
|
|
|
|
class FundollectorCN(FundCollector, ABC):
    """CN fund collector: symbol list from ``get_en_fund_symbols``, Asia/Shanghai timezone.

    NOTE(review): the class name is missing a "C" ("Fundollector"). It is kept as-is because
    ``FundCollectorCN1d`` (and possibly external code) refers to it by this name.
    """

    def get_instrument_list(self):
        """Return the fund codes to crawl, logging how many were found."""
        logger.info("get cn fund symbols......")
        fund_codes = get_en_fund_symbols()
        logger.info(f"get {len(fund_codes)} symbols.")
        return fund_codes

    def normalize_symbol(self, symbol):
        """Identity mapping: fund codes are used exactly as returned."""
        normalized = symbol
        return normalized

    @property
    def _timezone(self):
        """Timezone used to interpret the crawl's start/end bounds."""
        tz_name = "Asia/Shanghai"
        return tz_name
|
|
|
|
|
|
class FundCollectorCN1d(FundollectorCN):
    """Daily-interval CN fund collector (selected by ``Run`` as ``FundCollectorCN1d``)."""

    @property
    def min_numbers_trading(self):
        # A quarter of 252 (presumably trading days per year); true division, so this is 63.0.
        days_per_year = 252
        return days_per_year / 4
|
|
|
|
|
|
class FundNormalize(BaseNormalize):
    """Normalize raw fund frames: datetime index, de-duplication, optional calendar alignment."""

    DAILY_FORMAT = "%Y-%m-%d"

    @staticmethod
    def normalize_fund(
        df: pd.DataFrame,
        calendar_list: list = None,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """Return a normalized copy of ``df``.

        Steps: index by ``date_field_name`` parsed with ``pd.to_datetime``; drop duplicate
        dates (keeping the first); if ``calendar_list`` is given, reindex onto the calendar
        entries from the first data day through the end (23:59) of the last data day, so
        calendar days without data become NaN rows; sort by date; move the date back to a column.

        Parameters
        ----------
        df: pd.DataFrame
            raw data containing a ``date_field_name`` column; returned unchanged if empty
        calendar_list: list
            sortable datetime-like values used as the target index, or None to skip alignment
        date_field_name: str
            name of the date column
        symbol_field_name: str
            currently unused; kept for interface compatibility
        """
        if df.empty:
            return df
        df = df.copy()
        df.set_index(date_field_name, inplace=True)
        df.index = pd.to_datetime(df.index)
        df = df[~df.index.duplicated(keep="first")]
        if calendar_list is not None:
            # Build real Timestamps: adding a Timedelta to a ``datetime.date`` keeps only whole
            # days, which silently discarded the intended 23:59 end-of-day extension.
            first_day = pd.Timestamp(pd.Timestamp(df.index.min()).date())
            last_day_end = pd.Timestamp(pd.Timestamp(df.index.max()).date()) + pd.Timedelta(hours=23, minutes=59)
            calendar_index = pd.DataFrame(index=calendar_list).loc[first_day:last_day_end].index
            df = df.reindex(calendar_index)
        df.sort_index(inplace=True)

        # Reindexing drops the index name; restore it so reset_index yields the date column.
        df.index.names = [date_field_name]
        return df.reset_index()

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize one raw frame using this normalizer's calendar and field names."""
        return self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
|
|
|
|
|
class FundNormalize1d(FundNormalize):
    """Daily-interval fund normalizer; inherits all behavior from ``FundNormalize`` unchanged."""
|
|
|
|
|
|
class FundNormalizeCN:
    """Mixin that provides ``_get_calendar_list`` for CN fund normalizers."""

    def _get_calendar_list(self):
        """Return the trading calendar obtained from ``get_calendar_list("ALL")``."""
        calendar = get_calendar_list("ALL")
        return calendar
|
|
|
|
|
|
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
    """CN daily normalizer: CN calendar mixin combined with the daily fund normalizer."""
|
|
|
|
|
|
class Run(BaseRun):
    """Command-line front end (exposed through ``fire.Fire``) for fund data collection."""

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        max_workers: int
            Concurrent number, default is 4
        interval: str
            freq, default 1d. This module only defines ``1d`` classes
            (``FundCollectorCN1d`` / ``FundNormalizeCN1d``).
        region: str
            region, value from ["CN"], default "CN"
        """
        super().__init__(source_dir, normalize_dir, max_workers, interval)
        self.region = region

    @property
    def collector_class_name(self):
        """Name of the collector class looked up for this region and interval."""
        region_tag = self.region.upper()
        return f"FundCollector{region_tag}{self.interval}"

    @property
    def normalize_class_name(self):
        """Name of the normalizer class looked up for this region and interval."""
        region_tag = self.region.upper()
        return f"FundNormalize{region_tag}{self.interval}"

    @property
    def default_base_dir(self) -> [Path, str]:
        """Directory containing this script."""
        base_dir = CUR_DIR
        return base_dir

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        interval="1d",
        check_data_length=False,
        limit_nums=None,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, default 1d (the only interval this module implements)
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
        """
        super().download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
        """
        super().normalize_data(date_field_name, symbol_field_name)
|
|
|
|
|
|
if __name__ == "__main__":
    # Turn Run's public methods (download_data, normalize_data, ...) into CLI subcommands.
    fire.Fire(component=Run)
|