From c0f1696adbddaa32df93e7e1e888214b4af9f82d Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 9 Dec 2020 11:21:37 +0000 Subject: [PATCH] add downcast to save data --- qlib/workflow/record_temp.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index be4ccdb77..6381b914e 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -130,25 +130,25 @@ class SignalRecord(RecordTemp): pprint(f"The following are prediction results of the {type(self.model).__name__} model.") pprint(pred.head(5)) - # save according label - if isinstance(self.dataset, TSDatasetH): - index = raw_label.get_index() - raw_label = raw_label.data.loc[index] - raw_label = raw_label[:, -1:] - self.recorder.save_objects(**{"label.pkl": raw_label}) + if isinstance(self.dataset, DatasetH): + # NOTE: + # Python doesn't provide the downcasting mechanism. + # We use the trick here to downcast the class + orig_cls = self.dataset.__class__ + self.dataset.__class__ = DatasetH - elif isinstance(self.dataset, DatasetH): - params = dict(self=self.dataset, segments="test", col_set="label", data_key=DataHandlerLP.DK_R) + params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R) try: # Assume the backend handler is DataHandlerLP - raw_label = DatasetH.prepare(**params) + raw_label = self.dataset.prepare(**params) except TypeError: # The argument number is not right del params["data_key"] # The backend handler should be DataHandler - raw_label = DatasetH.prepare(**params) + raw_label = self.dataset.prepare(**params) self.recorder.save_objects(**{"label.pkl": raw_label}) + self.dataset.__class__ = orig_cls def list(self): return ["pred.pkl", "label.pkl"]