1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Format with black

This commit is contained in:
Jactus
2020-10-29 13:22:49 +08:00
parent 490dbd908b
commit da9d1c8ac6
20 changed files with 290 additions and 251 deletions

View File

@@ -39,7 +39,7 @@ def init(default_conf="client", **kwargs):
LOG.info(f"default_conf: {default_conf}.")
C.set_mode(default_conf)
C.set_region(kwargs.get('region', C['region'] if 'region' in C else REG_CN ))
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
for k, v in kwargs.items():
C[k] = v
@@ -80,13 +80,13 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# set up QlibRecorder
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri
current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri
# exp manager module
module = get_module_by_module_path('qlib.workflow')
exp_manager = init_instance_by_config(C['exp_manager'], module)
module = get_module_by_module_path("qlib.workflow")
exp_manager = init_instance_by_config(C["exp_manager"], module)
qr = QlibRecorder(exp_manager, default_uri, current_uri)
R.register(qr)

View File

@@ -125,10 +125,7 @@ _default_config = {
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"kwargs": {}
},
"exp_manager": {"class": "MLflowExpManager", "kwargs": {}},
"exp_uri": None,
}

View File

@@ -46,10 +46,10 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
benchmark code, default is SH000905 CSI500
"""
# Convert format if the input format is not expected
if get_level_index(pred, level='datetime') == 1:
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if isinstance(pred, pd.Series):
pred = pred.to_frame('score')
pred = pred.to_frame("score")
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
@@ -80,8 +80,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
# 1. Load the score_series at pred_date
try:
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
score_series = score.reset_index(level="datetime",
drop=True)["score"] # pd.Series(index:stock_id, data: score)
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)
except KeyError:
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
score_series = None

View File

@@ -16,21 +16,16 @@ class ALPHA360(DataHandlerLP):
"kwargs": {
"config": {
"feature": {
"price": {
"windows": range(60)
},
"volume": {
"windows": range(60)
},
"price": {"windows": range(60)},
"volume": {"windows": range(60)},
},
"label": self.get_label_config()
"label": self.get_label_config(),
},
}
},
}
infer_processors = [{
"class": "ConfigSectionProcessor",
"module_path": "qlib.contrib.data.processor"
}] # ConfigSectionProcessor will normalize LABEL0
infer_processors = [
{"class": "ConfigSectionProcessor", "module_path": "qlib.contrib.data.processor"}
] # ConfigSectionProcessor will normalize LABEL0
super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors)
def get_label_config(self):
@@ -49,12 +44,7 @@ class Alpha158(DataHandlerLP):
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=["DropnaLabel", {
"class": "CSZScoreNorm",
"kwargs": {
"fields_group": "label"
}
}],
learn_processors=["DropnaLabel", {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}],
fit_start_time=None,
fit_end_time=None,
):
@@ -65,11 +55,13 @@ class Alpha158(DataHandlerLP):
klass, pkwargs = get_cls_kwargs(p, processor_module)
# FIXME: It's hard code here!!!!!
if isinstance(klass, (MinMaxNorm, ZscoreNorm)):
assert (fit_start_time is not None and fit_end_time is not None)
pkwargs.update({
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
})
assert fit_start_time is not None and fit_end_time is not None
pkwargs.update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)
@@ -81,18 +73,17 @@ class Alpha158(DataHandlerLP):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": self.get_feature_config(),
"label": self.get_label_config()
},
}
"config": {"feature": self.get_feature_config(), "label": self.get_label_config()},
},
}
super().__init__(instruments,
start_time,
end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors)
super().__init__(
instruments,
start_time,
end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
)
def get_feature_config(self):
conf = {
@@ -247,7 +238,8 @@ class Alpha158(DataHandlerLP):
if use("SUMD"):
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) for d in windows
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
@@ -258,26 +250,30 @@ class Alpha158(DataHandlerLP):
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" %
(d, d) for d in windows
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) for d in windows
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]

View File

@@ -8,9 +8,10 @@ from ...data.dataset.processor import Processor, get_group_columns
class ConfigSectionProcessor(Processor):
'''
"""
This processor is designed for Alpha158. And will be replaced by simple processors in the future
'''
"""
def __init__(self, fields_group=None, **kwargs):
super().__init__()
# Options

View File

@@ -159,11 +159,11 @@ def get_exchange(
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values('instrument').unique())
codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values('datetime').unique())
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
exchange = Exchange(
@@ -298,7 +298,7 @@ def long_short_backtest(
"short": short_returns(excess),
"long_short": long_short_returns}
"""
if get_level_index(pred, level='datetime') == 1:
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if trade_unit is None:

View File

@@ -12,26 +12,29 @@ from ...data.dataset.handler import DataHandlerLP
class LGBModel(Model):
"""LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self._params = {'objective': loss}
self._params = {"objective": loss}
self._params.update(kwargs)
self.model = None
def fit(self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs):
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
df_train, df_valid = dataset.prepare(['train', 'valid'],
col_set=['feature', 'label'],
data_key=DataHandlerLP.DK_L)
x_train, y_train = df_train['feature'], df_train['label']
x_valid, y_valid = df_valid['feature'], df_valid['label']
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# LightGBM needs a 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
@@ -41,20 +44,22 @@ class LGBModel(Model):
dtrain = lgb.Dataset(x_train.values, label=y_train_1d)
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d)
self.model = lgb.train(self._params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs)
self.model = lgb.train(
self._params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare('test', col_set='feature')
x_test = dataset.prepare("test", col_set="feature")
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)

View File

@@ -6,11 +6,12 @@ import pandas as pd
class Dataset(Serializable):
'''
"""
Preparing data for model training and inferencing.
'''
"""
def __init__(self, *args, **kwargs):
'''
"""
init is designed to finish following steps
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing
@@ -18,7 +19,7 @@ class Dataset(Serializable):
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
The data could specify the info to calculate the essential data for preparation
'''
"""
self.setup_data(*args, **kwargs)
super().__init__()
@@ -51,14 +52,15 @@ class Dataset(Serializable):
class DatasetH(Dataset):
'''
"""
Dataset with Data(H)andler
User should try to put the data preprocessing functions into handler.
Only following data processing functions should be placed in Dataset
- The processing is related to specific model.
- The processing is related to data split
'''
"""
def __init__(self, handler: Union[dict, DataHandler], segments: list):
"""
Parameters
@@ -96,10 +98,9 @@ class DatasetH(Dataset):
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
self._segments = segments
def prepare(self,
segments: Union[List[str], Tuple[str], str, slice],
col_set=DataHandler.CS_ALL,
**kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]:
def prepare(
self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs
) -> Union[List[pd.DataFrame], pd.DataFrame]:
"""
prepare the data for learning and inference
@@ -124,9 +125,7 @@ class DatasetH(Dataset):
[TODO:description]
"""
if isinstance(segments, (list, tuple)):
return [
self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments
]
return [self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments]
elif isinstance(segments, str):
return self._handler.fetch(slice(*self._segments[segments]), col_set=col_set, **kwargs)
else:

View File

@@ -25,7 +25,7 @@ from . import loader as data_loader_module
# TODO: A more general handler interface which does not rely on internal pd.DataFrame is needed.
class DataHandler(Serializable):
'''
"""
The steps to using a handler
1. initialized data handler (call by `init`).
2. use the data
@@ -46,13 +46,21 @@ class DataHandler(Serializable):
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
'''
def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, init_data=True):
"""
def __init__(
self,
instruments,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
init_data=True,
):
# Set logger
self.logger = get_module_logger("DataHandler")
# Setup data loader
assert(data_loader is not None) # to make start_time end_time could have None default value
assert data_loader is not None # to make start_time end_time could have None default value
self.data_loader = init_instance_by_config(data_loader, data_loader_module, accept_types=DataLoader)
self.instruments = instruments
@@ -62,7 +70,7 @@ class DataHandler(Serializable):
self.init()
super().__init__()
def init(self, enable_cache: bool=True):
def init(self, enable_cache: bool = True):
"""
initialize the data.
In case of running initialization multiple times, it will do nothing the second time.
@@ -83,7 +91,9 @@ class DataHandler(Serializable):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
# TODO: cache
def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame:
def _fetch_df_by_index(
self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]
) -> pd.DataFrame:
"""
fetch data from `data` with `selector` and `level`
@@ -100,7 +110,7 @@ class DataHandler(Serializable):
idx_slc = idx_slc[1], idx_slc[0]
return df.loc(axis=0)[idx_slc]
CS_ALL = '__all'
CS_ALL = "__all"
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
cln = len(df.columns.levels)
@@ -111,10 +121,12 @@ class DataHandler(Serializable):
else:
return df.loc(axis=1)[col_set]
def fetch(self,
selector: Union[pd.Timestamp, slice, str],
level: Union[str, int] = 'datetime',
col_set: Union[str, List[str]] = CS_ALL) -> pd.DataFrame:
def fetch(
self,
selector: Union[pd.Timestamp, slice, str],
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -157,32 +169,35 @@ class DataHandler(Serializable):
class DataHandlerLP(DataHandler):
'''
"""
DataHandler with **(L)earnable (P)rocessor**
'''
"""
# data key
DK_R = 'raw'
DK_I = 'infer'
DK_L = 'learn'
DK_R = "raw"
DK_I = "infer"
DK_L = "learn"
# process type
PTYPE_I = 'independent'
PTYPE_I = "independent"
# - _proc_infer_df will be processed by infer_processors
# - _proc_learn_df will be processed by learn_processors
PTYPE_A = 'append'
PTYPE_A = "append"
# - _proc_infer_df will be processed by infer_processors
# - _proc_learn_df will be processed by infer_processors + learn_processors
# - (e.g. _proc_infer_df processed by learn_processors )
def __init__(self,
instruments,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
**kwargs):
def __init__(
self,
instruments,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
**kwargs,
):
"""
Parameters
----------
@@ -217,10 +232,11 @@ class DataHandlerLP(DataHandler):
# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
for pname in 'infer_processors', 'learn_processors':
for pname in "infer_processors", "learn_processors":
for proc in locals()[pname]:
getattr(self, pname).append(init_instance_by_config(proc, processor_module,
accept_types=(processor_module.Processor,)))
getattr(self, pname).append(
init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,))
)
self.process_type = process_type
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
@@ -240,8 +256,7 @@ class DataHandlerLP(DataHandler):
"""
self.process_data(with_fit=True)
def process_data(self, with_fit: bool=False):
def process_data(self, with_fit: bool = False):
"""
Process the data. Run `processor.fit` if necessary
@@ -281,11 +296,11 @@ class DataHandlerLP(DataHandler):
self._learn = _learn_df
# init type
IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor
IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df
IT_LS = 'load_state' # The state of the object has been load by pickle
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
IT_LS = "load_state" # The state of the object has been load by pickle
def init(self, init_type: str=IT_FIT_SEQ, enable_cache: bool=False):
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
"""
Initialize the data of Qlib
@@ -314,15 +329,17 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: str=DK_I) -> pd.DataFrame:
df = getattr(self, {self.DK_R: '_data', self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
return df
def fetch(self,
selector: Union[pd.Timestamp, slice, str],
level: Union[str, int] = 'datetime',
col_set=DataHandler.CS_ALL,
data_key: str = DK_I) -> pd.DataFrame:
def fetch(
self,
selector: Union[pd.Timestamp, slice, str],
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -345,7 +362,7 @@ class DataHandlerLP(DataHandler):
df = self._fetch_df_by_index(df, selector, level)
return self._fetch_df_by_col(df, col_set)
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str=DK_I) -> list:
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
"""
get the column names

