1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

Merge pull request #1 from you-n-g/online_srv

qlib auto init basedon project & black format
This commit is contained in:
lzh222333
2021-03-02 10:06:45 +08:00
committed by GitHub
6 changed files with 158 additions and 66 deletions

View File

@@ -3,6 +3,7 @@
__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
@@ -10,12 +11,13 @@ import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
# init qlib
def init(default_conf="client", **kwargs):
from .config import C
from .log import get_module_logger
from .data.cache import H
H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
def _mount_nfs_uri(C):
from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -151,3 +152,73 @@ def init_from_yaml_conf(conf_path, **kwargs):
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml") -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
- There is a file named `config.yaml` in qlib.
For example:
If your project file system stucuture follows such a pattern
<project_path>/
- config.yaml
- ...some folders...
- qlib/
This folder will return <project_path>
NOTE: link is not supported here.
This method is often used when
- user want to use a relative config path instead of hard-coding qlib config path in code
Raises
------
FileNotFoundError:
If project path is not found
"""
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
if cur_path == cur_path.parent:
raise FileNotFoundError("We can't find the project path")
cur_path = cur_path.parent
def auto_init(**kwargs):
"""
This function will init qlib automatically with following priority
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
"""
try:
pp = get_project_path()
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":
# The type of config is just like original qlib config
init_from_yaml_conf(conf_pp, **kwargs)
elif conf_type == "ref":
# This config type will be more convenient in following scenario
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -33,6 +33,9 @@ class Config:
raise AttributeError(f"No such {attr} in self._config")
def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -310,8 +313,22 @@ class QlibConfig(Config):
# clean up experiment when python program ends
experiment_exit_handler()
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()
self._registered = True
def reset_qlib_version(self):
import qlib
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
# Using __version__bak instead of __version__
@property
def registered(self):
return self._registered

View File

@@ -5,11 +5,12 @@ from tqdm.auto import tqdm
class RollingEnsemble:
'''
"""
Rolling Models Ensemble based on (R)ecord
This shares nothing with Ensemble
'''
"""
# TODO: 这边还可以加加速
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func
@@ -44,9 +45,8 @@ class RollingEnsemble:
for k, rec_l in recs_group.items():
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0])
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
pred = pd.concat(pred_l).sort_index()
reduce_group[k] = pred
return reduce_group

View File

@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
'''
"""
this is a task generator
'''
"""
import abc
import copy
import typing
@@ -54,8 +54,8 @@ class RollingGen(TaskGen):
self.rtype = rtype
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
self.test_key = 'test'
self.train_key = 'train'
self.test_key = "test"
self.train_key = "train"
def __call__(self, task: dict):
"""
@@ -102,7 +102,7 @@ class RollingGen(TaskGen):
if prev_seg is None:
# First rolling
# 1) prepare the end porint
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and the init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
@@ -120,14 +120,12 @@ class RollingGen(TaskGen):
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
except KeyError:
# We reach the end of tasks
# No more rolling
break
t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
res.append(t)
return res

View File

@@ -19,6 +19,8 @@ import time
import concurrent
import pymongo
from qlib.config import C
from .utils import get_mongodb
from qlib import auto_init
class TaskManager:
@@ -41,12 +43,13 @@ class TaskManager:
NOTE:
- 假设: 存储在db里面的都是encode过的 拿出来的都是decode过的
"""
STATUS_WAITING = 'waiting'
STATUS_RUNNING = 'running'
STATUS_DONE = 'done'
STATUS_PART_DONE = 'part_done'
ENCODE_FIELDS_PREFIX = ['def', 'res']
STATUS_WAITING = "waiting"
STATUS_RUNNING = "running"
STATUS_DONE = "done"
STATUS_PART_DONE = "part_done"
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool=None):
self.mdb = get_mongodb()
@@ -73,7 +76,7 @@ class TaskManager:
if task_pool is None:
task_pool = self.task_pool
if task_pool is None:
raise ValueError('You must specify a task pool.')
raise ValueError("You must specify a task pool.")
if isinstance(task_pool, str):
return getattr(self.mdb, task_pool)
return task_pool
@@ -85,11 +88,11 @@ class TaskManager:
# 这里的假设是从接口拿出来的都是decode过的在接口内部的都是 encode过的
new_task = self._encode_task(new_task)
task_pool = self._get_task_pool(task_pool)
query = {'_id': ObjectId(task['_id'])}
query = {"_id": ObjectId(task["_id"])}
try:
task_pool.replace_one(query, new_task)
except InvalidDocument:
task['filter'] = self._dict_to_str(task['filter'])
task["filter"] = self._dict_to_str(task["filter"])
task_pool.replace_one(query, new_task)
def insert_task(self, task, task_pool=None):
@@ -97,16 +100,18 @@ class TaskManager:
try:
task_pool.insert_one(task)
except InvalidDocument:
task['filter'] = self._dict_to_str(task['filter'])
task["filter"] = self._dict_to_str(task["filter"])
task_pool.insert_one(task)
def insert_task_def(self, task_def, task_pool=None):
task_pool = self._get_task_pool(task_pool)
task = self._encode_task({
'def': task_def,
'filter': task_def, # FIXME: catch the raised error
'status': self.STATUS_WAITING,
})
task = self._encode_task(
{
"def": task_def,
"filter": task_def, # FIXME: catch the raised error
"status": self.STATUS_WAITING,
}
)
self.insert_task(task, task_pool)
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
@@ -114,9 +119,9 @@ class TaskManager:
new_tasks = []
for t in task_def_l:
try:
r = task_pool.find_one({'filter': t})
r = task_pool.find_one({"filter": t})
except InvalidDocument:
r = task_pool.find_one({'filter': self._dict_to_str(t)})
r = task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
@@ -134,17 +139,16 @@ class TaskManager:
def fetch_task(self, query={}, task_pool=None):
task_pool = self._get_task_pool(task_pool)
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
query.update({'status': self.STATUS_WAITING})
task = task_pool.find_one_and_update(query, {'$set': {
'status': self.STATUS_RUNNING
}},
sort=[('priority', pymongo.DESCENDING)])
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query.update({"status": self.STATUS_WAITING})
task = task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
if task is None:
return None
task['status'] = self.STATUS_RUNNING
task["status"] = self.STATUS_RUNNING
return self._decode_task(task)
@contextmanager
@@ -154,9 +158,9 @@ class TaskManager:
yield task
except Exception:
if task is not None:
logger.info('Returning task before raising error')
logger.info("Returning task before raising error")
self.return_task(task)
logger.info('Task returned')
logger.info("Task returned")
raise
def task_fetcher_iter(self, query={}, task_pool=None):
@@ -175,8 +179,8 @@ class TaskManager:
:param task_pool:
"""
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
task_pool = self._get_task_pool(task_pool)
for t in task_pool.find(query):
yield self._decode_task(t)
@@ -186,44 +190,44 @@ class TaskManager:
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})
task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
def return_task(self, task, status=None, task_pool=None):
task_pool = self._get_task_pool(task_pool)
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {'$set': {'status': status}}
task_pool.update_one({"_id": task['_id']}, update_dict)
update_dict = {"$set": {"status": status}}
task_pool.update_one({"_id": task["_id"]}, update_dict)
def remove(self, query={}, task_pool=None):
query = query.copy()
task_pool = self._get_task_pool(task_pool)
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
task_pool.delete_many(query)
def task_stat(self, query={}, task_pool=None):
query = query.copy()
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
tasks = self.query(task_pool=task_pool, query=query, decode=False)
status_stat = {}
for t in tasks:
status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
return status_stat
def reset_waiting(self, query={}, task_pool=None):
query = query.copy()
# default query
if 'status' not in query:
query['status'] = self.STATUS_RUNNING
if "status" not in query:
query["status"] = self.STATUS_RUNNING
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
def reset_status(self, query, status, task_pool=None):
query = query.copy()
task_pool = self._get_task_pool(task_pool)
if '_id' in query:
query['_id'] = ObjectId(query['_id'])
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
print(task_pool.update_many(query, {"$set": {"status": status}}))
def _get_undone_n(self, task_stat):
@@ -274,17 +278,18 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
with tm.safe_fetch_task() as task:
if task is None:
break
logger.info(task['def'])
logger.info(task["def"])
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, task['def'], *args, **kwargs).result()
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
else:
res = task_func(task['def'], *args, **kwargs)
res = task_func(task["def"], *args, **kwargs)
tm.commit_task_res(task, res)
ever_run = True
return ever_run
if __name__ == '__main__':
if __name__ == "__main__":
auto_init()
Fire(TaskManager)

View File

@@ -10,17 +10,18 @@ from pymongo import MongoClient
def get_mongodb():
try:
cfg = C['mongo']
cfg = C["mongo"]
except KeyError:
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
raise
client = MongoClient(cfg['task_url'])
return client.get_database(name=cfg['task_db_name'])
client = MongoClient(cfg["task_url"])
return client.get_database(name=cfg["task_db_name"])
class TimeAdjuster:
'''找到合适的日期然后adjust date'''
"""找到合适的日期然后adjust date"""
def __init__(self, future=False):
self.cals = D.calendar(future=future)
@@ -45,9 +46,9 @@ class TimeAdjuster:
def align_idx(self, time_point, tp_type="start"):
time_point = pd.Timestamp(time_point)
if tp_type == 'start':
if tp_type == "start":
idx = bisect.bisect_left(self.cals, time_point)
elif tp_type == 'end':
elif tp_type == "end":
idx = bisect.bisect_right(self.cals, time_point) - 1
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -91,7 +92,7 @@ class TimeAdjuster:
new_seg = []
for time_point in segment:
tp_idx = min(self.align_idx(time_point), test_idx - days)
assert (tp_idx > 0)
assert tp_idx > 0
new_seg.append(self.get(tp_idx))
return tuple(new_seg)
else: