1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add BaseCollector

This commit is contained in:
zhupr
2021-03-10 16:23:20 +08:00
parent e2817ab87c
commit 42be8ac312
2 changed files with 548 additions and 370 deletions

View File

@@ -0,0 +1,430 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import time
import datetime
import importlib
from pathlib import Path
from typing import Type
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import code_to_fname
class BaseCollector(abc.ABC):
CACHE_FLAG = "CACHED"
NORMAL_FLAG = "NORMAL"
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
INTERVAL_1min = "1min"
INTERVAL_1d = "1d"
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="1d",
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
):
"""
Parameters
----------
save_dir: str
stock save dir
max_workers: int
workers, default 4
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1d
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self.delay = delay
self.max_workers = max_workers
self.max_collector_count = max_collector_count
self.mini_symbol_map = {}
self.interval = interval
self.check_small_data = check_data_length
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
return (
pd.Timestamp(str(start_datetime))
if start_datetime
else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
)
def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
return (
pd.Timestamp(str(end_datetime))
if end_datetime
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
)
@property
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewrite get_stock_list")
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
"""normalize symbol"""
raise NotImplementedError("rewrite normalize_symbol")
@abc.abstractmethod
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
"""get data with symbol
Parameters
----------
symbol: str
interval: str
value from [1min, 1d]
start_datetime: pd.Timestamp
end_datetime: pd.Timestamp
Returns
---------
pd.DataFrame, "symbol" in pd.columns
"""
raise NotImplementedError("rewrite get_timezone")
def sleep(self):
time.sleep(self.delay)
def _simple_collector(self, symbol: str):
"""
Parameters
----------
symbol: str
"""
self.sleep()
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
_result = self.NORMAL_FLAG
if self.check_small_data:
_result = self.cache_small_data(symbol, df)
if _result == self.NORMAL_FLAG:
self.save_instrument(symbol, df)
return _result
def save_instrument(self, symbol, df: pd.DataFrame):
"""save stock data to file
Parameters
----------
symbol: str
stock code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
logger.warning(f"{symbol} is empty")
return
symbol = self.normalize_symbol(symbol)
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
_temp = self.mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return self.CACHE_FLAG
else:
if symbol in self.mini_symbol_map:
self.mini_symbol_map.pop(symbol)
return self.NORMAL_FLAG
def _collector(self, stock_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
error_symbol.extend(self.mini_symbol_map.keys())
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
logger.info("start collector data......")
stock_list = self.stock_list
for i in range(self.max_collector_count):
if not stock_list:
break
logger.info(f"getting data: {i+1}")
stock_list = self._collector(stock_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
class BaseNormalize(abc.ABC):
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self._calendar_list = self._get_calendar_list()
@abc.abstractmethod
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
raise NotImplementedError("")
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
class Normalize:
def __init__(
self,
source_dir: [str, Path],
target_dir: [str, Path],
normalize_class: Type[BaseNormalize],
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
normalize_class: Type[YahooNormalize]
normalize class
max_workers: int
Concurrent number, default is 16
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._target_dir.mkdir(parents=True, exist_ok=True)
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
def _executor(self, file_path: Path):
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):
logger.info("normalize data......")
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(self._executor, file_list):
p_bar.update()
class BaseRun(abc.ABC):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
interval: str
freq, value from [1min, 1d], default 1d
"""
if source_dir is None:
source_dir = Path(self.default_base_dir).joinpath("_source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = Path(self.default_base_dir).joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
self.interval = interval
@property
@abc.abstractmethod
def collector_class_name(self):
raise NotImplementedError("rewrite normalize_symbol")
@property
@abc.abstractmethod
def normalize_class_name(self):
raise NotImplementedError("rewrite normalize_symbol")
@property
@abc.abstractmethod
def default_base_dir(self) -> [Path, str]:
raise NotImplementedError("rewrite normalize_symbol")
def download_data(
self,
max_collector_count=2,
delay=0,
start=None,
end=None,
interval="1d",
check_data_length=False,
limit_nums=None,
):
"""download data from Internet
Parameters
----------
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1d
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
_class(
self.source_dir,
max_workers=self.max_workers,
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
).collector_data()
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data
Parameters
----------
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, self.normalize_class_name)
yc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
)
yc.normalize()

View File

