1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Black format

This commit is contained in:
Jactus
2021-02-22 11:42:36 +08:00
parent 37871389b9
commit dc4aa67503
13 changed files with 147 additions and 33 deletions

View File

@@ -191,7 +191,15 @@ man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)]
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, "QLib", u"QLib Documentation", author, "QLib", "One line description of project.", "Miscellaneous",),
(
master_doc,
"QLib",
u"QLib Documentation",
author,
"QLib",
"One line description of project.",
"Miscellaneous",
),
]

View File

@@ -721,7 +721,12 @@ class TemporalFusionTransformer:
encoder_steps = self.num_encoder_steps
# Inputs.
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))
all_inputs = tf.keras.layers.Input(
shape=(
time_steps,
combined_input_size,
)
)
unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
@@ -861,7 +866,10 @@ class TemporalFusionTransformer:
"""Returns LSTM cell initialized with default parameters."""
if self.use_cudnn:
lstm = tf.keras.layers.CuDNNLSTM(
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
self.hidden_layer_size,
return_sequences=True,
return_state=return_state,
stateful=False,
)
else:
lstm = tf.keras.layers.LSTM(

View File

@@ -20,7 +20,10 @@ class HighFreqHandler(DataHandlerLP):
new_l = []
for p in proc_l:
p["kwargs"].update(
{"fit_start_time": fit_start_time, "fit_end_time": fit_end_time,}
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append(p)
return new_l
@@ -30,7 +33,11 @@ class HighFreqHandler(DataHandlerLP):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",},
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments,
@@ -61,7 +68,8 @@ class HighFreqHandler(DataHandlerLP):
feature_ops = template_norm.format(
template_if.format(
template_fillnan.format(template_paused.format("$close")), template_paused.format(price_field),
template_fillnan.format(template_paused.format("$close")),
template_paused.format(price_field),
),
template_fillnan.format(template_paused.format("$close")),
)
@@ -111,14 +119,24 @@ class HighFreqHandler(DataHandlerLP):
class HighFreqBacktestHandler(DataHandler):
def __init__(
self, instruments="csi300", start_time=None, end_time=None,
self,
instruments="csi300",
start_time=None,
end_time=None,
):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",},
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments, start_time=start_time, end_time=end_time, data_loader=data_loader,
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
)
def get_feature_config(self):
@@ -137,7 +155,8 @@ class HighFreqBacktestHandler(DataHandler):
fields += [
"Cut({0}, 240, None)".format(
template_if.format(
template_fillnan.format(template_paused.format("$close")), template_paused.format(simpson_vwap),
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
)
)
]

View File

@@ -65,6 +65,8 @@ class HighFreqNorm(Processor):
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
df_new_features = pd.DataFrame(
data=np.concatenate((feat, feat_1), axis=1), index=idx, columns=["FEATURE_%d" % i for i in range(12 * 240)],
data=np.concatenate((feat, feat_1), axis=1),
index=idx,
columns=["FEATURE_%d" % i for i in range(12 * 240)],
).sort_index()
return df_new_features

View File

@@ -63,7 +63,13 @@ class HighfreqWorkflow(object):
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG0,
},
"segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
"dataset_backtest": {
@@ -75,7 +81,13 @@ class HighfreqWorkflow(object):
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG1,
},
"segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
}
@@ -140,11 +152,24 @@ class HighfreqWorkflow(object):
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),},
segment_kwargs={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.init(
handler_kwargs={"start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00",},
segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),},
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
##=============get data=============

View File

@@ -34,7 +34,10 @@ exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
exp_manager = {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {"uri": "file:" + exp_path, "default_exp_name": "Experiment",},
"kwargs": {
"uri": "file:" + exp_path,
"default_exp_name": "Experiment",
},
}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")

View File

@@ -81,7 +81,10 @@ if __name__ == "__main__":
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {"topk": 50, "n_drop": 5,},
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,

View File

