# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import division from __future__ import print_function import os import pickle import re import sys import copy import json import yaml import redis import bisect import shutil import difflib import hashlib import datetime import requests import tempfile import importlib import contextlib import collections import numpy as np import pandas as pd from pathlib import Path from typing import Union, Tuple, Any, Text, Optional from types import ModuleType from urllib.parse import urlparse from ..config import C from ..log import get_module_logger, set_log_with_config log = get_module_logger("utils") #################### Server #################### def get_redis_connection(): """get redis connection instance.""" return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db) #################### Data #################### def read_bin(file_path: Union[str, Path], start_index, end_index): file_path = Path(file_path.expanduser().resolve()) with file_path.open("rb") as f: # read start_index ref_start_index = int(np.frombuffer(f.read(4), dtype=" end_index: return pd.Series(dtype=np.float32) # calculate offset f.seek(4 * (si - ref_start_index) + 4) # read nbytes count = end_index - si + 1 data = np.frombuffer(f.read(4 * count), dtype="= data[mid][level]: left = mid + 1 else: right = mid return left #################### HTTP #################### def requests_with_retry(url, retry=5, **kwargs): while retry > 0: retry -= 1 try: res = requests.get(url, timeout=1, **kwargs) assert res.status_code in {200, 206} return res except AssertionError: continue except Exception as e: log.warning("exception encountered {}".format(e)) continue raise Exception("ERROR: requests failed!") #################### Parse #################### def parse_config(config): # Check whether need parse, all object except str do not need to be parsed if not isinstance(config, str): return config # Check whether config is file if os.path.exists(config): with open(config, "r") as f: return yaml.safe_load(f) # Check whether the str can be parsed try: return yaml.safe_load(config) except BaseException: raise ValueError("cannot parse config!") #################### Other #################### def drop_nan_by_y_index(x, y, weight=None): # x, y, weight: DataFrame # Find index of rows which do not contain Nan in all columns from y. mask = ~y.isna().any(axis=1) # Get related rows from x, y, weight. x = x[mask] y = y[mask] if weight is not None: weight = weight[mask] return x, y, weight def hash_args(*args): # json.dumps will keep the dict keys always sorted. string = json.dumps(args, sort_keys=True, default=str) # frozenset return hashlib.md5(string.encode()).hexdigest() def parse_field(field): # Following patterns will be matched: # - $close -> Feature("close") # - $close5 -> Feature("close5") # - $open+$close -> Feature("open")+Feature("close") if not isinstance(field, str): field = str(field) return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)) def get_module_by_module_path(module_path: Union[str, ModuleType]): """Load module path :param module_path: :return: """ if isinstance(module_path, ModuleType): module = module_path else: if module_path.endswith(".py"): module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) module_spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module module_spec.loader.exec_module(module) else: module = importlib.import_module(module_path) return module def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict): """ extract class/func and kwargs from config info Parameters ---------- config : [dict, str] similar to config default_module : Python module or str It should be a python module to load the class type This function will load class from the config['module_path'] first. If config['module_path'] doesn't exists, it will load the class from default_module. Returns ------- (type, dict): the class/func object and it's arguments. """ if isinstance(config, dict): module = get_module_by_module_path(config.get("module_path", default_module)) # raise AttributeError _callable = getattr(module, config["class" if "class" in config else "func"]) kwargs = config.get("kwargs", {}) elif isinstance(config, str): module = get_module_by_module_path(default_module) _callable = getattr(module, config) kwargs = {} else: raise NotImplementedError(f"This type of input is not supported") return _callable, kwargs def init_instance_by_config( config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs ) -> Any: """ get initialized instance with config Parameters ---------- config : Union[str, dict, object] dict example. { 'class': 'ClassName', 'kwargs': dict, # It is optional. {} will be used if not given 'model_path': path, # It is optional if module is given } str example. 1) specify a pickle object - path like 'file:////obj.pkl' 2) specify a class name - "ClassName": getattr(module, config)() will be used. object example: instance of accept_types default_module : Python module Optional. It should be a python module. NOTE: the "module_path" will be override by `module` arguments This function will load class from the config['module_path'] first. If config['module_path'] doesn't exists, it will load the class from default_module. accept_types: Union[type, Tuple[type]] Optional. If the config is a instance of specific type, return the config directly. This will be passed into the second parameter of isinstance. Returns ------- object: An initialized object based on the config info """ if isinstance(config, accept_types): return config if isinstance(config, str): # path like 'file:////obj.pkl' pr = urlparse(config) if pr.scheme == "file": with open(os.path.join(pr.netloc, pr.path), "rb") as f: return pickle.load(f) klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module) return klass(**cls_kwargs, **kwargs) def compare_dict_value(src_data: dict, dst_data: dict): """Compare dict value :param src_data: :param dst_data: :return: """ class DateEncoder(json.JSONEncoder): # FIXME: This class can only be accurate to the day. If it is a minute, # there may be a bug def default(self, o): if isinstance(o, (datetime.datetime, datetime.date)): return o.strftime("%Y-%m-%d %H:%M:%S") return json.JSONEncoder.default(self, o) src_data = json.dumps(src_data, indent=4, sort_keys=True, cls=DateEncoder) dst_data = json.dumps(dst_data, indent=4, sort_keys=True, cls=DateEncoder) diff = difflib.ndiff(src_data, dst_data) changes = [line for line in diff if line.startswith("+ ") or line.startswith("- ")] return changes def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False): """Create or get a file or directory given the path and return_dir. Parameters ---------- path: a string indicates the path or None indicates creating a temporary path. return_dir: if True, create and return a directory; otherwise c&r a file. """ if path: if return_dir and not os.path.exists(path): os.makedirs(path) elif not return_dir: # return a file, thus we need to create its parent directory xpath = os.path.abspath(os.path.join(path, "..")) if not os.path.exists(xpath): os.makedirs(xpath) else: temp_dir = os.path.expanduser("~/tmp") if not os.path.exists(temp_dir): os.makedirs(temp_dir) if return_dir: _, path = tempfile.mkdtemp(dir=temp_dir) else: _, path = tempfile.mkstemp(dir=temp_dir) return path @contextlib.contextmanager def save_multiple_parts_file(filename, format="gztar"): """Save multiple parts file Implementation process: 1. get the absolute path to 'filename' 2. create a 'filename' directory 3. user does something with file_path('filename/') 4. remove 'filename' directory 5. make_archive 'filename' directory, and rename 'archive file' to filename :param filename: result model path :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar" :return: real model path Usage:: >>> # The following code will create an archive file('~/tmp/test_file') containing 'test_doc_i'(i is 0-10) files. >>> with save_multiple_parts_file('~/tmp/test_file') as filename_dir: ... for i in range(10): ... temp_path = os.path.join(filename_dir, 'test_doc_{}'.format(str(i))) ... with open(temp_path) as fp: ... fp.write(str(i)) ... """ if filename.startswith("~"): filename = os.path.expanduser(filename) file_path = os.path.abspath(filename) # Create model dir if os.path.exists(file_path): raise FileExistsError("ERROR: file exists: {}, cannot be create the directory.".format(file_path)) os.makedirs(file_path) # return model dir yield file_path # filename dir to filename.tar.gz file tar_file = shutil.make_archive(file_path, format=format, root_dir=file_path) # Remove filename dir if os.path.exists(file_path): shutil.rmtree(file_path) # filename.tar.gz rename to filename os.rename(tar_file, file_path) @contextlib.contextmanager def unpack_archive_with_buffer(buffer, format="gztar"): """Unpack archive with archive buffer After the call is finished, the archive file and directory will be deleted. Implementation process: 1. create 'tempfile' in '~/tmp/' and directory 2. 'buffer' write to 'tempfile' 3. unpack archive file('tempfile') 4. user does something with file_path('tempfile/') 5. remove 'tempfile' and 'tempfile directory' :param buffer: bytes :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar" :return: unpack archive directory path Usage:: >>> # The following code is to print all the file names in 'test_unpack.tar.gz' >>> with open('test_unpack.tar.gz') as fp: ... buffer = fp.read() ... >>> with unpack_archive_with_buffer(buffer) as temp_dir: ... for f_n in os.listdir(temp_dir): ... print(f_n) ... """ temp_dir = os.path.expanduser("~/tmp") if not os.path.exists(temp_dir): os.makedirs(temp_dir) with tempfile.NamedTemporaryFile("wb", delete=False, dir=temp_dir) as fp: fp.write(buffer) file_path = fp.name try: tar_file = file_path + ".tar.gz" os.rename(file_path, tar_file) # Create dir os.makedirs(file_path) shutil.unpack_archive(tar_file, format=format, extract_dir=file_path) # Return temp dir yield file_path except Exception as e: log.error(str(e)) finally: # Remove temp tar file if os.path.exists(tar_file): os.unlink(tar_file) # Remove temp model dir if os.path.exists(file_path): shutil.rmtree(file_path) @contextlib.contextmanager def get_tmp_file_with_buffer(buffer): temp_dir = os.path.expanduser("~/tmp") if not os.path.exists(temp_dir): os.makedirs(temp_dir) with tempfile.NamedTemporaryFile("wb", delete=True, dir=temp_dir) as fp: fp.write(buffer) file_path = fp.name yield file_path def remove_repeat_field(fields): """remove repeat field :param fields: list; features fields :return: list """ fields = copy.deepcopy(fields) _fields = set(fields) return sorted(_fields, key=fields.index) def remove_fields_space(fields: [list, str, tuple]): """remove fields space :param fields: features fields :return: list or str """ if isinstance(fields, str): return fields.replace(" ", "") return [i.replace(" ", "") for i in fields if isinstance(i, str)] def normalize_cache_fields(fields: [list, tuple]): """normalize cache fields :param fields: features fields :return: list """ return sorted(remove_repeat_field(remove_fields_space(fields))) def normalize_cache_instruments(instruments): """normalize cache instruments :return: list or dict """ if isinstance(instruments, (list, tuple, pd.Index, np.ndarray)): instruments = sorted(list(instruments)) else: # dict type stockpool if "market" in instruments: pass else: instruments = {k: sorted(v) for k, v in instruments.items()} return instruments def is_tradable_date(cur_date): """judgy whether date is a tradable date ---------- date : pandas.Timestamp current date """ from ..data import D return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date()) def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): """get trading date range by shift Parameters ---------- trading_date: pd.Timestamp left_shift: int right_shift: int future: bool """ from ..data import D start = get_date_by_shift(trading_date, left_shift, future=future) end = get_date_by_shift(trading_date, right_shift, future=future) calendar = D.calendar(start, end, future=future) return calendar def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"): """get trading date with shift bias wil cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date ---------- trading_date : pandas.Timestamp current date shift : int clip_shift: bool """ from qlib.data import D cal = D.calendar(future=future, freq=freq) if pd.to_datetime(trading_date) not in list(cal): raise ValueError("{} is not trading day!".format(str(trading_date))) _index = bisect.bisect_left(cal, trading_date) shift_index = _index + shift if shift_index < 0 or shift_index >= len(cal): if clip_shift: shift_index = np.clip(shift_index, 0, len(cal) - 1) else: raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range") return cal[shift_index] def get_next_trading_date(trading_date, future=False): """get next trading date ---------- cur_date : pandas.Timestamp current date """ return get_date_by_shift(trading_date, 1, future=future) def get_pre_trading_date(trading_date, future=False): """get previous trading date ---------- date : pandas.Timestamp current date """ return get_date_by_shift(trading_date, -1, future=future) def transform_end_date(end_date=None, freq="day"): """handle the end date with various format If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned. Otherwise, returns the end_date ---------- end_date: str end trading date date : pandas.Timestamp current date """ from ..data import D last_date = D.calendar(freq=freq)[-1] if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)): log.warning( "\nInfo: the end_date in the configuration file is {}, " "so the default last date {} is used.".format(end_date, last_date) ) end_date = last_date return end_date def get_date_in_file_name(file_name): """Get the date(YYYY-MM-DD) written in file name Parameter file_name : str :return date : str 'YYYY-MM-DD' """ pattern = "[0-9]{4}-[0-9]{2}-[0-9]{2}" date = re.search(pattern, str(file_name)).group() return date def split_pred(pred, number=None, split_date=None): """split the score file into two part Parameter --------- pred : pd.DataFrame (index:) A score file of stocks number: the number of dates for pred_left split_date: the last date of the pred_left Return ------- pred_left : pd.DataFrame (index:) The first part of original score file pred_right : pd.DataFrame (index:) The second part of original score file """ if number is None and split_date is None: raise ValueError("`number` and `split date` cannot both be None") dates = sorted(pred.index.get_level_values("datetime").unique()) dates = list(map(pd.Timestamp, dates)) if split_date is None: date_left_end = dates[number - 1] date_right_begin = dates[number] date_left_start = None else: split_date = pd.Timestamp(split_date) date_left_end = split_date date_right_begin = split_date + pd.Timedelta(days=1) if number is None: date_left_start = None else: end_idx = bisect.bisect_right(dates, split_date) date_left_start = dates[end_idx - number] pred_temp = pred.sort_index() pred_left = pred_temp.loc(axis=0)[:, date_left_start:date_left_end] pred_right = pred_temp.loc(axis=0)[:, date_right_begin:] return pred_left, pred_right def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]: """ Time slicing in Qlib or Pandas is a frequently-used action. However, user often input all kinds of data format to represent time. This function will help user to convert these inputs into a uniform format which is friendly to time slicing. Parameters ---------- t : Union[None, str, pd.Timestamp] original time Returns ------- Union[None, pd.Timestamp]: """ if t is None: # None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]). return t else: return pd.Timestamp(t) def can_use_cache(): res = True r = get_redis_connection() try: r.client() except redis.exceptions.ConnectionError: res = False finally: r.close() return res def exists_qlib_data(qlib_dir): qlib_dir = Path(qlib_dir).expanduser() if not qlib_dir.exists(): return False calendars_dir = qlib_dir.joinpath("calendars") instruments_dir = qlib_dir.joinpath("instruments") features_dir = qlib_dir.joinpath("features") # check dir for _dir in [calendars_dir, instruments_dir, features_dir]: if not (_dir.exists() and list(_dir.iterdir())): return False # check calendar bin for _calendar in calendars_dir.iterdir(): if ("_future" not in _calendar.name) and ( not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")) ): return False # check instruments code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): return False return True def check_qlib_data(qlib_config): inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments") for _p in inst_dir.glob("*.txt"): try: assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, ( f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:" f"\n\tIf you are using the data provided by qlib: " f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset" f"\n\tIf you are using your own data, please dump the data again: " f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format" ) except AssertionError: raise def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: """ make the df index sorted df.sort_index() will take a lot of time even when `df.is_lexsorted() == True` This function could avoid such case Parameters ---------- df : pd.DataFrame Returns ------- pd.DataFrame: sorted dataframe """ idx = df.index if axis == 0 else df.columns # NOTE: MultiIndex.is_lexsorted() is a deprecated method in Pandas 1.3.0 and is suggested to be replaced by MultiIndex.is_monotonic_increasing (see discussion here: https://github.com/pandas-dev/pandas/issues/32259). However, in case older versions of Pandas is implemented, MultiIndex.is_lexsorted() is necessary to prevent certain fatal errors. if idx.is_monotonic_increasing and not (isinstance(idx, pd.MultiIndex) and not idx.is_lexsorted()): return df else: return df.sort_index(axis=axis) FLATTEN_TUPLE = "_FLATTEN_TUPLE" def flatten_dict(d, parent_key="", sep=".") -> dict: """ Flatten a nested dict. >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}) >>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10} >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE) >>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10} Args: d (dict): the dict waiting for flatting parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "". sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting. Returns: dict: flatten dict """ items = [] for k, v in d.items(): if sep == FLATTEN_TUPLE: new_key = (parent_key, k) if parent_key else k else: new_key = parent_key + sep + k if parent_key else k if isinstance(v, collections.abc.MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items) #################### Wrapper ##################### class Wrapper: """Wrapper class for anything that needs to set up during qlib.init""" def __init__(self): self._provider = None def register(self, provider): self._provider = provider def __repr__(self): return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider) def __getattr__(self, key): if self.__dict__.get("_provider", None) is None: raise AttributeError("Please run qlib.init() first using qlib") return getattr(self._provider, key) def register_wrapper(wrapper, cls_or_obj, module_path=None): """register_wrapper :param wrapper: A wrapper. :param cls_or_obj: A class or class name or object instance. """ if isinstance(cls_or_obj, str): module = get_module_by_module_path(module_path) cls_or_obj = getattr(module, cls_or_obj) obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj wrapper.register(obj) def load_dataset(path_or_obj): """load dataset from multiple file formats""" if isinstance(path_or_obj, pd.DataFrame): return path_or_obj if not os.path.exists(path_or_obj): raise ValueError(f"file {path_or_obj} doesn't exist") _, extension = os.path.splitext(path_or_obj) if extension == ".h5": return pd.read_hdf(path_or_obj) elif extension == ".pkl": return pd.read_pickle(path_or_obj) elif extension == ".csv": return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1]) raise ValueError(f"unsupported file type `{extension}`") def code_to_fname(code: str): """stock code to file name Parameters ---------- code: str """ # NOTE: In windows, the following name is I/O device, and the file with the corresponding name cannot be created # reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows replace_names = ["CON", "PRN", "AUX", "NUL"] replace_names += [f"COM{i}" for i in range(10)] replace_names += [f"LPT{i}" for i in range(10)] prefix = "_qlib_" if str(code).upper() in replace_names: code = prefix + str(code) return code def fname_to_code(fname: str): """file name to stock code Parameters ---------- fname: str """ prefix = "_qlib_" if fname.startswith(prefix): fname = fname.lstrip(prefix) return fname