mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge pull request #292 from wangershi/addFund
Add fund data as an example
This commit is contained in:
@@ -46,7 +46,7 @@ class BaseCollector(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
@@ -77,11 +77,11 @@ class BaseCollector(abc.ABC):
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
self.instrument_list = sorted(set(self.get_instrument_list()))
|
||||
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
self.instrument_list = self.instrument_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
@@ -108,8 +108,8 @@ class BaseCollector(abc.ABC):
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
@@ -158,27 +158,27 @@ class BaseCollector(abc.ABC):
|
||||
return _result
|
||||
|
||||
def save_instrument(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
"""save instrument data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
instrument code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
if instrument_path.exists():
|
||||
_old_df = pd.read_csv(instrument_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
@@ -191,38 +191,38 @@ class BaseCollector(abc.ABC):
|
||||
self.mini_symbol_map.pop(symbol)
|
||||
return self.NORMAL_FLAG
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
error_symbol.extend(self.mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector data......")
|
||||
stock_list = self.stock_list
|
||||
instrument_list = self.instrument_list
|
||||
for i in range(self.max_collector_count):
|
||||
if not stock_list:
|
||||
if not instrument_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
@@ -386,9 +386,9 @@ class BaseRun(abc.ABC):
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
|
||||
@@ -416,7 +416,7 @@ class BaseRun(abc.ABC):
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, self.normalize_class_name)
|
||||
yc = Normalize(
|
||||
|
||||
51
scripts/data_collector/fund/README.md
Normal file
51
scripts/data_collector/fund/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Collect Fund Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### CN Data
|
||||
|
||||
#### 1d from East Money
|
||||
|
||||
```bash
|
||||
|
||||
# download from eastmoney.com
|
||||
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
|
||||
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
|
||||
```
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
pythono collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- interval: 1d
|
||||
- region: CN
|
||||
312
scripts/data_collector/fund/collector.py
Normal file
312
scripts/data_collector/fund/collector.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
||||
|
||||
|
||||
class FundCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
fund save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
super(FundCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
self.init_datetime()
|
||||
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
try:
|
||||
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
data = json.loads(resp.text.split("(")[-1].split(")")[0])
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class FundollectorCN(FundCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get cn fund symbols......")
|
||||
symbols = get_en_fund_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def normalize_fund(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
class FundNormalize1d(FundNormalize):
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN"], default "CN"
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"FundCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"FundNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
10
scripts/data_collector/fund/requirements.txt
Normal file
10
scripts/data_collector/fund/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
json
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -14,6 +15,9 @@ import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -34,6 +38,7 @@ _BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
@@ -93,6 +98,78 @@ def get_calendar_list(bench_code="CSI300") -> list:
|
||||
return calendar
|
||||
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
source_dir: [str, Path],
|
||||
date_field_name: str = "date",
|
||||
threshold: float = 0.5,
|
||||
minimum_count: int = 10,
|
||||
max_workers: int = 16,
|
||||
) -> list:
|
||||
"""get calendar list by selecting the date when few funds trade in this day
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
threshold: float
|
||||
threshold to exclude some days when few funds trade in this day, default 0.5
|
||||
minimum_count: int
|
||||
minimum count of funds should trade in one day
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
|
||||
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
file_list = list(source_dir.glob("*.csv"))
|
||||
|
||||
_number_all_funds = len(file_list)
|
||||
|
||||
logger.info(f"count how many funds trade in this day......")
|
||||
_dict_count_trade = dict() # dict{date:count}
|
||||
_fun = partial(return_date_list, date_field_name)
|
||||
all_oldest_list = []
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, file_list):
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
if date < oldest_date:
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
@@ -220,6 +297,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get en fund symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = []
|
||||
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
|
||||
data = sub_data.replace('"', "").replace("'", "")
|
||||
# TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
|
||||
_symbols.append(data.split(",")[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
return _symbols
|
||||
|
||||
if _EN_FUND_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney()
|
||||
|
||||
_EN_FUND_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _EN_FUND_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
|
||||
Reference in New Issue
Block a user