@@ -10,158 +10,26 @@ import importlib
from abc import ABC
from pathlib import Path
from typing import Iterable, Type
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import fire
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname, fname_to_code
from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
REGION_CN = "CN"
REGION_US = "US"
class YahooData:
START_DATETIME = pd.Timestamp("2000-01-01")
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
INTERVAL_1min = "1min"
INTERVAL_1d = "1d"
def __init__(
self,
timezone: str = None,
start=None,
end=None,
interval="1d",
delay=0,
show_1min_logging: bool = False,
):
"""
Parameters
----------
timezone: str
The timezone where the data is located
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1min
start: str
start datetime, default None
end: str
end datetime, default None
show_1min_logging: bool
show 1min logging, by default False; if True, there may be many warning logs
"""
self._timezone = tzlocal() if timezone is None else timezone
self._delay = delay
self._interval = interval
self._show_1min_logging = show_1min_logging
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval == self.INTERVAL_1min:
self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == self.INTERVAL_1d:
pass
else:
raise ValueError(f"interval error: {self._interval}")
# using for 1min
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
try:
dt = pd.Timestamp(dt, tz=timezone).timestamp()
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
except ValueError as e:
pass
return dt
def _sleep(self):
time.sleep(self._delay)
@staticmethod
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
error_msg = f"{symbol}-{interval}-{start}-{end}"
def _show_logging_func():
if interval == YahooData.INTERVAL_1min and show_1min_logging:
logger.warning(f"{error_msg}:{_resp}")
interval = "1m" if interval in ["1m", "1min"] else interval
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
elif isinstance(_resp, dict):
_temp_data = _resp.get(symbol, {})
if isinstance(_temp_data, str) or (
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
):
_show_logging_func()
else:
_show_logging_func()
except Exception as e:
logger.warning(f"{error_msg}:{e}")
def get_data(self, symbol: str) -> [pd.DataFrame]:
def _get_simple(start_, end_):
self._sleep()
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
return self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
show_1min_logging=self._show_1min_logging,
)
_result = None
if self._interval == self.INTERVAL_1d:
_result = _get_simple(self.start_datetime, self.end_datetime)
elif self._interval == self.INTERVAL_1min:
if self._next_datetime >= self._latest_datetime:
_result = _get_simple(self.start_datetime, self.end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None and not _resp.empty:
_res.append(_resp)
for _s, _e in (
(self.start_datetime, self._next_datetime),
(self._latest_datetime, self.end_datetime),
):
_get_multi(_s, _e)
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
_get_multi(_start, _end)
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
raise ValueError(f"cannot support {self._interval}")
return _result
class YahooCollector:
class YahooCollector(BaseCollector):
def __init__(
self,
save_dir: [str, Path],
@@ -173,7 +41,6 @@ class YahooCollector:
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
show_1min_logging: bool = False,
):
"""
@@ -197,131 +64,118 @@ class YahooCollector:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1min_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
self.max_workers = max_workers
self._max_collector_count = max_collector_count
self._mini_symbol_map = {}
self._interval = interval
self._check_small_data = check_data_length
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
self.yahoo_data = YahooData(
timezone=self._timezone,
super(YahooCollector, self).__init__(
save_dir=save_dir,
start=start,
end=end,
interval=interval,
max_workers=max_workers,
max_collector_count=max_collector_count,
delay=delay,
show_1min_logging=show_1min_logging,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
@property
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewrite min_numbers_trading")
self.init_datetime()
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewrite get_stock_list")
def init_datetime(self):
if self.interval == self.INTERVAL_1min:
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
elif self.interval == self.INTERVAL_1d:
pass
else:
raise ValueError(f"interval error: {self.interval}")
# using for 1min
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@staticmethod
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
try:
dt = pd.Timestamp(dt, tz=timezone).timestamp()
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
except ValueError as e:
pass
return dt
@property
@abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
def save_stock(self, symbol, df: pd.DataFrame):
"""save stock data to file
@staticmethod
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
error_msg = f"{symbol}-{interval}-{start}-{end}"
Parameters
----------
symbol: str
stock code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
logger.warning(f"{symbol} is empty")
return
def _show_logging_func():
if interval == YahooCollector.INTERVAL_1min and show_1min_logging:
logger.warning(f"{error_msg}:{_resp}")
symbol = self.normalize_symbol(symbol)
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
interval = "1m" if interval in ["1m", "1min"] else interval
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
elif isinstance(_resp, dict):
_temp_data = _resp.get(symbol, {})
if isinstance(_temp_data, str) or (
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
):
_show_logging_func()
else:
_show_logging_func()
except Exception as e:
logger.warning(f"{error_msg}:{e}")
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
_temp = self._mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return None
else:
if symbol in self._mini_symbol_map:
self._mini_symbol_map.pop(symbol)
return symbol
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
def _get_simple(start_, end_):
self.sleep()
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
return self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
def _get_data(self, symbol):
_result = None
df = self.yahoo_data.get_data(symbol)
if isinstance(df, pd.DataFrame):
if not df.empty:
if self._check_small_data:
if self._save_small_data(symbol, df) is not None:
_result = symbol
self.save_stock(symbol, df)
else:
_result = symbol
self.save_stock(symbol, df)
return _result
if interval == self.INTERVAL_1d:
_result = _get_simple(start_datetime, end_datetime)
elif interval == self.INTERVAL_1min:
if self._next_datetime >= self._latest_datetime:
_result = _get_simple(start_datetime, end_datetime)
else:
_res = []
def _collector(self, stock_list):
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None and not _resp.empty:
_res.append(_resp)
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
if _result is None:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
error_symbol.extend(self._mini_symbol_map.keys())
return sorted(set(error_symbol))
for _s, _e in (
(self.start_datetime, self._next_datetime),
(self._latest_datetime, self.end_datetime),
):
_get_multi(_s, _e)
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
_get_multi(_start, _end)
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
raise ValueError(f"cannot support {self.interval}")
return pd.DataFrame() if _result is None else _result
def collector_data(self):
"""collector data"""
logger.info("start collector yahoo data......")
stock_list = self.stock_list
for i in range(self._max_collector_count):
if not stock_list:
break
logger.info(f"getting data: {i+1}")
stock_list = self._collector(stock_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self._mini_symbol_map.items():
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
if self._mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
super(YahooCollector, self).collector_data()
self.download_index_data()
@abc.abstractmethod
@@ -329,11 +183,6 @@ class YahooCollector:
"""download index data"""
raise NotImplementedError("rewrite download_index_data")
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
"""normalize symbol"""
raise NotImplementedError("rewrite normalize_symbol")
class YahooCollectorCN(YahooCollector, ABC):
def get_stock_list(self):
@@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN):
def download_index_data(self):
# TODO: from MSN
_format = "%Y%m%d"
_begin = self.yahoo_data.start_datetime.strftime(_format)
_end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
_begin = self.start_datetime.strftime(_format)
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
try:
@@ -396,7 +245,7 @@ class YahooCollectorCN1min(YahooCollectorCN):
def download_index_data(self):
# TODO: 1m
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
class YahooCollectorUS(YahooCollector, ABC):
@@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS):
return 60 * 6.5 * 5
class YahooNormalize:
class YahooNormalize(BaseNormalize):
COLUMNS = ["open", "close", "high", "low", "volume"]
DAILY_FORMAT = "%Y-%m-%d"
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self._calendar_list = self._get_calendar_list()
@staticmethod
def normalize_yahoo(
df: pd.DataFrame,
@@ -498,11 +328,6 @@ class YahooNormalize:
df = self.adjusted_price(df)
return df
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
@abc.abstractmethod
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
"""adjusted price"""
@@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end)
data_1d = YahooCollector.get_data_from_remote(
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
)
if data_1d is None or data_1d.empty:
df["factor"] = 1
# TODO: np.nan or 1 or 0
@@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
return get_calendar_list("ALL")
class Normalize:
def __init__(
self,
source_dir: [str, Path],
target_dir: [str, Path],
normalize_class: Type[YahooNormalize],
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
normalize_class: Type[YahooNormalize]
normalize class
max_workers: int
Concurrent number, default is 16
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._target_dir.mkdir(parents=True, exist_ok=True)
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
def _executor(self, file_path: Path):
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):
logger.info("normalize data......")
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(self._executor, file_list):
p_bar.update()
class Run:
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
"""
Parameters
@@ -789,23 +562,26 @@ class Run:
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
interval: str
freq, value from [1min, 1d], default 1d
region: str
region, value from ["CN", "US"], default "CN"
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region
@property
def collector_class_name(self):
return f"YahooCollector{self.region.upper()}{self.interval}"
@property
def normalize_class_name(self):
return f"YahooNormalize{self.region.upper()}{self.interval}"
@property
def default_base_dir(self) -> [Path, str]:
return CUR_DIR
def download_data(
self,
max_collector_count=2,
@@ -815,7 +591,6 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
show_1min_logging=False,
):
"""download data from Internet
@@ -835,8 +610,6 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1min_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
Examples
---------
@@ -846,29 +619,13 @@ class Run:
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
_class = getattr(
self._cur_module, f"YahooCollector{self.region.upper()}{interval}"
) # type: Type[YahooCollector]
_class(
self.source_dir,
max_workers=self.max_workers,
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
show_1min_logging=show_1min_logging,
).collector_data()
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data
Parameters
----------
interval: str
freq, value from [1min, 1d], default 1d
date_field_name: str
date field name, default date
symbol_field_name: str
@@ -878,16 +635,7 @@ class Run:
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}")
yc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
)
yc.normalize()
super(Run, self).normalize_data(date_field_name, symbol_field_name)
if __name__ == "__main__":