From 94ab4bbf3feb5496720c6359dc85cfb1766ed5dd Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Tue, 1 Jun 2021 07:45:39 +0000 Subject: [PATCH] add docs --- qlib/workflow/task/manage.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 167087260..dd42caf65 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -24,7 +24,9 @@ from bson.binary import Binary from bson.objectid import ObjectId from pymongo.errors import InvalidDocument from qlib import auto_init, get_module_logger +import qlib from tqdm.cli import tqdm +import yaml from .utils import get_mongodb @@ -72,24 +74,26 @@ class TaskManager: def __init__(self, task_pool: str): """ Init Task Manager, remember to make the statement of MongoDB url and database name firstly. + A TaskManager instance serves a specific task pool. + The static method of this module serves the whole MongoDB. Parameters ---------- task_pool: str the name of Collection in MongoDB """ - self.mdb = get_mongodb() - self.task_pool = getattr(self.mdb, task_pool) + self.task_pool = getattr(get_mongodb(), task_pool) self.logger = get_module_logger(self.__class__.__name__) - def list(self) -> list: + @staticmethod + def list() -> list: """ - List the all collection(task_pool) of the db + List the all collection(task_pool) of the db. Returns: list """ - return self.mdb.list_collection_names() + return get_mongodb().list_collection_names() def _encode_task(self, task): for prefix in self.ENCODE_FIELDS_PREFIX: @@ -109,6 +113,16 @@ class TaskManager: return {k: str(v) for k, v in flt.items()} def _decode_query(self, query): + """ + If the query includes any `_id`, then it needs `ObjectId` to decode. + For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`. + + Args: + query (dict): query dict. Defaults to {}. + + Returns: + dict: the query after decoding. + """ if "_id" in query: if isinstance(query["_id"], dict): for key in query["_id"]: