1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

add downcast to save data

This commit is contained in:
Young
2020-12-09 11:21:37 +00:00
committed by you-n-g
parent 361d168890
commit c0f1696adb

View File

@@ -130,25 +130,25 @@ class SignalRecord(RecordTemp):
pprint(f"The following are prediction results of the {type(self.model).__name__} model.")
pprint(pred.head(5))
# save according label
if isinstance(self.dataset, TSDatasetH):
index = raw_label.get_index()
raw_label = raw_label.data.loc[index]
raw_label = raw_label[:, -1:]
self.recorder.save_objects(**{"label.pkl": raw_label})
if isinstance(self.dataset, DatasetH):
# NOTE:
# Python doesn't provide the downcasting mechanism.
# We use the trick here to downcast the class
orig_cls = self.dataset.__class__
self.dataset.__class__ = DatasetH
elif isinstance(self.dataset, DatasetH):
params = dict(self=self.dataset, segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = DatasetH.prepare(**params)
raw_label = self.dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = DatasetH.prepare(**params)
raw_label = self.dataset.prepare(**params)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls
def list(self):
return ["pred.pkl", "label.pkl"]