1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 01:21:18 +08:00

fix task name & add cur_path

This commit is contained in:
Young
2021-03-12 10:17:16 +00:00
parent 5de7870f9b
commit e4e8a4abcd
4 changed files with 10 additions and 8 deletions

View File

@@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
init(default_conf, **config)
def get_project_path(config_name="config.yaml") -> Path:
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
@@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path:
FileNotFoundError:
If project path is not found
"""
cur_path = Path(__file__).absolute().resolve()
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
@@ -199,7 +200,7 @@ def auto_init(**kwargs):
"""
try:
pp = get_project_path()
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:

View File

@@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype
R.save_objects(task=task_config) # keep the original format and datatype
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])

View File

@@ -1,5 +1,6 @@
from qlib.workflow import R
import pandas as pd
import tqdm.auto import tqdm
from typing import Union
from qlib import get_module_logger
@@ -35,7 +36,7 @@ class TaskCollector:
recs_flt = {}
for rid, rec in recs.items():
params = rec.load_object("task.pkl")
params = rec.load_object("task")
if rec.status == rec.STATUS_FI:
if filter_func is None or filter_func(params):
rec.params = params
@@ -83,7 +84,7 @@ class RollingCollector:
recs_flt = {}
for rid, rec in tqdm(recs.items(), desc="Loading data"):
params = rec.load_object("task.pkl")
params = rec.load_object("task")
if rec.status == rec.STATUS_FI:
if self.flt_func is None or self.flt_func(params):
rec.params = params

View File

@@ -74,7 +74,7 @@ class ModelUpdater:
rec = self.exp.get_recorder(recorder_id=rid)
old_pred = rec.load_object("pred.pkl")
last_end = old_pred.index.get_level_values("datetime").max()
task_config = rec.load_object("task.pkl")
task_config = rec.load_object("task")
# updated to the latest trading day
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
@@ -107,7 +107,7 @@ class ModelUpdater:
.. code-block:: python
def record_filter(record):
task_config = record.load_object("task.pkl")
task_config = record.load_object("task")
if task_config["model"]["class"]=="LGBModel":
return True
return False