mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
431 lines
14 KiB
Python
431 lines
14 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
import abc
|
|
import time
|
|
import datetime
|
|
import importlib
|
|
from pathlib import Path
|
|
from typing import Type
|
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from loguru import logger
|
|
from qlib.utils import code_to_fname
|
|
|
|
|
|
class BaseCollector(abc.ABC):
    """Download raw per-instrument data and save it as ``<save_dir>/<symbol>.csv``.

    Subclasses implement ``get_instrument_list``, ``normalize_symbol``, ``get_data`` and the
    ``min_numbers_trading`` property; ``collector_data`` drives the (retrying) download.
    """

    # Results of ``_simple_collector``: CACHE_FLAG means the data was not saved this round
    # (the symbol is retried), NORMAL_FLAG means it was handed to ``save_instrument``.
    CACHE_FLAG = "CACHED"
    NORMAL_FLAG = "NORMAL"

    # NOTE: these are evaluated once, when the class body runs (module import), not per instance.
    DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
    DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
    DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))

    INTERVAL_1min = "1min"
    INTERVAL_1d = "1d"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="1d",
        max_workers=4,
        max_collector_count=2,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            instrument save dir (created if missing)
        max_workers: int
            workers, default 4
        max_collector_count: int
            maximum number of download rounds, default 2
        delay: float
            time.sleep(delay) before each symbol request, default 0
        interval: str
            freq, value from [1min, 1d], default 1d
        start: str
            start datetime, default None (use DEFAULT_START_DATETIME_<INTERVAL>)
        end: str
            end datetime, default None (use DEFAULT_END_DATETIME_<INTERVAL>)
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None; keeps only the first ``limit_nums`` sorted instruments
        """
        self.save_dir = Path(save_dir).expanduser().resolve()
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.delay = delay
        self.max_workers = max_workers
        self.max_collector_count = max_collector_count
        # symbol -> list of DataFrames that were too short to be saved immediately
        self.mini_symbol_map = {}
        self.interval = interval
        self.check_small_data = check_data_length

        self.start_datetime = self.normalize_start_datetime(start)
        self.end_datetime = self.normalize_end_datetime(end)

        self.instrument_list = sorted(set(self.get_instrument_list()))

        if limit_nums is not None:
            try:
                self.instrument_list = self.instrument_list[: int(limit_nums)]
            except (TypeError, ValueError, OverflowError):
                # best effort: an unusable limit is reported and ignored rather than aborting
                logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")

    def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
        """Return ``start_datetime`` as a Timestamp, or the interval's default start if it is falsy."""
        return (
            pd.Timestamp(str(start_datetime))
            if start_datetime
            else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
        )

    def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
        """Return ``end_datetime`` as a Timestamp, or the interval's default end if it is falsy."""
        return (
            pd.Timestamp(str(end_datetime))
            if end_datetime
            else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
        )

    @property
    @abc.abstractmethod
    def min_numbers_trading(self):
        """Minimum number of rows a symbol's data must have to be saved without retrying."""
        # daily, one year: 252 / 4
        # us 1min, a week: 6.5 * 60 * 5
        # cn 1min, a week: 4 * 60 * 5
        raise NotImplementedError("rewrite min_numbers_trading")

    @abc.abstractmethod
    def get_instrument_list(self):
        """Return the iterable of symbols to download."""
        raise NotImplementedError("rewrite get_instrument_list")

    @abc.abstractmethod
    def normalize_symbol(self, symbol: str):
        """normalize symbol"""
        raise NotImplementedError("rewrite normalize_symbol")

    @abc.abstractmethod
    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> pd.DataFrame:
        """get data with symbol

        Parameters
        ----------
        symbol: str
        interval: str
            value from [1min, 1d]
        start_datetime: pd.Timestamp
        end_datetime: pd.Timestamp

        Returns
        ---------
            pd.DataFrame, "symbol" in pd.columns

        """
        raise NotImplementedError("rewrite get_data")

    def sleep(self):
        """Pause for ``self.delay`` seconds (throttles requests)."""
        time.sleep(self.delay)

    def _simple_collector(self, symbol: str):
        """Download one symbol and save it unless the small-data check holds it back.

        Parameters
        ----------
        symbol: str

        Returns
        -------
        str
            NORMAL_FLAG if the data was passed to ``save_instrument``, otherwise CACHE_FLAG.
        """
        self.sleep()
        df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
        _result = self.NORMAL_FLAG
        if self.check_small_data:
            _result = self.cache_small_data(symbol, df)
        if _result == self.NORMAL_FLAG:
            self.save_instrument(symbol, df)
        return _result

    def save_instrument(self, symbol, df: pd.DataFrame):
        """save instrument data to file

        Rows are appended to an existing ``<save_dir>/<symbol>.csv`` if present.

        Parameters
        ----------
        symbol: str
            instrument code
        df : pd.DataFrame
            df.columns must contain "symbol" and "datetime"; its "symbol" column is overwritten in place
        """
        if df is None or df.empty:
            logger.warning(f"{symbol} is empty")
            return

        symbol = self.normalize_symbol(symbol)
        symbol = code_to_fname(symbol)
        instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
        df["symbol"] = symbol
        if instrument_path.exists():
            _old_df = pd.read_csv(instrument_path)
            # DataFrame.append was removed in pandas 2.0; concat is the equivalent
            df = pd.concat([_old_df, df], sort=False)
        df.to_csv(instrument_path, index=False)

    def cache_small_data(self, symbol, df):
        """Hold back ``df`` in ``mini_symbol_map`` when it has fewer than ``min_numbers_trading`` rows.

        Returns
        -------
        str
            CACHE_FLAG when the data is not saved now (the symbol will be retried),
            NORMAL_FLAG when it is long enough to be saved.
        """
        if df is None or df.empty:
            # ``len(None)`` would raise here; inside the thread pool that exception is re-raised by
            # ``executor.map`` and aborts the whole collection. Nothing to cache, so just retry.
            return self.CACHE_FLAG
        if len(df) < self.min_numbers_trading:
            logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
            self.mini_symbol_map.setdefault(symbol, []).append(df.copy())
            return self.CACHE_FLAG
        # long enough now: any earlier short fragments are superseded
        self.mini_symbol_map.pop(symbol, None)
        return self.NORMAL_FLAG

    def _collector(self, instrument_list):
        """Run one download round over ``instrument_list`` with a thread pool.

        Returns
        -------
        list
            sorted, de-duplicated symbols to retry: those whose result was not NORMAL_FLAG,
            plus every symbol still held in ``mini_symbol_map``.
        """
        error_symbol = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            with tqdm(total=len(instrument_list)) as p_bar:
                for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
                    if _result != self.NORMAL_FLAG:
                        error_symbol.append(_symbol)
                    p_bar.update()
        logger.info(f"error symbols: {error_symbol}")
        logger.info(f"error symbol nums: {len(error_symbol)}")
        logger.info(f"current get symbol nums: {len(instrument_list)}")
        error_symbol.extend(self.mini_symbol_map.keys())
        return sorted(set(error_symbol))

    def collector_data(self):
        """collector data

        Runs up to ``max_collector_count`` rounds, each retrying the previous round's failures,
        then saves whatever short data is still cached in ``mini_symbol_map``.
        """
        logger.info("start collector data......")
        instrument_list = self.instrument_list
        for i in range(self.max_collector_count):
            if not instrument_list:
                break
            logger.info(f"getting data: {i+1}")
            instrument_list = self._collector(instrument_list)
            logger.info(f"{i+1} finish.")
        for _symbol, _df_list in self.mini_symbol_map.items():
            self.save_instrument(
                _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
            )
        if self.mini_symbol_map:
            logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
        logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
|
|
|
|
|
class BaseNormalize(abc.ABC):
    """Interface for per-file normalization used by ``Normalize``.

    The calendar is fetched once, at construction, via ``_get_calendar_list``.
    """

    def __init__(
        self,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """

        Parameters
        ----------
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        self._date_field_name = date_field_name
        self._symbol_field_name = symbol_field_name

        self._calendar_list = self._get_calendar_list()

    @abc.abstractmethod
    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the normalized version of one instrument's raw DataFrame."""
        raise NotImplementedError("rewrite normalize")

    @abc.abstractmethod
    def _get_calendar_list(self):
        """Get benchmark calendar"""
        raise NotImplementedError("rewrite _get_calendar_list")
|
|
|
|
|
|
class Normalize:
    """Normalize every ``*.csv`` in ``source_dir`` into ``target_dir`` using a process pool."""

    def __init__(
        self,
        source_dir: [str, Path],
        target_dir: [str, Path],
        normalize_class: "Type[BaseNormalize]",
        max_workers: int = 16,
        date_field_name: str = "date",
        symbol_field_name: str = "symbol",
    ):
        """

        Parameters
        ----------
        source_dir: str or Path
            The directory where the raw data collected from the Internet is saved
        target_dir: str or Path
            Directory for normalize data (created if missing)
        normalize_class: Type[BaseNormalize]
            normalize class; instantiated once with ``date_field_name`` and ``symbol_field_name``
        max_workers: int
            Concurrent number, default is 16
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol

        Raises
        ------
        ValueError
            if ``source_dir`` or ``target_dir`` is empty/None
        """
        if not (source_dir and target_dir):
            raise ValueError("source_dir and target_dir cannot be None")
        self._source_dir = Path(source_dir).expanduser()
        self._target_dir = Path(target_dir).expanduser()
        self._target_dir.mkdir(parents=True, exist_ok=True)

        self._max_workers = max_workers

        self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)

    def _executor(self, file_path: Path):
        """Normalize one CSV and write it under ``target_dir`` with the same file name.

        Nothing is written when the normalizer returns ``None`` or an empty DataFrame.
        """
        file_path = Path(file_path)
        df = pd.read_csv(file_path)
        df = self._normalize_obj.normalize(df)
        # guard against normalizers that signal "no usable data" by returning None
        if df is not None and not df.empty:
            df.to_csv(self._target_dir.joinpath(file_path.name), index=False)

    def normalize(self):
        """Normalize all ``*.csv`` files in ``source_dir`` in parallel, showing a progress bar."""
        logger.info("normalize data......")

        with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
            file_list = list(self._source_dir.glob("*.csv"))
            with tqdm(total=len(file_list)) as p_bar:
                for _ in worker.map(self._executor, file_list):
                    p_bar.update()
