mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Support Point-in-time Data Operation (#343)
* add period ops class * black format * add pit data read * fix bug in period ops * update ops runnable * update PIT test example * black format * update PIT test * update tets_PIT * update code format * add check_feature_exist * black format * optimize the PIT Algorithm * fix bug * update example * update test_PIT name * add pit collector * black format * fix bugs * fix try * fix bug & add dump_pit.py * Successfully run and understand PIT * Add some docs and remove a bug * mv crypto collector * black format * Run succesfully after merging master * Pass test and fix code * remove useless PIT code * fix PYlint * Rename Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -323,7 +323,7 @@ class BaseRun(abc.ABC):
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = Path(self.default_base_dir).joinpath("_source")
|
||||
source_dir = Path(self.default_base_dir).joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -359,6 +359,7 @@ class BaseRun(abc.ABC):
|
||||
end=None,
|
||||
check_data_length: int = None,
|
||||
limit_nums=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -398,6 +399,7 @@ class BaseRun(abc.ABC):
|
||||
interval=self.interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
**kwargs,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
|
||||
@@ -13,7 +13,7 @@ from dateutil.tz import tzlocal
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_cg_crypto_symbols
|
||||
from data_collector.utils import get_cg_crypto_symbols, deco_retry
|
||||
|
||||
from pycoingecko import CoinGeckoAPI
|
||||
from time import mktime
|
||||
@@ -21,6 +21,40 @@ from datetime import datetime as dt
|
||||
import time
|
||||
|
||||
|
||||
_CG_CRYPTO_SYMBOLS = None
|
||||
|
||||
|
||||
def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get crypto symbols in coingecko

    The symbol list is requested once and cached in the module-level
    ``_CG_CRYPTO_SYMBOLS``; later calls return the cached list.

    Parameters
    ----------
    qlib_data_path: str or Path
        not used by this function

    Returns
    -------
        sorted, de-duplicated list of the coin ids returned by coingecko

    Raises
    ------
    ValueError
        if the request to coingecko fails
    """
    global _CG_CRYPTO_SYMBOLS

    @deco_retry
    def _get_coingecko():
        try:
            cg = CoinGeckoAPI()
            resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
        except Exception as e:
            # `except Exception` (not a bare except) so KeyboardInterrupt/SystemExit
            # are not converted into a ValueError; keep the cause for debugging.
            raise ValueError("request error") from e
        try:
            _symbols = resp["id"].to_list()
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        return _symbols

    if _CG_CRYPTO_SYMBOLS is None:
        _all_symbols = _get_coingecko()
        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))

    return _CG_CRYPTO_SYMBOLS
|
||||
|
||||
|
||||
class CryptoCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
35
scripts/data_collector/pit/README.md
Normal file
35
scripts/data_collector/pit/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Collect Point-in-Time Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [baostock](http://baostock.com) and the data might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### Download Quarterly CN Data
|
||||
|
||||
```bash
|
||||
cd qlib/scripts/data_collector/pit/
|
||||
# download from baostock.com
|
||||
python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly
|
||||
```
|
||||
|
||||
Downloading data for all stocks is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below:
|
||||
``` bash
|
||||
python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Dump Data into PIT Format
|
||||
|
||||
```bash
|
||||
cd qlib/scripts
|
||||
# data_collector/pit/csv_pit is the data you downloaded just now.
|
||||
python dump_pit.py dump --csv_path data_collector/pit/csv_pit --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
|
||||
```
|
||||
334
scripts/data_collector/pit/collector.py
Normal file
334
scripts/data_collector/pit/collector.py
Normal file
@@ -0,0 +1,334 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols
|
||||
|
||||
|
||||
class PitCollector(BaseCollector):
    """Collect point-in-time (PIT) financial indicators for CN stocks from baostock.

    Every collected row carries four columns: ``date`` (publication date),
    ``period`` (report period), ``value`` and ``field`` (indicator name).
    """

    DEFAULT_START_DATETIME_QUARTER = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01")
    DEFAULT_END_DATETIME_QUARTER = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
    DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))

    INTERVAL_quarterly = "quarterly"
    INTERVAL_annual = "annual"

    def __init__(
        self,
        save_dir: [str, Path],
        start=None,
        end=None,
        interval="quarterly",
        max_workers=1,
        max_collector_count=1,
        delay=0,
        check_data_length: bool = False,
        limit_nums: int = None,
        symbol_flt_regx=None,
    ):
        """

        Parameters
        ----------
        save_dir: str
            pit save dir
        interval: str:
            value from ['quarterly', 'annual']
        max_workers: int
            workers, default 1
        max_collector_count: int
            default 1
        delay: float
            time.sleep(delay), default 0
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        check_data_length: bool
            forwarded unchanged to ``BaseCollector``, default False
        limit_nums: int
            using for debug, by default None
        symbol_flt_regx: str
            regular expression applied with ``re.match`` to every symbol; only matching
            symbols are collected. ``None`` (default) disables the filtering.
        """
        # The filter must exist before the base initializer runs.
        # NOTE(review): presumably BaseCollector.__init__ calls `get_instrument_list` - confirm.
        self.symbol_flt_regx = None if symbol_flt_regx is None else re.compile(symbol_flt_regx)
        super(PitCollector, self).__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

    def normalize_symbol(self, symbol):
        """Convert ``<code>.<market>`` into ``sh<code>`` or ``sz<code>``.

        A ``ss`` market suffix maps to the ``sh`` prefix; every other suffix maps to ``sz``.
        """
        symbol_s = symbol.split(".")
        prefix = "sh" if symbol_s[-1] == "ss" else "sz"
        return f"{prefix}{symbol_s[0]}"

    def get_instrument_list(self):
        """Return the CN stock symbols, restricted by ``symbol_flt_regx`` when it is set."""
        logger.info("get cn stock symbols......")
        symbols = get_hs_stock_symbols()
        logger.info(f"get {symbols[:10]}[{len(symbols)}] symbols.")
        if self.symbol_flt_regx is None:
            return symbols

        s_flt = [s for s in symbols if self.symbol_flt_regx.match(s) is not None]
        logger.info(f"after filtering, it becomes {s_flt[:10]}[{len(s_flt)}] symbols")
        return s_flt

    @staticmethod
    def _str_to_float(r):
        """Convert ``r`` to float; return NaN when it cannot be parsed."""
        try:
            return float(r)
        except (TypeError, ValueError):
            return np.nan

    @staticmethod
    def _fetch_rows(rs):
        """Drain a baostock result set into a list of rows (fetch one record at a time)."""
        rows = []
        # `and` short-circuits, so `next()` is never called on a failed query
        while rs.error_code == "0" and rs.next():
            rows.append(rs.get_row_data())
        return rows

    @staticmethod
    def _fetch_quarterly_rows(query_func, symbol, start_datetime, end_datetime):
        """Query every quarter of the covered years and keep rows published inside the range.

        Returns ``(rows, fields)`` where ``fields`` are the column names of the last
        result set (an empty list when no query was issued).
        """
        rows, fields = [], []
        # start one year earlier: a report published inside the range may belong to
        # a period of the previous year
        for year in range(start_datetime.year - 1, end_datetime.year + 1):
            for quarter in range(1, 5):
                rs = query_func(code=symbol, year=year, quarter=quarter)
                fields = rs.fields
                while rs.error_code == "0" and rs.next():
                    row_data = rs.get_row_data()
                    if "pubDate" not in rs.fields:
                        continue
                    pub_date = pd.Timestamp(row_data[rs.fields.index("pubDate")])
                    if start_datetime <= pub_date <= end_datetime:
                        rows.append(row_data)
        return rows, fields

    def _get_data_from_baostock(self, symbol, interval, start_datetime, end_datetime):
        """Download and merge the PIT indicators of one symbol.

        Four baostock sources are merged into one frame:

        - performance express report: ``performanceExpressROEWa`` / 100 -> field ``roeWa``
        - profit data: ``roeAvg`` -> field ``roeWa``
        - forecast report: mean of ``profitForcastChgPctUp`` and ``profitForcastChgPctDwn``,
          divided by 100 -> field ``YOYNI``
        - growth data: ``YOYNI`` -> field ``YOYNI``

        Returns
        -------
        pd.DataFrame or None
            the merged data; ``None`` when anything fails (the error is logged as a warning)
        """
        error_msg = f"{symbol}-{interval}-{start_datetime}-{end_datetime}"
        to_float = self._str_to_float

        try:
            code, market = symbol.split(".")
            market = {"ss": "sh"}.get(market, market)  # baostock's API naming is different from default symbol list
            symbol = f"{market}.{code}"
            start_date, end_date = str(start_datetime.date()), str(end_datetime.date())

            # performance express report
            rs_report = bs.query_performance_express_report(code=symbol, start_date=start_date, end_date=end_date)
            df_report = pd.DataFrame(self._fetch_rows(rs_report), columns=rs_report.fields)
            report_columns = {
                "performanceExpPubDate": "date",
                "performanceExpStatDate": "period",
                "performanceExpressROEWa": "value",
            }
            if set(report_columns) <= set(rs_report.fields):
                # non-inplace rename: avoids mutating a slice of the original frame
                df_report = df_report[list(report_columns)].rename(columns=report_columns)
                df_report["value"] = df_report["value"].apply(lambda r: to_float(r) / 100.0)
                df_report["field"] = "roeWa"

            # profit data
            profit_list, profit_fields = self._fetch_quarterly_rows(
                bs.query_profit_data, symbol, start_datetime, end_datetime
            )
            df_profit = pd.DataFrame(profit_list, columns=profit_fields)
            profit_columns = {"pubDate": "date", "statDate": "period", "roeAvg": "value"}
            if set(profit_columns) <= set(profit_fields):
                df_profit = df_profit[list(profit_columns)].rename(columns=profit_columns)
                df_profit["value"] = df_profit["value"].apply(to_float)
                df_profit["field"] = "roeWa"

            # forecast report
            rs_forecast = bs.query_forecast_report(code=symbol, start_date=start_date, end_date=end_date)
            df_forecast = pd.DataFrame(self._fetch_rows(rs_forecast), columns=rs_forecast.fields)
            forecast_columns = [
                "profitForcastExpPubDate",
                "profitForcastExpStatDate",
                "profitForcastChgPctUp",
                "profitForcastChgPctDwn",
            ]
            if set(forecast_columns) <= set(rs_forecast.fields):
                df_forecast = df_forecast[forecast_columns].rename(
                    columns={
                        "profitForcastExpPubDate": "date",
                        "profitForcastExpStatDate": "period",
                    }
                )
                pct_up = df_forecast["profitForcastChgPctUp"].apply(to_float)
                pct_down = df_forecast["profitForcastChgPctDwn"].apply(to_float)
                # mean of the two percentage bounds, converted to a ratio
                df_forecast["value"] = (pct_up + pct_down) / 200
                df_forecast["field"] = "YOYNI"
                df_forecast = df_forecast.drop(["profitForcastChgPctUp", "profitForcastChgPctDwn"], axis=1)

            # growth data
            growth_list, growth_fields = self._fetch_quarterly_rows(
                bs.query_growth_data, symbol, start_datetime, end_datetime
            )
            df_growth = pd.DataFrame(growth_list, columns=growth_fields)
            growth_columns = {"pubDate": "date", "statDate": "period", "YOYNI": "value"}
            if set(growth_columns) <= set(growth_fields):
                df_growth = df_growth[list(growth_columns)].rename(columns=growth_columns)
                df_growth["value"] = df_growth["value"].apply(to_float)
                df_growth["field"] = "YOYNI"

            # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the equivalent
            return pd.concat([df_report, df_profit, df_forecast, df_growth])
        except Exception as e:
            # best effort: a failing symbol is logged and yields None
            logger.warning(f"{error_msg}:{e}")

    def _process_data(self, df, symbol, interval):
        """Fill missing publication dates and encode ``period`` as an integer.

        Missing ``date`` values are replaced by the period date plus 45 days
        (quarterly) or 90 days (any other interval). ``period`` becomes ``YYYY``
        for annual data and ``YYYYQQ`` (e.g. ``201902``) otherwise.

        Returns ``None`` when processing fails (the error is logged as a warning).
        """
        error_msg = f"{symbol}-{interval}"

        def _process_period(r):
            _date = pd.Timestamp(r)
            return _date.year if interval == self.INTERVAL_annual else _date.year * 100 + (_date.month - 1) // 3 + 1

        try:
            _date = df["period"].apply(
                lambda x: (
                    pd.to_datetime(x) + pd.DateOffset(days=(45 if interval == self.INTERVAL_quarterly else 90))
                ).date()
            )
            df["date"] = df["date"].fillna(_date.astype(str))
            df["period"] = df["period"].apply(_process_period)
            return df
        except Exception as e:
            logger.warning(f"{error_msg}:{e}")

    def get_data(
        self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
    ) -> [pd.DataFrame]:
        """Collect the PIT data of ``symbol``.

        Returns the processed frame, or the raw result unchanged when it is ``None`` or empty.

        Raises
        ------
        ValueError
            if ``interval`` is not ``quarterly`` (the only interval supported)
        """
        if interval != self.INTERVAL_quarterly:
            raise ValueError(f"cannot support {interval}")

        _result = self._get_data_from_baostock(symbol, interval, start_datetime, end_datetime)
        if _result is None or _result.empty:
            return _result
        return self._process_data(_result, symbol, interval)

    @property
    def min_numbers_trading(self):
        # intentionally returns None
        # NOTE(review): presumably only consulted by the base class when data length checking is enabled - confirm
        pass
|
||||
|
||||
|
||||
class Run(BaseRun):
    """Command-line entry point for collecting PIT data with ``PitCollector``."""

    def __init__(self, source_dir=None, max_workers=1, interval="quarterly"):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        max_workers: int
            Concurrent number, default is 1
        interval: str
            freq, value from [quarterly, annual], default quarterly
        """
        super().__init__(source_dir=source_dir, max_workers=max_workers, interval=interval)

    @property
    def collector_class_name(self):
        return "PitCollector"

    @property
    def default_base_dir(self) -> [Path, str]:
        return CUR_DIR

    def download_data(
        self,
        max_collector_count=1,
        delay=0,
        start=None,
        end=None,
        interval="quarterly",
        check_data_length=False,
        limit_nums=None,
        **kwargs,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 1
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [quarterly, annual], default quarterly.
            Kept for interface compatibility only: the collector is built with the
            ``interval`` given to ``__init__`` (``self.interval``).
        start: str
            start datetime, default "2000-01-01"
        end: str
            end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None
        **kwargs
            forwarded to the collector (e.g. ``symbol_flt_regx``)

        Examples
        ---------
            # get quarterly data
            $ python collector.py download_data --source_dir ~/.qlib/cn_data/source/pit_quarter --start 2000-01-01 --end 2021-01-01 --interval quarterly
        """
        if interval != self.interval:
            logger.warning(f"interval={interval} is ignored, the collector uses interval={self.interval}")

        # `BaseRun.download_data` has no `interval` parameter: passing it positionally
        # shifted it into `check_data_length`. The trailing arguments are therefore
        # passed by keyword.
        super(Run, self).download_data(
            max_collector_count,
            delay,
            start,
            end,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
            **kwargs,
        )

    def normalize_class_name(self):
        # PIT data has no normalization step; nothing to return.
        pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # baostock needs a logged-in session for the whole run; always log out,
    # even when the command raises.
    bs.login()
    try:
        fire.Fire(Run)
    finally:
        bs.logout()
|
||||
9
scripts/data_collector/pit/requirements.txt
Normal file
9
scripts/data_collector/pit/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
loguru
|
||||
fire
|
||||
tqdm
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
baostock
|
||||
yahooquery
|
||||
194
scripts/data_collector/pit/test_pit.py
Normal file
194
scripts/data_collector/pit/test_pit.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
import unittest
|
||||
|
||||
|
||||
class TestPIT(unittest.TestCase):
    """
    NOTE!!!!!!
    The assertions of this test assume that users follow the cmd below and only download 2 stocks.
    `python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"`
    """

    def setUp(self):
        # NOTE: use `qlib.init(kernels=1)` instead when debugging; a single kernel makes it easier
        qlib.init()

    def to_str(self, obj):
        """Return ``str(obj)`` with all whitespace removed, so comparisons ignore layout."""
        return "".join(str(obj).split())

    def check_same(self, a, b):
        """Assert that ``a`` and ``b`` have the same whitespace-insensitive string form."""
        self.assertEqual(self.to_str(a), self.to_str(b))

    def test_query(self):
        instruments = ["sh600519"]
        fields = ["P($$roewa_q)", "P($$yoyni_q)"]
        # Mao Tai published 2019Q2 report at 2019-07-13 & 2019-07-18
        # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index
        data = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")

        print(data)

        res = """
        P($$roewa_q) P($$yoyni_q)
        count 133.000000 133.000000
        mean 0.196412 0.277930
        std 0.097591 0.030262
        min 0.000000 0.243892
        25% 0.094737 0.243892
        50% 0.255220 0.304181
        75% 0.255220 0.305041
        max 0.344644 0.305041
        """
        self.check_same(data.describe(), res)

        res = """
        P($$roewa_q) P($$yoyni_q)
        instrument datetime
        sh600519 2019-07-15 0.000000 0.305041
        2019-07-16 0.000000 0.305041
        2019-07-17 0.000000 0.305041
        2019-07-18 0.175322 0.252650
        2019-07-19 0.175322 0.252650
        """
        self.check_same(data.tail(), res)

    def test_no_exist_data(self):
        fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"]
        data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="20190719", freq="day")
        data["$close"] = 1  # in case of different dataset gives different values
        print(data)
        expect = """
        P($$roewa_q) P($$yoyni_q) $close
        instrument datetime
        sh600519 2019-01-02 0.25522 0.243892 1
        2019-01-03 0.25522 0.243892 1
        2019-01-04 0.25522 0.243892 1
        2019-01-07 0.25522 0.243892 1
        2019-01-08 0.25522 0.243892 1
        ... ... ... ...
        sh601988 2019-07-15 NaN NaN 1
        2019-07-16 NaN NaN 1
        2019-07-17 NaN NaN 1
        2019-07-18 NaN NaN 1
        2019-07-19 NaN NaN 1

        [266 rows x 3 columns]
        """
        self.check_same(data, expect)

    def test_expr(self):
        fields = [
            "P(Mean($$roewa_q, 1))",
            "P($$roewa_q)",
            "P(Mean($$roewa_q, 2))",
            "P(Ref($$roewa_q, 1))",
            "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)",
        ]
        instruments = ["sh600519"]
        data = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
        expect = """
        P(Mean($$roewa_q, 1)) P($$roewa_q) P(Mean($$roewa_q, 2)) P(Ref($$roewa_q, 1)) P((Ref($$roewa_q, 1) +$$roewa_q) / 2)
        instrument datetime
        sh600519 2019-07-01 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-02 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-03 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-04 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-05 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-08 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-09 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-10 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-11 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-12 0.094737 0.094737 0.219691 0.344644 0.219691
        2019-07-15 0.000000 0.000000 0.047369 0.094737 0.047369
        2019-07-16 0.000000 0.000000 0.047369 0.094737 0.047369
        2019-07-17 0.000000 0.000000 0.047369 0.094737 0.047369
        2019-07-18 0.175322 0.175322 0.135029 0.094737 0.135029
        2019-07-19 0.175322 0.175322 0.135029 0.094737 0.135029
        """
        self.check_same(data.tail(15), expect)

    def test_unlimit(self):
        # fields = ["P(Mean($$roewa_q, 1))", "P($$roewa_q)", "P(Mean($$roewa_q, 2))", "P(Ref($$roewa_q, 1))", "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)"]
        fields = ["P($$roewa_q)"]
        instruments = ["sh600519"]
        _ = D.features(instruments, fields, freq="day")  # this should not raise error
        data = D.features(instruments, fields, end_time="20200101", freq="day")  # this should not raise error
        s = data.iloc[:, 0]
        # You can check the expected value based on the content in `docs/advanced/PIT.rst`
        expect = """
        instrument datetime
        sh600519 1999-11-10 NaN
        2007-04-30 0.090219
        2007-08-17 0.139330
        2007-10-23 0.245863
        2008-03-03 0.347900
        2008-03-13 0.395989
        2008-04-22 0.100724
        2008-08-28 0.249968
        2008-10-27 0.334120
        2009-03-25 0.390117
        2009-04-21 0.102675
        2009-08-07 0.230712
        2009-10-26 0.300730
        2010-04-02 0.335461
        2010-04-26 0.083825
        2010-08-12 0.200545
        2010-10-29 0.260986
        2011-03-21 0.307393
        2011-04-25 0.097411
        2011-08-31 0.248251
        2011-10-18 0.318919
        2012-03-23 0.403900
        2012-04-11 0.403925
        2012-04-26 0.112148
        2012-08-10 0.264847
        2012-10-26 0.370487
        2013-03-29 0.450047
        2013-04-18 0.099958
        2013-09-02 0.210442
        2013-10-16 0.304543
        2014-03-25 0.394328
        2014-04-25 0.083217
        2014-08-29 0.164503
        2014-10-30 0.234085
        2015-04-21 0.078494
        2015-08-28 0.137504
        2015-10-26 0.201709
        2016-03-24 0.264205
        2016-04-21 0.073664
        2016-08-29 0.136576
        2016-10-31 0.188062
        2017-04-17 0.244385
        2017-04-25 0.080614
        2017-07-28 0.151510
        2017-10-26 0.254166
        2018-03-28 0.329542
        2018-05-02 0.088887
        2018-08-02 0.170563
        2018-10-29 0.255220
        2019-03-29 0.344644
        2019-04-25 0.094737
        2019-07-15 0.000000
        2019-07-18 0.175322
        2019-10-16 0.255819
        Name: P($$roewa_q), dtype: float32
        """

        # keep only the rows where the value changes
        self.check_same(s[~s.duplicated().values], expect)

    def test_expr2(self):
        instruments = ["sh600519"]
        fields = ["P($$roewa_q)", "P($$yoyni_q)"]
        fields += ["P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)"]
        fields += ["P(Sum($$yoyni_q, 4))"]
        fields += ["$close", "P($$roewa_q) * $close"]
        # mixing PIT and non-PIT expressions should not raise error
        data = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day")
        print(data)
        print(data.describe())
        # one column per requested expression
        self.assertEqual(len(data.columns), len(fields))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -19,7 +19,6 @@ from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from pycoingecko import CoinGeckoAPI
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -43,7 +42,6 @@ _HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_IN_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CG_CRYPTO_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
@@ -379,37 +377,6 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
return _EN_FUND_SYMBOLS
|
||||
|
||||
|
||||
def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get crypto symbols in coingecko

    The symbol list is requested once and cached in the module-level
    ``_CG_CRYPTO_SYMBOLS``; later calls return the cached list.

    Parameters
    ----------
    qlib_data_path: str or Path
        not used by this function

    Returns
    -------
        sorted, de-duplicated list of the coin ids returned by coingecko

    Raises
    ------
    ValueError
        if the request to coingecko fails
    """
    global _CG_CRYPTO_SYMBOLS

    @deco_retry
    def _get_coingecko():
        try:
            cg = CoinGeckoAPI()
            resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
        except Exception as e:
            # `except Exception` (not a bare except) so KeyboardInterrupt/SystemExit
            # are not converted into a ValueError; keep the cause for debugging.
            raise ValueError("request error") from e
        try:
            _symbols = resp["id"].to_list()
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        return _symbols

    if _CG_CRYPTO_SYMBOLS is None:
        _all_symbols = _get_coingecko()
        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))

    return _CG_CRYPTO_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
