mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add normalize 1min to use local data && change the default parameters for collecting 1min
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from typing import Type, Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
@@ -38,7 +38,7 @@ class BaseCollector(abc.ABC):
|
||||
max_workers=1,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -59,8 +59,8 @@ class BaseCollector(abc.ABC):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -72,7 +72,7 @@ class BaseCollector(abc.ABC):
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0)
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
@@ -99,14 +99,6 @@ class BaseCollector(abc.ABC):
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
@@ -132,7 +124,7 @@ class BaseCollector(abc.ABC):
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
pd.DataFrame, "symbol" and "date"in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
@@ -151,7 +143,7 @@ class BaseCollector(abc.ABC):
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
if self.check_data_length > 0:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
@@ -181,8 +173,8 @@ class BaseCollector(abc.ABC):
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
if len(df) < self.check_data_length:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
@@ -194,9 +186,17 @@ class BaseCollector(abc.ABC):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
if self.max_workers is not None and self.max_workers > 1:
|
||||
logger.info(f"concurrent collector, max_workers: {self.max_workers}")
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
else:
|
||||
for _symbol in instrument_list:
|
||||
_result = self._simple_collector(_symbol)
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
@@ -217,11 +217,11 @@ class BaseCollector(abc.ABC):
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
_df = pd.concat(_df_list, sort=False)
|
||||
if not _df.empty:
|
||||
self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
@@ -247,7 +247,7 @@ class BaseNormalize(abc.ABC):
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@@ -296,7 +296,7 @@ class Normalize:
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
if df is not None and not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
|
||||
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
23
scripts/data_collector/contrib/fill_cn_1min_data/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Use 1d data to fill in the missing symbols relative to 1min
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## fill 1min data
|
||||
|
||||
```bash
|
||||
python fill_1min_using_1d.py --data_1min_dir ~/.qlib/csv_data/cn_data_1min --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- data_1min_dir: csv data
|
||||
- qlib_data_1d_dir: qlib data directory
|
||||
- max_workers: `ThreadPoolExecutor(max_workers=max_workers)`, by default *16*
|
||||
- date_field_name: date field name, by default *date*
|
||||
- symbol_field_name: symbol field name, by default *symbol*
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from qlib.data import D
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"):
    """Scan every 1min csv file and return the overall date span they cover.

    Parameters
    ----------
    data_1min_dir: Path
        directory containing the 1min ``*.csv`` files
    max_workers: int
        ThreadPoolExecutor(max_workers), by default 16
    date_field_name: str
        name of the date column in the csv files, by default date

    Returns
    -------
    tuple
        ``(min_date, max_date)``; both are None when no non-empty csv file is found
    """
    csv_files = list(data_1min_dir.glob("*.csv"))
    min_date = None
    max_date = None
    with tqdm(total=len(csv_files)) as p_bar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for _result in executor.map(pd.read_csv, csv_files):
                if not _result.empty:
                    _dates = pd.to_datetime(_result[date_field_name])

                    # Use the builtins here: `min_date` / `max_date` are the running
                    # values (Timestamp or None), not callables.
                    _tmp_min = _dates.min()
                    min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min

                    _tmp_max = _dates.max()
                    max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
                p_bar.update()
    return min_date, max_date
|
||||
|
||||
|
||||
def get_symbols(data_1min_dir: Path):
    """Return the upper-cased symbol of every ``*.csv`` file in ``data_1min_dir``.

    The symbol is the file name with its trailing four characters (``.csv``) removed.
    """
    return [csv_path.name[:-4].upper() for csv_path in data_1min_dir.glob("*.csv")]
|
||||
|
||||
|
||||
def fill_1min_using_1d(
    data_1min_dir: [str, Path],
    qlib_data_1d_dir: [str, Path],
    max_workers: int = 16,
    date_field_name: str = "date",
    symbol_field_name: str = "symbol",
):
    """Use 1d data to fill in the missing symbols relative to 1min

    For every instrument that exists in the 1d qlib data but has no csv file in
    ``data_1min_dir``, write a csv whose rows are the 1min calendar generated from
    that instrument's 1d dates; only the date and symbol columns are populated.

    Parameters
    ----------
    data_1min_dir: str
        1min data dir
    qlib_data_1d_dir: str
        1d qlib data(bin data) dir, from: https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format
    max_workers: int
        ThreadPoolExecutor(max_workers), by default 16
    date_field_name: str
        date field name, by default date
    symbol_field_name: str
        symbol field name, by default symbol

    """
    data_1min_dir = Path(data_1min_dir).expanduser().resolve()
    qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()

    min_date, max_date = get_date_range(data_1min_dir, max_workers, date_field_name)
    symbols_1min = get_symbols(data_1min_dir)

    qlib.init(provider_uri=str(qlib_data_1d_dir))
    data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day")

    miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min)
    if not miss_symbols:
        logger.warning("More symbols in 1min than 1d, no padding required")
        return

    logger.info(f"miss_symbols {len(miss_symbols)}: {miss_symbols}")
    # Use an existing 1min file as the template for the column layout and symbol casing.
    tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
    columns = tmp_df.columns
    _si = tmp_df[symbol_field_name].first_valid_index()
    # Look up the first valid symbol by its row label; the original indexed the frame
    # with itself (`tmp_df.loc[tmp_df]`). With no valid symbol, keep the upper-case name.
    is_lower = _si is not None and str(tmp_df.loc[_si, symbol_field_name]).islower()
    for symbol in tqdm(miss_symbols):
        if is_lower:
            symbol = symbol.lower()
        # The 1d frame is looked up with the upper-cased symbol regardless of file casing.
        index_1d = data_1d.loc(axis=0)[symbol.upper()].index
        index_1min = generate_minutes_calendar_from_daily(index_1d)
        index_1min.name = date_field_name
        _df = pd.DataFrame(columns=columns, index=index_1min)
        _df.reset_index(inplace=True)
        _df[symbol_field_name] = symbol
        _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(fill_1min_using_1d)
|
||||
@@ -0,0 +1,5 @@
|
||||
fire
|
||||
pandas
|
||||
loguru
|
||||
tqdm
|
||||
pyqlib
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
sys.path.append(str(CUR_DIR.parent.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
@@ -3,18 +3,13 @@
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
@@ -38,7 +33,7 @@ class FundCollector(BaseCollector):
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
@@ -59,8 +54,8 @@ class FundCollector(BaseCollector):
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
@@ -168,10 +163,7 @@ class FundollectorCN(FundCollector, ABC):
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
pass
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
@@ -261,7 +253,7 @@ class Run(BaseRun):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -278,8 +270,8 @@ class Run(BaseRun):
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
check_data_length: int # if this param useful?
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ class YahooCollector(BaseCollector):
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
@deco_retry(retry_sleep=1)
|
||||
@deco_retry(retry_sleep=self.delay)
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
@@ -200,10 +200,6 @@ class YahooCollectorCN(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorCN1d(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
@@ -237,10 +233,6 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
|
||||
|
||||
class YahooCollectorCN1min(YahooCollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 4 * 5
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
@@ -269,15 +261,11 @@ class YahooCollectorUS(YahooCollector, ABC):
|
||||
|
||||
|
||||
class YahooCollectorUS1d(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorUS1min(YahooCollectorUS):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 60 * 6.5 * 5
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
@@ -514,7 +502,17 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
    """Fetch remote 1d data for ``symbol`` and normalize it when any rows come back.

    Returns
    ------
        data_1d: pd.DataFrame
            set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
            A ``None`` or empty remote result is returned unchanged.

    """
    yahoo_symbol = self.symbol_to_yahoo(symbol)
    raw_1d = YahooCollector.get_data_from_remote(yahoo_symbol, interval="1d", start=start, end=end)
    if raw_1d is None or raw_1d.empty:
        # nothing to normalize; hand the empty result back to the caller
        return raw_1d
    return self.data_1d_obj.normalize(raw_1d)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -526,13 +524,12 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
data_1d = self.data_1d_obj.normalize(data_1d) # type: pd.DataFrame
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
@@ -621,12 +618,12 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("US_ALL")
|
||||
|
||||
@@ -638,7 +635,7 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CONSISTENT_1d = False
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: support 1min
|
||||
raise ValueError("Does not support 1min")
|
||||
|
||||
@@ -650,7 +647,7 @@ class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
|
||||
|
||||
class YahooNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
@@ -670,7 +667,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
def _get_calendar_list(self):
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_1min_from_daily(self.calendar_list_1d)
|
||||
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
@@ -680,10 +677,67 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
symbol = symbol[2:] + "." + _exchange
|
||||
return symbol
|
||||
|
||||
def _get_1d_calendar_list(self):
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
    """Normalised to 1min using local 1d data

    1d data usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    """

    def __init__(
        self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
    ):
        """

        Parameters
        ----------
        qlib_data_1d_dir: str, Path
            local 1d qlib data directory, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        # Set the data directory first: `_get_1d_calendar_list` reads it, and the
        # parent `_get_calendar_list` reaches that hook through `calendar_list_1d`.
        # NOTE(review): presumably a base constructor builds the calendar - confirm.
        self.qlib_data_1d_dir = qlib_data_1d_dir
        super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
        self._all_1d_data = self._get_all_1d_data()

    def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
        """Return the daily calendar stored in the local qlib 1d data."""
        import qlib
        from qlib.data import D

        qlib.init(provider_uri=self.qlib_data_1d_dir)
        return list(D.calendar(freq="day"))

    def _get_all_1d_data(self):
        """Load paused/volume/factor of all instruments from the local 1d data.

        The index levels are flattened into columns, renamed to the configured
        date/symbol field names, and the leading ``$`` is stripped from feature names.
        """
        import qlib
        from qlib.data import D

        qlib.init(provider_uri=self.qlib_data_1d_dir)
        df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day")
        df.reset_index(inplace=True)
        df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
        df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
        return df

    def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
        """get 1d data

        Rows are selected from the preloaded local 1d data for the upper-cased
        ``symbol`` with ``start <= date < end``.

        Returns
        ------
            data_1d: pd.DataFrame
                set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}

        """
        return self._all_1d_data[
            (self._all_1d_data[self._symbol_field_name] == symbol.upper())
            & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
            & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
        ]
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
@@ -722,7 +776,7 @@ class Run(BaseRun):
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length=False,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
@@ -734,14 +788,21 @@ class Run(BaseRun):
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
start datetime, default "2000-01-01"; closed interval(including start)
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Notes
|
||||
-----
|
||||
check_data_length, example:
|
||||
daily, one year: 252 // 4
|
||||
us 1min, a week: 6.5 * 60 * 5
|
||||
cn 1min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -813,6 +874,85 @@ class Run(BaseRun):
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
def normalize_data_1min_cn_offline(
    self, qlib_data_1d_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
    """Normalised to 1min using local 1d data

    The normalize class is looked up on the current module as
    ``f"{self.normalize_class_name}Offline"``.

    Parameters
    ----------
    qlib_data_1d_dir: str
        the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    date_field_name: str
        date field name, default date
    symbol_field_name: str
        symbol field name, default symbol

    Examples
    ---------
        $ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
    """
    offline_class_name = f"{self.normalize_class_name}Offline"
    offline_class = getattr(self._cur_module, offline_class_name)
    normalizer = Normalize(
        source_dir=self.source_dir,
        target_dir=self.normalize_dir,
        normalize_class=offline_class,
        max_workers=self.max_workers,
        date_field_name=date_field_name,
        symbol_field_name=symbol_field_name,
        qlib_data_1d_dir=qlib_data_1d_dir,
    )
    normalizer.normalize()
|
||||
|
||||
def download_today_data(
    self,
    max_collector_count=2,
    delay=0,
    check_data_length=None,
    limit_nums=None,
):
    """download today data from Internet

    Parameters
    ----------
    max_collector_count: int
        default 2
    delay: float
        time.sleep(delay), default 0
    check_data_length: int
        check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
    limit_nums: int
        using for debug, by default None

    Notes
    -----
        Download today's data:
            start_time = datetime.datetime.now().date(); closed interval(including start)
            end_time = pd.Timestamp(start_time + pd.Timedelta(days=1)).date(); open interval(excluding end)

        check_data_length, example:
            daily, one year: 252 // 4
            us 1min, a week: 6.5 * 60 * 5
            cn 1min, a week: 4 * 60 * 5

    Examples
    ---------
        # get daily data
        $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1d
        # get 1m data
        $ python collector.py download_today_data --source_dir ~/.qlib/stock_data/source --region CN --delay 0.1 --interval 1m
    """
    day_format = "%Y-%m-%d"
    today = datetime.datetime.now().date()
    tomorrow = pd.Timestamp(today + pd.Timedelta(days=1)).date()
    # Arguments are forwarded positionally, in the order download_data is called with here.
    self.download_data(
        max_collector_count,
        delay,
        today.strftime(day_format),
        tomorrow.strftime(day_format),
        check_data_length,
        limit_nums,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
|
||||
Reference in New Issue
Block a user