View File

@@ -8,44 +8,46 @@ from typing import Tuple
class DataLoader(ABC):
'''
"""
DataLoader is designed for loading raw data from original data source.
'''
"""
@abstractmethod
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
"""
load the data as pd.DataFrame
load the data as pd.DataFrame
Parameters
----------
self : [TODO:type]
[TODO:description]
instruments : [TODO:type]
[TODO:description]
start_time : [TODO:type]
[TODO:description]
end_time : [TODO:type]
[TODO:description]
Parameters
----------
self : [TODO:type]
[TODO:description]
instruments : [TODO:type]
[TODO:description]
start_time : [TODO:type]
[TODO:description]
end_time : [TODO:type]
[TODO:description]
Returns
-------
pd.DataFrame:
data load from the under layer source
Returns
-------
pd.DataFrame:
data load from the under layer source
Example of the data:
The multi-index of the columns is optional.
feature label
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Example of the data:
The multi-index of the columns is optional.
feature label
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
"""
pass
class QlibDataLoader(DataLoader):
'''Same as QlibDataLoader. The fields can be define by config'''
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
"""
Parameters
@@ -65,7 +67,7 @@ class QlibDataLoader(DataLoader):
Here is a few examples to describe the fields
TODO:
"""
self.is_group = isinstance(config, dict)
self.is_group = isinstance(config, dict)
if self.is_group:
self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
@@ -88,6 +90,7 @@ class QlibDataLoader(DataLoader):
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time)
df.columns = names
return df
if self.is_group:
df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1)
else:

View File

@@ -30,8 +30,7 @@ def get_group_columns(df: pd.DataFrame, group: str):
class Processor(Serializable):
def fit(self, df: pd.DataFrame=None):
def fit(self, df: pd.DataFrame = None):
"""
learn data processing parameters
@@ -40,7 +39,7 @@ class Processor(Serializable):
df : pd.DataFrame
When we fit and process data with processors one by one, the fit function relies on the output of the previous
processor, i.e. `df`.
"""
pass
@@ -81,16 +80,17 @@ class DropnaProcessor(Processor):
class DropnaLabel(DropnaProcessor):
def __init__(self, group='label'):
def __init__(self, group="label"):
super().__init__(group=group)
def is_for_infer(self) -> bool:
'''The samples are dropped according to label. So it is not usable for inference'''
"""The samples are dropped according to label. So it is not usable for inference"""
return False
class ProcessInf(Processor):
'''Process infinity '''
"""Process infinity """
def __call__(self, df):
def replace_inf(data):
def process_inf(df):
@@ -102,6 +102,7 @@ class ProcessInf(Processor):
data = data.groupby("datetime").apply(process_inf)
data.sort_index(inplace=True)
return data
return replace_inf(df)
@@ -126,6 +127,7 @@ class MinMaxNorm(Processor):
if not ignore[i]:
x[i] = (x[i] - min_val) / (max_val - min_val)
return x
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
return df
@@ -151,17 +153,19 @@ class ZscoreNorm(Processor):
if not ignore[i]:
x[i] = (x[i] - mean_train) / std_train
return x
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
return df
class CSZScoreNorm(Processor):
'''Cross Sectional ZScore Normalization'''
"""Cross Sectional ZScore Normalization"""
def __init__(self, fields_group=None):
self.fields_group = fields_group
def __call__(self, df):
# try not modify original dataframe
cols = get_group_columns(df,self.fields_group)
df[cols] = df[cols].groupby('datetime').apply(lambda df: (df - df.mean()).div(df.std()))
cols = get_group_columns(df, self.fields_group)
df[cols] = df[cols].groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
return df

View File

@@ -24,9 +24,8 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
return df.index.names.index(level)
except (AttributeError, ValueError):
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
return ('datetime', 'instrument').index(level)
return ("datetime", "instrument").index(level)
elif isinstance(level, int):
return level
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -6,7 +6,7 @@ from ..data.dataset import Dataset
class BaseModel(Serializable, metaclass=abc.ABCMeta):
'''Modeling things'''
"""Modeling things"""
@abc.abstractmethod
def predict(self, *args, **kwargs) -> object:
@@ -19,7 +19,7 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
class Model(BaseModel):
'''Learnable Models'''
"""Learnable Models"""
def fit(self, dataset: Dataset):
"""