|
||||
282
scripts/dump_pit.py
Normal file
282
scripts/dump_pit.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
TODO:
|
||||
- A more well-designed PIT database is required.
|
||||
- Separate insert, delete, update and query operations are required.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import struct
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class DumpPitData:
|
||||
PIT_DIR_NAME = "financial"
|
||||
PIT_CSV_SEP = ","
|
||||
DATA_FILE_SUFFIX = ".data"
|
||||
INDEX_FILE_SUFFIX = ".index"
|
||||
|
||||
INTERVAL_quarterly = "quarterly"
|
||||
INTERVAL_annual = "annual"
|
||||
|
||||
PERIOD_DTYPE = C.pit_record_type["period"]
|
||||
INDEX_DTYPE = C.pit_record_type["index"]
|
||||
DATA_DTYPE = "".join(
|
||||
[
|
||||
C.pit_record_type["date"],
|
||||
C.pit_record_type["period"],
|
||||
C.pit_record_type["value"],
|
||||
C.pit_record_type["index"],
|
||||
]
|
||||
)
|
||||
|
||||
NA_INDEX = C.pit_record_nan["index"]
|
||||
|
||||
INDEX_DTYPE_SIZE = struct.calcsize(INDEX_DTYPE)
|
||||
PERIOD_DTYPE_SIZE = struct.calcsize(PERIOD_DTYPE)
|
||||
DATA_DTYPE_SIZE = struct.calcsize(DATA_DTYPE)
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
|
||||
def __init__(
    self,
    csv_path: str,
    qlib_dir: str,
    backup_dir: str = None,
    freq: str = "quarterly",
    max_workers: int = 16,
    date_column_name: str = "date",
    period_column_name: str = "period",
    value_column_name: str = "value",
    field_column_name: str = "field",
    file_suffix: str = ".csv",
    exclude_fields: str = "",
    include_fields: str = "",
    limit_nums: int = None,
):
    """

    Parameters
    ----------
    csv_path: str
        stock data path or directory
    qlib_dir: str
        qlib(dump) data directory
    backup_dir: str, default None
        if backup_dir is not None, backup qlib_dir to backup_dir
    freq: str, default "quarterly"
        data frequency (not used by this constructor)
    max_workers: int, default 16
        number of workers
    date_column_name: str, default "date"
        the name of the date field in the csv
    period_column_name: str, default "period"
        the name of the period field in the csv
    value_column_name: str, default "value"
        the name of the value field in the csv
    field_column_name: str, default "field"
        the name of the field-name column in the csv
    file_suffix: str, default ".csv"
        file suffix
    include_fields: str or iterable of str
        dump fields; a string is split on ","
    exclude_fields: str or iterable of str
        fields not dumped; a string is split on ","
    limit_nums: int
        Use when debugging, default None
    """
    csv_path = Path(csv_path).expanduser()
    if isinstance(exclude_fields, str):
        exclude_fields = exclude_fields.split(",")
    if isinstance(include_fields, str):
        include_fields = include_fields.split(",")
    # strip whitespace and drop empty names (e.g. the result of "".split(","))
    self._exclude_fields = tuple(name for name in map(str.strip, exclude_fields) if name)
    self._include_fields = tuple(name for name in map(str.strip, include_fields) if name)
    self.file_suffix = file_suffix
    self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
    if limit_nums is not None:
        self.csv_files = self.csv_files[: int(limit_nums)]
    self.qlib_dir = Path(qlib_dir).expanduser()
    self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
    if self.backup_dir is not None:
        # reuse the already expanded path instead of expanding it a second time
        self._backup_qlib_dir(self.backup_dir)

    self.works = max_workers
    self.date_column_name = date_column_name
    self.period_column_name = period_column_name
    self.value_column_name = value_column_name
    self.field_column_name = field_column_name

    self._mode = self.ALL_MODE
