mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
Format with black
This commit is contained in:
@@ -39,7 +39,7 @@ def init(default_conf="client", **kwargs):
|
||||
LOG.info(f"default_conf: {default_conf}.")
|
||||
|
||||
C.set_mode(default_conf)
|
||||
C.set_region(kwargs.get('region', C['region'] if 'region' in C else REG_CN ))
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
C[k] = v
|
||||
@@ -80,13 +80,13 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
if "flask_server" in C:
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
|
||||
# set up QlibRecorder
|
||||
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
|
||||
current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri
|
||||
current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri
|
||||
# exp manager module
|
||||
module = get_module_by_module_path('qlib.workflow')
|
||||
exp_manager = init_instance_by_config(C['exp_manager'], module)
|
||||
module = get_module_by_module_path("qlib.workflow")
|
||||
exp_manager = init_instance_by_config(C["exp_manager"], module)
|
||||
qr = QlibRecorder(exp_manager, default_uri, current_uri)
|
||||
R.register(qr)
|
||||
|
||||
|
||||
@@ -125,10 +125,7 @@ _default_config = {
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
},
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"kwargs": {}
|
||||
},
|
||||
"exp_manager": {"class": "MLflowExpManager", "kwargs": {}},
|
||||
"exp_uri": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -46,10 +46,10 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
benchmark code, default is SH000905 CSI500
|
||||
"""
|
||||
# Convert format if the input format is not expected
|
||||
if get_level_index(pred, level='datetime') == 1:
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame('score')
|
||||
pred = pred.to_frame("score")
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
@@ -80,8 +80,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
# 1. Load the score_series at pred_date
|
||||
try:
|
||||
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime",
|
||||
drop=True)["score"] # pd.Series(index:stock_id, data: score)
|
||||
score_series = score.reset_index(level="datetime", drop=True)[
|
||||
"score"
|
||||
] # pd.Series(index:stock_id, data: score)
|
||||
except KeyError:
|
||||
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
|
||||
score_series = None
|
||||
|
||||
@@ -16,21 +16,16 @@ class ALPHA360(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": {
|
||||
"price": {
|
||||
"windows": range(60)
|
||||
},
|
||||
"volume": {
|
||||
"windows": range(60)
|
||||
},
|
||||
"price": {"windows": range(60)},
|
||||
"volume": {"windows": range(60)},
|
||||
},
|
||||
"label": self.get_label_config()
|
||||
"label": self.get_label_config(),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
infer_processors = [{
|
||||
"class": "ConfigSectionProcessor",
|
||||
"module_path": "qlib.contrib.data.processor"
|
||||
}] # ConfigSectionProcessor will normalize LABEL0
|
||||
infer_processors = [
|
||||
{"class": "ConfigSectionProcessor", "module_path": "qlib.contrib.data.processor"}
|
||||
] # ConfigSectionProcessor will normalize LABEL0
|
||||
super().__init__(instruments, start_time, end_time, data_loader=data_loader, infer_processors=infer_processors)
|
||||
|
||||
def get_label_config(self):
|
||||
@@ -49,12 +44,7 @@ class Alpha158(DataHandlerLP):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=["DropnaLabel", {
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {
|
||||
"fields_group": "label"
|
||||
}
|
||||
}],
|
||||
learn_processors=["DropnaLabel", {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
):
|
||||
@@ -65,11 +55,13 @@ class Alpha158(DataHandlerLP):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
# FIXME: It's hard code here!!!!!
|
||||
if isinstance(klass, (MinMaxNorm, ZscoreNorm)):
|
||||
assert (fit_start_time is not None and fit_end_time is not None)
|
||||
pkwargs.update({
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
})
|
||||
assert fit_start_time is not None and fit_end_time is not None
|
||||
pkwargs.update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
@@ -81,18 +73,17 @@ class Alpha158(DataHandlerLP):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": self.get_label_config()
|
||||
},
|
||||
}
|
||||
"config": {"feature": self.get_feature_config(), "label": self.get_label_config()},
|
||||
},
|
||||
}
|
||||
super().__init__(instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors)
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
conf = {
|
||||
@@ -247,7 +238,8 @@ class Alpha158(DataHandlerLP):
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
@@ -258,26 +250,30 @@ class Alpha158(DataHandlerLP):
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)" %
|
||||
(d, d) for d in windows
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d) for d in windows
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@ from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
class ConfigSectionProcessor(Processor):
|
||||
'''
|
||||
"""
|
||||
This processor is designed for Alpha158. And will be replaced by simple processors in the future
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, fields_group=None, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
|
||||
@@ -159,11 +159,11 @@ def get_exchange(
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values('instrument').unique())
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values('datetime').unique())
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
@@ -298,7 +298,7 @@ def long_short_backtest(
|
||||
"short": short_returns(excess),
|
||||
"long_short": long_short_returns}
|
||||
"""
|
||||
if get_level_index(pred, level='datetime') == 1:
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
|
||||
if trade_unit is None:
|
||||
|
||||
@@ -12,26 +12,29 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
class LGBModel(Model):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self._params = {'objective': loss}
|
||||
self._params = {"objective": loss}
|
||||
self._params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def fit(self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs):
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(['train', 'valid'],
|
||||
col_set=['feature', 'label'],
|
||||
data_key=DataHandlerLP.DK_L)
|
||||
x_train, y_train = df_train['feature'], df_train['label']
|
||||
x_valid, y_valid = df_valid['feature'], df_valid['label']
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
@@ -41,20 +44,22 @@ class LGBModel(Model):
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d)
|
||||
self.model = lgb.train(self._params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs)
|
||||
self.model = lgb.train(
|
||||
self._params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare('test', col_set='feature')
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -6,11 +6,12 @@ import pandas as pd
|
||||
|
||||
|
||||
class Dataset(Serializable):
|
||||
'''
|
||||
"""
|
||||
Preparing data for model training and inferencing.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
'''
|
||||
"""
|
||||
init is designed to finish following steps
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing
|
||||
@@ -18,7 +19,7 @@ class Dataset(Serializable):
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
'''
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
super().__init__()
|
||||
|
||||
@@ -51,14 +52,15 @@ class Dataset(Serializable):
|
||||
|
||||
|
||||
class DatasetH(Dataset):
|
||||
'''
|
||||
"""
|
||||
Dataset with Data(H)andler
|
||||
|
||||
User should try to put the data preprocessing functions into handler.
|
||||
Only following data processing functions should be placed in Dataset
|
||||
- The processing is related to specific model.
|
||||
- The processing is related to data split
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: list):
|
||||
"""
|
||||
Parameters
|
||||
@@ -96,10 +98,9 @@ class DatasetH(Dataset):
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments
|
||||
|
||||
def prepare(self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
**kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
def prepare(
|
||||
self, segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, **kwargs
|
||||
) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
"""
|
||||
prepare the data for learning and inference
|
||||
|
||||
@@ -124,9 +125,7 @@ class DatasetH(Dataset):
|
||||
[TODO:description]
|
||||
"""
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [
|
||||
self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments
|
||||
]
|
||||
return [self._handler.fetch(slice(*self._segments[seg]), col_set=col_set, **kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._handler.fetch(slice(*self._segments[segments]), col_set=col_set, **kwargs)
|
||||
else:
|
||||
|
||||
@@ -25,7 +25,7 @@ from . import loader as data_loader_module
|
||||
|
||||
# TODO: A more general handler interface which does not rely on internal pd.DataFrame is needed.
|
||||
class DataHandler(Serializable):
|
||||
'''
|
||||
"""
|
||||
The steps to using a handler
|
||||
1. initialized data handler (call by `init`).
|
||||
2. use the data
|
||||
@@ -46,13 +46,21 @@ class DataHandler(Serializable):
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
'''
|
||||
def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, init_data=True):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruments,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
):
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# Setup data loader
|
||||
assert(data_loader is not None) # to make start_time end_time could have None default value
|
||||
assert data_loader is not None # to make start_time end_time could have None default value
|
||||
self.data_loader = init_instance_by_config(data_loader, data_loader_module, accept_types=DataLoader)
|
||||
|
||||
self.instruments = instruments
|
||||
@@ -62,7 +70,7 @@ class DataHandler(Serializable):
|
||||
self.init()
|
||||
super().__init__()
|
||||
|
||||
def init(self, enable_cache: bool=True):
|
||||
def init(self, enable_cache: bool = True):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running initialization multiple times, it will do nothing the second time.
|
||||
@@ -83,7 +91,9 @@ class DataHandler(Serializable):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# TODO: cache
|
||||
|
||||
def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame:
|
||||
def _fetch_df_by_index(
|
||||
self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
|
||||
@@ -100,7 +110,7 @@ class DataHandler(Serializable):
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc(axis=0)[idx_slc]
|
||||
|
||||
CS_ALL = '__all'
|
||||
CS_ALL = "__all"
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
|
||||
cln = len(df.columns.levels)
|
||||
@@ -111,10 +121,12 @@ class DataHandler(Serializable):
|
||||
else:
|
||||
return df.loc(axis=1)[col_set]
|
||||
|
||||
def fetch(self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = 'datetime',
|
||||
col_set: Union[str, List[str]] = CS_ALL) -> pd.DataFrame:
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
|
||||
@@ -157,32 +169,35 @@ class DataHandler(Serializable):
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
'''
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
'''
|
||||
"""
|
||||
|
||||
# data key
|
||||
DK_R = 'raw'
|
||||
DK_I = 'infer'
|
||||
DK_L = 'learn'
|
||||
DK_R = "raw"
|
||||
DK_I = "infer"
|
||||
DK_L = "learn"
|
||||
|
||||
# process type
|
||||
PTYPE_I = 'independent'
|
||||
PTYPE_I = "independent"
|
||||
# - _proc_infer_df will be processed by infer_processors
|
||||
# - _proc_learn_df will be processed by learn_processors
|
||||
PTYPE_A = 'append'
|
||||
PTYPE_A = "append"
|
||||
# - _proc_infer_df will be processed by infer_processors
|
||||
# - _proc_learn_df will be processed by infer_processors + learn_processors
|
||||
# - (e.g. _proc_infer_df processed by learn_processors )
|
||||
|
||||
def __init__(self,
|
||||
instruments,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
instruments,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -217,10 +232,11 @@ class DataHandlerLP(DataHandler):
|
||||
# Setup preprocessor
|
||||
self.infer_processors = [] # for lint
|
||||
self.learn_processors = [] # for lint
|
||||
for pname in 'infer_processors', 'learn_processors':
|
||||
for pname in "infer_processors", "learn_processors":
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(init_instance_by_config(proc, processor_module,
|
||||
accept_types=(processor_module.Processor,)))
|
||||
getattr(self, pname).append(
|
||||
init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,))
|
||||
)
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
@@ -240,8 +256,7 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
self.process_data(with_fit=True)
|
||||
|
||||
|
||||
def process_data(self, with_fit: bool=False):
|
||||
def process_data(self, with_fit: bool = False):
|
||||
"""
|
||||
process_data data. Run `processor.fit` if necessary
|
||||
|
||||
@@ -281,11 +296,11 @@ class DataHandlerLP(DataHandler):
|
||||
self._learn = _learn_df
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df
|
||||
IT_LS = 'load_state' # The state of the object has been load by pickle
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str=IT_FIT_SEQ, enable_cache: bool=False):
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
|
||||
@@ -314,15 +329,17 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: str=DK_I) -> pd.DataFrame:
|
||||
df = getattr(self, {self.DK_R: '_data', self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
return df
|
||||
|
||||
def fetch(self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = 'datetime',
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I) -> pd.DataFrame:
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
|
||||
@@ -345,7 +362,7 @@ class DataHandlerLP(DataHandler):
|
||||
df = self._fetch_df_by_index(df, selector, level)
|
||||
return self._fetch_df_by_col(df, col_set)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str=DK_I) -> list:
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
get the column names
|
||||
|
||||
|
||||
@@ -8,44 +8,46 @@ from typing import Tuple
|
||||
|
||||
|
||||
class DataLoader(ABC):
|
||||
'''
|
||||
"""
|
||||
DataLoader is designed for loading raw data from original data source.
|
||||
'''
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
"""
|
||||
load the data as pd.DataFrame
|
||||
load the data as pd.DataFrame
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
|
||||
Example of the data:
|
||||
The multi-index of the columns is optional.
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
Example of the data:
|
||||
The multi-index of the columns is optional.
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class QlibDataLoader(DataLoader):
|
||||
'''Same as QlibDataLoader. The fields can be define by config'''
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
|
||||
"""
|
||||
Parameters
|
||||
@@ -65,7 +67,7 @@ class QlibDataLoader(DataLoader):
|
||||
Here is a few examples to describe the fields
|
||||
TODO:
|
||||
"""
|
||||
self.is_group = isinstance(config, dict)
|
||||
self.is_group = isinstance(config, dict)
|
||||
|
||||
if self.is_group:
|
||||
self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
|
||||
@@ -88,6 +90,7 @@ class QlibDataLoader(DataLoader):
|
||||
df = D.features(D.instruments(instruments, filter_pipe=self.filter_pipe), exprs, start_time, end_time)
|
||||
df.columns = names
|
||||
return df
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat({grp: _get_df(exprs, names) for grp, (exprs, names) in self.fields.items()}, axis=1)
|
||||
else:
|
||||
|
||||
@@ -30,8 +30,7 @@ def get_group_columns(df: pd.DataFrame, group: str):
|
||||
|
||||
|
||||
class Processor(Serializable):
|
||||
|
||||
def fit(self, df: pd.DataFrame=None):
|
||||
def fit(self, df: pd.DataFrame = None):
|
||||
"""
|
||||
learn data processing parameters
|
||||
|
||||
@@ -40,7 +39,7 @@ class Processor(Serializable):
|
||||
df : pd.DataFrame
|
||||
When we fit and process data with processor one by one. The fit function relies on the output of previous
|
||||
processor, i.e. `df`.
|
||||
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -81,16 +80,17 @@ class DropnaProcessor(Processor):
|
||||
|
||||
|
||||
class DropnaLabel(DropnaProcessor):
|
||||
def __init__(self, group='label'):
|
||||
def __init__(self, group="label"):
|
||||
super().__init__(group=group)
|
||||
|
||||
def is_for_infer(self) -> bool:
|
||||
'''The samples are dropped according to label. So it is not usable for inference'''
|
||||
"""The samples are dropped according to label. So it is not usable for inference"""
|
||||
return False
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
'''Process infinity '''
|
||||
"""Process infinity """
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
def process_inf(df):
|
||||
@@ -102,6 +102,7 @@ class ProcessInf(Processor):
|
||||
data = data.groupby("datetime").apply(process_inf)
|
||||
data.sort_index(inplace=True)
|
||||
return data
|
||||
|
||||
return replace_inf(df)
|
||||
|
||||
|
||||
@@ -126,6 +127,7 @@ class MinMaxNorm(Processor):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_val) / (max_val - min_val)
|
||||
return x
|
||||
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
@@ -151,17 +153,19 @@ class ZscoreNorm(Processor):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
|
||||
class CSZScoreNorm(Processor):
|
||||
'''Cross Sectional ZScore Normalization'''
|
||||
"""Cross Sectional ZScore Normalization"""
|
||||
|
||||
def __init__(self, fields_group=None):
|
||||
self.fields_group = fields_group
|
||||
|
||||
def __call__(self, df):
|
||||
# try not modify original dataframe
|
||||
cols = get_group_columns(df,self.fields_group)
|
||||
df[cols] = df[cols].groupby('datetime').apply(lambda df: (df - df.mean()).div(df.std()))
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df[cols] = df[cols].groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
|
||||
return df
|
||||
|
||||
@@ -24,9 +24,8 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
return df.index.names.index(level)
|
||||
except (AttributeError, ValueError):
|
||||
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
|
||||
return ('datetime', 'instrument').index(level)
|
||||
return ("datetime", "instrument").index(level)
|
||||
elif isinstance(level, int):
|
||||
return level
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from ..data.dataset import Dataset
|
||||
|
||||
|
||||
class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
'''Modeling things'''
|
||||
"""Modeling things"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
@@ -19,7 +19,7 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
'''Learnable Models'''
|
||||
"""Learnable Models"""
|
||||
|
||||
def fit(self, dataset: Dataset):
|
||||
"""
|
||||
|
||||
@@ -165,7 +165,7 @@ def get_module_by_module_path(module_path):
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
|
||||
@@ -184,8 +184,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config['class'])
|
||||
kwargs = config['kwargs']
|
||||
klass = getattr(module, config["class"])
|
||||
kwargs = config["kwargs"]
|
||||
elif isinstance(config, str):
|
||||
klass = getattr(module, config)
|
||||
kwargs = {}
|
||||
@@ -194,7 +194,9 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
return klass, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]]=tuple([])) -> object:
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]] = tuple([])
|
||||
) -> object:
|
||||
"""
|
||||
get initialized instance with config
|
||||
|
||||
@@ -647,4 +649,4 @@ def register_wrapper(wrapper, cls_or_obj):
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
cls_or_obj = getattr(module, cls_or_obj)
|
||||
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
|
||||
wrapper.register(obj)
|
||||
wrapper.register(obj)
|
||||
|
||||
@@ -24,7 +24,7 @@ class ObjManager:
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
"""
|
||||
save objects
|
||||
save objects
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -88,9 +88,10 @@ class ObjManager:
|
||||
|
||||
|
||||
class FileManager(ObjManager):
|
||||
'''
|
||||
"""
|
||||
Use file system to manage objects
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
self.path = Path(self.create_path())
|
||||
@@ -99,12 +100,12 @@ class FileManager(ObjManager):
|
||||
|
||||
def create_path(self) -> str:
|
||||
try:
|
||||
return tempfile.mkdtemp(prefix=str(C['file_manager_path']) + os.sep)
|
||||
return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
|
||||
|
||||
def save_obj(self, obj, name):
|
||||
with (self.path / name).open('wb') as f:
|
||||
with (self.path / name).open("wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
@@ -112,7 +113,7 @@ class FileManager(ObjManager):
|
||||
self.save_obj(obj, name)
|
||||
|
||||
def load_obj(self, name):
|
||||
with (self.path / name).open('rb') as f:
|
||||
with (self.path / name).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def exists(self, name):
|
||||
@@ -123,7 +124,7 @@ class FileManager(ObjManager):
|
||||
|
||||
def remove(self, fname=None):
|
||||
if fname is None:
|
||||
for fp in self.path.glob('*'):
|
||||
for fp in self.path.glob("*"):
|
||||
fp.unlink()
|
||||
self.path.rmdir()
|
||||
else:
|
||||
|
||||
@@ -6,17 +6,17 @@ import pickle
|
||||
|
||||
|
||||
class Serializable:
|
||||
'''
|
||||
"""
|
||||
Serializable behaves like pickle.
|
||||
But it only save the state whose name starts with `_`
|
||||
'''
|
||||
"""
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if k.startswith('_') }
|
||||
return {k: v for k, v in self.__dict__.items() if k.startswith("_")}
|
||||
|
||||
def __setstate__(self, state: dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def to_pickle(self, path: [Path, str]):
|
||||
with Path(path).open('wb') as f:
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@@ -5,6 +5,7 @@ from contextlib import contextmanager
|
||||
from .expm import *
|
||||
from ..utils import Wrapper
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
def __init__(self, exp_manager, default_uri, current_uri):
|
||||
self.exp_manager = exp_manager
|
||||
@@ -16,16 +17,16 @@ class QlibRecorder:
|
||||
run = self.start_exp(experiment_name, self.current_uri)
|
||||
yield run
|
||||
self.end_exp()
|
||||
|
||||
|
||||
def start_exp(self, experiment_name=None):
|
||||
return self.exp_manager.start_exp(experiment_name, self.current_uri)
|
||||
return self.exp_manager.start_exp(experiment_name, self.current_uri)
|
||||
|
||||
def end_exp(self):
|
||||
self.exp_manager.end_exp()
|
||||
|
||||
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
return self.exp_manager.search_records(experiment_ids, **kwargs)
|
||||
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name)
|
||||
|
||||
@@ -52,12 +53,13 @@ class QlibRecorder:
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
|
||||
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
self.exp_manager.active_recorder.set_tags(**kwargs)
|
||||
|
||||
def delete_tag(self, key):
|
||||
self.exp_manager.active_recorder.delete_tag(key)
|
||||
|
||||
|
||||
# global record
|
||||
R = Wrapper()
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
import mlflow
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Experiment:
|
||||
"""
|
||||
Thie is the `Experiment` class for each experiment being run. The API is designed
|
||||
Thie is the `Experiment` class for each experiment being run. The API is designed
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.name = None
|
||||
self.id = None
|
||||
@@ -39,9 +41,10 @@ class MLflowExperiment(Experiment):
|
||||
"""
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
|
||||
def search_records(self, **kwargs):
|
||||
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
|
||||
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
|
||||
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
|
||||
order_by = kwargs.get('order_by')
|
||||
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
|
||||
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
|
||||
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
@@ -8,15 +8,17 @@ from contextlib import contextmanager
|
||||
from .exp import MLflowExperiment
|
||||
from .record import MLflowRecorder
|
||||
|
||||
|
||||
class ExpManager:
|
||||
"""
|
||||
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.default_uri = None
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None, **kwargs):
|
||||
"""
|
||||
@@ -88,7 +90,7 @@ class ExpManager:
|
||||
An experiment object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
Retrieve an experiment by experiment_id from the backend store.
|
||||
@@ -111,7 +113,7 @@ class ExpManager:
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id.
|
||||
the experiment id.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
@@ -142,12 +144,13 @@ class ExpManager:
|
||||
An Recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
|
||||
|
||||
|
||||
class MLflowExpManager(ExpManager):
|
||||
'''
|
||||
"""
|
||||
Use mlflow to implement ExpManager.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MLflowExpManager, self).__init__()
|
||||
self.default_uri = None
|
||||
@@ -169,27 +172,31 @@ class MLflowExpManager(ExpManager):
|
||||
def end_exp(self):
|
||||
self.active_recorder.end_run()
|
||||
self.active_recorder = None
|
||||
|
||||
|
||||
def __create_exp(self, experiment_name=None, uri=None):
|
||||
# init experiment
|
||||
experiment = MLflowExperiment()
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.')
|
||||
print(
|
||||
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
|
||||
)
|
||||
else:
|
||||
self.current_uri = uri
|
||||
mlflow.set_tracking_uri(self.current_uri)
|
||||
# start the experiment
|
||||
if experiment_name is None:
|
||||
print('No experiment name provided. The default experiment name is set as `experiment`.')
|
||||
experiment_id = mlflow.create_experiment('experiment')
|
||||
print("No experiment name provided. The default experiment name is set as `experiment`.")
|
||||
experiment_id = mlflow.create_experiment("experiment")
|
||||
# set the active experiment
|
||||
mlflow.set_experiment('experiment')
|
||||
experiment_name = 'experiment'
|
||||
mlflow.set_experiment("experiment")
|
||||
experiment_name = "experiment"
|
||||
else:
|
||||
if experiment_name not in self.experiments:
|
||||
if mlflow.get_experiment_by_name(experiment_name) is not None:
|
||||
raise Exception('The experiment has already been created before. Please pick another name or delete the files under uri.')
|
||||
raise Exception(
|
||||
"The experiment has already been created before. Please pick another name or delete the files under uri."
|
||||
)
|
||||
experiment_id = mlflow.create_experiment(experiment_name)
|
||||
else:
|
||||
experiment_id = self.experiments[experiment_name].id
|
||||
@@ -197,40 +204,42 @@ class MLflowExpManager(ExpManager):
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(experiment_name)
|
||||
# set up experiment
|
||||
experiment.id = experiment_id
|
||||
experiment.id = experiment_id
|
||||
experiment.name = experiment_name
|
||||
|
||||
return experiment
|
||||
|
||||
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
|
||||
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
|
||||
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
|
||||
order_by = kwargs.get('order_by')
|
||||
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
|
||||
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
assert experiment_id is not None or experiment_name is not None, 'Please provide at least one of the experiment id or name to retrieve an experiment.'
|
||||
assert (
|
||||
experiment_id is not None or experiment_name is not None
|
||||
), "Please provide at least one of the experiment id or name to retrieve an experiment."
|
||||
if experiment_name is not None:
|
||||
return self.experiments[experiment_name]
|
||||
elif:
|
||||
elif experiment_id is not None:
|
||||
for name in self.experiments:
|
||||
if self.experiments[name].id == experiment_id:
|
||||
return self.experiments[name]
|
||||
else:
|
||||
print('No valid experiment is found. Please make sure the id and name are correctly given.')
|
||||
print("No valid experiment is found. Please make sure the id and name are correctly given.")
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
self.experiments = {key:val for key, val in self.experiments.items() if val.id != experiment_id}
|
||||
self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id}
|
||||
|
||||
def get_uri(self, type):
|
||||
if uri == 'default':
|
||||
if uri == "default":
|
||||
return self.default_uri
|
||||
elif uri == 'current':
|
||||
elif uri == "current":
|
||||
return self.current_uri
|
||||
else:
|
||||
raise ValueError('Input type is not supported. Please choose type default or current to get the uri.')
|
||||
raise ValueError("Input type is not supported. Please choose type default or current to get the uri.")
|
||||
|
||||
def get_recorder(self):
|
||||
return self.active_recorder
|
||||
return self.active_recorder
|
||||
|
||||
@@ -6,6 +6,7 @@ import shutil, os, pickle, tempfile, codecs
|
||||
from pathlib import Path
|
||||
from ..utils.objm import FileManager
|
||||
|
||||
|
||||
class Recorder:
|
||||
"""
|
||||
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
|
||||
@@ -16,7 +17,7 @@ class Recorder:
|
||||
self.experiment_id = experiment_id
|
||||
self.recorder_id = None
|
||||
self.recorder_name = None
|
||||
|
||||
|
||||
def set_recorder_name(self, rname):
|
||||
self.recorder_name = rname
|
||||
|
||||
@@ -63,10 +64,9 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load_object` method.")
|
||||
|
||||
def start_run(self, run_id=None, experiment_id=None,
|
||||
run_name=None, nested=False):
|
||||
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
|
||||
"""
|
||||
Start running the Recorder. The return value can be used as a context manager within a `with` block;
|
||||
Start running the Recorder. The return value can be used as a context manager within a `with` block;
|
||||
otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
|
||||
|
||||
Parameters
|
||||
@@ -85,7 +85,7 @@ class Recorder:
|
||||
An active running object (e.g. mlflow.ActiveRun object).
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_run` method.")
|
||||
|
||||
|
||||
def end_run(self):
|
||||
"""
|
||||
End an active Recorder.
|
||||
@@ -138,19 +138,19 @@ class Recorder:
|
||||
|
||||
|
||||
class MLflowRecorder(Recorder):
|
||||
'''
|
||||
"""
|
||||
Use mlflow to implement a Recorder.
|
||||
|
||||
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
|
||||
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
|
||||
use file manager to help maintain the objects in the project.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_id):
|
||||
super(MLflowRecorder, self).__init__(experiment_id)
|
||||
self.fm = None
|
||||
self.temp_dir = None
|
||||
|
||||
def start_run(self, run_id=None, experiment_id=None,
|
||||
run_name=None, nested=False):
|
||||
def start_run(self, run_id=None, experiment_id=None, run_name=None, nested=False):
|
||||
if run_id is None:
|
||||
run_id = self.recorder_id
|
||||
if experiment_id is None:
|
||||
@@ -166,7 +166,7 @@ class MLflowRecorder(Recorder):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
return run
|
||||
|
||||
|
||||
def end_run(self):
|
||||
mlflow.end_run()
|
||||
shutil.rmtree(self.temp_dir)
|
||||
@@ -194,13 +194,13 @@ class MLflowRecorder(Recorder):
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
path = client.download_artifacts(self.recorder_id, name)
|
||||
try:
|
||||
with Path(path).open('rb') as f:
|
||||
with Path(path).open("rb") as f:
|
||||
f.seek(0)
|
||||
return pickle.load(f)
|
||||
except:
|
||||
with codecs.open(path, mode="r", encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
with codecs.open(path, mode="r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
@@ -214,7 +214,7 @@ class MLflowRecorder(Recorder):
|
||||
mlflow.log_metric(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.log_metrics(dict(kwargs))
|
||||
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
@@ -228,4 +228,4 @@ class MLflowRecorder(Recorder):
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
if self.artifact_uri is not None:
|
||||
return self.artifact_uri
|
||||
return mlflow.get_artifact_uri(artifact_path)
|
||||
return mlflow.get_artifact_uri(artifact_path)
|
||||
|
||||
Reference in New Issue
Block a user