1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00
This commit is contained in:
lzh222333
2021-06-01 07:45:39 +00:00
parent ca0363ded8
commit 94ab4bbf3f

View File

@@ -24,7 +24,9 @@ from bson.binary import Binary
from bson.objectid import ObjectId
from pymongo.errors import InvalidDocument
from qlib import auto_init, get_module_logger
import qlib
from tqdm.cli import tqdm
import yaml
from .utils import get_mongodb
@@ -72,24 +74,26 @@ class TaskManager:
def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
A TaskManager instance serves a specific task pool.
The static method of this module serves the whole MongoDB.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
self.task_pool = getattr(self.mdb, task_pool)
self.task_pool = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
@staticmethod
def list() -> list:
"""
List the all collection(task_pool) of the db
List the all collection(task_pool) of the db.
Returns:
list
"""
return self.mdb.list_collection_names()
return get_mongodb().list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
@@ -109,6 +113,16 @@ class TaskManager:
return {k: str(v) for k, v in flt.items()}
def _decode_query(self, query):
"""
If the query includes any `_id`, then it needs `ObjectId` to decode.
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
Args:
query (dict): query dict. Defaults to {}.
Returns:
dict: the query after decoding.
"""
if "_id" in query:
if isinstance(query["_id"], dict):
for key in query["_id"]: