mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Enhance Task Dict Var (#778)
This commit is contained in:
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.model.trainer import TrainerR, TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class RollingTaskExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_pool=None, # if user want to "rolling_task"
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
@@ -43,14 +43,19 @@ class RollingTaskExample:
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
if task_pool is None:
|
||||
self.trainer = TrainerR(experiment_name=self.experiment_name)
|
||||
else:
|
||||
self.task_pool = task_pool
|
||||
self.trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -66,10 +71,10 @@ class RollingTaskExample:
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
self.trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# NOTE: this is only used for TrainerRM
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
@@ -86,10 +86,61 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
|
||||
return R.get_recorder()
|
||||
|
||||
|
||||
def get_item_from_obj(config: dict, name_path: str) -> object:
|
||||
"""
|
||||
Follow the name_path to get values from config
|
||||
For example:
|
||||
If we follow the example in in the Parameters section,
|
||||
Timestamp('2008-01-02 00:00:00') will be returned
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
e.g.
|
||||
{'dataset': {'class': 'DatasetH',
|
||||
'kwargs': {'handler': {'class': 'Alpha158',
|
||||
'kwargs': {'end_time': '2020-08-01',
|
||||
'fit_end_time': '<dataset.kwargs.segments.train.1>',
|
||||
'fit_start_time': '<dataset.kwargs.segments.train.0>',
|
||||
'instruments': 'csi100',
|
||||
'start_time': '2008-01-01'},
|
||||
'module_path': 'qlib.contrib.data.handler'},
|
||||
'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
|
||||
Timestamp('2019-04-08 00:00:00')),
|
||||
'train': (Timestamp('2008-01-02 00:00:00'),
|
||||
Timestamp('2014-12-31 00:00:00')),
|
||||
'valid': (Timestamp('2015-01-05 00:00:00'),
|
||||
Timestamp('2016-12-30 00:00:00'))}}
|
||||
}}
|
||||
name_path : str
|
||||
e.g.
|
||||
"dataset.kwargs.segments.train.1"
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
the retrieved object
|
||||
"""
|
||||
cur_cfg = config
|
||||
for k in name_path.split("."):
|
||||
if isinstance(cur_cfg, dict):
|
||||
cur_cfg = cur_cfg[k]
|
||||
elif k.isdigit():
|
||||
cur_cfg = cur_cfg[int(k)]
|
||||
else:
|
||||
raise ValueError(f"Error when getting {k} from cur_cfg")
|
||||
return cur_cfg
|
||||
|
||||
|
||||
def fill_placeholder(config: dict, config_extend: dict):
|
||||
"""
|
||||
Detect placeholder in config and fill them with config_extend.
|
||||
The item of dict must be single item(int, str, etc), dict and list. Tuples are not supported.
|
||||
There are two type of variables:
|
||||
- user-defined variables :
|
||||
e.g. when config_extend is `{"<MODEL>": model, "<DATASET>": dataset}`, "<MODEL>" and "<DATASET>" in `config` will be replaced with `model` `dataset`
|
||||
- variables extracted from `config` :
|
||||
e.g. the variables like "<dataset.kwargs.segments.train.0>" will be replaced with the values from `config`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -122,8 +173,13 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
|
||||
item_queue.append(now_item[key])
|
||||
tail += 1
|
||||
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
elif isinstance(now_item[key], str):
|
||||
if now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
else:
|
||||
m = re.match(r"<(?P<name_path>[^<>]+)>", now_item[key])
|
||||
if m is not None:
|
||||
now_item[key] = get_item_from_obj(config, m.groupdict()["name_path"])
|
||||
return config
|
||||
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ RECORD_CONFIG = [
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
fit_start_time="<dataset.kwargs.segments.train.0>",
|
||||
fit_end_time="<dataset.kwargs.segments.train.1>",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user