mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
feat: data improve, support parquet (#1966)
* refactor: relocate CLI modules to qlib.cli and update references
* refactor: introduce read_as_df and rename csv_path to data_path
* lint
* refactor: rename csv_path to data_path and use QSettings.provider_uri
* fix pylint error
* fix get_data command
* add comments to CI yaml
* update docs

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
This commit is contained in:
@@ -17,6 +17,39 @@ from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname
|
||||
|
||||
|
||||
def read_as_df(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:
    """
    Read a csv or parquet file into a pandas DataFrame.

    The reader is selected from the file suffix, compared case-insensitively.

    Parameters
    ----------
    file_path : Union[str, Path]
        Path to the data file. A leading ``~`` is expanded.
    **kwargs :
        Candidate keyword arguments for the underlying pandas reader.
        Only keys allow-listed for the detected format are forwarded
        (currently ``low_memory`` for ``.csv``; nothing for ``.parquet``).
        Any other key is silently dropped -- presumably so callers can pass
        csv-only options regardless of the actual file format.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If the file suffix is neither ``.csv`` nor ``.parquet``.
    """
    file_path = Path(file_path).expanduser()
    suffix = file_path.suffix.lower()

    # Per-format allow-list of reader options; formats without an entry
    # receive no extra keyword arguments at all.
    keep_keys = {".csv": ("low_memory",)}
    kept_kwargs = {k: kwargs[k] for k in keep_keys.get(suffix, ()) if k in kwargs}

    if suffix == ".csv":
        return pd.read_csv(file_path, **kept_kwargs)
    if suffix == ".parquet":
        return pd.read_parquet(file_path, **kept_kwargs)
    raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
|
||||
class DumpDataBase:
|
||||
INSTRUMENTS_START_FIELD = "start_datetime"
|
||||
INSTRUMENTS_END_FIELD = "end_datetime"
|
||||
@@ -34,7 +67,7 @@ class DumpDataBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
data_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
@@ -50,7 +83,7 @@ class DumpDataBase:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
data_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
@@ -73,7 +106,7 @@ class DumpDataBase:
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
data_path = Path(data_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
@@ -82,9 +115,9 @@ class DumpDataBase:
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
self.df_files = sorted(data_path.glob(f"*{self.file_suffix}") if data_path.is_dir() else [data_path])
|
||||
if limit_nums is not None:
|
||||
self.csv_files = self.csv_files[: int(limit_nums)]
|
||||
self.df_files = self.df_files[: int(limit_nums)]
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
@@ -134,13 +167,14 @@ class DumpDataBase:
|
||||
return _calendars.tolist()
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
|
||||
df = read_as_df(file_path, low_memory=False)
|
||||
if self.date_field_name in df.columns:
|
||||
df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
return fname_to_code(file_path.stem.strip().lower())
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
@@ -274,10 +308,10 @@ class DumpDataAll(DumpDataBase):
|
||||
all_datetime = set()
|
||||
date_range_list = []
|
||||
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
|
||||
self.csv_files, executor.map(_fun, self.csv_files)
|
||||
self.df_files, executor.map(_fun, self.df_files)
|
||||
):
|
||||
all_datetime = all_datetime | _set_calendars
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
@@ -305,9 +339,9 @@ class DumpDataAll(DumpDataBase):
|
||||
def _dump_features(self):
|
||||
logger.info("start dump features......")
|
||||
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(_dump_func, self.csv_files):
|
||||
for _ in executor.map(_dump_func, self.df_files):
|
||||
p_bar.update()
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
@@ -325,16 +359,15 @@ class DumpDataFix(DumpDataAll):
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(
|
||||
filter(
|
||||
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
|
||||
not in self._old_instruments,
|
||||
self.csv_files,
|
||||
lambda x: self.get_symbol_from_file(x).upper() not in self._old_instruments,
|
||||
self.df_files,
|
||||
)
|
||||
)
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
@@ -359,7 +392,7 @@ class DumpDataFix(DumpDataAll):
|
||||
class DumpDataUpdate(DumpDataBase):
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
data_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
@@ -375,7 +408,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csv_path: str
|
||||
data_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
@@ -399,7 +432,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
Use when debugging, default None
|
||||
"""
|
||||
super().__init__(
|
||||
csv_path,
|
||||
data_path,
|
||||
qlib_dir,
|
||||
backup_dir,
|
||||
freq,
|
||||
@@ -431,15 +464,19 @@ class DumpDataUpdate(DumpDataBase):
|
||||
logger.info("start load all source data....")
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
|
||||
def _read_df(file_path: Path):
|
||||
_df = read_as_df(file_path)
|
||||
if self.date_field_name in _df.columns and not np.issubdtype(
|
||||
_df[self.date_field_name].dtype, np.datetime64
|
||||
):
|
||||
_df[self.date_field_name] = pd.to_datetime(_df[self.date_field_name])
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with tqdm(total=len(self.df_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
for df in executor.map(_read_df, self.df_files):
|
||||
if not df.empty:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
|
||||
Reference in New Issue
Block a user