|
||||
|
||||
def _backup_qlib_dir(self, target_dir: Path):
    """Copy the whole qlib data directory (``self.qlib_dir``) into ``target_dir``."""
    source = str(self.qlib_dir.resolve())
    destination = str(target_dir.resolve())
    shutil.copytree(source, destination)
|
||||
|
||||
def get_source_data(self, file_path: Path) -> pd.DataFrame:
    """Load one PIT csv file.

    The value column is cast to ``float32``; the date column has its "-"
    separators removed and is cast to ``int32`` (e.g. "2019-07-13" -> 20190713).
    """
    value_col = self.value_column_name
    date_col = self.date_column_name
    frame = pd.read_csv(str(file_path.resolve()), low_memory=False)
    frame[value_col] = frame[value_col].astype("float32")
    compact_dates = frame[date_col].str.replace("-", "")
    frame[date_col] = compact_dates.astype("int32")
    return frame
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
    """Derive the instrument code from a csv file name.

    The configured file suffix is cut off, the rest is stripped and lower-cased,
    then converted with ``fname_to_code``.
    """
    stem = file_path.name[: -len(self.file_suffix)]
    normalized = stem.strip().lower()
    return fname_to_code(normalized)
|
||||
|
||||
def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]:
    """Return the set of field names to dump for ``df``.

    The configured include list wins when it is non-empty; otherwise every field
    found in the field column of ``df`` is used, minus the configured exclude list.
    """
    if self._include_fields:
        return set(self._include_fields)

    available = set(df[self.field_column_name])
    if self._exclude_fields:
        return available - set(self._exclude_fields)
    return available
