mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
783 lines
23 KiB
Python
783 lines
23 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import re
|
|
import copy
|
|
import json
|
|
import yaml
|
|
import redis
|
|
import bisect
|
|
import shutil
|
|
import difflib
|
|
import hashlib
|
|
import logging
|
|
import datetime
|
|
import requests
|
|
import tempfile
|
|
import importlib
|
|
import contextlib
|
|
import collections
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from typing import Union, Tuple
|
|
|
|
from ..config import C, REG_CN
|
|
from ..log import get_module_logger, set_log_with_config
|
|
|
|
|
|
log = get_module_logger("utils")
|
|
|
|
|
|
#################### Server ####################
|
|
def get_redis_connection():
    """Create a Redis client from the global qlib config ``C``.

    Returns
    -------
    redis.StrictRedis
        client built from ``C.redis_host``, ``C.redis_port`` and ``C.redis_task_db``.
    """
    connection_kwargs = {
        "host": C.redis_host,
        "port": C.redis_port,
        "db": C.redis_task_db,
    }
    return redis.StrictRedis(**connection_kwargs)
|
|
|
|
|
|
#################### Data ####################
|
|
def read_bin(file_path, start_index, end_index):
    """Read a slice of little-endian float32 values from a binary file.

    The first 4 bytes of the file hold (as a float32) the index of the first
    stored value; the remaining bytes are consecutive float32 values.

    Parameters
    ----------
    file_path : str or path-like
        binary file to read.
    start_index : int
        first requested index (inclusive).
    end_index : int
        last requested index (inclusive).

    Returns
    -------
    pd.Series
        the values that could be read, indexed by their position
        (a ``RangeIndex``). Empty float32 series when the requested range ends
        before the first available index.
    """
    with open(file_path, "rb") as fp:
        # header: index of the first value stored in the file
        header = np.frombuffer(fp.read(4), dtype="<f")
        first_stored = int(header[0])
        begin = max(first_stored, start_index)
        if begin > end_index:
            return pd.Series(dtype=np.float32)
        # skip the 4-byte header plus the values located before `begin`
        fp.seek(4 * (begin - first_stored) + 4)
        n_values = end_index - begin + 1
        values = np.frombuffer(fp.read(4 * n_values), dtype="<f")
    # the file may hold fewer values than requested, hence len(values)
    return pd.Series(values, index=pd.RangeIndex(begin, begin + len(values)))
|
|
|
|
|
|
def np_ffill(arr: np.array):
    """
    forward fill a 1D numpy array

    Each NaN is replaced by the closest non-NaN value before it. NaNs at the
    beginning of the array have nothing to copy from and are kept unchanged.

    Parameters
    ----------
    arr : np.array
        Input numpy 1D array

    Returns
    -------
    np.array
        a new array of the same length with NaNs forward filled.
    """
    # np.isnan only works on float arrays; the builtin `float` is used because
    # the `np.float` alias was removed in NumPy 1.24.
    mask = np.isnan(arr.astype(float))
    # position of each valid element, 0 for NaN positions
    idx = np.where(~mask, np.arange(mask.shape[0]), 0)
    # running maximum -> position of the last valid element seen so far
    np.maximum.accumulate(idx, out=idx)
    return arr[idx]
|
|
|
|
|
|
#################### Search ####################
|
|
def lower_bound(data, val, level=0):
    """Lower bound search on a list of multi-field records.

    Binary search over ``data`` comparing ``val`` with ``data[i][level]``;
    ``data`` is expected to be sorted on that field.
    For a single field list use `bisect.bisect_left` instead.

    Returns
    -------
    int
        the first position whose field ``level`` is not smaller than ``val``
        (``len(data)`` when there is none).
    """
    lo, hi = 0, len(data)
    while lo < hi:
        middle = (lo + hi) // 2
        if not (val <= data[middle][level]):
            # everything up to `middle` is strictly smaller than val
            lo = middle + 1
        else:
            hi = middle
    return lo
|
|
|
|
|
|
def upper_bound(data, val, level=0):
    """Upper bound search on a list of multi-field records.

    Binary search over ``data`` comparing ``val`` with ``data[i][level]``;
    ``data`` is expected to be sorted on that field.
    For a single field list use `bisect.bisect_right` instead.

    Returns
    -------
    int
        the first position whose field ``level`` is greater than ``val``
        (``len(data)`` when there is none).
    """
    lo, hi = 0, len(data)
    while lo < hi:
        middle = (lo + hi) // 2
        if not (val >= data[middle][level]):
            # data[middle] is strictly greater than val: answer is at or before it
            hi = middle
        else:
            lo = middle + 1
    return lo
