diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 08e1838a4..ce94f1783 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -238,7 +238,7 @@ class BaseNormalize(abc.ABC): """ self._date_field_name = date_field_name self._symbol_field_name = symbol_field_name - + self.kwargs = kwargs self._calendar_list = self._get_calendar_list() @abc.abstractmethod @@ -285,7 +285,9 @@ class Normalize: self._source_dir = Path(source_dir).expanduser() self._target_dir = Path(target_dir).expanduser() self._target_dir.mkdir(parents=True, exist_ok=True) - + self._date_field_name = date_field_name + self._symbol_field_name = symbol_field_name + self._end_date = kwargs.get("end_date", None) self._max_workers = max_workers self._normalize_obj = normalize_class( @@ -297,6 +299,9 @@ class Normalize: df = pd.read_csv(file_path) df = self._normalize_obj.normalize(df) if df is not None and not df.empty: + if self._end_date is not None: + _mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date) + df = df[_mask] df.to_csv(self._target_dir.joinpath(file_path.name), index=False) def normalize(self): diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 3f4539612..1a8d479d9 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import re -import os import time import bisect import pickle @@ -10,7 +9,7 @@ import random import requests import functools from pathlib import Path -from typing import Iterable, Tuple +from typing import Iterable, Tuple, List import numpy as np import pandas as pd @@ -47,7 +46,7 @@ _CALENDAR_MAP = {} MINIMUM_SYMBOLS_NUM = 3900 -def get_calendar_list(bench_code="CSI300") -> list: +def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: """get SH/SZ history calendar list Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f9a209168..44cfce7ca 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -379,14 +379,13 @@ class YahooNormalize1d(YahooNormalize, ABC): df = df.set_index(self._date_field_name) _close = self._get_first_close(df) for _col in df.columns: - if _col == self._symbol_field_name: + # NOTE: retain original adjclose, required for incremental updates + if _col in [self._symbol_field_name, "adjclose", "change"]: continue if _col == "volume": df[_col] = df[_col] * _close - elif _col != "change": - df[_col] = df[_col] / _close else: - pass + df[_col] = df[_col] / _close return df.reset_index() @@ -812,7 +811,7 @@ class Run(BaseRun): max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums ) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None): """normalize data Parameters @@ -821,12 +820,14 @@ class Run(BaseRun): date field name, default date symbol_field_name: str symbol field name, default symbol + end_date: str + if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None Examples --------- $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d """ - super(Run, self).normalize_data(date_field_name, symbol_field_name) + super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date) def normalize_data_1d_extend( self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"