mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix collector start datetime
This commit is contained in:
@@ -17,6 +17,7 @@ import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -42,6 +43,7 @@ class YahooCollector:
|
||||
max_collector_count=5,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -63,18 +65,25 @@ class YahooCollector:
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
self.max_workers = max_workers
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
self._start_datetime = pd.Timestamp(start) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(end) if end else self.END_DATETIME
|
||||
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
|
||||
if self._interval == "1m":
|
||||
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == "1d":
|
||||
@@ -82,7 +91,8 @@ class YahooCollector:
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
self._end_datetime = min(self._end_datetime, self.END_DATETIME)
|
||||
self._start_datetime = self.convert_datetime(self._start_datetime)
|
||||
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -90,11 +100,20 @@ class YahooCollector:
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("")
|
||||
raise NotImplementedError("rewirte min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("")
|
||||
raise NotImplementedError("rewirte get_stock_list")
|
||||
|
||||
@property
|
||||
@abc.abstractclassmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def convert_datetime(self, dt: pd.Timestamp):
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
@@ -112,80 +131,90 @@ class YahooCollector:
|
||||
if df.empty:
|
||||
raise ValueError("df is empty")
|
||||
|
||||
symbol = self.normailze_symbol(symbol)
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=None)
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return symbol
|
||||
return None
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return None
|
||||
return symbol
|
||||
|
||||
def _get_from_remote(self, symbol):
|
||||
if self._interval == "1d":
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
try:
|
||||
resp = Ticker(symbol, asynchronous=False).history(
|
||||
interval=self._interval, start=self._start_datetime, end=self._end_datetime
|
||||
)
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
else:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{self._start_datetime}-{self._end_datetime}:{e}")
|
||||
resp = None
|
||||
yield resp
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
|
||||
|
||||
_result = None
|
||||
if self._interval == "1d":
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
elif self._interval == "1m":
|
||||
_res = []
|
||||
for _start in pd.date_range(self._start_datetime, self._end_datetime + pd.Timedelta(days=-1)):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
try:
|
||||
resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=_start, end=_end)
|
||||
if isinstance(resp, pd.DataFrame):
|
||||
_res.append(resp)
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{_start}-{_end}:{e}")
|
||||
if _res:
|
||||
yield pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
|
||||
_end_date = self._end_datetime.date()
|
||||
if _start_date >= _end_date:
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
else:
|
||||
yield None
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(_start_date, _end_date, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
def _get_data(self, symbol):
|
||||
_result = None
|
||||
df = self._get_from_remote(symbol)
|
||||
if isinstance(df, pd.DataFrame):
|
||||
if not df.empty:
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(symbol, df) is not None:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
else:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
return _result
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
futures = {}
|
||||
for _symbol in tqdm(stock_list):
|
||||
for _resp in self._get_from_remote(_symbol):
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
df = _resp.reset_index()
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(_symbol, df) is not None:
|
||||
error_symbol.append(_symbol)
|
||||
futures[worker.submit(self.save_stock, _symbol, df)] = _symbol
|
||||
elif isinstance(_resp, dict):
|
||||
if "timestamp" in _resp[_symbol]:
|
||||
logger.warning(_resp[_symbol])
|
||||
error_symbol.append(_symbol)
|
||||
elif _resp is None:
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
else:
|
||||
if not (("1m data not available for" in _resp) or ("Data doesn't exist for" in _resp)):
|
||||
error_symbol.append(_symbol)
|
||||
logger.info("save stock data......")
|
||||
for future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
@@ -204,8 +233,9 @@ class YahooCollector:
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
self.download_index_data()
|
||||
|
||||
@@ -215,7 +245,7 @@ class YahooCollector:
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normailze_symbol(self, symbol: str):
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@@ -237,30 +267,41 @@ class YahooCollectorCN(YahooCollector):
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
# FIXME: 1m
|
||||
_format = "%Y%m%d"
|
||||
_begin = self._start_datetime.strftime(_format)
|
||||
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()["data"][
|
||||
"klines"
|
||||
],
|
||||
)
|
||||
)
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.astype(float, errors="ignore")
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
|
||||
if self._interval == "1d":
|
||||
_format = "%Y%m%d"
|
||||
_begin = self._start_datetime.strftime(_format)
|
||||
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
|
||||
"data"
|
||||
]["klines"],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"get {_index_name} error: {e}")
|
||||
continue
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.astype(float, errors="ignore")
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
|
||||
else:
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: downlaod_index_data")
|
||||
|
||||
def normailze_symbol(self, symbol):
|
||||
def normalize_symbol(self, symbol):
|
||||
symbol_s = symbol.split(".")
|
||||
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector):
|
||||
@property
|
||||
@@ -283,9 +324,13 @@ class YahooCollectorUS(YahooCollector):
|
||||
def download_index_data(self):
|
||||
pass
|
||||
|
||||
def normailze_symbol(self, symbol):
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol.upper()
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "America/New_York"
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
@@ -419,7 +464,14 @@ class Run:
|
||||
self.region = region
|
||||
|
||||
def download_data(
|
||||
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=True
|
||||
self,
|
||||
max_collector_count=5,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -436,8 +488,9 @@ class Run:
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default True
|
||||
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -456,6 +509,7 @@ class Run:
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self):
|
||||
@@ -469,7 +523,14 @@ class Run:
|
||||
_class(self.source_dir, self.normalize_dir, self.max_workers).normalize()
|
||||
|
||||
def collector_data(
|
||||
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=False
|
||||
self,
|
||||
max_collector_count=5,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download -> normalize
|
||||
|
||||
@@ -487,7 +548,8 @@ class Run:
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
Examples
|
||||
-------
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
@@ -499,6 +561,7 @@ class Run:
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
self.normalize_data()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user