|
|
|
|
|
|
#################### HTTP ####################
|
|
def requests_with_retry(url, retry=5, **kwargs):
    """Send a GET request, retrying until it succeeds or the attempts run out.

    Parameters
    ----------
    url : str
        address to request.
    retry : int
        maximum number of attempts.
    **kwargs
        forwarded to ``requests.get``. ``timeout`` defaults to 1 second and
        can be overridden by the caller.

    Returns
    -------
    requests.Response
        the first response whose status code is 200 or 206.

    Raises
    ------
    Exception
        when no attempt produced a 200/206 response.
    """
    # Keep the historical 1 second default, but let callers override it instead
    # of failing with "got multiple values for keyword argument 'timeout'".
    kwargs.setdefault("timeout", 1)
    while retry > 0:
        retry -= 1
        try:
            res = requests.get(url, **kwargs)
        except Exception as e:
            log.warning("exception encountered {}".format(e))
            continue
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and would let any status code through.
        if res.status_code in {200, 206}:
            return res
    raise Exception("ERROR: requests failed!")
|
|
|
|
|
|
#################### Parse ####################
|
|
def parse_config(config):
    """Parse a configuration given as an object, a YAML file path or a YAML string.

    Parameters
    ----------
    config : object
        - any non-``str`` object is returned unchanged;
        - a ``str`` that is an existing path is read as a YAML file;
        - any other ``str`` is parsed as YAML text.

    Returns
    -------
    object
        the parsed configuration.

    Raises
    ------
    ValueError
        if the string cannot be parsed as YAML.
    """
    # Check whether need parse, all object except str do not need to be parsed
    if not isinstance(config, str):
        return config
    # Check whether config is file
    if os.path.exists(config):
        with open(config, "r") as f:
            # NOTE: `yaml.load` without an explicit Loader is rejected by PyYAML >= 6
            # and is unsafe on untrusted input; `safe_load` only builds plain
            # Python objects. Configs relying on `!!python/...` tags are no longer
            # accepted.
            return yaml.safe_load(f)
    # Check whether the str can be parsed
    try:
        return yaml.safe_load(config)
    except Exception as err:
        raise ValueError("cannot parse config!") from err
|
|
|
|
|
|
#################### Other ####################
|
|
def drop_nan_by_y_index(x, y, weight=None):
    """Drop the rows for which ``y`` contains a NaN in any column.

    Parameters
    ----------
    x, y : pd.DataFrame
        frames filtered with the boolean row mask computed from ``y``.
    weight : pd.DataFrame, optional
        filtered with the same mask when given.

    Returns
    -------
    tuple
        ``(x, y, weight)`` restricted to the kept rows; ``weight`` stays
        ``None`` when it was not given.
    """
    # True for the rows of y without any NaN
    keep_rows = ~y.isna().any(axis=1)
    kept_x = x[keep_rows]
    kept_y = y[keep_rows]
    kept_weight = None if weight is None else weight[keep_rows]
    return kept_x, kept_y, kept_weight
|
|
|
|
|
|
def hash_args(*args):
    """Return the MD5 hex digest of the JSON serialization of ``args``.

    Dict keys are sorted during serialization, so dicts that differ only in
    key order produce the same digest. Objects JSON cannot serialize
    (e.g. frozenset) are converted with ``str``.
    """
    serialized = json.dumps(args, sort_keys=True, default=str)
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def parse_field(field):
    """Rewrite a field expression with ``Feature``/``Operators`` references.

    Following patterns will be matched:
        - $close -> Feature("close")
        - $close5 -> Feature("close5")
        - $open+$close -> Feature("open")+Feature("close")

    In addition, every ``name(`` call is prefixed with ``Operators.``.
    Non-string input is converted with ``str`` first.
    """
    expression = field if isinstance(field, str) else str(field)
    # 1) function calls: `Ref(` -> `Operators.Ref(`
    with_operators = re.sub(r"(\w+\s*)\(", r"Operators.\1(", expression)
    # 2) raw features: `$close` -> `Feature("close")`
    return re.sub(r"\$(\w+)", r'Feature("\1")', with_operators)
