mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
Merge pull request #1 from you-n-g/online_srv
qlib auto init based on project & black format
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
@@ -10,12 +11,13 @@ import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
@@ -151,3 +152,73 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
    """
    Locate the root folder of the project that qlib is embedded in.

    The search starts at ``cur_path`` and walks up through its parent folders;
    the first folder that contains an entry named ``config_name`` is returned.

    This is meant for projects that follow the following pattern.
    - Qlib is a sub folder in project path
    - There is a file named `config.yaml` in the project path.

    For example:
        If your project file system structure follows such a pattern

            <project_path>/
              - config.yaml
              - ...some folders...
                - qlib/

        This function will return <project_path>

        NOTE: link is not supported here. The starting path is resolved
        (symbolic links are followed) before the search, so a linked qlib
        folder is searched from its real location.

    This method is often used when
    - user want to use a relative config path instead of hard-coding qlib config path in code

    Parameters
    ----------
    config_name : str
        Name of the configuration file that marks the project root.
    cur_path : str or Path, optional
        Where the upward search starts. Defaults to the location of this
        source file, which is the original behaviour.

    Returns
    -------
    Path
        The closest folder (starting point included) containing ``config_name``.

    Raises
    ------
    FileNotFoundError:
        If project path is not found
    """
    start = Path(__file__) if cur_path is None else Path(cur_path)
    start = start.absolute().resolve()
    # `parents` ends with the filesystem root, so the root itself is checked
    # before giving up.
    for folder in (start, *start.parents):
        if (folder / config_name).exists():
            return folder
    raise FileNotFoundError("We can't find the project path")
