1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add normalize 1min to use local data && change the default parameters for collecting 1min

This commit is contained in:
zhupr
2021-06-08 14:45:20 +08:00
parent 554b9c7826
commit a845a2271b
9 changed files with 328 additions and 70 deletions

View File

@@ -7,7 +7,7 @@ import time
import datetime
import importlib
from pathlib import Path
from typing import Type
from typing import Type, Iterable
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import pandas as pd
@@ -38,7 +38,7 @@ class BaseCollector(abc.ABC):
max_workers=1,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
check_data_length: int = None,
limit_nums: int = None,
):
"""
@@ -59,8 +59,8 @@ class BaseCollector(abc.ABC):
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
"""
@@ -72,7 +72,7 @@ class BaseCollector(abc.ABC):
self.max_collector_count = max_collector_count
self.mini_symbol_map = {}
self.interval = interval
self.check_small_data = check_data_length
self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)
@@ -99,14 +99,6 @@ class BaseCollector(abc.ABC):
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
)
@property
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_instrument_list(self):
raise NotImplementedError("rewrite get_instrument_list")
@@ -132,7 +124,7 @@ class BaseCollector(abc.ABC):
Returns
---------
pd.DataFrame, "symbol" in pd.columns
pd.DataFrame, "symbol" and "date"in pd.columns
"""
raise NotImplementedError("rewrite get_timezone")
@@ -151,7 +143,7 @@ class BaseCollector(abc.ABC):
self.sleep()
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
_result = self.NORMAL_FLAG
if self.check_small_data:
if self.check_data_length > 0:
_result = self.cache_small_data(symbol, df)
if _result == self.NORMAL_FLAG:
self.save_instrument(symbol, df)
@@ -181,8 +173,8 @@ class BaseCollector(abc.ABC):
df.to_csv(instrument_path, index=False)
def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
if len(df) < self.check_data_length:
logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
_temp = self.mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return self.CACHE_FLAG
@@ -194,9 +186,17 @@ class BaseCollector(abc.ABC):
def _collector(self, instrument_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(instrument_list)) as p_bar:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
with tqdm(total=len(instrument_list)) as p_bar:
if self.max_workers is not None and self.max_workers > 1:
logger.info(f"concurrent collector, max_workers: {self.max_workers}")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
else:
for _symbol in instrument_list:
_result = self._simple_collector(_symbol)
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
@@ -217,11 +217,11 @@ class BaseCollector(abc.ABC):
instrument_list = self._collector(instrument_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
_df = pd.concat(_df_list, sort=False)
if not _df.empty:
self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
@@ -247,7 +247,7 @@ class BaseNormalize(abc.ABC):
raise NotImplementedError("")
@abc.abstractmethod
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
"""Get benchmark calendar"""
raise NotImplementedError("")
@@ -296,7 +296,7 @@ class Normalize:
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
if df is not None and not df.empty:
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):

View File

@@ -0,0 +1,23 @@
# Use 1d data to fill in the missing symbols relative to 1min
## Requirements
```bash
pip install -r requirements.txt
```
## fill 1min data
```bash
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
```
## Parameters
- ata_1min_dir: csv data
- qlib_data_1d_dir: qlib data directory
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
- date_field_name: date field name, by default *date*
- symbol_field_name: symbol field name, by default *symbol*

View File

@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import fire
import qlib
import pandas as pd
from tqdm import tqdm
from qlib.data import D
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
csv_files = list(data_1min_dir.glob("*.csv"))
min_date = None
max_date = None
with tqdm(total=len(csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for _file, _result in zip(csv_files, executor.map(pd.read_csv, csv_files)):
if not _result.empty:
_dates = pd.to_datetime(_result[date_field_name])
_tmp_min = _dates.min()
min_date = min_date(min_date, _tmp_min) if min_date is not None else _tmp_min
_tmp_max = _dates.max()
max_date = min_date(max_date, _tmp_max) if max_date is not None else _tmp_max
p_bar.update()
return min_date, max_date
def get_symbols(data_1min_dir: Path):
return list(map(lambda x: x.name[:-4].upper(), data_1min_dir.glob("*.csv")))
def fill_1min_using_1d(
data_1min_dir: [str, Path],
qlib_data_1d_dir: [str, Path],
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""Use 1d data to fill in the missing symbols relative to 1min
Parameters
----------
data_1min_dir: str
1min data dir
qlib_data_1d_dir: str
1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
max_workers: int
ThreadPoolExecutor(max_workers), by default 16
date_field_name: str
date field name, by default date
symbol_field_name: str
symbol field name, by default symbol
"""
data_1min_dir = Path(data_1min_dir).expanduser().resolve()
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
symbols_1min = get_symbols(data_1min_dir)
qlib.init(provider_uri=str(qlib_data_1d_dir))
data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")
miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
if not miss_symbols:
logger.warning("More symbols in 1min than 1d, no padding required")
return
logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")
tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
columns = tmp_df.columns
_si = tmp_df[symbol_field_name].first_valid_index()
is_lower = tmp_df.loc[tmp_df][symbol_field_name].islower()
for symbol in tqdm(miss_symbols):
if is_lower:
symbol = symbol.lower()
index_1d = data_1d.loc(axis=0)[symbol.upper()].index
index_1min = generate_minutes_calendar_from_daily(index_1d)
index_1min.name = date_field_name
_df = pd.DataFrame(columns=columns, index=index_1min)
_df.reset_index(inplace=True)
_df[symbol_field_name] = symbol
_df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
if __name__ == "__main__":
fire.Fire(fill_1min_using_1d)