|
|
|
|
|
|
def get_module_by_module_path(module_path):
    """Load a module from a ``.py`` file path or a dotted module name.

    :param module_path: path of a python source file (ending with ``.py``),
        or a module name importable with ``importlib.import_module``.
    :return: the loaded module object. A module loaded from a file is executed
        on every call and is created with an empty module name.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded; import it explicitly to avoid an AttributeError.
    import importlib.util

    if module_path.endswith(".py"):
        module_spec = importlib.util.spec_from_file_location("", module_path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)

    return module
|
|
|
|
|
|
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
    """
    Resolve the class and the constructor kwargs described by ``config``.

    Parameters
    ----------
    config : [dict, str]
        - dict: ``config["class"]`` is the attribute name looked up on
          ``module``; ``config["kwargs"]`` (optional) holds the arguments.
        - str: the attribute name looked up on ``module``; no arguments.

    module : Python module
        It should be a python module to load the class type

    Returns
    -------
    (type, dict):
        the class object and its arguments.

    Raises
    ------
    NotImplementedError
        if ``config`` is neither a dict nor a str.
    AttributeError
        if ``module`` has no attribute with the requested name.
    """
    if isinstance(config, str):
        return getattr(module, config), {}
    if isinstance(config, dict):
        klass = getattr(module, config["class"])
        return klass, config.get("kwargs", {})
    raise NotImplementedError("This type of input is not supported")
|
|
|
|
|
|
def init_instance_by_config(
    config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
) -> object:
    """
    Build an instance from a config description.

    Parameters
    ----------
    config : Union[str, dict, object]
        dict example.
            {
                'class': 'ClassName',
                'kwargs': dict, #  It is optional. {} will be used if not given
                'module_path': path, # It is optional if module is given
            }
        str example.
            "ClassName":  getattr(module, config)() will be used.
        object example:
            instance of accept_types
    module : Python module
        Optional. It should be a python module.
        NOTE: the "module_path" will be overridden by the `module` argument.
        When `module` is None, ``config["module_path"]`` is read, so `config`
        has to be a dict in that case.

    accept_types: Union[type, Tuple[type]]
        Optional. If the config is an instance of the given type(s), the config
        is returned directly.
        This will be passed into the second parameter of isinstance.

    **kwargs
        extra keyword arguments passed to the class constructor together with
        the kwargs found in the config.

    Returns
    -------
    object:
        An initialized object based on the config info
    """
    # already an instance of an accepted type: nothing to build
    if isinstance(config, accept_types):
        return config

    source_module = get_module_by_module_path(config["module_path"]) if module is None else module
    klass, cls_kwargs = get_cls_kwargs(config, source_module)
    return klass(**cls_kwargs, **kwargs)
|
|
|
|
|
|
def compare_dict_value(src_data: dict, dst_data: dict):
    """Compare two dicts through their pretty-printed JSON form.

    Both dicts are dumped to JSON (4-space indent, sorted keys) and the two
    texts are compared line by line.

    :param src_data: source dict
    :param dst_data: dict compared against ``src_data``
    :return: list of changed lines; lines only in ``src_data`` start with
        ``"- "``, lines only in ``dst_data`` start with ``"+ "``. The list is
        empty when both dicts serialize identically.
    """

    class DateEncoder(json.JSONEncoder):
        # FIXME: This class can only be accurate to the day. If it is a minute,
        # there may be a bug
        def default(self, o):
            if isinstance(o, (datetime.datetime, datetime.date)):
                return o.strftime("%Y-%m-%d %H:%M:%S")
            return json.JSONEncoder.default(self, o)

    src_text = json.dumps(src_data, indent=4, sort_keys=True, cls=DateEncoder)
    dst_text = json.dumps(dst_data, indent=4, sort_keys=True, cls=DateEncoder)
    # `difflib.ndiff` expects sequences of lines. Feeding it the raw strings makes
    # it diff character by character and report single characters as "changes".
    diff = difflib.ndiff(src_text.splitlines(), dst_text.splitlines())
    return [line for line in diff if line.startswith(("+ ", "- "))]
|
|
|
|
|
|
def create_save_path(save_path=None):
    """Create save path

    :param save_path: directory to create when it does not exist yet. When
        empty or None, a new temporary file is created under ``~/tmp`` instead.
    :return: ``save_path`` itself, or the path of the new temporary file.
    """
    if save_path:
        if not os.path.exists(save_path):
            # exist_ok: tolerate another process creating it in the meantime
            os.makedirs(save_path, exist_ok=True)
    else:
        temp_dir = os.path.expanduser("~/tmp")
        os.makedirs(temp_dir, exist_ok=True)
        fd, save_path = tempfile.mkstemp(dir=temp_dir)
        # mkstemp returns an open OS-level descriptor; close it so it is not
        # leaked (only the path is handed back to the caller).
        os.close(fd)
    return save_path
|
|
|
|
|
|
@contextlib.contextmanager
def save_multiple_parts_file(filename, format="gztar"):
    """Save several files as one archive stored at ``filename``.

    Implementation process:
        1. get the absolute path to 'filename' (a leading '~' is expanded)
        2. create a 'filename' directory
        3. user does something with file_path('filename/')
        4. make_archive of the 'filename' directory
        5. remove the 'filename' directory and rename the archive file to 'filename'

    NOTE: steps 4 and 5 only run when the ``with`` body finishes without raising.

    :param filename: result model path
    :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar"
    :return: real model path (the working directory yielded to the ``with`` body)
    :raises FileExistsError: if ``filename`` already exists

    Usage::

        >>> # The following code will create an archive file('~/tmp/test_file') containing 'test_doc_i'(i is 0-9) files.
        >>> with save_multiple_parts_file('~/tmp/test_file') as filename_dir:
        ...   for i in range(10):
        ...       temp_path = os.path.join(filename_dir, 'test_doc_{}'.format(str(i)))
        ...       with open(temp_path, 'w') as fp:
        ...           fp.write(str(i))
        ...

    """
    expanded = os.path.expanduser(filename) if filename.startswith("~") else filename
    work_dir = os.path.abspath(expanded)

    # The target must not exist yet: a directory with that name is created below
    if os.path.exists(work_dir):
        raise FileExistsError("ERROR: file exists: {}, cannot be create the directory.".format(work_dir))
    os.makedirs(work_dir)

    # hand the working directory to the caller
    yield work_dir

    # pack the directory content; make_archive appends the format's extension
    archive_path = shutil.make_archive(work_dir, format=format, root_dir=work_dir)

    # replace the working directory by the archive itself
    if os.path.exists(work_dir):
        shutil.rmtree(work_dir)
    os.rename(archive_path, work_dir)
|
|
|
|
|
|
@contextlib.contextmanager
def unpack_archive_with_buffer(buffer, format="gztar"):
    """Unpack an archive given as bytes and yield the extraction directory.

    After the call is finished, the archive file and directory will be deleted.

    Implementation process:
        1. create 'tempfile' in '~/tmp/' and write 'buffer' to it
        2. rename 'tempfile' to 'tempfile.tar.gz' and create the 'tempfile' directory
        3. unpack archive file('tempfile.tar.gz') into the directory
        4. user does something with file_path('tempfile/')
        5. remove 'tempfile.tar.gz' and the 'tempfile' directory

    NOTE(review): any ``Exception`` raised while unpacking or inside the
    ``with`` body is logged and not re-raised.

    :param buffer: bytes
    :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar"
    :return: unpack archive directory path

    Usage::

        >>> # The following code is to print all the file names in 'test_unpack.tar.gz'
        >>> with open('test_unpack.tar.gz', 'rb') as fp:
        ...     buffer = fp.read()
        ...
        >>> with unpack_archive_with_buffer(buffer) as temp_dir:
        ...     for f_n in os.listdir(temp_dir):
        ...         print(f_n)
        ...

    """
    tmp_root = os.path.expanduser("~/tmp")
    if not os.path.exists(tmp_root):
        os.makedirs(tmp_root)

    # dump the bytes into a uniquely named file that survives the `with`
    with tempfile.NamedTemporaryFile("wb", delete=False, dir=tmp_root) as fp:
        fp.write(buffer)
        extract_dir = fp.name

    try:
        # the unique name is reused for the directory, the archive gets a suffix
        archive_path = extract_dir + ".tar.gz"
        os.rename(extract_dir, archive_path)
        os.makedirs(extract_dir)
        shutil.unpack_archive(archive_path, format=format, extract_dir=extract_dir)

        # hand the extraction directory to the caller
        yield extract_dir

    except Exception as e:
        log.error(str(e))
    finally:
        # always clean up both the archive and the extracted content
        if os.path.exists(archive_path):
            os.unlink(archive_path)
        if os.path.exists(extract_dir):
            shutil.rmtree(extract_dir)
|
|
|
|
|
|
@contextlib.contextmanager
def get_tmp_file_with_buffer(buffer):
    """Write ``buffer`` to a temporary file under ``~/tmp`` and yield its path.

    The file is deleted when the ``with`` block is left.

    :param buffer: bytes written to the temporary file
    :return: path of the temporary file
    """
    temp_dir = os.path.expanduser("~/tmp")
    os.makedirs(temp_dir, exist_ok=True)
    with tempfile.NamedTemporaryFile("wb", delete=True, dir=temp_dir) as fp:
        fp.write(buffer)
        # The caller re-opens the file by path: flush the write buffer first,
        # otherwise the content may still be empty or truncated on disk.
        fp.flush()
        file_path = fp.name
        yield file_path
|
|
|
|
|
|
def remove_repeat_field(fields):
    """remove repeat field

    Duplicates are dropped while the order of first appearance is kept.
    The input is not modified.

    :param fields: list; features fields
    :return: list
    """
    # dict keys are unique and keep insertion order: one O(n) pass instead of
    # sorting a set with `fields.index` (O(n^2)) on a deep copy of the input.
    return list(dict.fromkeys(fields))
|
|
|
|
|
|
def remove_fields_space(fields: [list, str, tuple]):
    """remove fields space

    Every space character (``" "``) is removed. For a sequence input, items
    that are not ``str`` are left out of the result.

    :param fields: features fields
    :return: list or str
    """

    def _strip(text):
        return text.replace(" ", "")

    if isinstance(fields, str):
        return _strip(fields)

    cleaned = []
    for item in fields:
        if isinstance(item, str):
            cleaned.append(_strip(item))
    return cleaned
|
|
|
|
|
|
def normalize_cache_fields(fields: [list, tuple]):
    """normalize cache fields

    Spaces are removed from every field, duplicates are dropped and the
    result is sorted.

    :param fields: features fields
    :return: list
    """
    without_spaces = remove_fields_space(fields)
    unique_fields = remove_repeat_field(without_spaces)
    return sorted(unique_fields)
|
|
|
|
|
|
def normalize_cache_instruments(instruments):
    """normalize cache instruments

    - list / tuple / pd.Index / np.ndarray: returned as a sorted list
    - dict containing the key "market": returned unchanged
    - any other dict: a new dict whose values are sorted

    :return: list or dict
    """
    if isinstance(instruments, (list, tuple, pd.Index, np.ndarray)):
        return sorted(list(instruments))
    # dict type stockpool
    if "market" not in instruments:
        return {name: sorted(values) for name, values in instruments.items()}
    return instruments
|
|
|
|
|
|
def is_tradable_date(cur_date):
    """judge whether date is a tradable date

    The date is tradable when the first entry of the calendar starting at
    ``cur_date`` (future dates included) falls on the same day.

    Parameters
    ----------
    cur_date : pandas.Timestamp
        current date

    Returns
    -------
    bool
    """
    from ..data import D

    requested_day = str(cur_date.date())
    first_calendar_day = D.calendar(start_time=cur_date, future=True)[0]
    return requested_day == str(first_calendar_day.date())
|
|
|
|
|
|
def get_date_range(trading_date, shift, future=False):
    """get trading date range by shift

    :param trading_date: a date that must be part of the trading calendar
    :param shift: int; when positive, the ``shift`` calendar entries following
        ``trading_date`` are returned, otherwise the ``-shift`` entries
        preceding it
    :param future: bool; forwarded to ``D.calendar``
    :return: slice of the calendar; the whole calendar when the shifted
        position falls outside of it
    :raises ValueError: if ``trading_date`` is not in the calendar
    """
    from ..data import D

    calendar = D.calendar(future=future)
    if pd.to_datetime(trading_date) not in list(calendar):
        raise ValueError("{} is not trading day!".format(str(trading_date)))

    position = bisect.bisect_left(calendar, trading_date)
    shifted_in_calendar = 0 <= (position + shift) < len(calendar)
    if not shifted_in_calendar:
        return calendar
    if shift > 0:
        # entries strictly after trading_date
        return calendar[position + 1 : position + 1 + shift]
    # entries strictly before trading_date
    return calendar[position + shift : position]
|
|
|
|
|
|
def get_date_by_shift(trading_date, shift, future=False):
    """get trading date with shift bias with cur_date
        e.g. : shift == 1,  return next trading date
               shift == -1, return previous trading date
               shift == 0,  return ``trading_date`` itself

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date
    shift : int
    future : bool
        forwarded to ``get_date_range``
    """
    if shift == 0:
        return trading_date
    date_range = get_date_range(trading_date, shift, future)
    # going backwards: first entry of the range; going forwards: last entry
    return date_range[0] if shift < 0 else date_range[-1]
|
|
|
|
|
|
def get_next_trading_date(trading_date, future=False):
    """get next trading date

    Shortcut for ``get_date_by_shift`` with a shift of ``1``.

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date
    future : bool
        forwarded to ``get_date_by_shift``
    """
    one_day_forward = 1
    return get_date_by_shift(trading_date, one_day_forward, future=future)
|
|
|
|
|
|
def get_pre_trading_date(trading_date, future=False):
    """get previous trading date

    Shortcut for ``get_date_by_shift`` with a shift of ``-1``.

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date
    future : bool
        forwarded to ``get_date_by_shift``
    """
    one_day_back = -1
    return get_date_by_shift(trading_date, one_day_back, future=future)
|
|
|
|
|
|
def transform_end_date(end_date=None, freq="day"):
    """Clamp ``end_date`` to the last date of the trading calendar.

    If end_date is -1, None, or end_date is greater than the maximum trading
    day, the last trading date is returned (and a warning is logged).
    Otherwise, returns the end_date unchanged.

    Parameters
    ----------
    end_date : str
        end trading date
    freq : str
        calendar frequency, forwarded to ``D.calendar``
    """
    from ..data import D

    last_date = D.calendar(freq=freq)[-1]

    use_last_date = end_date is None or str(end_date) == "-1"
    if not use_last_date:
        # only compared when end_date is an actual date
        use_last_date = pd.Timestamp(last_date) < pd.Timestamp(end_date)

    if not use_last_date:
        return end_date

    log.warning(
        "\nInfo: the end_date in the configuration file is {}, "
        "so the default last date {} is used.".format(end_date, last_date)
    )
    return last_date
|
|
|
|
|
|
def get_date_in_file_name(file_name):
    """Get the date(YYYY-MM-DD) written in file name

    Parameters
    ----------
    file_name : str
        name searched for the first ``dddd-dd-dd`` pattern (converted with
        ``str`` first). A name without such a pattern makes the lookup fail
        with an ``AttributeError``.

    Returns
    -------
    date : str
        'YYYY-MM-DD'
    """
    date_match = re.search("[0-9]{4}-[0-9]{2}-[0-9]{2}", str(file_name))
    return date_match.group()
|
|
|
|
|
|
def split_pred(pred, number=None, split_date=None):
    """split the score file into two part

    Parameters
    ----------
    pred : pd.DataFrame (index:<instrument, datetime>)
        A score file of stocks
    number : int, optional
        the number of dates for pred_left
    split_date : optional
        the last date of the pred_left; pred_right starts one calendar day later

    Returns
    -------
    pred_left : pd.DataFrame (index:<instrument, datetime>)
        The first part of original score file
    pred_right : pd.DataFrame (index:<instrument, datetime>)
        The second part of original score file

    Raises
    ------
    ValueError
        if both ``number`` and ``split_date`` are None.
    """
    if number is None and split_date is None:
        raise ValueError("`number` and `split date` cannot both be None")

    raw_dates = sorted(pred.index.get_level_values("datetime").unique())
    all_dates = [pd.Timestamp(d) for d in raw_dates]

    # None means "from the very first date"
    left_start = None
    if split_date is not None:
        split_date = pd.Timestamp(split_date)
        left_end = split_date
        right_begin = split_date + pd.Timedelta(days=1)
        if number is not None:
            # keep only the last `number` dates up to (and including) split_date
            end_pos = bisect.bisect_right(all_dates, split_date)
            left_start = all_dates[end_pos - number]
    else:
        # split purely by count: the first `number` dates go to the left part
        left_end = all_dates[number - 1]
        right_begin = all_dates[number]

    ordered = pred.sort_index()
    left_part = ordered.loc(axis=0)[:, left_start:left_end]
    right_part = ordered.loc(axis=0)[:, right_begin:]
    return left_part, right_part
|
|
|
|
|
|
def can_use_cache():
    """Check whether the configured redis server can be reached.

    Returns
    -------
    bool
        False when calling ``client()`` on the redis connection raises
        ``redis.exceptions.ConnectionError``, True otherwise. The connection is
        closed in both cases.
    """
    connection = get_redis_connection()
    try:
        connection.client()
    except redis.exceptions.ConnectionError:
        return False
    else:
        return True
    finally:
        connection.close()
|
|
|
|
|
|
def exists_qlib_data(qlib_dir):
    """Check whether ``qlib_dir`` looks like a complete qlib data directory.

    The directory is accepted when:
        - ``calendars``, ``instruments`` and ``features`` exist and are not empty;
        - for every calendar file ``<freq>.<ext>`` at least one
          ``*.<freq>.bin`` file exists under ``features``;
        - ``instruments/all.txt`` exists and every instrument listed in it has
          an entry (compared in lower case) in ``features``; missing codes
          containing ``"sht"`` are tolerated.

    Parameters
    ----------
    qlib_dir : str or Path
        data directory; a leading ``~`` is expanded.

    Returns
    -------
    bool
    """
    qlib_dir = Path(qlib_dir).expanduser()
    if not qlib_dir.exists():
        return False

    calendars_dir = qlib_dir.joinpath("calendars")
    instruments_dir = qlib_dir.joinpath("instruments")
    features_dir = qlib_dir.joinpath("features")
    # check dir: each one must exist and contain at least one entry
    for _dir in [calendars_dir, instruments_dir, features_dir]:
        if not (_dir.exists() and any(_dir.iterdir())):
            return False
    # check calendar bin: `any` stops at the first match instead of walking
    # the whole features tree
    for _calendar in calendars_dir.iterdir():
        if not any(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")):
            return False

    # check instruments
    code_names = {entry.name.lower() for entry in features_dir.iterdir()}
    _instrument = instruments_dir.joinpath("all.txt")
    if not _instrument.exists():
        # without the instrument list the data cannot be complete
        return False
    df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
    # fall back to `inst` when `save_inst` is missing;
    # `fillna(method="ffill")` is deprecated, `ffill` is the supported spelling
    df = df.iloc[:, [0, -1]].ffill(axis=1)
    miss_code = set(df.iloc[:, -1].apply(str.lower)) - code_names
    if miss_code and any("sht" not in code for code in miss_code):
        return False

    return True
|
|
|
|
|
|
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
    """
    Sort the labels of ``df`` only when they are not sorted yet.

    df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`
    This function could avoid such case

    Parameters
    ----------
    df : pd.DataFrame
    axis : int
        0 checks/sorts the index, any other value the columns.

    Returns
    -------
    pd.DataFrame:
        ``df`` itself when its labels are already monotonic increasing,
        otherwise a sorted dataframe
    """
    labels = df.columns if axis != 0 else df.index
    return df if labels.is_monotonic_increasing else df.sort_index(axis=axis)
|
|
|
|
|
|
def flatten_dict(d, parent_key="", sep="."):
    """Flatten a nested dict into a single-level dict with joined keys.

        >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
        {'a': 1, 'c.a': 2, 'c.b.x': 5, 'c.b.y': 10, 'd': [1, 2, 3]}

    Parameters
    ----------
    d :
        the (possibly nested) mapping to flatten
    parent_key :
        prefix put in front of every key of ``d``; used by the recursion
    sep :
        separator inserted between the key parts
    """
    flat = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if isinstance(value, collections.abc.MutableMapping):
            # nested mapping: merge its flattened content under the joined key
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat
|
|
|
|
|
|
#################### Wrapper #####################
|
|
class Wrapper(object):
    """Wrapper class for anything that needs to set up during qlib.init

    Attribute access is delegated to the registered provider. Until a provider
    is registered, looking up any other attribute raises ``AttributeError``.
    """

    def __init__(self):
        self._provider = None

    def register(self, provider):
        """Set the object that attribute lookups are delegated to."""
        self._provider = provider

    def __getattr__(self, key):
        # Read `_provider` straight from the instance dict: going through
        # `self._provider` would call `__getattr__` again when the attribute is
        # missing (e.g. instances created by copy/pickle without `__init__`)
        # and recurse forever.
        provider = self.__dict__.get("_provider", None)
        if provider is None:
            raise AttributeError("Please run qlib.init() first using qlib")
        return getattr(provider, key)
|
|
|
|
|
|
def register_wrapper(wrapper, cls_or_obj, module_path=None):
    """register_wrapper

    :param wrapper: A wrapper.
    :param cls_or_obj:  A class or class name or object instance.
        A class name is looked up in the module loaded from ``module_path``;
        a class is instantiated without arguments; any other object is
        registered as is.
    :param module_path: module holding the class; only used when ``cls_or_obj``
        is a class name.
    """
    target = cls_or_obj
    if isinstance(target, str):
        target = getattr(get_module_by_module_path(module_path), target)
    if isinstance(target, type):
        target = target()
    wrapper.register(target)
|
|
|
|
|
|
def load_dataset(path_or_obj):
    """load dataset from multiple file formats

    Parameters
    ----------
    path_or_obj : pd.DataFrame or str
        a DataFrame is returned unchanged; otherwise the path of a ``.h5``,
        ``.pkl`` or ``.csv`` file (the csv is read with its first two columns
        as index and date parsing enabled).

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        if the file does not exist or its extension is not supported.
    """
    if isinstance(path_or_obj, pd.DataFrame):
        return path_or_obj
    if not os.path.exists(path_or_obj):
        raise ValueError(f"file {path_or_obj} doesn't exist")

    extension = os.path.splitext(path_or_obj)[1]
    # one reader per supported extension
    readers = {
        ".h5": lambda path: pd.read_hdf(path),
        ".pkl": lambda path: pd.read_pickle(path),
        ".csv": lambda path: pd.read_csv(path, parse_dates=True, index_col=[0, 1]),
    }
    if extension in readers:
        return readers[extension](path_or_obj)
    raise ValueError(f"unsupported file type `{extension}`")
|
|
|
|
|
|
def set_config(config_c, default_conf="client", **kwargs):
    """Reset ``config_c`` and fill it from ``default_conf`` and ``kwargs``.

    Steps, in order:
        1. reset the config and set up logging (``kwargs["logging_config"]``
           takes precedence over the config's own ``logging_config``);
        2. apply the ``default_conf`` mode and the region (``kwargs["region"]``,
           else the config's ``region``, else ``REG_CN``);
        3. copy every item of ``kwargs`` into the config, warning about keys
           the config does not know;
        4. resolve the paths;
        5. disable the expression/dataset caches when one of them is enabled
           but redis cannot be reached.

    Parameters
    ----------
    config_c :
        the config object to update in place
    default_conf : str
        mode passed to ``config_c.set_mode``
    **kwargs
        config items overriding the defaults
    """
    config_c.reset()

    # set global config
    active_logging_config = kwargs.get("logging_config", config_c.logging_config)
    if active_logging_config:
        set_log_with_config(active_logging_config)

    # FIXME: this logger ignored the level in config
    logger = get_module_logger("Initialization", level=logging.INFO)
    logger.info(f"default_conf: {default_conf}.")

    config_c.set_mode(default_conf)
    fallback_region = config_c["region"] if "region" in config_c else REG_CN
    config_c.set_region(kwargs.get("region", fallback_region))

    for name, value in kwargs.items():
        if name not in config_c:
            logger.warning("Unrecognized config %s" % name)
        config_c[name] = value

    config_c.resolve_path()

    cache_requested = config_c["expression_cache"] is not None or config_c["dataset_cache"] is not None
    # check redis
    if cache_requested and not can_use_cache():
        logger.warning(
            f"redis connection failed(host={config_c['redis_host']} port={config_c['redis_port']}), cache will not be used!"
        )
        config_c["expression_cache"] = None
        config_c["dataset_cache"] = None
|
|
|
|
|
|
def config_based_on_c(config_c):
    """Set up the objects that depend on the finalized config ``config_c``.

    Registers all data wrappers, builds the experiment manager described by
    ``config_c["exp_manager"]``, registers a ``QlibRecorder`` built on it in
    ``R`` and calls the experiment exit handler.
    """
    from ..data.data import register_all_wrappers
    from ..workflow import R, QlibRecorder
    from ..workflow.utils import experiment_exit_handler

    register_all_wrappers(config_c)

    # set up QlibRecorder
    manager = init_instance_by_config(config_c["exp_manager"])
    recorder = QlibRecorder(manager)
    R.register(recorder)

    # clean up experiment when python program ends
    experiment_exit_handler()
|