mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
format code
This commit is contained in:
@@ -119,7 +119,8 @@ def task_collecting(task_pool, exp_name):
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
@@ -70,9 +70,18 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
class RollingOnlineExample:
|
||||
|
||||
def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550):
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
exp_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
@@ -84,9 +93,9 @@ class RollingOnlineExample:
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool)
|
||||
self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer)
|
||||
|
||||
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer
|
||||
)
|
||||
|
||||
def print_online_model(self):
|
||||
print("========== print_online_model ==========")
|
||||
@@ -99,7 +108,6 @@ class RollingOnlineExample:
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(self):
|
||||
|
||||
@@ -114,11 +122,9 @@ class RollingOnlineExample:
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def task_training(self, tasks):
|
||||
self.trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
@@ -141,7 +147,6 @@ class RollingOnlineExample:
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
@@ -150,7 +155,6 @@ class RollingOnlineExample:
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
@@ -163,7 +167,6 @@ class RollingOnlineExample:
|
||||
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
|
||||
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
self.print_online_model()
|
||||
@@ -178,7 +181,7 @@ if __name__ == "__main__":
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python task_manager_rolling_with_updating.py after_day
|
||||
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
|
||||
@@ -71,12 +71,14 @@ def update_online_pred(experiment_name="online_srv"):
|
||||
|
||||
online_manager.update_online_pred()
|
||||
|
||||
def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
|
||||
|
||||
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
first_train(experiment_name)
|
||||
update_online_pred(experiment_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
|
||||
@@ -184,7 +184,7 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType]=None) -> (type, dict):
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
|
||||
|
||||
@@ -74,6 +74,8 @@ class OnlineManager(Serializable):
|
||||
self.update_online_pred(*args, **kwargs)
|
||||
self.reset_online_tag(*args, **kwargs)
|
||||
|
||||
# TODO: first_train?
|
||||
|
||||
|
||||
class OnlineManagerR(OnlineManager):
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,8 @@ class RollingGen(TaskGen):
|
||||
|
||||
def generate(self, task: dict):
|
||||
"""
|
||||
Converting the task into a rolling task
|
||||
Converting the task into a rolling task.
|
||||
# FIXME: only modify dataset layer, user need to change datahandler firstly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -62,10 +62,6 @@ class TaskManager:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
# @property
|
||||
# def task_pool(self):
|
||||
# return self._task_pool
|
||||
|
||||
def list(self):
|
||||
return self.mdb.list_collection_names()
|
||||
|
||||
@@ -83,22 +79,12 @@ class TaskManager:
|
||||
task[k] = pickle.loads(task[k])
|
||||
return task
|
||||
|
||||
# def _get_task_pool(self, task_pool=None):
|
||||
# if task_pool is None:
|
||||
# task_pool = self.task_pool
|
||||
# if task_pool is None:
|
||||
# raise ValueError("You must specify a task pool.")
|
||||
# if isinstance(task_pool, str):
|
||||
# return getattr(self.mdb, task_pool)
|
||||
# return task_pool
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
# assume that the data out of interface was decoded and the data in interface was encoded
|
||||
new_task = self._encode_task(new_task)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
try:
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
@@ -107,7 +93,7 @@ class TaskManager:
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
|
||||
try:
|
||||
insert_result = self.task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
@@ -123,14 +109,11 @@ class TaskManager:
|
||||
----------
|
||||
task_def: dict
|
||||
the task definition
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
task = self._encode_task(
|
||||
{
|
||||
"def": task_def,
|
||||
@@ -149,8 +132,6 @@ class TaskManager:
|
||||
----------
|
||||
task_def_l: list
|
||||
a list of task
|
||||
task_pool: str
|
||||
the name of task_pool (collection name of MongoDB)
|
||||
dry_run: bool
|
||||
if insert those new tasks to task pool
|
||||
print_nt: bool
|
||||
@@ -160,7 +141,6 @@ class TaskManager:
|
||||
list
|
||||
a list of the _id of new tasks
|
||||
"""
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
@@ -186,7 +166,6 @@ class TaskManager:
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
@@ -209,8 +188,6 @@ class TaskManager:
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -226,9 +203,9 @@ class TaskManager:
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
def task_fetcher_iter(self, query={}, task_pool=None):
|
||||
def task_fetcher_iter(self, query={}):
|
||||
while True:
|
||||
with self.safe_fetch_task(query=query, task_pool=task_pool) as task:
|
||||
with self.safe_fetch_task(query=query) as task:
|
||||
if task is None:
|
||||
break
|
||||
yield task
|
||||
@@ -242,8 +219,6 @@ class TaskManager:
|
||||
query: dict
|
||||
the dict of query
|
||||
decode: bool
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -252,24 +227,20 @@ class TaskManager:
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
t = self.task_pool.find_one({"_id": ObjectId(_id)})
|
||||
return self._decode_task(t)
|
||||
|
||||
def commit_task_res(self, task, res, status=None):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=None):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {"$set": {"status": status}}
|
||||
@@ -283,15 +254,12 @@ class TaskManager:
|
||||
----------
|
||||
query: dict
|
||||
the dict of query
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
self.task_pool.delete_many(query)
|
||||
@@ -306,7 +274,7 @@ class TaskManager:
|
||||
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
|
||||
return status_stat
|
||||
|
||||
def reset_waiting(self, query={}, task_pool=None):
|
||||
def reset_waiting(self, query={}):
|
||||
query = query.copy()
|
||||
# default query
|
||||
if "status" not in query:
|
||||
@@ -315,7 +283,6 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
Reference in New Issue
Block a user