# Mirror of https://github.com/microsoft/qlib.git (synced 2026-06-06 05:51:17 +08:00).
# Source listing: 453 lines, 16 KiB, Python.
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
import abc
|
|
import time
|
|
import datetime
|
|
import importlib
|
|
from pathlib import Path
|
|
from typing import Type, Iterable
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from loguru import logger
|
|
from joblib import Parallel, delayed
|
|
from qlib.utils import code_to_fname
|
|
|
|
|
|
class BaseCollector(abc.ABC):
    """Download raw data for a universe of instruments and save one CSV file per instrument.

    Subclasses supply the universe (``get_instrument_list``), the symbol normalization
    used for file names (``normalize_symbol``) and the download itself (``get_data``).
    ``collector_data`` drives the process, fetching each symbol at most
    ``max_collector_count`` times.
    """

    # return flags of _simple_collector / cache_small_data
    CACHE_FLAG = "CACHED"
    NORMAL_FLAG = "NORMAL"

    # Default date ranges, looked up by interval in normalize_start/end_datetime.
    # NOTE: the "now"-relative defaults are evaluated once, when this class is defined.
    DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
    DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
    DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D

    INTERVAL_1min = "1min"
    INTERVAL_1d = "1d"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=1,
        max_collector_count=2,
        delay=0,
        check_data_length: int = None,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            instrument save dir (created if missing)
        max_workers: int
            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
        max_collector_count: int
            maximum number of collection passes, default 2
        delay: float
            time.sleep(delay) before each symbol is fetched, default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default None (use the interval's DEFAULT_START_DATETIME_*)
        end: str
            end datetime, default None (use the interval's DEFAULT_END_DATETIME_*)
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, keep only the first ``limit_nums`` instruments; by default None
        """
        self.save_dir = Path(save_dir).expanduser().resolve()
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.delay = delay
        self.max_workers = max_workers
        self.max_collector_count = max_collector_count
        # symbol -> list of DataFrames that were shorter than check_data_length
        self.mini_symbol_map = {}
        self.interval = interval
        # None or non-positive values disable the length check (stored as 0)
        self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)

        self.start_datetime = self.normalize_start_datetime(start)
        self.end_datetime = self.normalize_end_datetime(end)

        self.instrument_list = sorted(set(self.get_instrument_list()))

        if limit_nums is not None:
            try:
                self.instrument_list = self.instrument_list[: int(limit_nums)]
            except Exception:
                logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")

    def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
        """Return ``start_datetime`` as a Timestamp, or the interval's default start when it is falsy."""
        return (
            pd.Timestamp(str(start_datetime))
            if start_datetime
            else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
        )

    def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
        """Return ``end_datetime`` as a Timestamp, or the interval's default end when it is falsy."""
        return (
            pd.Timestamp(str(end_datetime))
            if end_datetime
            else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
        )

    @abc.abstractmethod
    def get_instrument_list(self):
        """Return the instruments to collect (duplicates are removed by ``__init__``)."""
        raise NotImplementedError("rewrite get_instrument_list")

    @abc.abstractmethod
    def normalize_symbol(self, symbol: str):
        """normalize symbol"""
        raise NotImplementedError("rewrite normalize_symbol")

    @abc.abstractmethod
    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> pd.DataFrame:
        """get data with symbol

        Parameters
        ----------
        symbol: str
        interval: str
            value from [1min, 1d]
        start_datetime: pd.Timestamp
        end_datetime: pd.Timestamp

        Returns
        ---------
            pd.DataFrame, "symbol" and "date" in df.columns; None or an empty frame when nothing was fetched

        """
        raise NotImplementedError("rewrite get_data")

    def sleep(self):
        """Pause for ``self.delay`` seconds (throttling between requests)."""
        time.sleep(self.delay)

    def _simple_collector(self, symbol: str):
        """Fetch one symbol, then save it or (if it is too short) cache it.

        Parameters
        ----------
        symbol: str

        Returns
        -------
        str
            NORMAL_FLAG when the data was saved, CACHE_FLAG when it was cached
            in ``mini_symbol_map`` for another attempt.
        """
        self.sleep()
        df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
        _result = self.NORMAL_FLAG
        if self.check_data_length > 0:
            _result = self.cache_small_data(symbol, df)
        if _result == self.NORMAL_FLAG:
            self.save_instrument(symbol, df)
        return _result

    def save_instrument(self, symbol, df: pd.DataFrame):
        """save instrument data to file

        The data is appended to ``<save_dir>/<fname>.csv`` if that file already exists.

        Parameters
        ----------
        symbol: str
            instrument code
        df : pd.DataFrame
            data to save; its "symbol" column is (over)written in place with the file-name form of ``symbol``.
            None or an empty frame is skipped with a warning.
        """
        if df is None or df.empty:
            logger.warning(f"{symbol} is empty")
            return

        symbol = self.normalize_symbol(symbol)
        symbol = code_to_fname(symbol)
        instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
        df["symbol"] = symbol
        if instrument_path.exists():
            _old_df = pd.read_csv(instrument_path)
            df = pd.concat([_old_df, df], sort=False)
        df.to_csv(instrument_path, index=False)

    def cache_small_data(self, symbol, df):
        """Cache ``df`` if it is shorter than ``check_data_length``.

        Parameters
        ----------
        symbol: str
        df: pd.DataFrame or None
            None (nothing fetched) is treated as an empty frame.

        Returns
        -------
        str
            CACHE_FLAG if the data was too short and was appended to ``mini_symbol_map``;
            otherwise NORMAL_FLAG, after dropping any earlier cached data for ``symbol``.
        """
        if df is None:
            # get_data may legitimately return None; treat it as "no rows" instead of crashing on len(None)
            df = pd.DataFrame()
        if len(df) < self.check_data_length:
            logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
            _temp = self.mini_symbol_map.setdefault(symbol, [])
            _temp.append(df.copy())
            return self.CACHE_FLAG
        else:
            if symbol in self.mini_symbol_map:
                self.mini_symbol_map.pop(symbol)
            return self.NORMAL_FLAG

    def _collector(self, instrument_list):
        """Run one collection pass over ``instrument_list``.

        Returns
        -------
        list
            sorted, de-duplicated symbols to fetch again: those whose pass did not return
            NORMAL_FLAG plus every symbol still present in ``mini_symbol_map``.
        """
        # NOTE(review): with max_workers > 1 joblib may run _simple_collector in separate worker
        # processes; updates those workers make to self.mini_symbol_map would then not be
        # visible here. TODO confirm the backend in use (cf. the max_workers=1 recommendation).
        res = Parallel(n_jobs=self.max_workers)(
            delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
        )
        error_symbol = [_symbol for _symbol, _result in zip(instrument_list, res) if _result != self.NORMAL_FLAG]
        logger.info(f"error symbol: {error_symbol}")
        logger.info(f"error symbol nums: {len(error_symbol)}")
        logger.info(f"current get symbol nums: {len(instrument_list)}")
        error_symbol.extend(self.mini_symbol_map.keys())
        return sorted(set(error_symbol))

    def collector_data(self):
        """collector data

        Runs up to ``max_collector_count`` passes, each over the symbols the previous pass
        reported as errors, then saves whatever short data is still cached (de-duplicated
        and sorted by "date").
        """
        logger.info("start collector data......")
        instrument_list = self.instrument_list
        for i in range(self.max_collector_count):
            if not instrument_list:
                break
            logger.info(f"getting data: {i+1}")
            instrument_list = self._collector(instrument_list)
            logger.info(f"{i+1} finish.")
        for _symbol, _df_list in self.mini_symbol_map.items():
            _df = pd.concat(_df_list, sort=False)
            if not _df.empty:
                self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
        if self.mini_symbol_map:
            logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
        logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
|
|
|
|
|
class BaseNormalize(abc.ABC):
    """Abstract per-file normalizer driven by ``Normalize``.

    Concrete subclasses implement ``normalize`` (transform one raw DataFrame) and
    ``_get_calendar_list`` (the reference calendar, requested once during
    construction and kept as ``self._calendar_list``).
    """

    def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
        """

        Parameters
        ----------
        date_field_name: str
            name of the date column, "date" unless overridden
        symbol_field_name: str
            name of the symbol column, "symbol" unless overridden
        kwargs:
            any extra options, stored unchanged in ``self.kwargs``
        """
        self.kwargs = kwargs
        self._symbol_field_name = symbol_field_name
        self._date_field_name = date_field_name
        # requested last: a subclass implementation may rely on the attributes assigned above
        self._calendar_list = self._get_calendar_list()

    @abc.abstractmethod
    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the normalized version of one instrument's raw DataFrame."""
        raise NotImplementedError("")

    @abc.abstractmethod
    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
        """Return the benchmark calendar."""
        raise NotImplementedError("")
|
|
|
|
|
|
class Normalize:
|
|
def __init__(
|
|
self,
|
|
source_dir: [str, Path],
|
|
target_dir: [str, Path],
|
|
normalize_class: Type[BaseNormalize],
|
|
max_workers: int = 16,
|
|
date_field_name: str = "date",
|
|
symbol_field_name: str = "symbol",
|
|
**kwargs,
|
|
):
|
|
"""
|
|
|
|
Parameters
|
|
----------
|
|
source_dir: str or Path
|
|
The directory where the raw data collected from the Internet is saved
|
|
target_dir: str or Path
|
|
Directory for normalize data
|
|
normalize_class: Type[YahooNormalize]
|
|
normalize class
|
|
max_workers: int
|
|
Concurrent number, default is 16
|
|
date_field_name: str
|
|
date field name, default is date
|
|
symbol_field_name: str
|
|
symbol field name, default is symbol
|
|
"""
|
|
if not (source_dir and target_dir):
|
|
raise ValueError("source_dir and target_dir cannot be None")
|
|
self._source_dir = Path(source_dir).expanduser()
|
|
self._target_dir = Path(target_dir).expanduser()
|
|
self._target_dir.mkdir(parents=True, exist_ok=True)
|
|
self._date_field_name = date_field_name
|
|
self._symbol_field_name = symbol_field_name
|
|
self._end_date = kwargs.get("end_date", None)
|
|
self._max_workers = max_workers
|
|
self.interval = kwargs.get("interval", "1d")
|
|
|
|
self._normalize_obj = normalize_class(
|
|
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
|
)
|
|
|
|
def format_data(self, df: pd.DataFrame):
|
|
if self.interval == "1d":
|
|
try:
|
|
pd.to_datetime(df.iloc[-1]["date"], format="%Y-%m-%d", errors="raise")
|
|
except Exception:
|
|
df = df.iloc[:-1]
|
|
return df
|
|
|
|
def _executor(self, file_path: Path):
|
|
file_path = Path(file_path)
|
|
|
|
# some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.
|
|
# manually defines dtype and na_values of the symbol_field.
|
|
default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101
|
|
symbol_na = default_na.copy()
|
|
symbol_na.remove("NA")
|
|
columns = pd.read_csv(file_path, nrows=0).columns
|
|
df = pd.read_csv(
|
|
file_path,
|
|
dtype={self._symbol_field_name: str},
|
|
keep_default_na=False,
|
|
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
|
)
|
|
df = self.format_data(df=df)
|
|
|
|
if not df.empty:
|
|
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
|
df = self._normalize_obj.normalize(df)
|
|
if df is not None and not df.empty:
|
|
if self._end_date is not None:
|
|
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
|
df = df[_mask]
|
|
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
|
else:
|
|
logger.warning(f"{file_path.stem} source data is empty and will not undergo normalization processing.")
|
|
|
|
def normalize(self):
|
|
logger.info("normalize data......")
|
|
|
|
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
|
file_list = list(self._source_dir.glob("*.csv"))
|
|
with tqdm(total=len(file_list)) as p_bar:
|
|
for _ in worker.map(self._executor, file_list):
|
|
p_bar.update()
|
|
|
|
|
|
class BaseRun(abc.ABC):
    """Command-line entry point tying a collector class and a normalize class together.

    The concrete classes are looked up by name (``collector_class_name`` /
    ``normalize_class_name``) in the ``collector`` module.
    """

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "<default_base_dir>/source"
        normalize_dir: str
            Directory for normalize data, default "<default_base_dir>/normalize"
        max_workers: int
            Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
        interval: str
            freq, value from [1min, 1d], default 1d
        """
        if source_dir is None:
            source_dir = Path(self.default_base_dir).joinpath("source")
        self.source_dir = Path(source_dir).expanduser().resolve()
        self.source_dir.mkdir(parents=True, exist_ok=True)

        if normalize_dir is None:
            normalize_dir = Path(self.default_base_dir).joinpath("normalize")
        self.normalize_dir = Path(normalize_dir).expanduser().resolve()
        self.normalize_dir.mkdir(parents=True, exist_ok=True)

        # the concrete collector / normalize classes are resolved by name from this module
        self._cur_module = importlib.import_module("collector")
        self.max_workers = max_workers
        self.interval = interval

    @property
    @abc.abstractmethod
    def collector_class_name(self):
        """Name of the collector class in the ``collector`` module."""
        raise NotImplementedError("rewrite collector_class_name")

    @property
    @abc.abstractmethod
    def normalize_class_name(self):
        """Name of the normalize class in the ``collector`` module."""
        raise NotImplementedError("rewrite normalize_class_name")

    @property
    @abc.abstractmethod
    def default_base_dir(self) -> [Path, str]:
        """Base directory for the default ``source`` / ``normalize`` sub-directories."""
        raise NotImplementedError("rewrite default_base_dir")

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        check_data_length: int = None,
        limit_nums=None,
        **kwargs,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        start: str
            start datetime, default "2000-01-01" (for 1d)
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: int
            check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
        limit_nums: int
            using for debug, by default None
        kwargs:
            forwarded unchanged to the collector class

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
            # get 1m data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
        """

        _class = getattr(self._cur_module, self.collector_class_name)  # type: Type[BaseCollector]
        _class(
            self.source_dir,
            max_workers=self.max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=self.interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
            **kwargs,
        ).collector_data()

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol
        kwargs:
            forwarded to ``Normalize`` (and from there to the normalize class); ``interval``
            defaults to this runner's ``interval`` unless given explicitly

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
        """
        _class = getattr(self._cur_module, self.normalize_class_name)
        # Normalize reads the interval from kwargs (falling back to "1d"); without this, non-daily
        # runs would be treated as daily data and lose their last row in Normalize.format_data.
        kwargs.setdefault("interval", self.interval)
        yc = Normalize(
            source_dir=self.source_dir,
            target_dir=self.normalize_dir,
            normalize_class=_class,
            max_workers=self.max_workers,
            date_field_name=date_field_name,
            symbol_field_name=symbol_field_name,
            **kwargs,
        )
        yc.normalize()
|