@@ -39,7 +39,13 @@ class YahooData:
INTERVAL_1d = "1d"
def __init__(
self, timezone: str = None, start=None, end=None, interval="1d", delay=0, show_1min_logging: bool = False,
self,
timezone: str = None,
start=None,
end=None,
interval="1d",
delay=0,
show_1min_logging: bool = False,
):
"""
@@ -119,7 +125,11 @@ class YahooData:
self._sleep()
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
return self.get_data_from_remote(
symbol, interval=_remote_interval, start=start_, end=end_, show_1min_logging=self._show_1min_logging,
symbol,
interval=_remote_interval,
start=start_,
end=end_,
show_1min_logging=self._show_1min_logging,
)
_result = None
@@ -428,7 +438,9 @@ class YahooNormalize:
DAILY_FORMAT = "%Y-%m-%d"
def __init__(
self, date_field_name: str = "date", symbol_field_name: str = "symbol",
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""
@@ -446,7 +458,10 @@ class YahooNormalize:
@staticmethod
def normalize_yahoo(
df: pd.DataFrame, calendar_list: list = None, date_field_name: str = "date", symbol_field_name: str = "symbol",
df: pd.DataFrame,
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
if df.empty:
return df
@@ -551,7 +566,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
CONSISTENT_1d = False
def __init__(
self, date_field_name: str = "date", symbol_field_name: str = "symbol",
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
"""

View File

@@ -153,13 +153,22 @@ class DumpDataBase:
@staticmethod
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
return sorted(map(pd.Timestamp, pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),))
return sorted(
map(
pd.Timestamp,
pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),
)
)
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
df = pd.read_csv(
instrument_path,
sep=self.INSTRUMENTS_SEP,
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD,],
names=[
self.symbol_field_name,
self.INSTRUMENTS_START_FIELD,
self.INSTRUMENTS_END_FIELD,
],
)
return df

View File

@@ -70,10 +70,16 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
# Cython Extensions
extensions = [
Extension(
"qlib.data._libs.rolling", ["qlib/data/_libs/rolling.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE],
"qlib.data._libs.rolling",
["qlib/data/_libs/rolling.pyx"],
language="c++",
include_dirs=[NUMPY_INCLUDE],
),
Extension(
"qlib.data._libs.expanding", ["qlib/data/_libs/expanding.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE],
"qlib.data._libs.expanding",
["qlib/data/_libs/expanding.pyx"],
language="c++",
include_dirs=[NUMPY_INCLUDE],
),
]
@@ -92,7 +98,9 @@ setup(
# py_modules=['qlib'],
entry_points={
# 'console_scripts': ['mycli=mymodule:cli'],
"console_scripts": ["qrun=qlib.workflow.cli:run",],
"console_scripts": [
"qrun=qlib.workflow.cli:run",
],
},
ext_modules=extensions,
install_requires=REQUIRED,

View File

@@ -78,7 +78,10 @@ port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {"topk": 50, "n_drop": 5,},
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,
@@ -173,7 +176,9 @@ class TestAllFlow(TestAutoData):
def test_1_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], 0.10, "backtest failed",
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
"backtest failed",
)

View File

@@ -40,7 +40,9 @@ class TestDumpData(unittest.TestCase):
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
provider_uri = str(QLIB_DIR.resolve())
qlib.init(
provider_uri=provider_uri, expression_cache=None, dataset_cache=None,
provider_uri=provider_uri,
expression_cache=None,
dataset_cache=None,
)
@classmethod
@@ -52,7 +54,10 @@ class TestDumpData(unittest.TestCase):
def test_1_dump_calendars(self):
ori_calendars = set(
map(pd.Timestamp, pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,)
map(
pd.Timestamp,
pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,
)
)
res_calendars = set(D.calendar())
assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"

View File

@@ -26,7 +26,9 @@ class TestGetData(unittest.TestCase):
def setUpClass(cls) -> None:
provider_uri = str(QLIB_DIR.resolve())
qlib.init(
provider_uri=provider_uri, expression_cache=None, dataset_cache=None,
provider_uri=provider_uri,
expression_cache=None,
dataset_cache=None,
)
@classmethod