mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
add default protocol_version (#677)
* add default protocol_version * add comment to serial.Serializable.get_backend
This commit is contained in:
@@ -73,6 +73,9 @@ class Config:
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
@@ -107,6 +110,8 @@ _default_config = {
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# pickle.dump protocol version
|
||||
"dump_protocol_version": PROTOCOL_VERSION,
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
# If joblib_backend is None, use loky
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import yaml
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import shutil
|
||||
from ..backtest.account import Account
|
||||
from ..backtest.exchange import Exchange
|
||||
from ...backtest.account import Account
|
||||
from .user import User
|
||||
from .utils import load_instance
|
||||
from ...utils import save_instance, init_instance_by_config
|
||||
from .utils import load_instance, save_instance
|
||||
from ...utils import init_instance_by_config
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
||||
@@ -6,10 +6,10 @@ import pickle
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...utils import get_module_by_module_path, init_instance_by_config
|
||||
from ...utils import get_next_trading_date
|
||||
from ..backtest.exchange import Exchange
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
log = get_module_logger("utils")
|
||||
|
||||
@@ -42,7 +42,7 @@ def save_instance(instance, file_path):
|
||||
"""
|
||||
file_path = pathlib.Path(file_path)
|
||||
with file_path.open("wb") as fr:
|
||||
pickle.dump(instance, fr)
|
||||
pickle.dump(instance, fr, C.dump_protocol_version)
|
||||
|
||||
|
||||
def create_user_folder(path):
|
||||
|
||||
@@ -154,10 +154,11 @@ class Expression(abc.ABC):
|
||||
raise ValueError("Invalid index range: {} {}".format(start_index, end_index))
|
||||
try:
|
||||
series = self._load_internal(instrument, start_index, end_index, freq)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
get_module_logger("data").error(
|
||||
f"Loading data error: instrument={instrument}, expression={str(self)}, "
|
||||
f"start_index={start_index}, end_index={end_index}, freq={freq}"
|
||||
f"start_index={start_index}, end_index={end_index}, freq={freq}. "
|
||||
f"error info: {str(e)}"
|
||||
)
|
||||
raise
|
||||
series.name = str(self)
|
||||
|
||||
@@ -230,7 +230,7 @@ class CacheUtils:
|
||||
d["meta"]["visits"] = d["meta"]["visits"] + 1
|
||||
except KeyError:
|
||||
raise KeyError("Unknown meta keyword")
|
||||
pickle.dump(d, f)
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
except Exception as e:
|
||||
get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")
|
||||
|
||||
@@ -573,7 +573,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
pickle.dump(meta, f, protocol=C.dump_protocol_version)
|
||||
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
df = expression_data.to_frame()
|
||||
|
||||
@@ -638,7 +638,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -935,7 +935,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
with cache_path.with_suffix(".meta").open("wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
pickle.dump(meta, f, protocol=C.dump_protocol_version)
|
||||
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
# write index file
|
||||
im = DiskDatasetCache.IndexManager(cache_path)
|
||||
@@ -1057,7 +1057,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f)
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
@@ -1,102 +1,103 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
    """A socket.io client wrapper.

    Provide the connection tool functions for ClientProvider.
    """

    def __init__(self, host, port):
        """Create the underlying ``socketio.Client`` and bind logging callbacks.

        Parameters
        ----------
        host : str
            server host; concatenated into the ``ws://`` URL on connect.
        port :
            server port; converted with ``str()`` when the URL is built.
        """
        super().__init__()
        self.sio = socketio.Client()
        self.server_host = host
        self.server_port = port
        self.logger = get_module_logger(self.__class__.__name__)
        # bind connect/disconnect callbacks
        self.sio.on(
            "connect",
            lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)),
        )
        self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!"))

    def connect_server(self):
        """Connect to server.

        A ``socketio.exceptions.ConnectionError`` is logged and NOT re-raised,
        so callers cannot tell from this method alone whether the connection
        succeeded.
        """
        try:
            self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
        except socketio.exceptions.ConnectionError:
            self.logger.error("Cannot connect to server - check your network or server status")

    def disconnect(self):
        """Disconnect from server; any failure is logged instead of raised."""
        try:
            self.sio.eio.disconnect(True)
        except Exception as e:
            self.logger.error("Cannot disconnect from server : %s" % e)

    def _process_response(self, response, msg_proc_func):
        """Convert a raw response into the object that is put on the queue.

        Returns ``msg["result"]`` (optionally transformed by ``msg_proc_func``)
        for status 0, a ``ValueError`` instance for any other status, or the
        exception raised by ``msg_proc_func``. Exceptions are returned rather
        than raised so the consumer of the queue can inspect them.
        """
        msg = dict(response)
        detail = msg["detailed_info"]
        status = msg["status"]
        if detail is not None:
            if status != 0:
                self.logger.error(detail)
            else:
                self.logger.info(detail)
        if status != 0:
            return ValueError(f"Bad response(status=={status}), detailed info: {detail}")
        if msg_proc_func is None:
            return msg["result"]
        try:
            return msg_proc_func(msg["result"])
        except Exception as e:
            self.logger.exception("Error when processing message.")
            return e

    def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):
        """Send a certain request to server.

        Parameters
        ----------
        request_type : str
            type of proposed request, 'calendar'/'instrument'/'feature'.
        request_content : dict
            records the information of the request.
        msg_queue: Queue
            The queue to pass the message after callback. Exactly one object is
            put per response: the (processed) result, or an exception instance
            describing what went wrong.
        msg_proc_func : func
            optional function applied to the ``result`` field of a successful
            response; it is called with that single value.
        """
        head_info = {"version": qlib.__version__}

        def request_callback(*args):
            """callback_wrapper

            :param *args: args[0] is the response content
            """
            self.logger.debug("receive data and enter queue")
            try:
                try:
                    outcome = self._process_response(args[0], msg_proc_func)
                except Exception as e:
                    # A malformed response (missing key, no payload, ...) must
                    # still reach the consumer instead of vanishing here.
                    self.logger.exception("Error when handling response.")
                    outcome = e
                msg_queue.put(outcome)
            finally:
                # Always disconnect: `self.sio.wait()` below is expected to
                # return only once the connection has ended.
                self.disconnect()
                self.logger.debug("disconnected")

        self.logger.debug("try connecting")
        self.connect_server()
        # NOTE: logged even when connect_server() swallowed a connection error.
        self.logger.debug("connected")
        # The pickle is for passing some parameters with special type(such as
        # pd.Timestamp)
        payload = {"head": head_info, "body": pickle.dumps(request_content)}
        self.sio.on(request_type + "_response", request_callback)
        self.logger.debug("try sending")
        self.sio.emit(request_type + "_request", payload)
        self.sio.wait()
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
    """A socket.io client wrapper.

    Provide the connection tool functions for ClientProvider.
    """

    def __init__(self, host, port):
        """Create the underlying ``socketio.Client`` and bind logging callbacks.

        Parameters
        ----------
        host : str
            server host; concatenated into the ``ws://`` URL on connect.
        port :
            server port; converted with ``str()`` when the URL is built.
        """
        super().__init__()
        self.sio = socketio.Client()
        self.server_host = host
        self.server_port = port
        self.logger = get_module_logger(self.__class__.__name__)
        # bind connect/disconnect callbacks
        self.sio.on(
            "connect",
            lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)),
        )
        self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!"))

    def connect_server(self):
        """Connect to server.

        A ``socketio.exceptions.ConnectionError`` is logged and NOT re-raised,
        so callers cannot tell from this method alone whether the connection
        succeeded.
        """
        try:
            self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
        except socketio.exceptions.ConnectionError:
            self.logger.error("Cannot connect to server - check your network or server status")

    def disconnect(self):
        """Disconnect from server; any failure is logged instead of raised."""
        try:
            self.sio.eio.disconnect(True)
        except Exception as e:
            self.logger.error("Cannot disconnect from server : %s" % e)

    def _process_response(self, response, msg_proc_func):
        """Convert a raw response into the object that is put on the queue.

        Returns ``msg["result"]`` (optionally transformed by ``msg_proc_func``)
        for status 0, a ``ValueError`` instance for any other status, or the
        exception raised by ``msg_proc_func``. Exceptions are returned rather
        than raised so the consumer of the queue can inspect them.
        """
        msg = dict(response)
        detail = msg["detailed_info"]
        status = msg["status"]
        if detail is not None:
            if status != 0:
                self.logger.error(detail)
            else:
                self.logger.info(detail)
        if status != 0:
            return ValueError(f"Bad response(status=={status}), detailed info: {detail}")
        if msg_proc_func is None:
            return msg["result"]
        try:
            return msg_proc_func(msg["result"])
        except Exception as e:
            self.logger.exception("Error when processing message.")
            return e

    def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):
        """Send a certain request to server.

        Parameters
        ----------
        request_type : str
            type of proposed request, 'calendar'/'instrument'/'feature'.
        request_content : dict
            records the information of the request; pickled with the protocol
            configured in ``C.dump_protocol_version`` before being sent.
        msg_queue: Queue
            The queue to pass the message after callback. Exactly one object is
            put per response: the (processed) result, or an exception instance
            describing what went wrong.
        msg_proc_func : func
            optional function applied to the ``result`` field of a successful
            response; it is called with that single value.
        """
        head_info = {"version": qlib.__version__}

        def request_callback(*args):
            """callback_wrapper

            :param *args: args[0] is the response content
            """
            self.logger.debug("receive data and enter queue")
            try:
                try:
                    outcome = self._process_response(args[0], msg_proc_func)
                except Exception as e:
                    # A malformed response (missing key, no payload, ...) must
                    # still reach the consumer instead of vanishing here.
                    self.logger.exception("Error when handling response.")
                    outcome = e
                msg_queue.put(outcome)
            finally:
                # Always disconnect: `self.sio.wait()` below is expected to
                # return only once the connection has ended.
                self.disconnect()
                self.logger.debug("disconnected")

        self.logger.debug("try connecting")
        self.connect_server()
        # NOTE: logged even when connect_server() swallowed a connection error.
        self.logger.debug("connected")
        # The pickle is for passing some parameters with special type(such as
        # pd.Timestamp)
        payload = {
            "head": head_info,
            "body": pickle.dumps(request_content, protocol=C.dump_protocol_version),
        }
        self.sio.on(request_type + "_response", request_callback)
        self.logger.debug("try sending")
        self.sio.emit(request_type + "_request", payload)
        self.sio.wait()
|
||||
|
||||
@@ -726,10 +726,11 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
try:
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
get_module_logger("data").error(
|
||||
f"Loading expression error: "
|
||||
f"instrument={instrument}, field=({field}), start_time={start_time}, end_time={end_time}, freq={freq}"
|
||||
f"instrument={instrument}, field=({field}), start_time={start_time}, end_time={end_time}, freq={freq}. "
|
||||
f"error info: {str(e)}"
|
||||
)
|
||||
raise
|
||||
# Ensure that each column type is consistent
|
||||
|
||||
@@ -312,12 +312,12 @@ class NpPairOperator(PairOperator):
|
||||
warning_info = (
|
||||
f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), "
|
||||
f"The length of series_left and series_right is different: ({len(series_left)}, {len(series_right)}), "
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_left)}. Please check the data"
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data"
|
||||
)
|
||||
else:
|
||||
warning_info = (
|
||||
f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), "
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_left)}. Please check the data"
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data"
|
||||
)
|
||||
try:
|
||||
res = getattr(np, self.func)(series_left, series_right)
|
||||
|
||||
@@ -106,7 +106,7 @@ class FileManager(ObjManager):
|
||||
|
||||
def save_obj(self, obj, name):
    """Pickle ``obj`` into the file ``name`` located under ``self.path``.

    Parameters
    ----------
    obj :
        the object to serialize.
    name : str
        file name, joined onto ``self.path``.
    """
    with (self.path / name).open("wb") as f:
        # Write the object exactly once, using the configured pickle protocol.
        # A second, protocol-less dump into the same file would leave an extra
        # pickle in front of this one.
        pickle.dump(obj, f, protocol=C.dump_protocol_version)
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
for obj, name in obj_name_l:
|
||||
|
||||
@@ -5,6 +5,7 @@ import pickle
|
||||
import dill
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from ..config import C
|
||||
|
||||
|
||||
class Serializable:
|
||||
@@ -85,7 +86,8 @@ class Serializable:
|
||||
"""
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
self.get_backend().dump(self, f)
|
||||
# pickle interface like backend; such as dill
|
||||
self.get_backend().dump(self, f, protocol=C.dump_protocol_version)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
@@ -116,6 +118,7 @@ class Serializable:
|
||||
Returns:
|
||||
module: pickle or dill module based on pickle_backend
|
||||
"""
|
||||
# NOTE: pickle interface like backend; such as dill
|
||||
if cls.pickle_backend == "pickle":
|
||||
return pickle
|
||||
elif cls.pickle_backend == "dill":
|
||||
@@ -140,4 +143,4 @@ class Serializable:
|
||||
obj.to_pickle(path)
|
||||
else:
|
||||
with path.open("wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
pickle.dump(obj, f, protocol=C.dump_protocol_version)
|
||||
|
||||
@@ -27,6 +27,7 @@ from qlib import auto_init, get_module_logger
|
||||
from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
from ...config import C
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -108,7 +109,7 @@ class TaskManager:
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = Binary(pickle.dumps(task[k]))
|
||||
task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version))
|
||||
return task
|
||||
|
||||
def _decode_task(self, task):
|
||||
@@ -359,7 +360,10 @@ class TaskManager:
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
self.task_pool.update_one(
|
||||
{"_id": task["_id"]},
|
||||
{"$set": {"status": status, "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}},
|
||||
)
|
||||
|
||||
def return_task(self, task, status=STATUS_WAITING):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user