1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Add tests for SigAnaRecord

This commit is contained in:
D-X-Y
2021-03-16 08:17:13 +00:00
parent 9f57681032
commit 6559d44c7d
2 changed files with 14 additions and 12 deletions

View File

@@ -163,8 +163,8 @@ class SigAnaRecord(SignalRecord):
artifact_path = "sig_analysis"
def __init__(self, recorder, ana_long_short=False, ann_scaler=252):
super().__init__(recorder=recorder)
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler

View File

@@ -153,7 +153,7 @@ def train_with_sigana():
dataset = init_instance_by_config(task["dataset"])
# start exp
with R.start(experiment_name="workflow"):
with R.start(experiment_name="workflow_with_sigana"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
@@ -163,7 +163,8 @@ def train_with_sigana():
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
return pred_score, {"ic": ic, "ric": ric}, rid
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path
def fake_experiment():
@@ -222,12 +223,18 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
def test_0_train(self):
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def test_1_backtest(self):
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
@@ -235,17 +242,12 @@ class TestAllFlow(TestAutoData):
"backtest failed",
)
def test_2_expmanager(self):
def test_3_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg="default uri is incorrect")
self.assertTrue(pass_current, msg="current uri is incorrect")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
def test_3_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def suite():
_suite = unittest.TestSuite()