1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

use base.py

This commit is contained in:
wangershi
2021-03-14 14:24:14 +08:00
parent 1344c40598
commit 4e7a147759
2 changed files with 121 additions and 344 deletions

View File

@@ -46,7 +46,7 @@ class BaseCollector(abc.ABC):
Parameters
----------
save_dir: str
stock save dir
instrument save dir
max_workers: int
workers, default 4
max_collector_count: int
@@ -77,11 +77,11 @@ class BaseCollector(abc.ABC):
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)
self.stock_list = sorted(set(self.get_stock_list()))
self.instrument_list = sorted(set(self.get_instrument_list()))
if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
self.instrument_list = self.instrument_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
@@ -108,8 +108,8 @@ class BaseCollector(abc.ABC):
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewrite get_stock_list")
def get_instrument_list(self):
raise NotImplementedError("rewrite get_instrument_list")
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
@@ -158,27 +158,27 @@ class BaseCollector(abc.ABC):
return _result
def save_instrument(self, symbol, df: pd.DataFrame):
"""save stock data to file
"""save instrument data to file
Parameters
----------
symbol: str
stock code
instrument code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
if df is None or df.empty:
logger.warning(f"{symbol} is empty")
return
symbol = self.normalize_symbol(symbol)
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
if instrument_path.exists():
_old_df = pd.read_csv(instrument_path)
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
df.to_csv(instrument_path, index=False)
def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
@@ -191,38 +191,38 @@ class BaseCollector(abc.ABC):
self.mini_symbol_map.pop(symbol)
return self.NORMAL_FLAG
def _collector(self, stock_list):
def _collector(self, instrument_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
with tqdm(total=len(instrument_list)) as p_bar:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
logger.info(f"current get symbol nums: {len(instrument_list)}")
error_symbol.extend(self.mini_symbol_map.keys())
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
logger.info("start collector data......")
stock_list = self.stock_list
instrument_list = self.instrument_list
for i in range(self.max_collector_count):
if not stock_list:
if not instrument_list:
break
logger.info(f"getting data: {i+1}")
stock_list = self._collector(stock_list)
instrument_list = self._collector(instrument_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
class BaseNormalize(abc.ABC):
@@ -386,9 +386,9 @@ class BaseRun(abc.ABC):
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
@@ -416,7 +416,7 @@ class BaseRun(abc.ABC):
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, self.normalize_class_name)
yc = Normalize(

View File

@@ -11,123 +11,24 @@ import json
from abc import ABC
from pathlib import Path
from typing import Iterable, Type
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import fire
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from dateutil.tz import tzlocal
from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_en_fund_symbols
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
REGION_CN = "CN"
class FundData:
START_DATETIME = pd.Timestamp("2000-01-01")
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
INTERVAL_1d = "1d"
def __init__(
self,
timezone: str = None,
start=None,
end=None,
interval="1d",
delay=0,
):
"""
Parameters
----------
timezone: str
The timezone where the data is located
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1d], default 1d
start: str
start datetime, default None
end: str
end datetime, default None
"""
self._timezone = tzlocal() if timezone is None else timezone
self._delay = delay
self._interval = interval
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval != self.INTERVAL_1d:
raise ValueError(f"interval error: {self._interval}")
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
try:
dt = pd.Timestamp(dt, tz=timezone).timestamp()
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
except ValueError as e:
pass
return dt
def _sleep(self):
time.sleep(self._delay)
@staticmethod
def get_data_from_remote(symbol, interval, start, end):
error_msg = f"{symbol}-{interval}-{start}-{end}"
try:
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
url = INDEX_BENCH_URL.format(
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
)
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
if resp.status_code != 200:
raise ValueError("request error")
data = json.loads(resp.text.split("(")[-1].split(")")[0])
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
# TODO: should we sort the value by datetime?
_resp = pd.DataFrame(data["Data"]["LSJZList"])
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
except Exception as e:
logger.warning(f"{error_msg}:{e}")
def get_data(self, symbol: str) -> [pd.DataFrame]:
def _get_simple(start_, end_):
self._sleep()
_remote_interval = self._interval
return self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
if self._interval == self.INTERVAL_1d:
_result = _get_simple(self.start_datetime, self.end_datetime)
else:
raise ValueError(f"cannot support {self._interval}")
return _result
class FundCollector:
class FundCollector(BaseCollector):
def __init__(
self,
save_dir: [str, Path],
@@ -163,134 +64,108 @@ class FundCollector:
limit_nums: int
using for debug, by default None
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
self.max_workers = max_workers
self._max_collector_count = max_collector_count
self._mini_symbol_map = {}
self._interval = interval
self._check_small_data = check_data_length
self.fund_list = sorted(set(self.get_fund_list()))
if limit_nums is not None:
try:
self.fund_list = self.fund_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
self.fund_data = FundData(
timezone=self._timezone,
super(FundCollector, self).__init__(
save_dir=save_dir,
start=start,
end=end,
interval=interval,
max_workers=max_workers,
max_collector_count=max_collector_count,
delay=delay,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
@property
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewrite min_numbers_trading")
self.init_datetime()
@abc.abstractmethod
def get_fund_list(self):
raise NotImplementedError("rewrite get_fund_list")
def init_datetime(self):
if self.interval == self.INTERVAL_1min:
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
elif self.interval == self.INTERVAL_1d:
pass
else:
raise ValueError(f"interval error: {self.interval}")
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
try:
dt = pd.Timestamp(dt, tz=timezone).timestamp()
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
except ValueError as e:
pass
return dt
@property
@abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
def save_fund(self, symbol, df: pd.DataFrame):
"""save fund data to file
@staticmethod
def get_data_from_remote(symbol, interval, start, end):
error_msg = f"{symbol}-{interval}-{start}-{end}"
Parameters
----------
symbol: str
fund code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
logger.warning(f"{symbol} is empty")
return
try:
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
url = INDEX_BENCH_URL.format(
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
)
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
fund_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if fund_path.exists():
# TODO: read the fund code as str, not int, like "000001" shouldn't be "1"
_old_df = pd.read_csv(fund_path)
# TODO: remove the duplicate date
df = _old_df.append(df, sort=False)
df.to_csv(fund_path, index=False)
if resp.status_code != 200:
raise ValueError("request error")
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
_temp = self._mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return None
data = json.loads(resp.text.split("(")[-1].split(")")[0])
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
# TODO: should we sort the value by datetime?
_resp = pd.DataFrame(data["Data"]["LSJZList"])
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
except Exception as e:
logger.warning(f"{error_msg}:{e}")
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> [pd.DataFrame]:
def _get_simple(start_, end_):
self.sleep()
_remote_interval = interval
return self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
if interval == self.INTERVAL_1d:
_result = _get_simple(start_datetime, end_datetime)
else:
if symbol in self._mini_symbol_map:
self._mini_symbol_map.pop(symbol)
return symbol
def _get_data(self, symbol):
_result = None
df = self.fund_data.get_data(symbol)
if isinstance(df, pd.DataFrame):
if not df.empty:
if self._check_small_data:
if self._save_small_data(symbol, df) is not None:
_result = symbol
self.save_fund(symbol, df)
else:
_result = symbol
self.save_fund(symbol, df)
raise ValueError(f"cannot support {interval}")
return _result
def _collector(self, fund_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(fund_list)) as p_bar:
for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)):
if _result is None:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(fund_list)}")
error_symbol.extend(self._mini_symbol_map.keys())
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
logger.info("start collector fund data......")
fund_list = self.fund_list
for i in range(self._max_collector_count):
if not fund_list:
break
logger.info(f"getting data: {i+1}")
fund_list = self._collector(fund_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self._mini_symbol_map.items():
self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
if self._mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}")
logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}")
super(FundCollector, self).collector_data()
class FundollectorCN(FundCollector, ABC):
def get_fund_list(self):
def get_instrument_list(self):
logger.info("get cn fund symbols......")
symbols = get_en_fund_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def normalize_symbol(self, symbol):
return symbol
@property
def _timezone(self):
return "Asia/Shanghai"
@@ -302,29 +177,9 @@ class FundCollectorCN1d(FundollectorCN):
return 252 / 4
class FundNormalize:
COLUMNS = ["open", "close", "high", "low", "volume"]
class FundNormalize(BaseNormalize):
DAILY_FORMAT = "%Y-%m-%d"
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self._calendar_list = self._get_calendar_list()
@staticmethod
def normalize_fund(
df: pd.DataFrame,
@@ -357,11 +212,6 @@ class FundNormalize:
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
return df
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
class FundNormalize1d(FundNormalize, ABC):
DAILY_FORMAT = "%Y-%m-%d"
@@ -380,62 +230,8 @@ class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
pass
class Normalize:
def __init__(
self,
source_dir: [str, Path],
target_dir: [str, Path],
normalize_class: Type[FundNormalize],
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
normalize_class: Type[FundNormalize]
normalize class
max_workers: int
Concurrent number, default is 16
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._target_dir.mkdir(parents=True, exist_ok=True)
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
def _executor(self, file_path: Path):
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):
logger.info("normalize data......")
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(self._executor, file_list):
p_bar.update()
class Run:
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
"""
Parameters
@@ -446,23 +242,26 @@ class Run:
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
interval: str
freq, value from [1min, 1d], default 1d
region: str
region, value from ["CN"], default "CN"
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region
@property
def collector_class_name(self):
return f"FundCollector{self.region.upper()}{self.interval}"
@property
def normalize_class_name(self):
return f"FundNormalize{self.region.upper()}{self.interval}"
@property
def default_base_dir(self) -> [Path, str]:
return CUR_DIR
def download_data(
self,
max_collector_count=2,
@@ -498,26 +297,13 @@ class Run:
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
"""
_class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector]
_class(
self.source_dir,
max_workers=self.max_workers,
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
).collector_data()
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data
Parameters
----------
interval: str
freq, value from [1d], default 1d
date_field_name: str
date field name, default date
symbol_field_name: str
@@ -527,16 +313,7 @@ class Run:
---------
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
"""
_class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}")
fc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
)
fc.normalize()
super(Run, self).normalize_data(date_field_name, symbol_field_name)
if __name__ == "__main__":