View File

@@ -165,7 +165,7 @@ def get_module_by_module_path(module_path):
return module
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
"""
extract class and kwargs from config info
@@ -184,8 +184,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
"""
if isinstance(config, dict):
# raise AttributeError
klass = getattr(module, config['class'])
kwargs = config['kwargs']
klass = getattr(module, config["class"])
kwargs = config["kwargs"]
elif isinstance(config, str):
klass = getattr(module, config)
kwargs = {}
@@ -194,7 +194,9 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
return klass, kwargs
def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]]=tuple([])) -> object:
def init_instance_by_config(
config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]] = tuple([])
) -> object:
"""
get initialized instance with config
@@ -647,4 +649,4 @@ def register_wrapper(wrapper, cls_or_obj):
module = get_module_by_module_path("qlib.data")
cls_or_obj = getattr(module, cls_or_obj)
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
wrapper.register(obj)
wrapper.register(obj)

View File

@@ -24,7 +24,7 @@ class ObjManager:
def save_objs(self, obj_name_l):
"""
save objects
save objects
Parameters
----------
@@ -88,9 +88,10 @@ class ObjManager:
class FileManager(ObjManager):
'''
"""
Use file system to manage objects
'''
"""
def __init__(self, path=None):
if path is None:
self.path = Path(self.create_path())
@@ -99,12 +100,12 @@ class FileManager(ObjManager):
def create_path(self) -> str:
try:
return tempfile.mkdtemp(prefix=str(C['file_manager_path']) + os.sep)
return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep)
except AttributeError:
raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
def save_obj(self, obj, name):
with (self.path / name).open('wb') as f:
with (self.path / name).open("wb") as f:
pickle.dump(obj, f)
def save_objs(self, obj_name_l):
@@ -112,7 +113,7 @@ class FileManager(ObjManager):
self.save_obj(obj, name)
def load_obj(self, name):
with (self.path / name).open('rb') as f:
with (self.path / name).open("rb") as f:
return pickle.load(f)
def exists(self, name):
@@ -123,7 +124,7 @@ class FileManager(ObjManager):
def remove(self, fname=None):
if fname is None:
for fp in self.path.glob('*'):
for fp in self.path.glob("*"):
fp.unlink()
self.path.rmdir()
else:

View File

@@ -6,17 +6,17 @@ import pickle
class Serializable:
'''
"""
Serializable behaves like pickle.
But it only saves the state whose name starts with `_`
'''
"""
def __getstate__(self) -> dict:
return {k: v for k, v in self.__dict__.items() if k.startswith('_') }
return {k: v for k, v in self.__dict__.items() if k.startswith("_")}
def __setstate__(self, state: dict):
self.__dict__.update(state)
def to_pickle(self, path: [Path, str]):
with Path(path).open('wb') as f:
with Path(path).open("wb") as f:
pickle.dump(self, f)

View File

@@ -5,6 +5,7 @@ from contextlib import contextmanager
from .expm import *
from ..utils import Wrapper
class QlibRecorder:
def __init__(self, exp_manager, default_uri, current_uri):
self.exp_manager = exp_manager
@@ -16,16 +17,16 @@ class QlibRecorder:
run = self.start_exp(experiment_name, self.current_uri)
yield run
self.end_exp()
def start_exp(self, experiment_name=None):
return self.exp_manager.start_exp(experiment_name, self.current_uri)
return self.exp_manager.start_exp(experiment_name, self.current_uri)
def end_exp(self):
self.exp_manager.end_exp()
def search_records(self, experiment_ids, **kwargs):
return self.exp_manager.search_records(experiment_ids, **kwargs)
def get_exp(self, experiment_id=None, experiment_name=None):
return self.exp_manager.get_exp(experiment_id, experiment_name)
@@ -52,12 +53,13 @@ class QlibRecorder:
def log_metrics(self, step=None, **kwargs):
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
def set_tags(self, **kwargs):
self.exp_manager.active_recorder.set_tags(**kwargs)
def delete_tag(self, key):
self.exp_manager.active_recorder.delete_tag(key)
# global record
R = Wrapper()

View File

@@ -4,10 +4,12 @@
import mlflow
from pathlib import Path
class Experiment:
"""
This is the `Experiment` class for each experiment being run. The API is designed
This is the `Experiment` class for each experiment being run. The API is designed
"""
def __init__(self):
self.name = None
self.id = None
@@ -39,9 +41,10 @@ class MLflowExperiment(Experiment):
"""
Use mlflow to implement Experiment.
"""
def search_records(self, **kwargs):
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
order_by = kwargs.get('order_by')
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
order_by = kwargs.get("order_by")
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)

View File

@@ -8,15 +8,17 @@ from contextlib import contextmanager
from .exp import MLflowExperiment
from .record import MLflowRecorder
class ExpManager:
"""
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self):
self.default_uri = None
self.active_recorder = None # only one recorder can running each time
self.experiments = dict() # store the experiment name --> Experiment object
self.active_recorder = None # only one recorder can running each time
self.experiments = dict() # store the experiment name --> Experiment object
def start_exp(self, experiment_name=None, uri=None, **kwargs):
"""
@@ -88,7 +90,7 @@ class ExpManager:
An experiment object.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_exp(self, experiment_id=None, experiment_name=None):
"""
Retrieve an experiment by experiment_id from the backend store.
@@ -111,7 +113,7 @@ class ExpManager:
Parameters
----------
experiment_id : str
the experiment id.
the experiment id.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
@@ -142,12 +144,13 @@ class ExpManager:
An Recorder object.
"""
raise NotImplementedError(f"Please implement the `get_recorder` method.")
class MLflowExpManager(ExpManager):
'''
"""
Use mlflow to implement ExpManager.
'''
"""
def __init__(self):
super(MLflowExpManager, self).__init__()
self.default_uri = None
@@ -169,27 +172,31 @@ class MLflowExpManager(ExpManager):
def end_exp(self):
self.active_recorder.end_run()
self.active_recorder = None
def __create_exp(self, experiment_name=None, uri=None):
# init experiment
experiment = MLflowExperiment()
# set the tracking uri
if uri is None:
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.')
print(
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
)
else:
self.current_uri = uri
mlflow.set_tracking_uri(self.current_uri)
# start the experiment
if experiment_name is None:
print('No experiment name provided. The default experiment name is set as `experiment`.')
experiment_id = mlflow.create_experiment('experiment')
print("No experiment name provided. The default experiment name is set as `experiment`.")
experiment_id = mlflow.create_experiment("experiment")
# set the active experiment
mlflow.set_experiment('experiment')
experiment_name = 'experiment'
mlflow.set_experiment("experiment")
experiment_name = "experiment"
else:
if experiment_name not in self.experiments:
if mlflow.get_experiment_by_name(experiment_name) is not None:
raise Exception('The experiment has already been created before. Please pick another name or delete the files under uri.')
raise Exception(
"The experiment has already been created before. Please pick another name or delete the files under uri."
)
experiment_id = mlflow.create_experiment(experiment_name)
else:
experiment_id = self.experiments[experiment_name].id
@@ -197,40 +204,42 @@ class MLflowExpManager(ExpManager):
# set the active experiment
mlflow.set_experiment(experiment_name)
# set up experiment
experiment.id = experiment_id
experiment.id = experiment_id
experiment.name = experiment_name
return experiment
def search_records(self, experiment_ids, **kwargs):
    """
    Search the runs of the given experiments via ``mlflow.search_runs``.

    Parameters
    ----------
    experiment_ids
        forwarded unchanged as the first argument of ``mlflow.search_runs``.
    **kwargs
        optional ``filter_string``, ``run_view_type``, ``max_results`` and
        ``order_by``. A missing key and an explicit ``None`` are treated the
        same way and fall back to ``""``, ``1``, ``100000`` and ``None``
        respectively.

    Returns
    -------
    The value returned by ``mlflow.search_runs``.
    """
    # NOTE(review): run_view_type=1 is presumably mlflow's ACTIVE_ONLY view type -- confirm.
    fallbacks = (("filter_string", ""), ("run_view_type", 1), ("max_results", 100000), ("order_by", None))
    options = []
    for key, fallback in fallbacks:
        value = kwargs.get(key)
        options.append(fallback if value is None else value)
    # Arguments are passed positionally, in the order listed in `fallbacks`.
    return mlflow.search_runs(experiment_ids, *options)
