1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Support Point-in-time Data Operation (#343)

* add period ops class

* black format

* add pit data read

* fix bug in period ops

* update ops runnable

* update PIT test example

* black format

* update PIT test

* update tets_PIT

* update code format

* add check_feature_exist

* black format

* optimize the PIT Algorithm

* fix bug

* update example

* update test_PIT name

* add pit collector

* black format

* fix bugs

* fix try

* fix bug & add dump_pit.py

* Successfully run and understand PIT

* Add some docs and remove a bug

* mv crypto collector

* black format

* Run succesfully after merging master

* Pass test and fix code

* remove useless PIT code

* fix PYlint

* Rename

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
bxdd
2022-03-10 14:27:52 +08:00
committed by GitHub
parent 3a911bc09b
commit faa99f30fa
19 changed files with 1459 additions and 141 deletions

View File

@@ -323,7 +323,7 @@ class BaseRun(abc.ABC):
freq, value from [1min, 1d], default 1d
"""
if source_dir is None:
source_dir = Path(self.default_base_dir).joinpath("_source")
source_dir = Path(self.default_base_dir).joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
@@ -359,6 +359,7 @@ class BaseRun(abc.ABC):
end=None,
check_data_length: int = None,
limit_nums=None,
**kwargs,
):
"""download data from Internet
@@ -398,6 +399,7 @@ class BaseRun(abc.ABC):
interval=self.interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
**kwargs,
).collector_data()
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):

View File

