mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
add downcast to save data
This commit is contained in:
@@ -130,25 +130,25 @@ class SignalRecord(RecordTemp):
|
||||
pprint(f"The following are prediction results of the {type(self.model).__name__} model.")
|
||||
pprint(pred.head(5))
|
||||
|
||||
# save according label
|
||||
if isinstance(self.dataset, TSDatasetH):
|
||||
index = raw_label.get_index()
|
||||
raw_label = raw_label.data.loc[index]
|
||||
raw_label = raw_label[:, -1:]
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
# NOTE:
|
||||
# Python doesn't provide the downcasting mechanism.
|
||||
# We use the trick here to downcast the class
|
||||
orig_cls = self.dataset.__class__
|
||||
self.dataset.__class__ = DatasetH
|
||||
|
||||
elif isinstance(self.dataset, DatasetH):
|
||||
params = dict(self=self.dataset, segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.dataset.__class__ = orig_cls
|
||||
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
Reference in New Issue
Block a user