def get_exp(self, experiment_id=None, experiment_name=None):
    """
    Retrieve a registered experiment by name or by id.

    Parameters
    ----------
    experiment_id : str
        id of the experiment; only used when ``experiment_name`` is None.
    experiment_name : str
        name of the experiment; takes precedence over ``experiment_id``.

    Returns
    -------
    The matching entry of ``self.experiments``, or ``None`` (after printing
    a notice) when no registered experiment has the requested id.

    Raises
    ------
    AssertionError
        if neither ``experiment_id`` nor ``experiment_name`` is given.
    KeyError
        if ``experiment_name`` is given but is not a key of ``self.experiments``.
    """
    assert (
        experiment_id is not None or experiment_name is not None
    ), "Please provide at least one of the experiment id or name to retrieve an experiment."
    if experiment_name is not None:
        return self.experiments[experiment_name]
    for experiment in self.experiments.values():
        if experiment.id == experiment_id:
            return experiment
    # Reaching this point means the id lookup failed; report it instead of
    # silently falling through.
    print("No valid experiment is found. Please make sure the id and name are correctly given.")
    return None
def delete_exp(self, experiment_id):
    """
    Delete an experiment from mlflow and drop it from the local registry.

    Parameters
    ----------
    experiment_id : str
        id of the experiment to delete.
    """
    mlflow.delete_experiment(experiment_id)
    # Rebuild the registry without any entry whose id matches the deleted one.
    self.experiments = {
        name: experiment for name, experiment in self.experiments.items() if experiment.id != experiment_id
    }
