mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Black format
This commit is contained in:
10
docs/conf.py
10
docs/conf.py
@@ -191,7 +191,15 @@ man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)]
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, "QLib", u"QLib Documentation", author, "QLib", "One line description of project.", "Miscellaneous",),
|
||||
(
|
||||
master_doc,
|
||||
"QLib",
|
||||
u"QLib Documentation",
|
||||
author,
|
||||
"QLib",
|
||||
"One line description of project.",
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -721,7 +721,12 @@ class TemporalFusionTransformer:
|
||||
encoder_steps = self.num_encoder_steps
|
||||
|
||||
# Inputs.
|
||||
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))
|
||||
all_inputs = tf.keras.layers.Input(
|
||||
shape=(
|
||||
time_steps,
|
||||
combined_input_size,
|
||||
)
|
||||
)
|
||||
|
||||
unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
|
||||
|
||||
@@ -861,7 +866,10 @@ class TemporalFusionTransformer:
|
||||
"""Returns LSTM cell initialized with default parameters."""
|
||||
if self.use_cudnn:
|
||||
lstm = tf.keras.layers.CuDNNLSTM(
|
||||
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
|
||||
self.hidden_layer_size,
|
||||
return_sequences=True,
|
||||
return_state=return_state,
|
||||
stateful=False,
|
||||
)
|
||||
else:
|
||||
lstm = tf.keras.layers.LSTM(
|
||||
|
||||
@@ -20,7 +20,10 @@ class HighFreqHandler(DataHandlerLP):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
p["kwargs"].update(
|
||||
{"fit_start_time": fit_start_time, "fit_end_time": fit_end_time,}
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
@@ -30,7 +33,11 @@ class HighFreqHandler(DataHandlerLP):
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",},
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
@@ -61,7 +68,8 @@ class HighFreqHandler(DataHandlerLP):
|
||||
|
||||
feature_ops = template_norm.format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")), template_paused.format(price_field),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(price_field),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
@@ -111,14 +119,24 @@ class HighFreqHandler(DataHandlerLP):
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
|
||||
def __init__(
|
||||
self, instruments="csi300", start_time=None, end_time=None,
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",},
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments, start_time=start_time, end_time=end_time, data_loader=data_loader,
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
@@ -137,7 +155,8 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
fields += [
|
||||
"Cut({0}, 240, None)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")), template_paused.format(simpson_vwap),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -65,6 +65,8 @@ class HighFreqNorm(Processor):
|
||||
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
|
||||
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
|
||||
df_new_features = pd.DataFrame(
|
||||
data=np.concatenate((feat, feat_1), axis=1), index=idx, columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
data=np.concatenate((feat, feat_1), axis=1),
|
||||
index=idx,
|
||||
columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
).sort_index()
|
||||
return df_new_features
|
||||
|
||||
@@ -63,7 +63,13 @@ class HighfreqWorkflow(object):
|
||||
"module_path": "highfreq_handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG0,
|
||||
},
|
||||
"segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),},
|
||||
"segments": {
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
"dataset_backtest": {
|
||||
@@ -75,7 +81,13 @@ class HighfreqWorkflow(object):
|
||||
"module_path": "highfreq_handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG1,
|
||||
},
|
||||
"segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),},
|
||||
"segments": {
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -140,11 +152,24 @@ class HighfreqWorkflow(object):
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
handler_kwargs={"start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00",},
|
||||
segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),},
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
##=============get data=============
|
||||
|
||||
@@ -34,7 +34,10 @@ exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
|
||||
exp_manager = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {"uri": "file:" + exp_path, "default_exp_name": "Experiment",},
|
||||
"kwargs": {
|
||||
"uri": "file:" + exp_path,
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
@@ -81,7 +81,10 @@ if __name__ == "__main__":
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {"topk": 50, "n_drop": 5,},
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
|
||||
@@ -39,7 +39,13 @@ class YahooData:
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self, timezone: str = None, start=None, end=None, interval="1d", delay=0, show_1min_logging: bool = False,
|
||||
self,
|
||||
timezone: str = None,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
delay=0,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -119,7 +125,11 @@ class YahooData:
|
||||
self._sleep()
|
||||
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
|
||||
return self.get_data_from_remote(
|
||||
symbol, interval=_remote_interval, start=start_, end=end_, show_1min_logging=self._show_1min_logging,
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
show_1min_logging=self._show_1min_logging,
|
||||
)
|
||||
|
||||
_result = None
|
||||
@@ -428,7 +438,9 @@ class YahooNormalize:
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self, date_field_name: str = "date", symbol_field_name: str = "symbol",
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -446,7 +458,10 @@ class YahooNormalize:
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame, calendar_list: list = None, date_field_name: str = "date", symbol_field_name: str = "symbol",
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
@@ -551,7 +566,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
CONSISTENT_1d = False
|
||||
|
||||
def __init__(
|
||||
self, date_field_name: str = "date", symbol_field_name: str = "symbol",
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
|
||||
@@ -153,13 +153,22 @@ class DumpDataBase:
|
||||
|
||||
@staticmethod
|
||||
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
|
||||
return sorted(map(pd.Timestamp, pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),))
|
||||
return sorted(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(
|
||||
instrument_path,
|
||||
sep=self.INSTRUMENTS_SEP,
|
||||
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD,],
|
||||
names=[
|
||||
self.symbol_field_name,
|
||||
self.INSTRUMENTS_START_FIELD,
|
||||
self.INSTRUMENTS_END_FIELD,
|
||||
],
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
14
setup.py
14
setup.py
@@ -70,10 +70,16 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
|
||||
# Cython Extensions
|
||||
extensions = [
|
||||
Extension(
|
||||
"qlib.data._libs.rolling", ["qlib/data/_libs/rolling.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE],
|
||||
"qlib.data._libs.rolling",
|
||||
["qlib/data/_libs/rolling.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
Extension(
|
||||
"qlib.data._libs.expanding", ["qlib/data/_libs/expanding.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE],
|
||||
"qlib.data._libs.expanding",
|
||||
["qlib/data/_libs/expanding.pyx"],
|
||||
language="c++",
|
||||
include_dirs=[NUMPY_INCLUDE],
|
||||
),
|
||||
]
|
||||
|
||||
@@ -92,7 +98,9 @@ setup(
|
||||
# py_modules=['qlib'],
|
||||
entry_points={
|
||||
# 'console_scripts': ['mycli=mymodule:cli'],
|
||||
"console_scripts": ["qrun=qlib.workflow.cli:run",],
|
||||
"console_scripts": [
|
||||
"qrun=qlib.workflow.cli:run",
|
||||
],
|
||||
},
|
||||
ext_modules=extensions,
|
||||
install_requires=REQUIRED,
|
||||
|
||||
@@ -78,7 +78,10 @@ port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {"topk": 50, "n_drop": 5,},
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
@@ -173,7 +176,9 @@ class TestAllFlow(TestAutoData):
|
||||
def test_1_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], 0.10, "backtest failed",
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
0.10,
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,9 @@ class TestDumpData(unittest.TestCase):
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
qlib.init(
|
||||
provider_uri=provider_uri, expression_cache=None, dataset_cache=None,
|
||||
provider_uri=provider_uri,
|
||||
expression_cache=None,
|
||||
dataset_cache=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -52,7 +54,10 @@ class TestDumpData(unittest.TestCase):
|
||||
|
||||
def test_1_dump_calendars(self):
|
||||
ori_calendars = set(
|
||||
map(pd.Timestamp, pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,)
|
||||
map(
|
||||
pd.Timestamp,
|
||||
pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,
|
||||
)
|
||||
)
|
||||
res_calendars = set(D.calendar())
|
||||
assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
|
||||
|
||||
@@ -26,7 +26,9 @@ class TestGetData(unittest.TestCase):
|
||||
def setUpClass(cls) -> None:
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
qlib.init(
|
||||
provider_uri=provider_uri, expression_cache=None, dataset_cache=None,
|
||||
provider_uri=provider_uri,
|
||||
expression_cache=None,
|
||||
dataset_cache=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user