@@ -13,7 +13,7 @@ from dateutil.tz import tzlocal
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_cg_crypto_symbols
from data_collector.utils import get_cg_crypto_symbols, deco_retry
from pycoingecko import CoinGeckoAPI
from time import mktime
@@ -21,6 +21,40 @@ from datetime import datetime as dt
import time
_CG_CRYPTO_SYMBOLS = None
def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get crypto symbols in coingecko

    The symbols are requested only once: the result is cached in the
    module-level ``_CG_CRYPTO_SYMBOLS`` and returned directly on later calls.

    Parameters
    ----------
    qlib_data_path: str or Path
        not used by this function

    Returns
    -------
        sorted list of the unique ``id`` values returned by
        ``CoinGeckoAPI().get_coins_markets(vs_currency="usd")``

    Raises
    ------
    ValueError
        if the request to coingecko fails
    """
    global _CG_CRYPTO_SYMBOLS

    @deco_retry
    def _get_coingecko():
        try:
            cg = CoinGeckoAPI()
            resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
        except Exception as e:
            # `except Exception` (instead of a bare `except:`) lets
            # KeyboardInterrupt/SystemExit propagate; `from e` keeps the cause.
            raise ValueError("request error") from e
        try:
            _symbols = resp["id"].to_list()
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        return _symbols

    if _CG_CRYPTO_SYMBOLS is None:
        _all_symbols = _get_coingecko()
        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))

    return _CG_CRYPTO_SYMBOLS
class CryptoCollector(BaseCollector):
def __init__(
self,

View File

@@ -0,0 +1,35 @@
# Collect Point-in-Time Data
> *Please **NOTE** that the data is collected from [baostock](http://baostock.com) and might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, please refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
### Download Quarterly CN Data
```bash
cd qlib/scripts/data_collector/pit/
# download from baostock.com
python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
```
Downloading the data of all stocks is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below
``` bash
python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"
```
### Dump Data into PIT Format
```bash
cd qlib/scripts
# data_collector/pit/csv_pit is the data you downloaded just now.
python dump_pit.py dump --csv_path data_collector/pit/csv_pit --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
```

View File

@@ -0,0 +1,334 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import abc
import sys
import datetime
from abc import ABC
from pathlib import Path
import fire
import numpy as np
import pandas as pd
import baostock as bs
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseRun
from data_collector.utils import get_calendar_list, get_hs_stock_symbols
class PitCollector(BaseCollector):
    """Collect point-in-time (PIT) financial data of cn stocks from baostock.

    For each symbol, four baostock queries (performance express report, profit
    data, forecast report and growth data) are merged into one long-format
    table. When the expected baostock columns are present, every part has the
    columns ``date`` (publication date), ``period`` (reporting period),
    ``value`` and ``field`` (``"roeWa"`` or ``"YOYNI"``).
    """

    DEFAULT_START_DATETIME_QUARTER = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01")
    # NOTE: the default end datetimes are evaluated once, when the class is defined
    DEFAULT_END_DATETIME_QUARTER = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
    DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))

    INTERVAL_quarterly = "quarterly"
    INTERVAL_annual = "annual"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="quarterly",
        max_workers=1,
        max_collector_count=1,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
        symbol_flt_regx=None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            pit save dir
        interval: str:
            value from ['quarterly', 'annual']
        max_workers: int
            workers, default 1
        max_collector_count: int
            default 1
        delay: float
            time.sleep(delay), default 0
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: bool
            forwarded to ``BaseCollector``, by default False
        limit_nums: int
            using for debug, by default None
        symbol_flt_regx: str
            regular expression; when given, only the symbols it matches
            (``re.match``) are collected, by default None
        """
        if symbol_flt_regx is None:
            self.symbol_flt_regx = None
        else:
            self.symbol_flt_regx = re.compile(symbol_flt_regx)
        super(PitCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

    def normalize_symbol(self, symbol):
        """Convert ``<code>.<suffix>`` into ``sh<code>`` (suffix ``ss``) or ``sz<code>`` (any other suffix)."""
        symbol_s = symbol.split(".")
        symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
        return symbol

    def get_instrument_list(self):
        """Return the symbols to collect, filtered by ``symbol_flt_regx`` when it is set."""
        logger.info("get cn stock symbols......")
        symbols = get_hs_stock_symbols()
        logger.info(f"get {symbols[:10]}[{len(symbols)}] symbols.")
        if self.symbol_flt_regx is not None:
            s_flt = [s for s in symbols if self.symbol_flt_regx.match(s) is not None]
            logger.info(f"after filtering, it becomes {s_flt[:10]}[{len(s_flt)}] symbols")
            return s_flt
        return symbols

    def _get_data_from_baostock(self, symbol, interval, start_datetime, end_datetime):
        """Query baostock for one symbol and merge the results into a single frame.

        The following parts are concatenated (each one is only reshaped when
        the expected columns are present in the baostock result):

        - performance express report: ``performanceExpressROEWa / 100`` as field ``roeWa``
        - profit data: ``roeAvg`` as field ``roeWa``
        - forecast report: ``(profitForcastChgPctUp + profitForcastChgPctDwn) / 200`` as field ``YOYNI``
        - growth data: ``YOYNI`` as field ``YOYNI``

        Returns
        -------
        pd.DataFrame or None
            the merged data; None if any exception occurred (it is logged as a warning)
        """
        error_msg = f"{symbol}-{interval}-{start_datetime}-{end_datetime}"

        def _str_to_float(r):
            # values that cannot be parsed become NaN
            try:
                return float(r)
            except (TypeError, ValueError):
                return np.nan

        def _fetch_rows(rs):
            # fetch the records one by one and collect them
            rows = []
            while rs.error_code == "0" and rs.next():
                rows.append(rs.get_row_data())
            return rows

        def _fetch_quarterly_rows(query_func, code):
            # Quarterly queries are made per (year, quarter); only the rows
            # published within [start_datetime, end_datetime] are kept.
            rows, fields = [], []
            for year in range(start_datetime.year - 1, end_datetime.year + 1):
                for quarter in range(1, 5):
                    rs = query_func(code=code, year=year, quarter=quarter)
                    while rs.error_code == "0" and rs.next():
                        row_data = rs.get_row_data()
                        if "pubDate" in rs.fields:
                            pub_date = pd.Timestamp(row_data[rs.fields.index("pubDate")])
                            if start_datetime <= pub_date <= end_datetime:
                                rows.append(row_data)
                    fields = rs.fields
            return rows, fields

        try:
            code, market = symbol.split(".")
            market = {"ss": "sh"}.get(market, market)  # baostock's API naming is different from default symbol list
            bs_code = f"{market}.{code}"
            start_date = str(start_datetime.date())
            end_date = str(end_datetime.date())

            # performance express report
            rs_report = bs.query_performance_express_report(code=bs_code, start_date=start_date, end_date=end_date)
            df_report = pd.DataFrame(_fetch_rows(rs_report), columns=rs_report.fields)
            if {"performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"} <= set(rs_report.fields):
                df_report = df_report[
                    ["performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"]
                ].rename(
                    columns={
                        "performanceExpPubDate": "date",
                        "performanceExpStatDate": "period",
                        "performanceExpressROEWa": "value",
                    }
                )
                df_report["value"] = df_report["value"].apply(lambda r: _str_to_float(r) / 100.0)
                df_report["field"] = "roeWa"

            # profit data
            profit_list, profit_fields = _fetch_quarterly_rows(bs.query_profit_data, bs_code)
            df_profit = pd.DataFrame(profit_list, columns=profit_fields)
            if {"pubDate", "statDate", "roeAvg"} <= set(profit_fields):
                df_profit = df_profit[["pubDate", "statDate", "roeAvg"]].rename(
                    columns={"pubDate": "date", "statDate": "period", "roeAvg": "value"}
                )
                df_profit["value"] = df_profit["value"].apply(_str_to_float)
                df_profit["field"] = "roeWa"

            # forecast report
            rs_forecast = bs.query_forecast_report(code=bs_code, start_date=start_date, end_date=end_date)
            df_forecast = pd.DataFrame(_fetch_rows(rs_forecast), columns=rs_forecast.fields)
            forecast_columns = [
                "profitForcastExpPubDate",
                "profitForcastExpStatDate",
                "profitForcastChgPctUp",
                "profitForcastChgPctDwn",
            ]
            if set(forecast_columns) <= set(rs_forecast.fields):
                df_forecast = df_forecast[forecast_columns].rename(
                    columns={
                        "profitForcastExpPubDate": "date",
                        "profitForcastExpStatDate": "period",
                    }
                )
                df_forecast["profitForcastChgPctUp"] = df_forecast["profitForcastChgPctUp"].apply(_str_to_float)
                df_forecast["profitForcastChgPctDwn"] = df_forecast["profitForcastChgPctDwn"].apply(_str_to_float)
                # mean of the upper and lower bound, converted from percent
                df_forecast["value"] = (
                    df_forecast["profitForcastChgPctUp"] + df_forecast["profitForcastChgPctDwn"]
                ) / 200
                df_forecast["field"] = "YOYNI"
                df_forecast = df_forecast.drop(columns=["profitForcastChgPctUp", "profitForcastChgPctDwn"])

            # growth data
            growth_list, growth_fields = _fetch_quarterly_rows(bs.query_growth_data, bs_code)
            df_growth = pd.DataFrame(growth_list, columns=growth_fields)
            if {"pubDate", "statDate", "YOYNI"} <= set(growth_fields):
                df_growth = df_growth[["pubDate", "statDate", "YOYNI"]].rename(
                    columns={"pubDate": "date", "statDate": "period", "YOYNI": "value"}
                )
                df_growth["value"] = df_growth["value"].apply(_str_to_float)
                df_growth["field"] = "YOYNI"

            # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the equivalent
            return pd.concat([df_report, df_profit, df_forecast, df_growth])
        except Exception as e:
            logger.warning(f"{error_msg}:{e}")
            return None

    def _process_data(self, df, symbol, interval):
        """Fill missing publication dates and encode ``period`` as an integer.

        A missing ``date`` is replaced by the period date plus 45 days
        (quarterly) or 90 days (otherwise). ``period`` becomes ``year`` for
        annual data and ``year * 100 + quarter`` otherwise.

        Returns
        -------
        pd.DataFrame or None
            the processed frame; None if any exception occurred (it is logged as a warning)
        """
        error_msg = f"{symbol}-{interval}"

        def _process_period(r):
            _date = pd.Timestamp(r)
            return _date.year if interval == self.INTERVAL_annual else _date.year * 100 + (_date.month - 1) // 3 + 1

        try:
            _date = df["period"].apply(
                lambda x: (
                    pd.to_datetime(x) + pd.DateOffset(days=(45 if interval == self.INTERVAL_quarterly else 90))
                ).date()
            )
            df["date"] = df["date"].fillna(_date.astype(str))
            df["period"] = df["period"].apply(_process_period)
            return df
        except Exception as e:
            logger.warning(f"{error_msg}:{e}")
            return None

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> [pd.DataFrame]:
        """Download and process the PIT data of one symbol.

        Only the ``quarterly`` interval is supported.

        Returns
        -------
        pd.DataFrame or None
            the processed data; the raw result (None or an empty frame) when nothing was downloaded

        Raises
        ------
        ValueError
            if ``interval`` is not ``quarterly``
        """
        if interval != self.INTERVAL_quarterly:
            raise ValueError(f"cannot support {interval}")
        _result = self._get_data_from_baostock(symbol, interval, start_datetime, end_datetime)
        if _result is None or _result.empty:
            return _result
        return self._process_data(_result, symbol, interval)

    @property
    def min_numbers_trading(self):
        # no value is defined for PIT data; the property evaluates to None
        pass
class Run(BaseRun):
    """Command-line runner that downloads PIT data with ``PitCollector``."""

    def __init__(self, source_dir=None, max_workers=1, interval="quarterly"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        max_workers: int
            Concurrent number, default is 1
        interval: str
            freq, value from [quarterly, annual], default quarterly
        """
        super().__init__(source_dir=source_dir, max_workers=max_workers, interval=interval)

    @property
    def collector_class_name(self):
        return "PitCollector"

    @property
    def default_base_dir(self) -> [Path, str]:
        # the directory that contains this script
        return CUR_DIR

    def download_data(
        self,
        max_collector_count=1,
        delay=0,
        start=None,
        end=None,
        interval="quarterly",
        check_data_length=False,
        limit_nums=None,
        **kwargs,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 1
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [quarterly, annual], default quarterly
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None
        kwargs:
            extra keyword arguments, passed through to ``BaseRun.download_data``

        Examples
        ---------
            # get quarterly data
            $ python collector.py download_data --source_dir ~/.qlib/cn_data/source/pit_quarter --start 2000-01-01 --end 2021-01-01 --interval quarterly
        """
        # the arguments are handed to the base implementation positionally, in this order
        positional = (max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
        super().download_data(*positional, **kwargs)

    def normalize_class_name(self):
        # no normalize class is provided for PIT data
        pass
if __name__ == "__main__":
    # baostock requires a session around the queries issued by the collector
    bs.login()
    try:
        fire.Fire(Run)
    finally:
        # always close the session, even when the command raises or fire exits
        bs.logout()

View File

@@ -0,0 +1,9 @@
loguru
fire
tqdm
requests
pandas
lxml
baostock
yahooquery

View File

@@ -0,0 +1,194 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
from qlib.data import D
import unittest
class TestPIT(unittest.TestCase):
    """
    Checks for point-in-time (PIT) queries issued through `D.features`.

    NOTE!!!!!!
    The assertions of this test assume that the user follows the command below and downloads only 2 stocks.
    `python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"`
    """

    def setUp(self):
        # NOTE: use `qlib.init(kernels=1)` instead to make debugging easier
        qlib.init()

    def to_str(self, obj):
        # remove every whitespace character so that comparisons ignore the layout
        text = str(obj)
        return "".join(text.split())

    def check_same(self, a, b):
        left, right = self.to_str(a), self.to_str(b)
        self.assertEqual(left, right)

    def test_query(self):
        fields = ["P($$roewa_q)", "P($$yoyni_q)"]
        instruments = ["sh600519"]
        # Mao Tai published 2019Q2 report at 2019-07-13 & 2019-07-18
        # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index
        df = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
        print(df)

        expected_describe = """
P($$roewa_q) P($$yoyni_q)
count 133.000000 133.000000
mean 0.196412 0.277930
std 0.097591 0.030262
min 0.000000 0.243892
25% 0.094737 0.243892
50% 0.255220 0.304181
75% 0.255220 0.305041
max 0.344644 0.305041
"""
        self.check_same(df.describe(), expected_describe)

        expected_tail = """
P($$roewa_q) P($$yoyni_q)
instrument datetime
sh600519 2019-07-15 0.000000 0.305041
2019-07-16 0.000000 0.305041
2019-07-17 0.000000 0.305041
2019-07-18 0.175322 0.252650
2019-07-19 0.175322 0.252650
"""
        self.check_same(df.tail(), expected_tail)

    def test_no_exist_data(self):
        instruments = ["sh600519", "sh601988"]
        fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"]
        df = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
        # overwrite `$close` with a constant because different datasets give different values
        df["$close"] = 1
        print(df)
        expected = """
P($$roewa_q) P($$yoyni_q) $close
instrument datetime
sh600519 2019-01-02 0.25522 0.243892 1
2019-01-03 0.25522 0.243892 1
2019-01-04 0.25522 0.243892 1
2019-01-07 0.25522 0.243892 1
2019-01-08 0.25522 0.243892 1
... ... ... ...
sh601988 2019-07-15 NaN NaN 1
2019-07-16 NaN NaN 1
2019-07-17 NaN NaN 1
2019-07-18 NaN NaN 1
2019-07-19 NaN NaN 1
[266 rows x 3 columns]
"""
        self.check_same(df, expected)

    def test_expr(self):
        instruments = ["sh600519"]
        fields = [
            "P(Mean($$roewa_q, 1))",
            "P($$roewa_q)",
            "P(Mean($$roewa_q, 2))",
            "P(Ref($$roewa_q, 1))",
            "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)",
        ]
        df = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
        expected = """
P(Mean($$roewa_q, 1)) P($$roewa_q) P(Mean($$roewa_q, 2)) P(Ref($$roewa_q, 1)) P((Ref($$roewa_q, 1) +$$roewa_q) / 2)
instrument datetime
sh600519 2019-07-01 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-02 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-03 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-04 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-05 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-08 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-09 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-10 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-11 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-12 0.094737 0.094737 0.219691 0.344644 0.219691
2019-07-15 0.000000 0.000000 0.047369 0.094737 0.047369
2019-07-16 0.000000 0.000000 0.047369 0.094737 0.047369
2019-07-17 0.000000 0.000000 0.047369 0.094737 0.047369
2019-07-18 0.175322 0.175322 0.135029 0.094737 0.135029
2019-07-19 0.175322 0.175322 0.135029 0.094737 0.135029
"""
        self.check_same(df.tail(15), expected)

    def test_unlimit(self):
        # a single field is queried here, without a start time (and once without any time bound)
        instruments = ["sh600519"]
        fields = ["P($$roewa_q)"]
        D.features(instruments, fields, freq="day")  # this should not raise error
        df = D.features(instruments, fields, end_time="20200101", freq="day")  # this should not raise error
        series = df.iloc[:, 0]
        # You can check the expected value based on the content in `docs/advanced/PIT.rst`
        expected = """
instrument datetime
sh600519 1999-11-10 NaN
2007-04-30 0.090219
2007-08-17 0.139330
2007-10-23 0.245863
2008-03-03 0.347900
2008-03-13 0.395989
2008-04-22 0.100724
2008-08-28 0.249968
2008-10-27 0.334120
2009-03-25 0.390117
2009-04-21 0.102675
2009-08-07 0.230712
2009-10-26 0.300730
2010-04-02 0.335461
2010-04-26 0.083825
2010-08-12 0.200545
2010-10-29 0.260986
2011-03-21 0.307393
2011-04-25 0.097411
2011-08-31 0.248251
2011-10-18 0.318919
2012-03-23 0.403900
2012-04-11 0.403925
2012-04-26 0.112148
2012-08-10 0.264847
2012-10-26 0.370487
2013-03-29 0.450047
2013-04-18 0.099958
2013-09-02 0.210442
2013-10-16 0.304543
2014-03-25 0.394328
2014-04-25 0.083217
2014-08-29 0.164503
2014-10-30 0.234085
2015-04-21 0.078494
2015-08-28 0.137504
2015-10-26 0.201709
2016-03-24 0.264205
2016-04-21 0.073664
2016-08-29 0.136576
2016-10-31 0.188062
2017-04-17 0.244385
2017-04-25 0.080614
2017-07-28 0.151510
2017-10-26 0.254166
2018-03-28 0.329542
2018-05-02 0.088887
2018-08-02 0.170563
2018-10-29 0.255220
2019-03-29 0.344644
2019-04-25 0.094737
2019-07-15 0.000000
2019-07-18 0.175322
2019-10-16 0.255819
Name: P($$roewa_q), dtype: float32
"""
        # keep only the rows whose value has not appeared before
        first_occurrence = ~series.duplicated().values
        self.check_same(series[first_occurrence], expected)

    def test_expr2(self):
        # only checks that the expressions can be evaluated; the results are printed, not asserted
        instruments = ["sh600519"]
        fields = [
            "P($$roewa_q)",
            "P($$yoyni_q)",
            "P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)",
            "P(Sum($$yoyni_q, 4))",
            "$close",
            "P($$roewa_q) * $close",
        ]
        df = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day")
        print(df)
        print(df.describe())
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

View File

@@ -19,7 +19,6 @@ from yahooquery import Ticker
from tqdm import tqdm
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from pycoingecko import CoinGeckoAPI
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
@@ -43,7 +42,6 @@ _HS_SYMBOLS = None
_US_SYMBOLS = None
_IN_SYMBOLS = None
_EN_FUND_SYMBOLS = None
_CG_CRYPTO_SYMBOLS = None
_CALENDAR_MAP = {}
# NOTE: Until 2020-10-20 20:00:00
@@ -379,37 +377,6 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
return _EN_FUND_SYMBOLS
def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get crypto symbols in coingecko

    The symbols are requested only once: the result is cached in the
    module-level ``_CG_CRYPTO_SYMBOLS`` and returned directly on later calls.

    Parameters
    ----------
    qlib_data_path: str or Path
        not used by this function

    Returns
    -------
        sorted list of the unique ``id`` values returned by
        ``CoinGeckoAPI().get_coins_markets(vs_currency="usd")``

    Raises
    ------
    ValueError
        if the request to coingecko fails
    """
    global _CG_CRYPTO_SYMBOLS

    @deco_retry
    def _get_coingecko():
        try:
            cg = CoinGeckoAPI()
            resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
        except Exception as e:
            # `except Exception` (instead of a bare `except:`) lets
            # KeyboardInterrupt/SystemExit propagate; `from e` keeps the cause.
            raise ValueError("request error") from e
        try:
            _symbols = resp["id"].to_list()
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        return _symbols

    if _CG_CRYPTO_SYMBOLS is None:
        _all_symbols = _get_coingecko()
        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))

    return _CG_CRYPTO_SYMBOLS
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
"""symbol suffix to prefix

282
scripts/dump_pit.py Normal file
View File

@@ -0,0 +1,282 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
TODO:
- A more well-designed PIT database is required.
- separate insert, delete, update and query operations are required.
"""
import abc
import shutil
import struct
import traceback
from pathlib import Path
from typing import Iterable, List, Union
from functools import partial
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
from qlib.config import C
class DumpPitData:
    """Dump point-in-time (PIT) csv files into per-field ``.data`` / ``.index`` binary files.

    Each csv file holds the records of one symbol; the files are written to
    ``<qlib_dir>/financial/<symbol>/``.
    """

    PIT_DIR_NAME = "financial"
    PIT_CSV_SEP = ","
    DATA_FILE_SUFFIX = ".data"
    INDEX_FILE_SUFFIX = ".index"

    INTERVAL_quarterly = "quarterly"
    INTERVAL_annual = "annual"

    # struct formats of the binary records, taken from the qlib config
    PERIOD_DTYPE = C.pit_record_type["period"]
    INDEX_DTYPE = C.pit_record_type["index"]
    DATA_DTYPE = "".join(
        [
            C.pit_record_type["date"],
            C.pit_record_type["period"],
            C.pit_record_type["value"],
            C.pit_record_type["index"],
        ]
    )

    NA_INDEX = C.pit_record_nan["index"]

    INDEX_DTYPE_SIZE = struct.calcsize(INDEX_DTYPE)
    PERIOD_DTYPE_SIZE = struct.calcsize(PERIOD_DTYPE)
    DATA_DTYPE_SIZE = struct.calcsize(DATA_DTYPE)

    UPDATE_MODE = "update"
    ALL_MODE = "all"

    def __init__(
        self,
        csv_path: str,
        qlib_dir: str,
        backup_dir: str = None,
        freq: str = "quarterly",
        max_workers: int = 16,
        date_column_name: str = "date",
        period_column_name: str = "period",
        value_column_name: str = "value",
        field_column_name: str = "field",
        file_suffix: str = ".csv",
        exclude_fields: str = "",
        include_fields: str = "",
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        csv_path: str
            stock data path or directory
        qlib_dir: str
            qlib(dump) data directory
        backup_dir: str, default None
            if backup_dir is not None, backup qlib_dir to backup_dir
        freq: str, default "quarterly"
            data frequency (accepted but not used by this class)
        max_workers: int, default 16
            number of worker processes
        date_column_name: str, default "date"
            the name of the date field in the csv
        period_column_name: str, default "period"
            the name of the period field in the csv
        value_column_name: str, default "value"
            the name of the value field in the csv
        field_column_name: str, default "field"
            the name of the field-name column in the csv
        file_suffix: str, default ".csv"
            file suffix
        include_fields: tuple
            dump fields; a comma-separated string is also accepted
        exclude_fields: tuple
            fields not dumped; a comma-separated string is also accepted
        limit_nums: int
            Use when debugging, default None
        """
        csv_path = Path(csv_path).expanduser()
        if isinstance(exclude_fields, str):
            exclude_fields = exclude_fields.split(",")
        if isinstance(include_fields, str):
            include_fields = include_fields.split(",")
        # strip the names and drop the empty ones
        self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
        self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
        self.file_suffix = file_suffix
        self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
        if limit_nums is not None:
            self.csv_files = self.csv_files[: int(limit_nums)]
        self.qlib_dir = Path(qlib_dir).expanduser()
        self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
        if backup_dir is not None:
            self._backup_qlib_dir(Path(backup_dir).expanduser())

        self.works = max_workers
        self.date_column_name = date_column_name
        self.period_column_name = period_column_name
        self.value_column_name = value_column_name
        self.field_column_name = field_column_name

        self._mode = self.ALL_MODE

    def _backup_qlib_dir(self, target_dir: Path):
        """Copy the whole ``qlib_dir`` tree to ``target_dir``."""
        shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))

    def get_source_data(self, file_path: Path) -> pd.DataFrame:
        """Read one csv file; values are cast to float32 and dates (``YYYY-MM-DD``) to int32 (``YYYYMMDD``)."""
        df = pd.read_csv(str(file_path.resolve()), low_memory=False)
        df[self.value_column_name] = df[self.value_column_name].astype("float32")
        df[self.date_column_name] = df[self.date_column_name].str.replace("-", "").astype("int32")
        return df

    def get_symbol_from_file(self, file_path: Path) -> str:
        """Derive the symbol from the file name (suffix removed, stripped, lower-cased)."""
        return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())

    def get_dump_fields(self, df: pd.DataFrame) -> Iterable[str]:
        """Return the set of fields to dump.

        ``include_fields`` wins when given; otherwise all fields found in
        ``df`` are used, minus ``exclude_fields``.
        """
        if self._include_fields:
            return set(self._include_fields)
        fields = set(df[self.field_column_name])
        if self._exclude_fields:
            return fields - set(self._exclude_fields)
        return fields

    def get_filenames(self, symbol, field, interval):
        """Return the ``(data_file, index_file)`` paths of a field; the symbol directory is created if needed."""
        dir_name = self.qlib_dir.joinpath(self.PIT_DIR_NAME, symbol)
        dir_name.mkdir(parents=True, exist_ok=True)
        # the first letter of the interval is used as the file name suffix, e.g. `<field>_q.data`
        return (
            dir_name.joinpath(f"{field}_{interval[0]}{self.DATA_FILE_SUFFIX}".lower()),
            dir_name.joinpath(f"{field}_{interval[0]}{self.INDEX_FILE_SUFFIX}".lower()),
        )

    def _dump_pit(
        self,
        file_path: Path,
        interval: str = "quarterly",
        overwrite: bool = False,
    ):
        """
        dump data as the following format:
            `/path/to/<field>.data`
                [date, period, value, _next]
                [date, period, value, _next]
                [...]
            `/path/to/<field>.index`
                [first_year, index, index, ...]

        `<field.data>` contains the data as the point-in-time (PIT) order: `value` of `period`
        is published at `date`, and its successive revised value can be found at `_next` (linked list).

        `<field>.index` contains the index of value for each period (quarter or year). To save
        disk space, we only store the `first_year` as its following periods can be easily inferred.

        Parameters
        ----------
        file_path: Path
            csv file of one symbol
        interval: str
            data interval
        overwrite: bool
            whether overwrite existing data or update only
        """
        symbol = self.get_symbol_from_file(file_path)
        df = self.get_source_data(file_path)
        if df.empty:
            logger.warning(f"{symbol} file is empty")
            return
        for field in self.get_dump_fields(df):
            df_sub = df.query(f'{self.field_column_name}=="{field}"').sort_values(self.date_column_name)
            if df_sub.empty:
                logger.warning(f"field {field} of {symbol} is empty")
                continue
            data_file, index_file = self.get_filenames(symbol, field, interval)

            ## calculate first & last period
            start_year = df_sub[self.period_column_name].min()
            end_year = df_sub[self.period_column_name].max()
            if interval == self.INTERVAL_quarterly:
                # quarterly periods are encoded as `year * 100 + quarter`
                start_year //= 100
                end_year //= 100

            # adjust `first_year` if existing data found
            if not overwrite and index_file.exists():
                with open(index_file, "rb") as fi:
                    (first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE))
                    n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE
                    if interval == self.INTERVAL_quarterly:
                        n_years //= 4
                    start_year = first_year + n_years
            else:
                with open(index_file, "wb") as f:
                    f.write(struct.pack(self.PERIOD_DTYPE, start_year))
                first_year = start_year

            # if data already exists, continue to the next field
            if start_year > end_year:
                logger.warning(f"{symbol}-{field} data already exists, continue to the next field")
                continue

            # dump index filled with NA
            with open(index_file, "ab") as fi:
                for year in range(start_year, end_year + 1):
                    if interval == self.INTERVAL_quarterly:
                        fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4))
                    else:
                        fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX))

            # if data already exists, remove overlapped data
            # (an existing file without a complete record has no last date to read:
            #  seeking backwards from its end would raise OSError)
            has_records = data_file.exists() and data_file.stat().st_size >= self.DATA_DTYPE_SIZE
            if not overwrite and has_records:
                with open(data_file, "rb") as fd:
                    fd.seek(-self.DATA_DTYPE_SIZE, 2)
                    last_date, _, _, _ = struct.unpack(self.DATA_DTYPE, fd.read())
                df_sub = df_sub.query(f"{self.date_column_name}>{last_date}")
            # otherwise,
            # 1) truncate existing file or create a new file with `wb+` if overwrite,
            # 2) or append existing file or create a new file with `ab+` if not overwrite
            else:
                with open(data_file, "wb+" if overwrite else "ab+"):
                    pass

            with open(data_file, "rb+") as fd, open(index_file, "rb+") as fi:
                # `rb+` starts at offset 0: move to the end so that new records are
                # appended after the existing ones instead of overwriting them
                fd.seek(0, 2)

                # update index if needed
                for i, row in df_sub.iterrows():
                    # read the values through the configured column names
                    row_date = row[self.date_column_name]
                    row_period = row[self.period_column_name]
                    row_value = row[self.value_column_name]

                    # get index
                    offset = get_period_offset(first_year, row_period, interval == self.INTERVAL_quarterly)

                    fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
                    (cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE))

                    # Case I: new data => update `_next` with current index
                    if cur_index == self.NA_INDEX:
                        fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
                        fi.write(struct.pack(self.INDEX_DTYPE, fd.tell()))
                    # Case II: previous data exists => find and update the last `_next`
                    else:
                        _cur_fd = fd.tell()
                        prev_index = self.NA_INDEX
                        while cur_index != self.NA_INDEX:  # NOTE: first iter always != NA_INDEX
                            fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
                            prev_index = cur_index
                            (cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE))
                        fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
                        fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd))  # NOTE: add _next pointer
                        fd.seek(_cur_fd)

                    # dump data
                    fd.write(struct.pack(self.DATA_DTYPE, row_date, row_period, row_value, self.NA_INDEX))

    def dump(self, interval="quarterly", overwrite=False):
        """Dump every csv file, one file per task in a process pool of ``max_workers`` workers."""
        logger.info("start dump pit data......")
        _dump_func = partial(self._dump_pit, interval=interval, overwrite=overwrite)

        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.works) as executor:
                for _ in executor.map(_dump_func, self.csv_files):
                    p_bar.update()

    def __call__(self, *args, **kwargs):
        """Shortcut for ``dump``; the arguments are forwarded instead of being dropped."""
        self.dump(*args, **kwargs)
# Expose DumpPitData as a command-line interface through python-fire.
if __name__ == "__main__":
    fire.Fire(DumpPitData)