mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add end_date parameter to collector.normalize_data
This commit is contained in:
@@ -238,7 +238,7 @@ class BaseNormalize(abc.ABC):
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self.kwargs = kwargs
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -285,7 +285,9 @@ class Normalize:
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
self._end_date = kwargs.get("end_date", None)
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(
|
||||
@@ -297,6 +299,9 @@ class Normalize:
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -10,7 +9,7 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -47,7 +46,7 @@ _CALENDAR_MAP = {}
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_calendar_list(bench_code="CSI300") -> list:
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -379,14 +379,13 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = df.set_index(self._date_field_name)
|
||||
_close = self._get_first_close(df)
|
||||
for _col in df.columns:
|
||||
if _col == self._symbol_field_name:
|
||||
# NOTE: retain original adjclose, required for incremental updates
|
||||
if _col in [self._symbol_field_name, "adjclose", "change"]:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] * _close
|
||||
elif _col != "change":
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df[_col] = df[_col] / _close
|
||||
return df.reset_index()
|
||||
|
||||
|
||||
@@ -812,7 +811,7 @@ class Run(BaseRun):
|
||||
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
|
||||
)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -821,12 +820,14 @@ class Run(BaseRun):
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
end_date: str
|
||||
if not None, only rows dated on or before end_date (inclusive) are kept in the normalized output; if None, no date filtering is applied; by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date)
|
||||
|
||||
def normalize_data_1d_extend(
|
||||
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
|
||||
|
||||
Reference in New Issue
Block a user