View File

@@ -0,0 +1,5 @@
fire
pandas
loguru
tqdm
pyqlib

View File

@@ -14,7 +14,7 @@ from loguru import logger
import baostock as bs
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
sys.path.append(str(CUR_DIR.parent.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily

View File

@@ -3,18 +3,13 @@
import abc
import sys
import copy
import time
import datetime
import importlib
import json
from abc import ABC
from pathlib import Path
from typing import Iterable, Type
import fire
import requests
import numpy as np
import pandas as pd
from loguru import logger
from dateutil.tz import tzlocal
@@ -38,7 +33,7 @@ class FundCollector(BaseCollector):
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
check_data_length: int = None,
limit_nums: int = None,
):
"""
@@ -59,8 +54,8 @@ class FundCollector(BaseCollector):
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
"""
@@ -168,10 +163,7 @@ class FundollectorCN(FundCollector, ABC):
class FundCollectorCN1d(FundollectorCN):
@property
def min_numbers_trading(self):
return 252 / 4
pass
class FundNormalize(BaseNormalize):
DAILY_FORMAT = "%Y-%m-%d"
@@ -261,7 +253,7 @@ class Run(BaseRun):
start=None,
end=None,
interval="1d",
check_data_length=False,
check_data_length=None,
limit_nums=None,
):
"""download data from Internet
@@ -278,8 +270,8 @@ class Run(BaseRun):
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool # if this param useful?
check data length, by default False
check_data_length: int # if this param useful?
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None

View File

@@ -137,7 +137,7 @@ class YahooCollector(BaseCollector):
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
@deco_retry(retry_sleep=1)
@deco_retry(retry_sleep=self.delay)
def _get_simple(start_, end_):
self.sleep()
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
@@ -200,10 +200,6 @@ class YahooCollectorCN(YahooCollector, ABC):
class YahooCollectorCN1d(YahooCollectorCN):
@property
def min_numbers_trading(self):
return 252 / 4
def download_index_data(self):
# TODO: from MSN
_format = "%Y%m%d"
@@ -237,10 +233,6 @@ class YahooCollectorCN1d(YahooCollectorCN):
class YahooCollectorCN1min(YahooCollectorCN):
@property
def min_numbers_trading(self):
return 60 * 4 * 5
def download_index_data(self):
# TODO: 1m
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
@@ -269,15 +261,11 @@ class YahooCollectorUS(YahooCollector, ABC):
class YahooCollectorUS1d(YahooCollectorUS):
@property
def min_numbers_trading(self):
return 252 / 4
pass
class YahooCollectorUS1min(YahooCollectorUS):
@property
def min_numbers_trading(self):
return 60 * 6.5 * 5
pass
class YahooNormalize(BaseNormalize):
@@ -514,7 +502,17 @@ class YahooNormalize1min(YahooNormalize, ABC):
)
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
"""
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
if not (data_1d is None or data_1d.empty):
data_1d = self.data_1d_obj.normalize(data_1d)
return data_1d
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -526,13 +524,12 @@ class YahooNormalize1min(YahooNormalize, ABC):
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d = self.get_1d_data(symbol, _start, _end)
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
if data_1d is None or data_1d.empty:
df["factor"] = 1
# TODO: np.nan or 1 or 0
df["paused"] = np.nan
else:
data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame
# NOTE: volume is np.nan or volume <= 0, paused = 1
# FIXME: find a more accurate data source
data_1d["paused"] = 0
@@ -621,12 +618,12 @@ class YahooNormalize1min(YahooNormalize, ABC):
raise NotImplementedError("rewrite symbol_to_yahoo")
@abc.abstractmethod
def _get_1d_calendar_list(self):
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
raise NotImplementedError("rewrite _get_1d_calendar_list")
class YahooNormalizeUS:
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
return get_calendar_list("US_ALL")
@@ -638,7 +635,7 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
CONSISTENT_1d = False
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: support 1min
raise ValueError("Does not support 1min")
@@ -650,7 +647,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
class YahooNormalizeCN:
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
return get_calendar_list("ALL")
@@ -670,7 +667,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
def _get_calendar_list(self):
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return self.generate_1min_from_daily(self.calendar_list_1d)
def symbol_to_yahoo(self, symbol):
@@ -680,10 +677,67 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
symbol = symbol[2:] + "." + _exchange
return symbol
def _get_1d_calendar_list(self):
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
return get_calendar_list("ALL")
class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
"""Normalised to 1min using local 1d data
1d data usually from: Normalised to 1min using local 1d data
"""
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
self.qlib_data_1d_dir = qlib_data_1d_dir
self._all_1d_data = self._get_all_1d_data()
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
return list(D.calendar(freq="day"))
def _get_all_1d_data(self):
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day")
df.reset_index(inplace=True)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
"""
return self._all_1d_data[
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
]
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
@@ -722,7 +776,7 @@ class Run(BaseRun):
delay=0,
start=None,
end=None,
check_data_length=False,
check_data_length=None,
limit_nums=None,
):
"""download data from Internet
@@ -734,14 +788,21 @@ class Run(BaseRun):
delay: float
time.sleep(delay), default 0
start: str
start datetime, default "2000-01-01"
start datetime, default "2000-01-01"; closed interval(including start)
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end)
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
Notes
-----
check_data_length, example:
daily, one year: 252 // 4
us 1min, a week: 6.5 * 60 * 5
cn 1min, a week: 4 * 60 * 5
Examples
---------
# get daily data
@@ -813,6 +874,85 @@ class Run(BaseRun):
)
yc.normalize()
def normalize_data_1min_cn_offline(
self, qlib_data_1d_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
"""Normalised to 1min using local 1d data
Parameters
----------
qlib_data_1d_dir: str
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
_class = getattr(self._cur_module, f"{self.normalize_class_name}Offline")
yc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
qlib_data_1d_dir=qlib_data_1d_dir,
)
yc.normalize()
def download_today_data(
self,
max_collector_count=2,
delay=0,
check_data_length=None,
limit_nums=None,
):
"""download today data from Internet
Parameters
----------
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
limit_nums: int
using for debug, by default None
Notes
-----
Download today's data:
start_time = datetime.datetime.now().date(); closed interval(including start)
end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval(excluding end)
check_data_length, example:
daily, one year: 252 // 4
us 1min, a week: 6.5 * 60 * 5
cn 1min, a week: 4 * 60 * 5
Examples
---------
# get daily data
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1m
"""
start = datetime.datetime.now().date()
end = pd.Timestamp(start + pd.Timedelta(days=1)).date()
self.download_data(
max_collector_count,
delay,
start.strftime("%Y-%m-%d"),
end.strftime("%Y-%m-%d"),
check_data_length,
limit_nums,
)
if __name__ == "__main__":
fire.Fire(Run)