mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
bug fixed & code format
This commit is contained in:
@@ -73,7 +73,7 @@ def reset(task_pool, exp_name):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=task_pool).remove()
|
||||
|
||||
exp, _ = R.get_exp(experiment_name=exp_name)
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -115,7 +115,7 @@ def task_collecting(task_pool, exp_name):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, model_key, rolling_key
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
|
||||
@@ -117,7 +117,7 @@ def task_collecting():
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, model_key, rolling_key
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
@@ -136,7 +136,7 @@ def task_collecting():
|
||||
def reset():
|
||||
print("========== reset ==========")
|
||||
task_manager.remove()
|
||||
exp, _ = R.get_exp(experiment_name=exp_name)
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
|
||||
@@ -7,8 +7,7 @@ from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class Collector:
|
||||
"""The collector to collect different results based on experiment backend and ensemble method
|
||||
"""
|
||||
"""The collector to collect different results based on experiment backend and ensemble method"""
|
||||
|
||||
def collect(self, ensemble, get_group_key_func, *args, **kwargs):
|
||||
"""To collect the results, we need to get the experiment record firstly and divided them into
|
||||
@@ -23,7 +22,7 @@ class Collector:
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
def __init__(self, exp_name, artifacts_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
|
||||
def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
|
||||
"""init RecorderCollector
|
||||
|
||||
Args:
|
||||
@@ -48,14 +47,14 @@ class RecorderCollector(Collector):
|
||||
"""
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
|
||||
# prepare_ensemble
|
||||
ensemble_dict = {}
|
||||
for key in artifacts_key:
|
||||
ensemble_dict.setdefault(key,{})
|
||||
ensemble_dict.setdefault(key, {})
|
||||
# filter records
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
for _, rec in recs_flt.items():
|
||||
@@ -64,7 +63,6 @@ class RecorderCollector(Collector):
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
ensemble_dict[key][group_key] = artifact
|
||||
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
return ensemble(ensemble_dict[artifacts_key])
|
||||
|
||||
@@ -72,4 +70,3 @@ class RecorderCollector(Collector):
|
||||
for key in artifacts_key:
|
||||
collect_dict[key] = ensemble(ensemble_dict[key])
|
||||
return collect_dict
|
||||
|
||||
@@ -7,12 +7,10 @@ from qlib.workflow.task.utils import list_recorders
|
||||
from typing import Dict
|
||||
|
||||
|
||||
|
||||
class Ensemble:
|
||||
"""Merge the objects in an Ensemble.
|
||||
"""
|
||||
"""Merge the objects in an Ensemble."""
|
||||
|
||||
def __init__(self, merge_func = None, get_grouped_key_func = None) -> None:
|
||||
def __init__(self, merge_func=None, get_grouped_key_func=None) -> None:
|
||||
"""init Ensemble
|
||||
|
||||
Args:
|
||||
@@ -26,7 +24,7 @@ class Ensemble:
|
||||
self.get_grouped_key_func = get_grouped_key_func
|
||||
|
||||
def merge_func(self, group_inner_dict):
|
||||
"""Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
|
||||
"""Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
|
||||
merge it to object
|
||||
|
||||
Args:
|
||||
@@ -34,10 +32,10 @@ class Ensemble:
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `merge_func` method.")
|
||||
|
||||
|
||||
def get_grouped_key_func(self, group_key):
|
||||
"""Given a group_key and return the group_outer_key, group_inner_key.
|
||||
|
||||
|
||||
For example:
|
||||
(A,B,Rolling) -> (A,B):Rolling
|
||||
(A,B) -> C:(A,B)
|
||||
@@ -135,10 +133,10 @@ class Ensemble:
|
||||
grouped_dict = self.group(group_dict)
|
||||
return self.reduce(grouped_dict)
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
"""A specific implementation of Ensemble for Rolling.
|
||||
|
||||
"""
|
||||
class RollingEnsemble(Ensemble):
|
||||
"""A specific implementation of Ensemble for Rolling."""
|
||||
|
||||
def merge_func(self, group_inner_dict):
|
||||
"""merge group_inner_dict by datetime.
|
||||
|
||||
@@ -155,7 +153,7 @@ class RollingEnsemble(Ensemble):
|
||||
artifact = artifact[~artifact.index.duplicated(keep="last")]
|
||||
artifact = artifact.sort_index()
|
||||
return artifact
|
||||
|
||||
|
||||
def get_grouped_key_func(self, group_key):
|
||||
"""The final axis of group_key must be the Rolling key.
|
||||
When `collect`, get_group_key_func can add the statement below.
|
||||
@@ -174,7 +172,5 @@ class RollingEnsemble(Ensemble):
|
||||
Returns:
|
||||
tuple or str, tuple or str: group_outer_key, group_inner_key
|
||||
"""
|
||||
assert len(group_key)>=2
|
||||
assert len(group_key) >= 2
|
||||
return group_key[:-1], group_key[-1]
|
||||
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ class TaskManager:
|
||||
|
||||
for t in new_tasks:
|
||||
self.insert_task_def(t, task_pool)
|
||||
|
||||
|
||||
return len(new_tasks)
|
||||
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
@@ -250,7 +250,7 @@ class TaskManager:
|
||||
|
||||
def re_query(self, task, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
return task_pool.find_one({"_id":ObjectId(task["_id"])})
|
||||
return task_pool.find_one({"_id": ObjectId(task["_id"])})
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
|
||||
@@ -106,7 +106,9 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
pass
|
||||
|
||||
def prepare_tasks(self):
|
||||
latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
|
||||
latest_records, max_test = self.list_latest_recorders(
|
||||
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
|
||||
)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest_recorders.")
|
||||
return
|
||||
|
||||
@@ -9,6 +9,7 @@ from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class ModelUpdater:
|
||||
"""
|
||||
The model updater to update model results in new data.
|
||||
@@ -48,7 +49,7 @@ class ModelUpdater:
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments)
|
||||
return dataset
|
||||
|
||||
def update_pred(self, recorder: Recorder, frequency='day'):
|
||||
def update_pred(self, recorder: Recorder, frequency="day"):
|
||||
"""update predictions to the latest day in Calendar based on rid
|
||||
|
||||
Parameters
|
||||
@@ -60,10 +61,10 @@ class ModelUpdater:
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
# updated to the latest trading day
|
||||
if frequency=='day':
|
||||
if frequency == "day":
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
else:
|
||||
raise NotImplementedError("Now Qlib only support update daily frequency prediction")
|
||||
raise NotImplementedError("Now `ModelUpdater` only support update daily frequency prediction")
|
||||
|
||||
if len(cal) == 0:
|
||||
self.logger.info(
|
||||
|
||||
@@ -42,7 +42,7 @@ def list_recorders(experiment, rec_filter_func=None):
|
||||
dict: a dict {rid: recorder} after filtering.
|
||||
"""
|
||||
if isinstance(experiment, str):
|
||||
experiment, _ = R.get_exp(experiment_name=experiment)
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
recs = experiment.list_recorders()
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
|
||||
Reference in New Issue
Block a user