def get_uri(self, type):
    """
    Return one of the tracking URIs stored on the manager.

    Parameters
    ----------
    type : str
        ``"default"`` to get ``self.default_uri`` or ``"current"`` to get
        ``self.current_uri``. (The name shadows the builtin but is kept so
        existing keyword callers keep working.)

    Returns
    -------
    The requested URI attribute.

    Raises
    ------
    ValueError
        if ``type`` is neither ``"default"`` nor ``"current"``.
    """
    # The selector is the `type` parameter; comparing against an undefined
    # `uri` name would raise NameError on every call.
    if type == "default":
        return self.default_uri
    if type == "current":
        return self.current_uri
    raise ValueError("Input type is not supported. Please choose type default or current to get the uri.")
def get_recorder(self):
    """
    Return the currently active recorder.

    Returns
    -------
    The value of ``self.active_recorder`` (``None`` once the experiment has
    been ended via ``end_exp``).
    """
    return self.active_recorder

View File

@@ -6,6 +6,7 @@ import shutil, os, pickle, tempfile, codecs
from pathlib import Path
from ..utils.objm import FileManager
class Recorder:
"""
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
@@ -16,7 +17,7 @@ class Recorder:
self.experiment_id = experiment_id
self.recorder_id = None
self.recorder_name = None
def set_recorder_name(self, rname):
self.recorder_name = rname
@@ -63,10 +64,9 @@ class Recorder:
"""
raise NotImplementedError(f"Please implement the `load_object` method.")
def start_run(self, run_id=None, experiment_id=None,
run_name=None, nested=False):
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
"""
Start running the Recorder. The return value can be used as a context manager within a `with` block;
Start running the Recorder. The return value can be used as a context manager within a `with` block;
otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
Parameters
@@ -85,7 +85,7 @@ class Recorder:
An active running object (e.g. mlflow.ActiveRun object).
"""
raise NotImplementedError(f"Please implement the `start_run` method.")
def end_run(self):
"""
End an active Recorder.
@@ -138,19 +138,19 @@ class Recorder:
class MLflowRecorder(Recorder):
'''
"""
Use mlflow to implement a Recorder.
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
use file manager to help maintain the objects in the project.
'''
"""
def __init__(self, experiment_id):
    """
    Create a recorder bound to an experiment.

    Parameters
    ----------
    experiment_id : str
        id of the experiment this recorder belongs to; forwarded to the
        base ``Recorder`` initializer.
    """
    super(MLflowRecorder, self).__init__(experiment_id)
    # Both start out empty; elsewhere in this file they are assigned a
    # FileManager and a tempfile.mkdtemp() directory respectively.
    self.fm = None
    self.temp_dir = None
def start_run(self, run_id=None, experiment_id=None,
run_name=None, nested=False):
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
if run_id is None:
run_id = self.recorder_id
if experiment_id is None:
@@ -166,7 +166,7 @@ class MLflowRecorder(Recorder):
self.temp_dir = tempfile.mkdtemp()
self.fm = FileManager(Path(self.temp_dir).absolute())
return run
def end_run(self):
mlflow.end_run()
shutil.rmtree(self.temp_dir)
@@ -194,13 +194,13 @@ class MLflowRecorder(Recorder):
client = mlflow.tracking.MlflowClient()
path = client.download_artifacts(self.recorder_id, name)
try:
with Path(path).open('rb') as f:
with Path(path).open("rb") as f:
f.seek(0)
return pickle.load(f)
except:
with codecs.open(path, mode="r", encoding='utf-8') as f:
return f.read()
with codecs.open(path, mode="r", encoding="utf-8") as f:
return f.read()
def log_params(self, **kwargs):
keys = list(kwargs.keys())
if len(keys) == 0:
@@ -214,7 +214,7 @@ class MLflowRecorder(Recorder):
mlflow.log_metric(keys[0], kwargs.get(keys[0]))
else:
mlflow.log_metrics(dict(kwargs))
def set_tags(self, **kwargs):
keys = list(kwargs.keys())
if len(keys) == 0:
@@ -228,4 +228,4 @@ class MLflowRecorder(Recorder):
def get_artifact_uri(self, artifact_path=None):
    """
    Return the artifact URI of this recorder.

    Parameters
    ----------
    artifact_path : str, optional
        forwarded to ``mlflow.get_artifact_uri``; ignored when
        ``self.artifact_uri`` is already set.

    Returns
    -------
    ``self.artifact_uri`` when it is not None, otherwise the result of
    ``mlflow.get_artifact_uri(artifact_path)``.
    """
    if self.artifact_uri is not None:
        return self.artifact_uri
    return mlflow.get_artifact_uri(artifact_path)