|
|
|
|
|
|
class BaseRun(abc.ABC):
    """Command-line entry point that wires a collector class and a normalize class together.

    Subclasses name both classes (looked up by name on the ``collector`` module) and the
    default base directory used when ``source_dir``/``normalize_dir`` are not given.
    """

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved,
            default ``Path(self.default_base_dir)/"_source"``
        normalize_dir: str
            Directory for normalize data, default ``Path(self.default_base_dir)/"normalize"``
        max_workers: int
            Concurrent number, default is 4
        interval: str
            freq, value from [1min, 1d], default 1d; used by ``download_data`` unless it is
            given its own ``interval``
        """
        if source_dir is None:
            source_dir = Path(self.default_base_dir).joinpath("_source")
        self.source_dir = Path(source_dir).expanduser().resolve()
        self.source_dir.mkdir(parents=True, exist_ok=True)

        if normalize_dir is None:
            normalize_dir = Path(self.default_base_dir).joinpath("normalize")
        self.normalize_dir = Path(normalize_dir).expanduser().resolve()
        self.normalize_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): module name is hard-coded; resolves to whichever ``collector`` module is importable
        self._cur_module = importlib.import_module("collector")
        self.max_workers = max_workers
        self.interval = interval

    @property
    @abc.abstractmethod
    def collector_class_name(self):
        """Name of the ``BaseCollector`` subclass in the ``collector`` module."""
        raise NotImplementedError("rewrite collector_class_name")

    @property
    @abc.abstractmethod
    def normalize_class_name(self):
        """Name of the ``BaseNormalize`` subclass in the ``collector`` module."""
        raise NotImplementedError("rewrite normalize_class_name")

    @property
    @abc.abstractmethod
    def default_base_dir(self) -> [Path, str]:
        """Base directory for the default ``_source`` and ``normalize`` directories."""
        raise NotImplementedError("rewrite default_base_dir")

    def download_data(
        self,
        max_collector_count=2,
        delay=0,
        start=None,
        end=None,
        interval=None,
        check_data_length=False,
        limit_nums=None,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 2
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1min, 1d]; default None, meaning the ``interval`` given to the
            constructor (whose own default is 1d)
        start: str
            start datetime, default None: the collector's default start for the interval
        end: str
            end datetime, default None: the collector's default end for the interval
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
            # get 1m data
            $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1min
        """
        if interval is None:
            # previously a hard-coded "1d" default here silently overrode the constructor's interval
            interval = self.interval
        _class = getattr(self._cur_module, self.collector_class_name)  # type: Type[BaseCollector]
        _class(
            self.source_dir,
            max_workers=self.max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        ).collector_data()

    def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
        """normalize data

        Parameters
        ----------
        date_field_name: str
            date field name, default date
        symbol_field_name: str
            symbol field name, default symbol

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
        """
        _class = getattr(self._cur_module, self.normalize_class_name)
        yc = Normalize(
            source_dir=self.source_dir,
            target_dir=self.normalize_dir,
            normalize_class=_class,
            max_workers=self.max_workers,
            date_field_name=date_field_name,
            symbol_field_name=symbol_field_name,
        )
        yc.normalize()
|