mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
Add tests for SigAnaRecord
This commit is contained in:
@@ -163,8 +163,8 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252):
|
||||
super().__init__(recorder=recorder)
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ def train_with_sigana():
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
@@ -163,7 +163,8 @@ def train_with_sigana():
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
@@ -222,12 +223,18 @@ class TestAllFlow(TestAutoData):
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
def test_0_train(self):
|
||||
def test_0_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_train(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
def test_1_backtest(self):
|
||||
def test_2_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
@@ -235,17 +242,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
|
||||
def test_3_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg="default uri is incorrect")
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_3_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user