mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 01:21:18 +08:00
fix task name & add cur_path
This commit is contained in:
@@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml") -> Path:
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
@@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path:
|
||||
FileNotFoundError:
|
||||
If project path is not found
|
||||
"""
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
@@ -199,7 +200,7 @@ def auto_init(**kwargs):
|
||||
"""
|
||||
|
||||
try:
|
||||
pp = get_project_path()
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
@@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype
|
||||
R.save_objects(task=task_config) # keep the original format and datatype
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from qlib.workflow import R
|
||||
import pandas as pd
|
||||
import tqdm.auto import tqdm
|
||||
from typing import Union
|
||||
from qlib import get_module_logger
|
||||
|
||||
@@ -35,7 +36,7 @@ class TaskCollector:
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
params = rec.load_object("task.pkl")
|
||||
params = rec.load_object("task")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if filter_func is None or filter_func(params):
|
||||
rec.params = params
|
||||
@@ -83,7 +84,7 @@ class RollingCollector:
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in tqdm(recs.items(), desc="Loading data"):
|
||||
params = rec.load_object("task.pkl")
|
||||
params = rec.load_object("task")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if self.flt_func is None or self.flt_func(params):
|
||||
rec.params = params
|
||||
|
||||
@@ -74,7 +74,7 @@ class ModelUpdater:
|
||||
rec = self.exp.get_recorder(recorder_id=rid)
|
||||
old_pred = rec.load_object("pred.pkl")
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
task_config = rec.load_object("task.pkl")
|
||||
task_config = rec.load_object("task")
|
||||
|
||||
# updated to the latest trading day
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
@@ -107,7 +107,7 @@ class ModelUpdater:
|
||||
.. code-block:: python
|
||||
|
||||
def record_filter(record):
|
||||
task_config = record.load_object("task.pkl")
|
||||
task_config = record.load_object("task")
|
||||
if task_config["model"]["class"]=="LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user