1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

feat: data improve, support parquet (#1966)

* refactor: relocate CLI modules to qlib.cli and update references

* refactor: introduce read_as_df and rename csv_path to data_path

* lint

* refactor: rename csv_path to data_path and use QSettings.provider_uri

* fix pylint error

* fix get_data command

* add comments to CI yaml

* update docs

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
This commit is contained in:
you-n-g
2025-08-07 15:04:37 +08:00
committed by GitHub
parent 78b77e302b
commit 1b426503fc
21 changed files with 105 additions and 62 deletions

View File

@@ -17,6 +17,39 @@ from loguru import logger
from qlib.utils import fname_to_code, code_to_fname
def read_as_df(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:
"""
Read a csv or parquet file into a pandas DataFrame.
Parameters
----------
file_path : Union[str, Path]
Path to the data file.
**kwargs :
Additional keyword arguments passed to the underlying pandas
reader.
Returns
-------
pd.DataFrame
"""
file_path = Path(file_path).expanduser()
suffix = file_path.suffix.lower()
keep_keys = {".csv": ("low_memory",)}
kept_kwargs = {}
for k in keep_keys.get(suffix, []):
if k in kwargs:
kept_kwargs[k] = kwargs[k]
if suffix == ".csv":
return pd.read_csv(file_path, **kept_kwargs)
elif suffix == ".parquet":
return pd.read_parquet(file_path, **kept_kwargs)
else:
raise ValueError(f"Unsupported file format: {suffix}")
class DumpDataBase:
INSTRUMENTS_START_FIELD = "start_datetime"
INSTRUMENTS_END_FIELD = "end_datetime"
@@ -34,7 +67,7 @@ class DumpDataBase:
def __init__(
self,
csv_path: str,
data_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
@@ -50,7 +83,7 @@ class DumpDataBase:
Parameters
----------
csv_path: str
data_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
@@ -73,7 +106,7 @@ class DumpDataBase:
limit_nums: int
Use when debugging, default None
"""
csv_path = Path(csv_path).expanduser()
data_path = Path(data_path).expanduser()
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
@@ -82,9 +115,9 @@ class DumpDataBase:
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
self.df_files = sorted(data_path.glob(f"*{self.file_suffix}") if data_path.is_dir() else [data_path])
if limit_nums is not None:
self.csv_files = self.csv_files[: int(limit_nums)]
self.df_files = self.df_files[: int(limit_nums)]
self.qlib_dir = Path(qlib_dir).expanduser()
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
if backup_dir is not None:
@@ -134,13 +167,14 @@ class DumpDataBase:
return _calendars.tolist()
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
df = read_as_df(file_path, low_memory=False)
if self.date_field_name in df.columns:
df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])
# df.drop_duplicates([self.date_field_name], inplace=True)
return df
def get_symbol_from_file(self, file_path: Path) -> str:
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
return fname_to_code(file_path.stem.strip().lower())
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
@@ -274,10 +308,10 @@ class DumpDataAll(DumpDataBase):
all_datetime = set()
date_range_list = []
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
self.csv_files, executor.map(_fun, self.csv_files)
self.df_files, executor.map(_fun, self.df_files)
):
all_datetime = all_datetime | _set_calendars
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
@@ -305,9 +339,9 @@ class DumpDataAll(DumpDataBase):
def _dump_features(self):
logger.info("start dump features......")
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(_dump_func, self.csv_files):
for _ in executor.map(_dump_func, self.df_files):
p_bar.update()
logger.info("end of features dump.\n")
@@ -325,16 +359,15 @@ class DumpDataFix(DumpDataAll):
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(
filter(
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
not in self._old_instruments,
self.csv_files,
lambda x: self.get_symbol_from_file(x).upper() not in self._old_instruments,
self.df_files,
)
)
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
symbol = self.get_symbol_from_file(file_path).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
@@ -359,7 +392,7 @@ class DumpDataFix(DumpDataAll):
class DumpDataUpdate(DumpDataBase):
def __init__(
self,
csv_path: str,
data_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
@@ -375,7 +408,7 @@ class DumpDataUpdate(DumpDataBase):
Parameters
----------
csv_path: str
data_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
@@ -399,7 +432,7 @@ class DumpDataUpdate(DumpDataBase):
Use when debugging, default None
"""
super().__init__(
csv_path,
data_path,
qlib_dir,
backup_dir,
freq,
@@ -431,15 +464,19 @@ class DumpDataUpdate(DumpDataBase):
logger.info("start load all source data....")
all_df = []
def _read_csv(file_path: Path):
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
def _read_df(file_path: Path):
_df = read_as_df(file_path)
if self.date_field_name in _df.columns and not np.issubdtype(
_df[self.date_field_name].dtype, np.datetime64
):
_df[self.date_field_name] = pd.to_datetime(_df[self.date_field_name])
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df
with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
for df in executor.map(_read_df, self.df_files):
if not df.empty:
all_df.append(df)
p_bar.update()