mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add BaseCollector
This commit is contained in:
430
scripts/data_collector/base.py
Normal file
430
scripts/data_collector/base.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import abc
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
class BaseCollector(abc.ABC):
|
||||
|
||||
CACHE_FLAG = "CACHED"
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.delay = delay
|
||||
self.max_workers = max_workers
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(start_datetime))
|
||||
if start_datetime
|
||||
else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(end_datetime))
|
||||
if end_datetime
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
"""get data with symbol
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
interval: str
|
||||
value from [1min, 1d]
|
||||
start_datetime: pd.Timestamp
|
||||
end_datetime: pd.Timestamp
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def sleep(self):
|
||||
time.sleep(self.delay)
|
||||
|
||||
def _simple_collector(self, symbol: str):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
|
||||
"""
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
return _result
|
||||
|
||||
def save_instrument(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
else:
|
||||
if symbol in self.mini_symbol_map:
|
||||
self.mini_symbol_map.pop(symbol)
|
||||
return self.NORMAL_FLAG
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self.mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self.max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[BaseNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = Path(self.default_base_dir).joinpath("_source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = Path(self.default_base_dir).joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
self.interval = interval
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def collector_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def normalize_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, self.normalize_class_name)
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
@@ -10,158 +10,26 @@ import importlib
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooData:
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timezone: str = None,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
delay=0,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timezone: str
|
||||
The timezone where the data is located
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
show_1min_logging: bool
|
||||
show 1min logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self._timezone = tzlocal() if timezone is None else timezone
|
||||
self._delay = delay
|
||||
self._interval = interval
|
||||
self._show_1min_logging = show_1min_logging
|
||||
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
def _show_logging_func():
|
||||
if interval == YahooData.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(self, symbol: str) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
show_1min_logging=self._show_1min_logging,
|
||||
)
|
||||
|
||||
_result = None
|
||||
if self._interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
elif self._interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
class YahooCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
@@ -173,7 +41,6 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -197,131 +64,118 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._delay = delay
|
||||
self.max_workers = max_workers
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
self.yahoo_data = YahooData(
|
||||
timezone=self._timezone,
|
||||
super(YahooCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
show_1min_logging=show_1min_logging,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
self.init_datetime()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
def _show_logging_func():
|
||||
if interval == YahooCollector.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return None
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return symbol
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
def _get_data(self, symbol):
|
||||
_result = None
|
||||
df = self.yahoo_data.get_data(symbol)
|
||||
if isinstance(df, pd.DataFrame):
|
||||
if not df.empty:
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(symbol, df) is not None:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
else:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
return _result
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
super(YahooCollector, self).collector_data()
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -329,11 +183,6 @@ class YahooCollector:
|
||||
"""download index data"""
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
@@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
_begin = self.yahoo_data.start_datetime.strftime(_format)
|
||||
_end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
_begin = self.start_datetime.strftime(_format)
|
||||
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
@@ -396,7 +245,7 @@ class YahooCollectorCN1min(YahooCollectorCN):
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
@@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS):
|
||||
return 60 * 6.5 * 5
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
class YahooNormalize(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame,
|
||||
@@ -498,11 +328,6 @@ class YahooNormalize:
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""adjusted price"""
|
||||
@@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[YahooNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -789,23 +562,26 @@ class Run:
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN", "US"], default "CN"
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"YahooCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"YahooNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
@@ -815,7 +591,6 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1min_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -835,8 +610,6 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -846,29 +619,13 @@ class Run:
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(
|
||||
self._cur_module, f"YahooCollector{self.region.upper()}{interval}"
|
||||
) # type: Type[YahooCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1min_logging=show_1min_logging,
|
||||
).collector_data()
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
@@ -878,16 +635,7 @@ class Run:
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user