|
||||
|
||||
|
||||
def auto_init(**kwargs):
    """
    This function will init qlib automatically with following priority
    - Find the project configuration and init qlib
        - The parsing process will be affected by the `conf_type` of the configuration file
    - Init qlib with default config

    Supported values of `conf_type` in the project `config.yaml`:
    - "origin" (default): the file itself is a qlib config and is passed to
      `init_from_yaml_conf` together with `kwargs`.
    - "ref": the file points to another qlib config through `qlib_cfg` and may
      carry overrides in `qlib_cfg_update`.

    Parameters
    ----------
    **kwargs :
        Extra qlib settings. They are forwarded to the initialization call and
        take precedence over the values found in `qlib_cfg_update`.

    Raises
    ------
    KeyError:
        If `conf_type` is "ref" but `qlib_cfg` is missing.
    """
    try:
        pp = get_project_path()
    except FileNotFoundError:
        # No project configuration around: plain default initialization.
        init(**kwargs)
        return

    logger = get_module_logger("Initialization")
    conf_pp = pp / "config.yaml"
    with conf_pp.open() as f:
        # An empty YAML document is loaded as None; treat it as an empty config.
        conf = yaml.safe_load(f) or {}

    conf_type = conf.get("conf_type", "origin")
    if conf_type == "origin":
        # The type of config is just like original qlib config
        init_from_yaml_conf(conf_pp, **kwargs)
    elif conf_type == "ref":
        # This config type will be more convenient in following scenario
        # - There is a shared configure file and you don't want to edit it inplace.
        # - The shared configure may be updated later and you don't want to copy it.
        # - You have some customized config.
        qlib_conf_path = conf["qlib_cfg"]
        # `qlib_cfg_update` is optional; a missing or empty entry means "no overrides".
        qlib_conf_update = dict(conf.get("qlib_cfg_update") or {})
        # Merge instead of passing two `**` mappings, so a key given both in the
        # file and in `kwargs` does not raise "multiple values for keyword argument".
        qlib_conf_update.update(kwargs)
        init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
    else:
        # As before, nothing is initialized for an unknown type; say so instead
        # of reporting a successful load.
        logger.warning(f"Unknown conf_type '{conf_type}' in {conf_pp}; qlib is not initialized")
        return
    logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -33,6 +33,9 @@ class Config:
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def get(self, key, default=None):
    """
    Dict-style lookup on the stored configuration.

    Returns the value stored under `key`, or `default` when the key is absent.
    """
    # The mapping is fetched through the instance `__dict__` rather than as
    # `self._config` (presumably to bypass the class's custom attribute lookup).
    settings = self.__dict__["_config"]
    return settings.get(key, default)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -310,8 +313,22 @@ class QlibConfig(Config):
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
|
||||
self.reset_qlib_version()
|
||||
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
    """
    Set `qlib.__version__` according to the `qlib_reset_version` setting.

    When `qlib_reset_version` is configured, `qlib.__version__` is overwritten
    with it; otherwise `qlib.__version__` is restored from the backup stored
    in the module attribute `__version__bak`.
    """
    import qlib

    override = self.get("qlib_reset_version", None)
    # NOTE: the backup is read with getattr() and a string on purpose. Inside a
    # class body `qlib.__version__bak` would be name-mangled to
    # `qlib._QlibConfig__version__bak`, because the identifier starts with two
    # underscores and does not end with two.
    qlib.__version__ = getattr(qlib, "__version__bak") if override is None else override
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
@@ -5,11 +5,12 @@ from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class RollingEnsemble:
|
||||
'''
|
||||
"""
|
||||
Rolling Models Ensemble based on (R)ecord
|
||||
|
||||
This shares nothing with Ensemble
|
||||
'''
|
||||
"""
|
||||
|
||||
# TODO: 这边还可以加加速
|
||||
def __init__(self, get_key_func, flt_func=None):
|
||||
self.get_key_func = get_key_func
|
||||
@@ -44,9 +45,8 @@ class RollingEnsemble:
|
||||
for k, rec_l in recs_group.items():
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0])
|
||||
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
return reduce_group
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
'''
|
||||
"""
|
||||
this is a task generator
|
||||
'''
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
import typing
|
||||
@@ -54,8 +54,8 @@ class RollingGen(TaskGen):
|
||||
self.rtype = rtype
|
||||
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
|
||||
|
||||
self.test_key = 'test'
|
||||
self.train_key = 'train'
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def __call__(self, task: dict):
|
||||
"""
|
||||
@@ -102,7 +102,7 @@ class RollingGen(TaskGen):
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
|
||||
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and the init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
@@ -120,14 +120,12 @@ class RollingGen(TaskGen):
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
|
||||
t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import time
|
||||
import concurrent
|
||||
import pymongo
|
||||
from qlib.config import C
|
||||
from .utils import get_mongodb
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -41,12 +43,13 @@ class TaskManager:
|
||||
NOTE:
|
||||
- 假设: 存储在db里面的都是encode过的, 拿出来的都是decode过的
|
||||
"""
|
||||
STATUS_WAITING = 'waiting'
|
||||
STATUS_RUNNING = 'running'
|
||||
STATUS_DONE = 'done'
|
||||
STATUS_PART_DONE = 'part_done'
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ['def', 'res']
|
||||
STATUS_WAITING = "waiting"
|
||||
STATUS_RUNNING = "running"
|
||||
STATUS_DONE = "done"
|
||||
STATUS_PART_DONE = "part_done"
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool=None):
|
||||
self.mdb = get_mongodb()
|
||||
@@ -73,7 +76,7 @@ class TaskManager:
|
||||
if task_pool is None:
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
raise ValueError('You must specify a task pool.')
|
||||
raise ValueError("You must specify a task pool.")
|
||||
if isinstance(task_pool, str):
|
||||
return getattr(self.mdb, task_pool)
|
||||
return task_pool
|
||||
@@ -85,11 +88,11 @@ class TaskManager:
|
||||
# 这里的假设是从接口拿出来的都是decode过的,在接口内部的都是 encode过的
|
||||
new_task = self._encode_task(new_task)
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
query = {'_id': ObjectId(task['_id'])}
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
try:
|
||||
task_pool.replace_one(query, new_task)
|
||||
except InvalidDocument:
|
||||
task['filter'] = self._dict_to_str(task['filter'])
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task, task_pool=None):
|
||||
@@ -97,16 +100,18 @@ class TaskManager:
|
||||
try:
|
||||
task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
task['filter'] = self._dict_to_str(task['filter'])
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
task_pool.insert_one(task)
|
||||
|
||||
def insert_task_def(self, task_def, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
task = self._encode_task({
|
||||
'def': task_def,
|
||||
'filter': task_def, # FIXME: catch the raised error
|
||||
'status': self.STATUS_WAITING,
|
||||
})
|
||||
task = self._encode_task(
|
||||
{
|
||||
"def": task_def,
|
||||
"filter": task_def, # FIXME: catch the raised error
|
||||
"status": self.STATUS_WAITING,
|
||||
}
|
||||
)
|
||||
self.insert_task(task, task_pool)
|
||||
|
||||
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
|
||||
@@ -114,9 +119,9 @@ class TaskManager:
|
||||
new_tasks = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
r = task_pool.find_one({'filter': t})
|
||||
r = task_pool.find_one({"filter": t})
|
||||
except InvalidDocument:
|
||||
r = task_pool.find_one({'filter': self._dict_to_str(t)})
|
||||
r = task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
|
||||
@@ -134,17 +139,16 @@ class TaskManager:
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
query.update({'status': self.STATUS_WAITING})
|
||||
task = task_pool.find_one_and_update(query, {'$set': {
|
||||
'status': self.STATUS_RUNNING
|
||||
}},
|
||||
sort=[('priority', pymongo.DESCENDING)])
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query.update({"status": self.STATUS_WAITING})
|
||||
task = task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
)
|
||||
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
|
||||
if task is None:
|
||||
return None
|
||||
task['status'] = self.STATUS_RUNNING
|
||||
task["status"] = self.STATUS_RUNNING
|
||||
return self._decode_task(task)
|
||||
|
||||
@contextmanager
|
||||
@@ -154,9 +158,9 @@ class TaskManager:
|
||||
yield task
|
||||
except Exception:
|
||||
if task is not None:
|
||||
logger.info('Returning task before raising error')
|
||||
logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
logger.info('Task returned')
|
||||
logger.info("Task returned")
|
||||
raise
|
||||
|
||||
def task_fetcher_iter(self, query={}, task_pool=None):
|
||||
@@ -175,8 +179,8 @@ class TaskManager:
|
||||
:param task_pool:
|
||||
"""
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
for t in task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
@@ -186,44 +190,44 @@ class TaskManager:
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})
|
||||
task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {'$set': {'status': status}}
|
||||
task_pool.update_one({"_id": task['_id']}, update_dict)
|
||||
update_dict = {"$set": {"status": status}}
|
||||
task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def remove(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
tasks = self.query(task_pool=task_pool, query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
|
||||
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
|
||||
return status_stat
|
||||
|
||||
def reset_waiting(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
# default query
|
||||
if 'status' not in query:
|
||||
query['status'] = self.STATUS_RUNNING
|
||||
if "status" not in query:
|
||||
query["status"] = self.STATUS_RUNNING
|
||||
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
|
||||
|
||||
def reset_status(self, query, status, task_pool=None):
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
print(task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
@@ -274,17 +278,18 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
with tm.safe_fetch_task() as task:
|
||||
if task is None:
|
||||
break
|
||||
logger.info(task['def'])
|
||||
logger.info(task["def"])
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, task['def'], *args, **kwargs).result()
|
||||
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
|
||||
else:
|
||||
res = task_func(task['def'], *args, **kwargs)
|
||||
res = task_func(task["def"], *args, **kwargs)
|
||||
tm.commit_task_res(task, res)
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
Fire(TaskManager)
|
||||
|
||||
@@ -10,17 +10,18 @@ from pymongo import MongoClient
|
||||
|
||||
def get_mongodb():
|
||||
try:
|
||||
cfg = C['mongo']
|
||||
cfg = C["mongo"]
|
||||
except KeyError:
|
||||
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
|
||||
raise
|
||||
|
||||
client = MongoClient(cfg['task_url'])
|
||||
return client.get_database(name=cfg['task_db_name'])
|
||||
client = MongoClient(cfg["task_url"])
|
||||
return client.get_database(name=cfg["task_db_name"])
|
||||
|
||||
|
||||
class TimeAdjuster:
|
||||
'''找到合适的日期,然后adjust date'''
|
||||
"""找到合适的日期,然后adjust date"""
|
||||
|
||||
def __init__(self, future=False):
|
||||
self.cals = D.calendar(future=future)
|
||||
|
||||
@@ -45,9 +46,9 @@ class TimeAdjuster:
|
||||
|
||||
def align_idx(self, time_point, tp_type="start"):
|
||||
time_point = pd.Timestamp(time_point)
|
||||
if tp_type == 'start':
|
||||
if tp_type == "start":
|
||||
idx = bisect.bisect_left(self.cals, time_point)
|
||||
elif tp_type == 'end':
|
||||
elif tp_type == "end":
|
||||
idx = bisect.bisect_right(self.cals, time_point) - 1
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -91,7 +92,7 @@ class TimeAdjuster:
|
||||
new_seg = []
|
||||
for time_point in segment:
|
||||
tp_idx = min(self.align_idx(time_point), test_idx - days)
|
||||
assert (tp_idx > 0)
|
||||
assert tp_idx > 0
|
||||
new_seg.append(self.get(tp_idx))
|
||||
return tuple(new_seg)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user