mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update & fix
This commit is contained in:
@@ -16,62 +16,6 @@ from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
@@ -103,10 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
#tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
tasks = [task_xgboost_config, task_lgb_config]
|
||||
#pprint(CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE)
|
||||
#pprint(task_xgboost_config)
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
|
||||
@@ -19,6 +19,7 @@ from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
|
||||
@@ -163,17 +163,20 @@ class OnlineManager(Serializable):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
def get_collector(self, **kwargs) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Args:
|
||||
**kwargs: the params for get_collector.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
|
||||
Reference in New Issue
Block a user