mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
US stock code supports Windows
This commit is contained in:
@@ -20,7 +20,7 @@ pip install -r requirements.txt
|
||||
|
||||
### Download data and Normalize data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize
|
||||
```
|
||||
|
||||
### Download Data
|
||||
|
||||
@@ -18,6 +18,7 @@ from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -40,7 +41,7 @@ class YahooCollector:
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=5,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
@@ -55,7 +56,7 @@ class YahooCollector:
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 5
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
@@ -147,11 +148,10 @@ class YahooCollector:
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=False)
|
||||
_temp_df = pd.read_csv(stock_path, nrows=0)
|
||||
df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a")
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
df.to_csv(stock_path, index=False, mode="w")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
@@ -350,7 +350,7 @@ class YahooCollectorUS(YahooCollector):
|
||||
pass
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol.upper()
|
||||
return code_to_fname(symbol).upper()
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
|
||||
@@ -14,6 +14,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname
|
||||
|
||||
|
||||
class DumpDataBase:
|
||||
@@ -27,7 +28,6 @@ class DumpDataBase:
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
SAVE_INST_FIELD = "save_inst"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
@@ -45,7 +45,6 @@ class DumpDataBase:
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
inst_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -73,9 +72,6 @@ class DumpDataBase:
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
inst_prefix: str
|
||||
add a column to the instruments file and record the saved instrument name,
|
||||
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
if isinstance(exclude_fields, str):
|
||||
@@ -84,7 +80,6 @@ class DumpDataBase:
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self._inst_prefix = inst_prefix.strip()
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
@@ -145,7 +140,7 @@ class DumpDataBase:
|
||||
return df
|
||||
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
|
||||
return file_path.name[: -len(self.file_suffix)].strip().lower()
|
||||
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
@@ -173,7 +168,6 @@ class DumpDataBase:
|
||||
self.symbol_field_name,
|
||||
self.INSTRUMENTS_START_FIELD,
|
||||
self.INSTRUMENTS_END_FIELD,
|
||||
self.SAVE_INST_FIELD,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -190,13 +184,11 @@ class DumpDataBase:
|
||||
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
|
||||
if isinstance(instruments_data, pd.DataFrame):
|
||||
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
|
||||
if self._inst_prefix:
|
||||
_df_fields.append(self.SAVE_INST_FIELD)
|
||||
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: f"{self._inst_prefix}{x}"
|
||||
)
|
||||
instruments_data = instruments_data.loc[:, _df_fields]
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(
|
||||
lambda x: fname_to_code(x.lower()).upper()
|
||||
)
|
||||
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
|
||||
else:
|
||||
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
|
||||
|
||||
@@ -223,26 +215,26 @@ class DumpDataBase:
|
||||
logger.warning(f"{features_dir.name} data is None or empty")
|
||||
return
|
||||
# align index
|
||||
_df = self.data_merge_calendar(df, self._calendars_list)
|
||||
_df = self.data_merge_calendar(df, calendar_list)
|
||||
# used when creating a bin file
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if self._mode == self.UPDATE_MODE:
|
||||
if bin_path.exists() and self._mode == self.UPDATE_MODE:
|
||||
# update
|
||||
with bin_path.open("ab") as fp:
|
||||
np.array(_df[field]).astype("<f").tofile(fp)
|
||||
elif self._mode == self.ALL_MODE:
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
else:
|
||||
raise ValueError(f"{self._mode} cannot support!")
|
||||
# append; self._mode == self.ALL_MODE or not bin_path.exists()
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
|
||||
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = file_or_data.iloc[0][self.symbol_field_name].lower()
|
||||
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
@@ -253,8 +245,7 @@ class DumpDataBase:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
code = self._inst_prefix + code if self._inst_prefix else code
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir = self._features_dir.joinpath(code_to_fname(code).lower())
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
|
||||
@@ -283,8 +274,6 @@ class DumpDataAll(DumpDataBase):
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
_inst_fields = [symbol.upper(), _begin_time, _end_time]
|
||||
if self._inst_prefix:
|
||||
_inst_fields.append(self._inst_prefix + symbol.upper())
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
@@ -323,12 +312,18 @@ class DumpDataFix(DumpDataAll):
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
|
||||
new_stock_files = sorted(
|
||||
filter(
|
||||
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
|
||||
not in self._old_instruments,
|
||||
self.csv_files,
|
||||
)
|
||||
)
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
@@ -406,10 +401,10 @@ class DumpDataUpdate(DumpDataBase):
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
self._update_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
self._update_instruments = (
|
||||
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
|
||||
.set_index([self.symbol_field_name])
|
||||
.to_dict(orient="index")
|
||||
) # type: dict
|
||||
|
||||
# load all csv files
|
||||
@@ -425,10 +420,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
all_df = []
|
||||
|
||||
def _read_csv(file_path: Path):
|
||||
if self._include_fields:
|
||||
_df = pd.read_csv(file_path, usecols=self._include_fields)
|
||||
else:
|
||||
_df = pd.read_csv(file_path)
|
||||
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
|
||||
if self.symbol_field_name not in _df.columns:
|
||||
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
|
||||
return _df
|
||||
@@ -436,7 +428,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for df in executor.map(_read_csv, self.csv_files):
|
||||
if df:
|
||||
if not df.empty:
|
||||
all_df.append(df)
|
||||
p_bar.update()
|
||||
|
||||
@@ -455,25 +447,27 @@ class DumpDataUpdate(DumpDataBase):
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
futures = {}
|
||||
for _code, _df in self._all_data.groupby(self.symbol_field_name):
|
||||
_code = str(_code).upper()
|
||||
_code = fname_to_code(str(_code).lower()).upper()
|
||||
_start, _end = self._get_date(_df, is_begin_end=True)
|
||||
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
|
||||
continue
|
||||
if _code in self._update_instruments:
|
||||
self._update_instruments[_code]["end_time"] = _end
|
||||
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
|
||||
else:
|
||||
# new stock
|
||||
_dt_range = self._update_instruments.setdefault(_code, dict())
|
||||
_dt_range["start_time"] = _start
|
||||
_dt_range["end_time"] = _end
|
||||
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
|
||||
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
|
||||
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
|
||||
|
||||
for _future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
with tqdm(total=len(futures)) as p_bar:
|
||||
for _future in as_completed(futures):
|
||||
try:
|
||||
_future.result()
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
p_bar.update()
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
@@ -481,7 +475,9 @@ class DumpDataUpdate(DumpDataBase):
|
||||
def dump(self):
|
||||
self.save_calendars(self._new_calendar_list)
|
||||
self._dump_features()
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
|
||||
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
|
||||
df.index.names = [self.symbol_field_name]
|
||||
self.save_instruments(df.reset_index())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
Reference in New Issue
Block a user