|
||||
|
||||
def get_filenames(self, symbol, field, interval):
    """Return the ``(data_file, index_file)`` paths of one field of one symbol.

    The files live in ``<qlib_dir>/<PIT_DIR_NAME>/<symbol>/`` (created when missing)
    and are named ``<field>_<first letter of interval><suffix>``, lower-cased.
    """
    symbol_dir = self.qlib_dir.joinpath(self.PIT_DIR_NAME, symbol)
    symbol_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{field}_{interval[0]}"
    data_file = symbol_dir.joinpath(f"{stem}{self.DATA_FILE_SUFFIX}".lower())
    index_file = symbol_dir.joinpath(f"{stem}{self.INDEX_FILE_SUFFIX}".lower())
    return (data_file, index_file)
|
||||
|
||||
def _dump_pit(
    self,
    file_path: Path,
    interval: str = "quarterly",
    overwrite: bool = False,
):
    """
    dump data as the following format:
        `/path/to/<field>.data`
            [date, period, value, _next]
            [date, period, value, _next]
            [...]
        `/path/to/<field>.index`
            [first_year, index, index, ...]

    `<field>.data` contains the data as the point-in-time (PIT) order: `value` of `period`
    is published at `date`, and its successive revised value can be found at `_next` (linked list).

    `<field>.index` contains the index of value for each period (quarter or year). To save
    disk space, we only store the `first_year` as its following periods can be easily inferred.

    Parameters
    ----------
    file_path: Path
        source csv file; the symbol is derived from its name
    interval: str
        data interval
    overwrite: bool
        whether overwrite existing data or update only
    """
    symbol = self.get_symbol_from_file(file_path)
    df = self.get_source_data(file_path)
    if df.empty:
        logger.warning(f"{symbol} file is empty")
        return

    is_quarterly = interval == self.INTERVAL_quarterly
    date_col = self.date_column_name
    period_col = self.period_column_name
    value_col = self.value_column_name

    for field in self.get_dump_fields(df):
        # boolean mask instead of a string-built `query`, so that field values
        # containing quotes cannot break the expression
        df_sub = df[df[self.field_column_name] == field].sort_values(date_col)
        if df_sub.empty:
            logger.warning(f"field {field} of {symbol} is empty")
            continue
        data_file, index_file = self.get_filenames(symbol, field, interval)

        ## calculate first & last period
        start_year = df_sub[period_col].min()
        end_year = df_sub[period_col].max()
        if is_quarterly:
            start_year //= 100
            end_year //= 100

        # adjust `first_year` if existing data found
        if not overwrite and index_file.exists():
            with open(index_file, "rb") as fi:
                (first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE))
                n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE
                if is_quarterly:
                    n_years //= 4
                start_year = first_year + n_years
        else:
            with open(index_file, "wb") as f:
                f.write(struct.pack(self.PERIOD_DTYPE, start_year))
            first_year = start_year

        # if data already exists, continue to the next field
        if start_year > end_year:
            logger.warning(f"{symbol}-{field} data already exists, continue to the next field")
            continue

        # dump index filled with NA
        with open(index_file, "ab") as fi:
            for _ in range(start_year, end_year + 1):
                if is_quarterly:
                    fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4))
                else:
                    fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX))

        # if data already exists, remove overlapped data
        # (an existing file shorter than one record has no last record to read,
        # so it is handled like a new file)
        if not overwrite and data_file.exists() and data_file.stat().st_size >= self.DATA_DTYPE_SIZE:
            with open(data_file, "rb") as fd:
                fd.seek(-self.DATA_DTYPE_SIZE, 2)
                last_date, _, _, _ = struct.unpack(self.DATA_DTYPE, fd.read())
            df_sub = df_sub[df_sub[date_col] > last_date]
        # otherwise,
        # 1) truncate existing file or create a new file with `wb+` if overwrite,
        # 2) or append existing file or create a new file with `ab+` if not overwrite
        else:
            with open(data_file, "wb+" if overwrite else "ab+"):
                pass

        with open(data_file, "rb+") as fd, open(index_file, "rb+") as fi:
            # NOTE: `rb+` positions the stream at offset 0. New records must be
            # written after the existing ones, otherwise an update would overwrite
            # the beginning of the file and store wrong offsets in the index.
            fd.seek(0, 2)

            # update index if needed
            for _, row in df_sub.iterrows():
                # get index
                offset = get_period_offset(first_year, row[period_col], is_quarterly)
                index_pos = self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset

                fi.seek(index_pos)
                (cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE))

                # Case I: new data => update `_next` with current index
                if cur_index == self.NA_INDEX:
                    fi.seek(index_pos)
                    fi.write(struct.pack(self.INDEX_DTYPE, fd.tell()))
                # Case II: previous data exists => find and update the last `_next`
                else:
                    _cur_fd = fd.tell()
                    prev_index = self.NA_INDEX
                    while cur_index != self.NA_INDEX:  # NOTE: first iter always != NA_INDEX
                        # `_next` is the last member of a record
                        fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
                        prev_index = cur_index
                        (cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE))
                    fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
                    fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd))  # NOTE: add _next pointer
                    fd.seek(_cur_fd)  # go back to the write position of the new record

                # dump data
                fd.write(
                    struct.pack(self.DATA_DTYPE, row[date_col], row[period_col], row[value_col], self.NA_INDEX)
                )
|
||||
|
||||
def dump(self, interval="quarterly", overwrite=False):
    """Dump every file in ``self.csv_files`` with ``_dump_pit`` using a process pool.

    Parameters
    ----------
    interval: str
        data interval, forwarded to ``_dump_pit``
    overwrite: bool
        whether overwrite existing data or update only, forwarded to ``_dump_pit``
    """
    logger.info("start dump pit data......")
    worker = partial(self._dump_pit, interval=interval, overwrite=overwrite)

    with tqdm(total=len(self.csv_files)) as progress, ProcessPoolExecutor(max_workers=self.works) as pool:
        # advance the progress bar once per finished file
        for _result in pool.map(worker, self.csv_files):
            progress.update()
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Run ``dump`` with its default arguments.

    Any positional or keyword arguments are accepted but not forwarded.
    """
    del args, kwargs  # accepted for call compatibility only
    self.dump()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # script entry point: hand the DumpPitData class over to `fire`
    fire.Fire(DumpPitData)
|
||||
Reference in New Issue
Block a user