diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 236d9cddf..d3b7ad57a 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -369,8 +369,6 @@ class BaseRun(abc.ABC): default 2 delay: float time.sleep(delay), default 0 - interval: str - freq, value from [1min, 1d], default 1d start: str start datetime, default "2000-01-01" end: str diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py index ebca199ec..c1e811bbd 100644 --- a/scripts/data_collector/pit/collector.py +++ b/scripts/data_collector/pit/collector.py @@ -2,10 +2,8 @@ # Licensed under the MIT License. import re -import abc import sys import datetime -from abc import ABC from pathlib import Path import fire @@ -114,15 +112,27 @@ class PitCollector(BaseCollector): market = {"ss": "sh"}.get(market, market) # baostock's API naming is different from default symbol list symbol = f"{market}.{code}" rs_report = bs.query_performance_express_report( - code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date()) + code=symbol, + start_date=str(start_datetime.date()), + end_date=str(end_datetime.date()), ) report_list = [] while (rs_report.error_code == "0") & rs_report.next(): report_list.append(rs_report.get_row_data()) df_report = pd.DataFrame(report_list, columns=rs_report.fields) - if {"performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"} <= set(rs_report.fields): - df_report = df_report[["performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"]] + if { + "performanceExpPubDate", + "performanceExpStatDate", + "performanceExpressROEWa", + } <= set(rs_report.fields): + df_report = df_report[ + [ + "performanceExpPubDate", + "performanceExpStatDate", + "performanceExpressROEWa", + ] + ] df_report.rename( columns={ "performanceExpPubDate": "date", @@ -149,7 +159,11 @@ class PitCollector(BaseCollector): if {"pubDate", "statDate", "roeAvg"} <= set(rs_profit.fields): df_profit = df_profit[["pubDate", "statDate", "roeAvg"]] df_profit.rename( - columns={"pubDate": "date", "statDate": "period", "roeAvg": "value"}, + columns={ + "pubDate": "date", + "statDate": "period", + "roeAvg": "value", + }, inplace=True, ) df_profit["value"] = df_profit["value"].apply(_str_to_float) @@ -157,7 +171,9 @@ class PitCollector(BaseCollector): forecast_list = [] rs_forecast = bs.query_forecast_report( - code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date()) + code=symbol, + start_date=str(start_datetime.date()), + end_date=str(end_datetime.date()), ) while (rs_forecast.error_code == "0") & rs_forecast.next(): @@ -192,7 +208,11 @@ class PitCollector(BaseCollector): df_forecast["profitForcastChgPctUp"] + df_forecast["profitForcastChgPctDwn"] ) / 200 df_forecast["field"] = "YOYNI" - df_forecast.drop(["profitForcastChgPctUp", "profitForcastChgPctDwn"], axis=1, inplace=True) + df_forecast.drop( + ["profitForcastChgPctUp", "profitForcastChgPctDwn"], + axis=1, + inplace=True, + ) growth_list = [] for year in range(start_datetime.year - 1, end_datetime.year + 1): @@ -240,7 +260,11 @@ class PitCollector(BaseCollector): logger.warning(f"{error_msg}:{e}") def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> [pd.DataFrame]: if interval == self.INTERVAL_quarterly: @@ -266,8 +290,6 @@ class Run(BaseRun): ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - normalize_dir: str - Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 interval: str @@ -289,7 +311,6 @@ class Run(BaseRun): delay=0, start=None, end=None, - interval="quarterly", check_data_length=False, limit_nums=None, **kwargs, @@ -302,8 +323,6 @@ class Run(BaseRun): default 2 delay: float time.sleep(delay), default 0 - interval: str - freq, value from [quarterly, annual], default 1d start: str start datetime, default "2000-01-01" end: str @@ -320,7 +339,13 @@ class Run(BaseRun): """ super(Run, self).download_data( - max_collector_count, delay, start, end, interval, check_data_length, limit_nums, **kwargs + max_collector_count, + delay, + start, + end, + check_data_length, + limit_nums, + **kwargs, ) def normalize_class_name(self):