mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add segment args for pred and refine MultiSegRecord
This commit is contained in:
@@ -61,10 +61,10 @@ class LGBModel(ModelFT):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset, segment="test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
@@ -84,8 +84,8 @@ class LinearModel(Model):
|
||||
self.coef_ = coef
|
||||
self.intercept_ = 0.0
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset, segment="test"):
|
||||
if self.coef_ is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
|
||||
|
||||
@@ -57,8 +57,8 @@ class XGBModel(Model):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset, segment="test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
@@ -30,16 +29,29 @@ class MultiSegRecord(RecordTemp):
|
||||
self.dataset = dataset
|
||||
|
||||
def generate(self, segments: Dict[Text, Any], save: bool = False):
|
||||
# generate prediction
|
||||
for key, segment in segments.items():
|
||||
predics = self.model.predict(self.dataset, segment)
|
||||
if isinstance(pred, pd.Series):
|
||||
predics = predictions.to_frame("score")
|
||||
# self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
if isinstance(predics, pd.Series):
|
||||
predics = predics.to_frame("score")
|
||||
labels = self.dataset.prepare(
|
||||
segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R
|
||||
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
|
||||
)
|
||||
# compute ic, rank_ic
|
||||
# Compute the IC and Rank IC
|
||||
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
|
||||
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
|
||||
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
|
||||
ic_x100, ric_x100 = ic * 100, ric * 100
|
||||
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
|
||||
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
|
||||
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
|
||||
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{save_name}' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
@@ -67,7 +79,7 @@ class SignalMseRecord(SignalRecord):
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
|
||||
@@ -63,32 +63,46 @@ task = {
|
||||
}
|
||||
|
||||
|
||||
def test_multiseg():
|
||||
def train_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
sr.generate(dict(valid="valid", test="test"))
|
||||
|
||||
sr.generate(dict(valid="valid", test="test"), True)
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
def train_mse():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
sr.generate()
|
||||
uri = R.get_uri()
|
||||
return uri
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
def test_0_multiseg(self):
|
||||
uri_path = test_multiseg()
|
||||
uri_path = train_multiseg()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_mse(self):
|
||||
uri_path = train_mse()
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
_suite.addTest(TestAllFlow("test_0_multiseg"))
|
||||
_suite.addTest(TestAllFlow("test_1_mse"))
|
||||
return _suite
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user