1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix pit download_data script TypeError (#978) (#979)

* Fix pit download_data script TypeError (#978)

* Format pit collector with black

* Format pit collector with black
This commit is contained in:
Chauncey
2022-03-15 14:02:14 +08:00
committed by GitHub
parent 2681c61c60
commit 5f18ba7970
2 changed files with 40 additions and 17 deletions

View File

@@ -369,8 +369,6 @@ class BaseRun(abc.ABC):
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1d
start: str
start datetime, default "2000-01-01"
end: str

View File

@@ -2,10 +2,8 @@
# Licensed under the MIT License.
import re
import abc
import sys
import datetime
from abc import ABC
from pathlib import Path
import fire
@@ -114,15 +112,27 @@ class PitCollector(BaseCollector):
market = {"ss": "sh"}.get(market, market) # baostock's API naming is different from default symbol list
symbol = f"{market}.{code}"
rs_report = bs.query_performance_express_report(
code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date())
code=symbol,
start_date=str(start_datetime.date()),
end_date=str(end_datetime.date()),
)
report_list = []
while (rs_report.error_code == "0") & rs_report.next():
report_list.append(rs_report.get_row_data())
df_report = pd.DataFrame(report_list, columns=rs_report.fields)
if {"performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"} <= set(rs_report.fields):
df_report = df_report[["performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"]]
if {
"performanceExpPubDate",
"performanceExpStatDate",
"performanceExpressROEWa",
} <= set(rs_report.fields):
df_report = df_report[
[
"performanceExpPubDate",
"performanceExpStatDate",
"performanceExpressROEWa",
]
]
df_report.rename(
columns={
"performanceExpPubDate": "date",
@@ -149,7 +159,11 @@ class PitCollector(BaseCollector):
if {"pubDate", "statDate", "roeAvg"} <= set(rs_profit.fields):
df_profit = df_profit[["pubDate", "statDate", "roeAvg"]]
df_profit.rename(
columns={"pubDate": "date", "statDate": "period", "roeAvg": "value"},
columns={
"pubDate": "date",
"statDate": "period",
"roeAvg": "value",
},
inplace=True,
)
df_profit["value"] = df_profit["value"].apply(_str_to_float)
@@ -157,7 +171,9 @@ class PitCollector(BaseCollector):
forecast_list = []
rs_forecast = bs.query_forecast_report(
code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date())
code=symbol,
start_date=str(start_datetime.date()),
end_date=str(end_datetime.date()),
)
while (rs_forecast.error_code == "0") & rs_forecast.next():
@@ -192,7 +208,11 @@ class PitCollector(BaseCollector):
df_forecast["profitForcastChgPctUp"] + df_forecast["profitForcastChgPctDwn"]
) / 200
df_forecast["field"] = "YOYNI"
df_forecast.drop(["profitForcastChgPctUp", "profitForcastChgPctDwn"], axis=1, inplace=True)
df_forecast.drop(
["profitForcastChgPctUp", "profitForcastChgPctDwn"],
axis=1,
inplace=True,
)
growth_list = []
for year in range(start_datetime.year - 1, end_datetime.year + 1):
@@ -240,7 +260,11 @@ class PitCollector(BaseCollector):
logger.warning(f"{error_msg}:{e}")
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
self,
symbol: str,
interval: str,
start_datetime: pd.Timestamp,
end_datetime: pd.Timestamp,
) -> [pd.DataFrame]:
if interval == self.INTERVAL_quarterly:
@@ -266,8 +290,6 @@ class Run(BaseRun):
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
interval: str
@@ -289,7 +311,6 @@ class Run(BaseRun):
delay=0,
start=None,
end=None,
interval="quarterly",
check_data_length=False,
limit_nums=None,
**kwargs,
@@ -302,8 +323,6 @@ class Run(BaseRun):
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [quarterly, annual], default 1d
start: str
start datetime, default "2000-01-01"
end: str
@@ -320,7 +339,13 @@ class Run(BaseRun):
"""
super(Run, self).download_data(
max_collector_count, delay, start, end, interval, check_data_length, limit_nums, **kwargs
max_collector_count,
delay,
start,
end,
check_data_length,
limit_nums,
**kwargs,
)
def normalize_class_name(self):