mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
Merge
This commit is contained in:
@@ -16,7 +16,8 @@ from qlib.contrib.evaluate import (
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
from qlib.model.learner import train_model
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -57,13 +58,6 @@ if __name__ == "__main__":
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
# use default DataHandler
|
||||
# custom DataHandler, refer to: TODO: DataHandler API url
|
||||
handler = Alpha158(**DATA_HANDLER_CONFIG)
|
||||
|
||||
data = handler.fetch(slice('2008-01-01', '2014-12-31'), data_key=handler.DK_I)
|
||||
print(data)
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
@@ -80,59 +74,33 @@ if __name__ == "__main__":
|
||||
"num_threads": 20,
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
'handler': {
|
||||
"class": "Alpha158",
|
||||
"kwargs": DATA_HANDLER_CONFIG
|
||||
},
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
'handler': {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG
|
||||
},
|
||||
'segments': {
|
||||
'train': ("2008-01-01", "2014-12-31"),
|
||||
'valid': ("2015-01-01", "2016-12-31",),
|
||||
'test': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
# You should record the data in a specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = train_model(task)
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task['model'])
|
||||
dataset = init_instance_by_config(task['dataset'])
|
||||
|
||||
model.fit(dataset)
|
||||
|
||||
|
||||
sys.exit(0) # I have tested the code above ---------------------------------------------
|
||||
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
# use default model
|
||||
# custom Model, refer to: TODO: Model API url
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
|
||||
|
||||
# backtest requires pred_score
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
|
||||
@@ -10,6 +10,7 @@ from ...data import D
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...data.dataset.utils import get_level_index
|
||||
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
@@ -18,7 +19,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <instrument, datetime> index and one `score` column
|
||||
prediction should have a <datetime, instrument> index and one `score` column
|
||||
Qlib wants to support multi-signal strategies in the future, so pd.Series is not used.
|
||||
strategy : Strategy()
|
||||
strategy part for backtest
|
||||
trade_exchange : Exchange()
|
||||
@@ -43,6 +45,12 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
"""
|
||||
# Convert format if the input format is not expected
|
||||
if get_level_index(pred, level='datetime') == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame('score')
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
@@ -71,10 +79,9 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
|
||||
# 1. Load the score_series at pred_date
|
||||
try:
|
||||
score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime", drop=True)[
|
||||
"score"
|
||||
] # pd.Series(index:stock_id, data: score)
|
||||
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime",
|
||||
drop=True)["score"] # pd.Series(index:stock_id, data: score)
|
||||
except KeyError:
|
||||
LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
|
||||
score_series = None
|
||||
|
||||
@@ -15,6 +15,7 @@ from .backtest.backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
@@ -158,11 +159,11 @@ def get_exchange(
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values(0).unique())
|
||||
codes = sorted(pred.index.get_level_values('instrument').unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values(1).unique())
|
||||
dates = sorted(pred.index.get_level_values('datetime').unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
@@ -187,7 +188,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
|
||||
# backtest workflow related or common arguments
|
||||
pred : pandas.DataFrame
|
||||
predict should has <instrument, datetime> index and one `score` column
|
||||
prediction should have a <datetime, instrument> index and one `score` column
|
||||
account : float
|
||||
init account value
|
||||
shift : int
|
||||
@@ -297,6 +298,8 @@ def long_short_backtest(
|
||||
"short": short_returns(excess),
|
||||
"long_short": long_short_returns}
|
||||
"""
|
||||
if get_level_index(pred, level='datetime') == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
@@ -333,13 +336,13 @@ def long_short_backtest(
|
||||
ls_returns = {}
|
||||
|
||||
for pdate, date in zip(predict_dates, trade_dates):
|
||||
score = pred.loc(axis=0)[:, pdate]
|
||||
score = pred.loc(axis=0)[pdate, :]
|
||||
score = score.reset_index().sort_values(by="score", ascending=False)
|
||||
|
||||
long_stocks = list(score.iloc[:topk]["instrument"])
|
||||
short_stocks = list(score.iloc[-topk:]["instrument"])
|
||||
|
||||
score = score.set_index(["instrument", "datetime"]).sort_index()
|
||||
score = score.set_index(["datetime", "instrument"]).sort_index()
|
||||
|
||||
long_profit = []
|
||||
short_profit = []
|
||||
@@ -363,7 +366,7 @@ def long_short_backtest(
|
||||
else:
|
||||
short_profit.append(-profit)
|
||||
|
||||
for stock in list(score.loc(axis=0)[:, pdate].index.get_level_values(level=0)):
|
||||
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
|
||||
# exclude the suspend stock
|
||||
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
|
||||
continue
|
||||
|
||||
@@ -1,91 +1,60 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
from ...model.base import Model
|
||||
from ...utils import drop_nan_by_y_index
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class LGBModel(Model):
|
||||
"""LightGBM Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param_update : dict
|
||||
training parameters
|
||||
"""
|
||||
|
||||
_params = dict()
|
||||
|
||||
"""LightGBM Model"""
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
self._params.update(objective=loss, **kwargs)
|
||||
self._model = None
|
||||
self._params = {'objective': loss}
|
||||
self._params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def fit(self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs):
|
||||
|
||||
df_train, df_valid = dataset.prepare(['train', 'valid'],
|
||||
col_set=['feature', 'label'],
|
||||
data_key=DataHandlerLP.DK_L)
|
||||
x_train, y_train = df_train['feature'], df_train['label']
|
||||
x_valid, y_valid = df_valid['feature'], df_valid['label']
|
||||
|
||||
def fit(
|
||||
self,
|
||||
x_train,
|
||||
y_train,
|
||||
x_valid,
|
||||
y_valid,
|
||||
w_train=None,
|
||||
w_valid=None,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
w_train_weight = None if w_train is None else w_train.values
|
||||
w_valid_weight = None if w_valid is None else w_valid.values
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
|
||||
self._model = lgb.train(
|
||||
self._params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d)
|
||||
self.model = lgb.train(self._params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, x_test):
|
||||
if self._model is None:
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
return self._model.predict(x_test.values)
|
||||
|
||||
def score(self, x_test, y_test, w_test=None):
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
|
||||
preds = self.predict(x_test)
|
||||
w_test_weight = None if w_test is None else w_test.values
|
||||
return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
|
||||
|
||||
def save(self, filename):
|
||||
if self._model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
self._model.save_model(filename)
|
||||
|
||||
def load(self, buffer):
|
||||
self._model = lgb.Booster(params={"model_str": buffer.decode("utf-8")})
|
||||
x_test = dataset.prepare('test', col_set='feature')
|
||||
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -1,8 +1,133 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from ...utils import init_instance_by_config
|
||||
from .handler import DataHandler
|
||||
import pandas as pd
|
||||
|
||||
class Dataset:
|
||||
|
||||
class Dataset(Serializable):
|
||||
'''
|
||||
Preparing data for model training.
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
Preparing data for model training and inferencing.
|
||||
'''
|
||||
def generate(self):
|
||||
def __init__(self, *args, **kwargs):
|
||||
'''
|
||||
init is designed to finish following steps
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
'''
|
||||
self.setup_data(*args, **kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
"""
|
||||
setup the data
|
||||
|
||||
We split the setup_data function for following situation
|
||||
- 1) User have a Dataset object with learned status on disk
|
||||
- 2) User loads the Dataset object from the disk (note that the init function is skipped)
|
||||
- 3) User call `setup_data` to load new data
|
||||
- 4) User prepare data for model based on previous status
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
The method should
|
||||
- process the data
|
||||
- return the processed data
|
||||
|
||||
Returns
|
||||
-------
|
||||
object:
|
||||
return the object
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DatasetH(Dataset):
    '''
    Dataset with Data(H)andler

    User should try to put the data preprocessing functions into handler.
    Only following data processing functions should be placed in Dataset
    - The processing is related to specific model.
    - The processing is related to data split
    '''

    def __init__(self, handler: Union[dict, DataHandler], segments: dict):
        """
        Parameters
        ----------
        handler : Union[dict, DataHandler]
            handler will be passed into setup_data
        segments : dict
            segments will be passed into setup_data
        """
        super().__init__(handler, segments)

    def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
        """
        setup the underlying data

        Parameters
        ----------
        handler : Union[dict, DataHandler]
            handler could be
            1) instance of `DataHandler`
            2) config of `DataHandler`. Please refer to `DataHandler`
        segments : dict
            Describe the options to segment the data.
            Each key is a segment name; each value is unpacked into `slice(*value)`
            when the segment is prepared, e.g. a (start, end) pair.
            Here are some examples
            1) 'segments': {
                    'train': ("2008-01-01", "2014-12-31"),
                    'valid': ("2015-01-01", "2016-12-31",),
                    'test': ("2017-01-01", "2020-08-01",),
                }
            2) 'segments': {
                    'insample': ("2008-01-01", "2014-12-31"),
                    'outsample': ("2017-01-01", "2020-08-01",),
                }
        """
        # An existing DataHandler instance is accepted as-is; a config is instantiated.
        self._handler = init_instance_by_config(handler, accept_types=DataHandler)
        self._segments = segments

    def prepare(self,
                segments: Union[List[str], Tuple[str], str, slice],
                col_set=DataHandler.CS_ALL,
                **kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]:
        """
        prepare the data for learning and inference

        Parameters
        ----------
        segments : Union[List[str], Tuple[str], str, slice]
            Describe the scope of the data to be prepared
            Here are some examples
            1) 'train'
            2) ['train', 'valid']
            3) slice("2008-01-01", "2014-12-31"), which is passed to the handler directly
        col_set : Union[str, List[str]]
            the column set(s) to select; forwarded unchanged to the handler's `fetch`
        **kwargs :
            extra keyword arguments forwarded unchanged to the handler's `fetch`

        Returns
        -------
        Union[List[pd.DataFrame], pd.DataFrame]:
            a list with one fetched result per requested segment (in the given order) if
            `segments` is a list or tuple; otherwise the single fetched result

        Raises
        ------
        NotImplementedError:
            if `segments` is not a list, tuple, str or slice
        """
        if isinstance(segments, (list, tuple)):
            return [self._fetch_segment(seg, col_set, **kwargs) for seg in segments]
        if isinstance(segments, str):
            return self._fetch_segment(segments, col_set, **kwargs)
        if isinstance(segments, slice):
            # The signature advertises `slice`; a slice is already a valid selector for the handler.
            return self._handler.fetch(segments, col_set=col_set, **kwargs)
        raise NotImplementedError("This type of input is not supported")

    def _fetch_segment(self, name: str, col_set, **kwargs) -> pd.DataFrame:
        """Fetch the data of the segment registered under `name` from the handler."""
        return self._handler.fetch(slice(*self._segments[name]), col_set=col_set, **kwargs)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
from typing import Union, Tuple
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -15,6 +15,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -82,34 +83,6 @@ class DataHandler(Serializable):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# TODO: cache
|
||||
|
||||
def _get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
data
|
||||
level : Union[str, int]
|
||||
index level
|
||||
|
||||
Returns
|
||||
-------
|
||||
int:
|
||||
The level index in the multiple index
|
||||
"""
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
return df.index.names.index(level)
|
||||
except (AttributeError, ValueError):
|
||||
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
|
||||
return ('datetime', 'instrument').index(level)
|
||||
elif isinstance(level, int):
|
||||
return level
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _fetch_df_by_index(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
@@ -123,11 +96,11 @@ class DataHandler(Serializable):
|
||||
"""
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if self._get_level_index(df, level) == 1:
|
||||
if get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc(axis=0)[idx_slc]
|
||||
|
||||
CS_ALL = '_all'
|
||||
CS_ALL = '__all'
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
|
||||
cln = len(df.columns.levels)
|
||||
@@ -138,7 +111,10 @@ class DataHandler(Serializable):
|
||||
else:
|
||||
return df.loc(axis=1)[col_set]
|
||||
|
||||
def fetch(self, selector: Union[pd.Timestamp, slice, str], level: Union[str, int]='datetime', col_set=CS_ALL) -> pd.DataFrame:
|
||||
def fetch(self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
level: Union[str, int] = 'datetime',
|
||||
col_set: Union[str, List[str]] = CS_ALL) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
|
||||
@@ -148,8 +124,11 @@ class DataHandler(Serializable):
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
col_set : Union[str, List[str]]
|
||||
if isinstance(col_set, str):
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
if isinstance(col_set, List[str]):
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -195,7 +174,15 @@ class DataHandlerLP(DataHandler):
|
||||
# - _proc_learn_df will be processed by infer_processors + learn_processors
|
||||
# - (e.g. _proc_infer_df processed by learn_processors )
|
||||
|
||||
def __init__(self, instruments, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader]=None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs):
|
||||
def __init__(self,
|
||||
instruments,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
**kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
||||
32
qlib/data/dataset/utils.py
Normal file
32
qlib/data/dataset/utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level: Union[str, int]) -> int:
    """
    get the level index of `df` given `level`

    Parameters
    ----------
    df : pd.DataFrame
        data
    level : Union[str, int]
        index level, given either by name (str) or by position (int)

    Returns
    -------
    int:
        The level index in the multiple index

    Raises
    ------
    NotImplementedError:
        if `level` is neither a str nor an int
    ValueError:
        if `level` is a name found neither in the index names of `df` nor in the
        default ('datetime', 'instrument') layout
    """
    if isinstance(level, int):
        # A positional level is already the answer.
        return level
    if isinstance(level, str):
        try:
            return df.index.names.index(level)
        except (AttributeError, ValueError):
            # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
            return ('datetime', 'instrument').index(level)
    raise NotImplementedError("This type of input is not supported")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
import abc
|
||||
from ..utils.serial import Serializable
|
||||
from ..data.dataset import Dataset
|
||||
|
||||
|
||||
class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
@@ -20,45 +21,27 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
class Model(BaseModel):
|
||||
'''Learnable Models'''
|
||||
|
||||
# TODO: Make the model easier.
|
||||
def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
|
||||
"""fix train with cross-validation
|
||||
Fit model when ex_config.finetune is False
|
||||
def fit(self, dataset: Dataset):
|
||||
"""
|
||||
Learn model from the base model
|
||||
|
||||
** NOTE **: The attribute names of the learned model should **not** start with '_', so that the model could be
|
||||
dumped to disk.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_train : pd.dataframe
|
||||
train data
|
||||
y_train : pd.dataframe
|
||||
train label
|
||||
x_valid : pd.dataframe
|
||||
valid data
|
||||
y_valid : pd.dataframe
|
||||
valid label
|
||||
w_train : pd.dataframe
|
||||
train weight
|
||||
w_valid : pd.dataframe
|
||||
valid weight
|
||||
|
||||
Returns
|
||||
----------
|
||||
Model
|
||||
trained model
|
||||
dataset : Dataset
|
||||
dataset will generate the processed data from model training
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, x_test, **kwargs):
|
||||
"""predict given test data
|
||||
def predict(self, dataset: Dataset) -> object:
|
||||
"""give prediction given Dataset
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_test : pd.dataframe
|
||||
test data
|
||||
|
||||
Returns
|
||||
----------
|
||||
np.ndarray
|
||||
test predict label
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -194,7 +194,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
return klass, kwargs
|
||||
|
||||
|
||||
def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Tuple[type]=tuple([])) -> object:
|
||||
def init_instance_by_config(config: Union[str, dict], module=None, accept_types: Union[type, Tuple[type]]=tuple([])) -> object:
|
||||
"""
|
||||
get initialized instance with config
|
||||
|
||||
@@ -212,8 +212,9 @@ def init_instance_by_config(config: Union[str, dict], module=None, accept_types:
|
||||
module : Python module
|
||||
Optional. It should be a python module.
|
||||
|
||||
accept_types: Tuple[type]
|
||||
accept_types: Union[type, Tuple[type]]
|
||||
Optional. If the config is an instance of a specific type, return the config directly.
|
||||
This will be passed into the second parameter of isinstance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -60,4 +60,4 @@ class QlibRecorder:
|
||||
self.exp_manager.active_recorder.delete_tag(key)
|
||||
|
||||
# global record
|
||||
R = Wrapper()
|
||||
R = Wrapper()
|
||||
|
||||
Reference in New Issue
Block a user