1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

bug fixed & code format

This commit is contained in:
lzh222333
2021-03-31 03:08:48 +00:00
parent 3724273d73
commit edcd7b1ff9
8 changed files with 28 additions and 32 deletions

View File

@@ -73,7 +73,7 @@ def reset(task_pool, exp_name):
print("========== reset ==========")
TaskManager(task_pool=task_pool).remove()
exp, _ = R.get_exp(experiment_name=exp_name)
exp = R.get_exp(experiment_name=exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
@@ -115,7 +115,7 @@ def task_collecting(task_pool, exp_name):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, model_key, rolling_key
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"

View File

@@ -117,7 +117,7 @@ def task_collecting():
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, model_key, rolling_key
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
@@ -136,7 +136,7 @@ def task_collecting():
def reset():
print("========== reset ==========")
task_manager.remove()
exp, _ = R.get_exp(experiment_name=exp_name)
exp = R.get_exp(experiment_name=exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)

View File

@@ -7,8 +7,7 @@ from qlib.workflow.task.utils import list_recorders
class Collector:
"""The collector to collect different results based on experiment backend and ensemble method
"""
"""The collector to collect different results based on experiment backend and ensemble method"""
def collect(self, ensemble, get_group_key_func, *args, **kwargs):
"""To collect the results, we need to get the experiment record firstly and divided them into
@@ -23,7 +22,7 @@ class Collector:
class RecorderCollector(Collector):
def __init__(self, exp_name, artifacts_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
"""init RecorderCollector
Args:
@@ -48,14 +47,14 @@ class RecorderCollector(Collector):
"""
if artifacts_key is None:
artifacts_key = self.artifacts_path.keys()
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
# prepare_ensemble
ensemble_dict = {}
for key in artifacts_key:
ensemble_dict.setdefault(key,{})
ensemble_dict.setdefault(key, {})
# filter records
recs_flt = list_recorders(self.exp_name, rec_filter_func)
for _, rec in recs_flt.items():
@@ -64,7 +63,6 @@ class RecorderCollector(Collector):
artifact = rec.load_object(self.artifacts_path[key])
ensemble_dict[key][group_key] = artifact
if isinstance(artifacts_key, str):
return ensemble(ensemble_dict[artifacts_key])
@@ -72,4 +70,3 @@ class RecorderCollector(Collector):
for key in artifacts_key:
collect_dict[key] = ensemble(ensemble_dict[key])
return collect_dict

View File

@@ -7,12 +7,10 @@ from qlib.workflow.task.utils import list_recorders
from typing import Dict
class Ensemble:
"""Merge the objects in an Ensemble.
"""
"""Merge the objects in an Ensemble."""
def __init__(self, merge_func = None, get_grouped_key_func = None) -> None:
def __init__(self, merge_func=None, get_grouped_key_func=None) -> None:
"""init Ensemble
Args:
@@ -26,7 +24,7 @@ class Ensemble:
self.get_grouped_key_func = get_grouped_key_func
def merge_func(self, group_inner_dict):
"""Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
"""Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
merge it to object
Args:
@@ -34,10 +32,10 @@ class Ensemble:
"""
raise NotImplementedError(f"Please implement the `merge_func` method.")
def get_grouped_key_func(self, group_key):
"""Given a group_key and return the group_outer_key, group_inner_key.
For example:
(A,B,Rolling) -> (A,B):Rolling
(A,B) -> C:(A,B)
@@ -135,10 +133,10 @@ class Ensemble:
grouped_dict = self.group(group_dict)
return self.reduce(grouped_dict)
class RollingEnsemble(Ensemble):
"""A specific implementation of Ensemble for Rolling.
"""
class RollingEnsemble(Ensemble):
"""A specific implementation of Ensemble for Rolling."""
def merge_func(self, group_inner_dict):
"""merge group_inner_dict by datetime.
@@ -155,7 +153,7 @@ class RollingEnsemble(Ensemble):
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact
def get_grouped_key_func(self, group_key):
"""The final axis of group_key must be the Rolling key.
When `collect`, get_group_key_func can add the statement below.
@@ -174,7 +172,5 @@ class RollingEnsemble(Ensemble):
Returns:
tuple or str, tuple or str: group_outer_key, group_inner_key
"""
assert len(group_key)>=2
assert len(group_key) >= 2
return group_key[:-1], group_key[-1]

View File

@@ -174,7 +174,7 @@ class TaskManager:
for t in new_tasks:
self.insert_task_def(t, task_pool)
return len(new_tasks)
def fetch_task(self, query={}, task_pool=None):
@@ -250,7 +250,7 @@ class TaskManager:
def re_query(self, task, task_pool=None):
task_pool = self._get_task_pool(task_pool)
return task_pool.find_one({"_id":ObjectId(task["_id"])})
return task_pool.find_one({"_id": ObjectId(task["_id"])})
def commit_task_res(self, task, res, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)

View File

@@ -106,7 +106,9 @@ class RollingOnlineManager(OnlineManagerR):
pass
def prepare_tasks(self):
latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
latest_records, max_test = self.list_latest_recorders(
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
)
if max_test is None:
self.logger.warn(f"No latest_recorders.")
return

View File

@@ -9,6 +9,7 @@ from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
from qlib.data.dataset.handler import DataHandlerLP
class ModelUpdater:
"""
The model updater to update model results in new data.
@@ -48,7 +49,7 @@ class ModelUpdater:
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments)
return dataset
def update_pred(self, recorder: Recorder, frequency='day'):
def update_pred(self, recorder: Recorder, frequency="day"):
"""update predictions to the latest day in Calendar based on rid
Parameters
@@ -60,10 +61,10 @@ class ModelUpdater:
last_end = old_pred.index.get_level_values("datetime").max()
# updated to the latest trading day
if frequency=='day':
if frequency == "day":
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
else:
raise NotImplementedError("Now Qlib only support update daily frequency prediction")
raise NotImplementedError("Now `ModelUpdater` only support update daily frequency prediction")
if len(cal) == 0:
self.logger.info(

View File

@@ -42,7 +42,7 @@ def list_recorders(experiment, rec_filter_func=None):
dict: a dict {rid: recorder} after filtering.
"""
if isinstance(experiment, str):
experiment, _ = R.get_exp(experiment_name=experiment)
experiment = R.get_exp(experiment_name=experiment)
recs = experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():