diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 907086487..3e914cc63 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -119,7 +119,8 @@ def task_collecting(task_pool, exp_name): return False artifact = ens_workflow( - RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(), + RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), + RollingGroup(), ) print(artifact) diff --git a/examples/online_srv/task_manager_rolling_with_updating.py b/examples/online_srv/task_manager_rolling_with_updating.py index 5b80f9133..bfdc5f3c0 100644 --- a/examples/online_srv/task_manager_rolling_with_updating.py +++ b/examples/online_srv/task_manager_rolling_with_updating.py @@ -70,9 +70,18 @@ task_xgboost_config = { "record": record_config, } -class RollingOnlineExample: - def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550): +class RollingOnlineExample: + def __init__( + self, + exp_name="rolling_exp", + task_pool="rolling_task", + provider_uri="~/.qlib/qlib_data/cn_data", + region="cn", + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + rolling_step=550, + ): self.exp_name = exp_name self.task_pool = task_pool mongo_conf = { @@ -84,9 +93,9 @@ class RollingOnlineExample: self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) self.trainer = TrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) - self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer) - - + self.rolling_online_manager = RollingOnlineManager( + experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer + ) def print_online_model(self): print("========== print_online_model ==========") @@ -99,7 +108,6 @@ class RollingOnlineExample: if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG: print(rid) - # This part corresponds to "Task Generating" in the document def task_generating(self): @@ -114,11 +122,9 @@ class RollingOnlineExample: return tasks - def task_training(self, tasks): self.trainer.train(tasks) - # This part corresponds to "Task Collecting" in the document def task_collecting(self): print("========== task_collecting ==========") @@ -141,7 +147,6 @@ class RollingOnlineExample: ) print(artifact) - # Reset all things to the first status, be careful to save important data def reset(self): print("========== reset ==========") @@ -150,7 +155,6 @@ class RollingOnlineExample: for rid in exp.list_recorders(): exp.delete_recorder(rid) - # Run this firstly to see the workflow in Task Management def first_run(self): print("========== first_run ==========") @@ -163,7 +167,6 @@ class RollingOnlineExample: latest_rec, _ = self.rolling_online_manager.list_latest_recorders() self.rolling_online_manager.reset_online_tag(list(latest_rec.values())) - def routine(self): print("========== routine ==========") self.print_online_model() @@ -178,7 +181,7 @@ if __name__ == "__main__": ####### to update the models and predictions after the trading time, use the command below # python task_manager_rolling_with_updating.py after_day - + ####### to define your own parameters, use `--` # python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40 fire.Fire(RollingOnlineExample) diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 84472bc3b..0f075abcd 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -71,12 +71,14 @@ def update_online_pred(experiment_name="online_srv"): online_manager.update_online_pred() -def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"): + +def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"): provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir qlib.init(provider_uri=provider_uri, region=region) first_train(experiment_name) update_online_pred(experiment_name) + if __name__ == "__main__": ## to train a model and set it to online model, use the command below # python update_online_pred.py first_train diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 35f067304..2a34035f3 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -184,7 +184,7 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]): return module -def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType]=None) -> (type, dict): +def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict): """ extract class and kwargs from config info diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 0676bfb6b..66df160cd 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -74,6 +74,8 @@ class OnlineManager(Serializable): self.update_online_pred(*args, **kwargs) self.reset_online_tag(*args, **kwargs) + # TODO: first_train? + class OnlineManagerR(OnlineManager): """ diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index c64939e82..a8426d920 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -106,7 +106,8 @@ class RollingGen(TaskGen): def generate(self, task: dict): """ - Converting the task into a rolling task + Converting the task into a rolling task. + # FIXME: only modify dataset layer, user need to change datahandler firstly. Parameters ---------- diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 0d6f8c0de..720eeb12f 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -62,10 +62,6 @@ class TaskManager: self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) - # @property - # def task_pool(self): - # return self._task_pool - def list(self): return self.mdb.list_collection_names() @@ -83,22 +79,12 @@ class TaskManager: task[k] = pickle.loads(task[k]) return task - # def _get_task_pool(self, task_pool=None): - # if task_pool is None: - # task_pool = self.task_pool - # if task_pool is None: - # raise ValueError("You must specify a task pool.") - # if isinstance(task_pool, str): - # return getattr(self.mdb, task_pool) - # return task_pool - def _dict_to_str(self, flt): return {k: str(v) for k, v in flt.items()} def replace_task(self, task, new_task): # assume that the data out of interface was decoded and the data in interface was encoded new_task = self._encode_task(new_task) - # task_pool = self._get_task_pool(task_pool) query = {"_id": ObjectId(task["_id"])} try: self.task_pool.replace_one(query, new_task) @@ -107,7 +93,7 @@ class TaskManager: self.task_pool.replace_one(query, new_task) def insert_task(self, task): - # task_pool = self._get_task_pool(task_pool) + try: insert_result = self.task_pool.insert_one(task) except InvalidDocument: @@ -123,14 +109,11 @@ class TaskManager: ---------- task_def: dict the task definition - task_pool: str - the name of Collection in MongoDB Returns ------- """ - # task_pool = self._get_task_pool(task_pool) task = self._encode_task( { "def": task_def, @@ -149,8 +132,6 @@ class TaskManager: ---------- task_def_l: list a list of task - task_pool: str - the name of task_pool (collection name of MongoDB) dry_run: bool if insert those new tasks to task pool print_nt: bool @@ -160,7 +141,6 @@ class TaskManager: list a list of the _id of new tasks """ - # task_pool = self._get_task_pool(task_pool) new_tasks = [] for t in task_def_l: try: @@ -186,7 +166,6 @@ class TaskManager: return _id_list def fetch_task(self, query={}): - # task_pool = self._get_task_pool(task_pool) query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) @@ -209,8 +188,6 @@ class TaskManager: ---------- query: dict the dict of query - task_pool: str - the name of Collection in MongoDB Returns ------- @@ -226,9 +203,9 @@ class TaskManager: self.logger.info("Task returned") raise - def task_fetcher_iter(self, query={}, task_pool=None): + def task_fetcher_iter(self, query={}): while True: - with self.safe_fetch_task(query=query, task_pool=task_pool) as task: + with self.safe_fetch_task(query=query) as task: if task is None: break yield task @@ -242,8 +219,6 @@ class TaskManager: query: dict the dict of query decode: bool - task_pool: str - the name of Collection in MongoDB Returns ------- @@ -252,24 +227,20 @@ class TaskManager: query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) - # task_pool = self._get_task_pool(task_pool) for t in self.task_pool.find(query): yield self._decode_task(t) def re_query(self, _id): - # task_pool = self._get_task_pool(task_pool) t = self.task_pool.find_one({"_id": ObjectId(_id)}) return self._decode_task(t) def commit_task_res(self, task, res, status=None): - # task_pool = self._get_task_pool(task_pool) # A workaround to use the class attribute. if status is None: status = TaskManager.STATUS_DONE self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) def return_task(self, task, status=None): - # task_pool = self._get_task_pool(task_pool) if status is None: status = TaskManager.STATUS_WAITING update_dict = {"$set": {"status": status}} @@ -283,15 +254,12 @@ class TaskManager: ---------- query: dict the dict of query - task_pool: str - the name of Collection in MongoDB Returns ------- """ query = query.copy() - # task_pool = self._get_task_pool(task_pool) if "_id" in query: query["_id"] = ObjectId(query["_id"]) self.task_pool.delete_many(query) @@ -306,7 +274,7 @@ class TaskManager: status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1 return status_stat - def reset_waiting(self, query={}, task_pool=None): + def reset_waiting(self, query={}): query = query.copy() # default query if "status" not in query: @@ -315,7 +283,6 @@ class TaskManager: def reset_status(self, query, status): query = query.copy() - # task_pool = self._get_task_pool(task_pool) if "_id" in query: query["_id"] = ObjectId(query["_id"]) print(self.task_pool.update_many(query, {"$set": {"status": status}}))