mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add normalizer
This commit is contained in:
@@ -20,10 +20,12 @@ pip install -r requirements.txt
|
||||
# download from eastmoney.com
|
||||
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from dateutil.tz import tzlocal
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.utils import get_en_fund_symbols
|
||||
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
||||
REGION_CN = "CN"
|
||||
@@ -302,14 +302,149 @@ class FundCollectorCN1d(FundollectorCN):
|
||||
return 252 / 4
|
||||
|
||||
|
||||
class FundNormalize:
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
print (self._calendar_list)
|
||||
|
||||
@staticmethod
|
||||
def normalize_fund(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class FundNormalize1d(FundNormalize, ABC):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = super(FundNormalize, self).normalize(df)
|
||||
return df
|
||||
|
||||
|
||||
class FundNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[FundNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[FundNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, max_workers=4, region=REGION_CN):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
region: str
|
||||
@@ -320,6 +455,11 @@ class Run:
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
self.region = region
|
||||
@@ -372,6 +512,33 @@ class Run:
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interval: str
|
||||
freq, value from [1d], default 1d
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}")
|
||||
fc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
fc.normalize()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
|
||||
@@ -98,14 +98,14 @@ def get_calendar_list(bench_code="CSI300") -> list:
|
||||
return calendar
|
||||
|
||||
|
||||
def return_date_list(source_dir, date_field_name, file_path):
|
||||
df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)
|
||||
|
||||
return df[date_field_name].to_list()
|
||||
def return_date_list(source_dir, date_field_name: str, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16
|
||||
source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16
|
||||
) -> list:
|
||||
"""get calendar list by selecting the date when few funds trade in this day
|
||||
|
||||
@@ -117,6 +117,8 @@ def get_calendar_list_by_ratio(
|
||||
date field name, default is date
|
||||
threshold: float
|
||||
threshold to exclude some days when few funds trade in this day, default 0.5
|
||||
minimum_count: int
|
||||
minimum count of funds should trade in one day
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
|
||||
@@ -126,25 +128,36 @@ def get_calendar_list_by_ratio(
|
||||
"""
|
||||
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
|
||||
|
||||
_number_all_funds = len(os.listdir(source_dir))
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
file_list = list(source_dir.glob("*.csv"))
|
||||
|
||||
_list_all_date = dict()
|
||||
_number_all_funds = len(file_list)
|
||||
|
||||
logger.info(f"count how many funds trade in this day......")
|
||||
_dict_count_trade = dict() # dict{date:count}
|
||||
_fun = partial(return_date_list, source_dir, date_field_name)
|
||||
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, os.listdir(source_dir)):
|
||||
for date_list in executor.map(_fun, file_list[:_number_all_funds]):
|
||||
for date in date_list:
|
||||
if date in _list_all_date.keys():
|
||||
_list_all_date[date] += 1
|
||||
else:
|
||||
_list_all_date[date] = 0
|
||||
if date not in _dict_count_trade.keys():
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, file_list[:_number_all_funds]):
|
||||
oldest_date = sorted(date_list)[0] # this fund haven't found before this day
|
||||
for date in _dict_count_founding.keys():
|
||||
if date < oldest_date:
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
_threshold_number = int(_number_all_funds * threshold)
|
||||
calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number]
|
||||
calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)]
|
||||
|
||||
return calendar
|
||||
|
||||
|
||||
Reference in New Issue
Block a user