1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

add normalizer

This commit is contained in:
wangershi
2021-03-07 18:51:38 +08:00
parent 34b7da1dd8
commit 11412727ef
3 changed files with 200 additions and 18 deletions

View File

@@ -20,10 +20,12 @@ pip install -r requirements.txt
# download from eastmoney.com
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# normalize
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
```

View File

@@ -23,7 +23,7 @@ from dateutil.tz import tzlocal
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import get_en_fund_symbols
from data_collector.utils import get_calendar_list, get_en_fund_symbols
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
REGION_CN = "CN"
@@ -302,14 +302,149 @@ class FundCollectorCN1d(FundollectorCN):
return 252 / 4
class FundNormalize:
COLUMNS = ["open", "close", "high", "low", "volume"]
DAILY_FORMAT = "%Y-%m-%d"
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self._date_field_name = date_field_name
self._symbol_field_name = symbol_field_name
self._calendar_list = self._get_calendar_list()
print (self._calendar_list)
@staticmethod
def normalize_fund(
df: pd.DataFrame,
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
if df.empty:
return df
df = df.copy()
df.set_index(date_field_name, inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
if calendar_list is not None:
df = df.reindex(
pd.DataFrame(index=calendar_list)
.loc[
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
+ pd.Timedelta(hours=23, minutes=59)
]
.index
)
df.sort_index(inplace=True)
df.index.names = [date_field_name]
return df.reset_index()
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
return df
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
class FundNormalize1d(FundNormalize, ABC):
DAILY_FORMAT = "%Y-%m-%d"
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
df = super(FundNormalize, self).normalize(df)
return df
class FundNormalizeCN:
def _get_calendar_list(self):
return get_calendar_list("ALL")
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
pass
class Normalize:
def __init__(
self,
source_dir: [str, Path],
target_dir: [str, Path],
normalize_class: Type[FundNormalize],
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
Parameters
----------
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
normalize_class: Type[FundNormalize]
normalize class
max_workers: int
Concurrent number, default is 16
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._target_dir.mkdir(parents=True, exist_ok=True)
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
def _executor(self, file_path: Path):
file_path = Path(file_path)
df = pd.read_csv(file_path)
df = self._normalize_obj.normalize(df)
if not df.empty:
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
def normalize(self):
logger.info("normalize data......")
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(self._executor, file_list):
p_bar.update()
class Run:
def __init__(self, source_dir=None, max_workers=4, region=REGION_CN):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
region: str
@@ -320,6 +455,11 @@ class Run:
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
self.region = region
@@ -372,6 +512,33 @@ class Run:
limit_nums=limit_nums,
).collector_data()
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data
Parameters
----------
interval: str
freq, value from [1d], default 1d
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
"""
_class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}")
fc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
)
fc.normalize()
if __name__ == "__main__":
fire.Fire(Run)

View File

@@ -98,14 +98,14 @@ def get_calendar_list(bench_code="CSI300") -> list:
return calendar
def return_date_list(source_dir, date_field_name, file_path):
df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)
return df[date_field_name].to_list()
def return_date_list(source_dir, date_field_name: str, file_path: Path):
file_path = Path(file_path)
date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list()
return sorted(map(lambda x: pd.Timestamp(x), date_list))
def get_calendar_list_by_ratio(
source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16
source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16
) -> list:
"""get calendar list by selecting the date when few funds trade in this day
@@ -117,6 +117,8 @@ def get_calendar_list_by_ratio(
date field name, default is date
threshold: float
threshold to exclude some days when few funds trade in this day, default 0.5
minimum_count: int
minimum count of funds should trade in one day
max_workers: int
Concurrent number, default is 16
@@ -126,25 +128,36 @@ def get_calendar_list_by_ratio(
"""
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
_number_all_funds = len(os.listdir(source_dir))
source_dir = Path(source_dir).expanduser()
file_list = list(source_dir.glob("*.csv"))
_list_all_date = dict()
_number_all_funds = len(file_list)
logger.info(f"count how many funds trade in this day......")
_dict_count_trade = dict() # dict{date:count}
_fun = partial(return_date_list, source_dir, date_field_name)
with tqdm(total=_number_all_funds) as p_bar:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for date_list in executor.map(_fun, os.listdir(source_dir)):
for date_list in executor.map(_fun, file_list[:_number_all_funds]):
for date in date_list:
if date in _list_all_date.keys():
_list_all_date[date] += 1
else:
_list_all_date[date] = 0
if date not in _dict_count_trade.keys():
_dict_count_trade[date] = 0
_dict_count_trade[date] += 1
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for date_list in executor.map(_fun, file_list[:_number_all_funds]):
oldest_date = sorted(date_list)[0] # this fund haven't found before this day
for date in _dict_count_founding.keys():
if date < oldest_date:
_dict_count_founding[date] -= 1
_threshold_number = int(_number_all_funds * threshold)
calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number]
calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)]
return calendar