mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* Fix pit download_data script TypeError (#978) * Format pit collector with black * Format pit collector with black
This commit is contained in:
@@ -369,8 +369,6 @@ class BaseRun(abc.ABC):
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
@@ -114,15 +112,27 @@ class PitCollector(BaseCollector):
|
||||
market = {"ss": "sh"}.get(market, market) # baostock's API naming is different from default symbol list
|
||||
symbol = f"{market}.{code}"
|
||||
rs_report = bs.query_performance_express_report(
|
||||
code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date())
|
||||
code=symbol,
|
||||
start_date=str(start_datetime.date()),
|
||||
end_date=str(end_datetime.date()),
|
||||
)
|
||||
report_list = []
|
||||
while (rs_report.error_code == "0") & rs_report.next():
|
||||
report_list.append(rs_report.get_row_data())
|
||||
|
||||
df_report = pd.DataFrame(report_list, columns=rs_report.fields)
|
||||
if {"performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"} <= set(rs_report.fields):
|
||||
df_report = df_report[["performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"]]
|
||||
if {
|
||||
"performanceExpPubDate",
|
||||
"performanceExpStatDate",
|
||||
"performanceExpressROEWa",
|
||||
} <= set(rs_report.fields):
|
||||
df_report = df_report[
|
||||
[
|
||||
"performanceExpPubDate",
|
||||
"performanceExpStatDate",
|
||||
"performanceExpressROEWa",
|
||||
]
|
||||
]
|
||||
df_report.rename(
|
||||
columns={
|
||||
"performanceExpPubDate": "date",
|
||||
@@ -149,7 +159,11 @@ class PitCollector(BaseCollector):
|
||||
if {"pubDate", "statDate", "roeAvg"} <= set(rs_profit.fields):
|
||||
df_profit = df_profit[["pubDate", "statDate", "roeAvg"]]
|
||||
df_profit.rename(
|
||||
columns={"pubDate": "date", "statDate": "period", "roeAvg": "value"},
|
||||
columns={
|
||||
"pubDate": "date",
|
||||
"statDate": "period",
|
||||
"roeAvg": "value",
|
||||
},
|
||||
inplace=True,
|
||||
)
|
||||
df_profit["value"] = df_profit["value"].apply(_str_to_float)
|
||||
@@ -157,7 +171,9 @@ class PitCollector(BaseCollector):
|
||||
|
||||
forecast_list = []
|
||||
rs_forecast = bs.query_forecast_report(
|
||||
code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date())
|
||||
code=symbol,
|
||||
start_date=str(start_datetime.date()),
|
||||
end_date=str(end_datetime.date()),
|
||||
)
|
||||
|
||||
while (rs_forecast.error_code == "0") & rs_forecast.next():
|
||||
@@ -192,7 +208,11 @@ class PitCollector(BaseCollector):
|
||||
df_forecast["profitForcastChgPctUp"] + df_forecast["profitForcastChgPctDwn"]
|
||||
) / 200
|
||||
df_forecast["field"] = "YOYNI"
|
||||
df_forecast.drop(["profitForcastChgPctUp", "profitForcastChgPctDwn"], axis=1, inplace=True)
|
||||
df_forecast.drop(
|
||||
["profitForcastChgPctUp", "profitForcastChgPctDwn"],
|
||||
axis=1,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
growth_list = []
|
||||
for year in range(start_datetime.year - 1, end_datetime.year + 1):
|
||||
@@ -240,7 +260,11 @@ class PitCollector(BaseCollector):
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
self,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
start_datetime: pd.Timestamp,
|
||||
end_datetime: pd.Timestamp,
|
||||
) -> [pd.DataFrame]:
|
||||
|
||||
if interval == self.INTERVAL_quarterly:
|
||||
@@ -266,8 +290,6 @@ class Run(BaseRun):
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
@@ -289,7 +311,6 @@ class Run(BaseRun):
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="quarterly",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
**kwargs,
|
||||
@@ -302,8 +323,6 @@ class Run(BaseRun):
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [quarterly, annual], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
@@ -320,7 +339,13 @@ class Run(BaseRun):
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(
|
||||
max_collector_count, delay, start, end, interval, check_data_length, limit_nums, **kwargs
|
||||
max_collector_count,
|
||||
delay,
|
||||
start,
|
||||
end,
|
||||
check_data_length,
|
||||
limit_nums,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def normalize_class_name(self):
|
||||
|
||||
Reference in New Issue
Block a user