mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
black format
This commit is contained in:
@@ -70,7 +70,7 @@ class HighFreqNorm(Processor):
|
||||
columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
).sort_index()
|
||||
return df_new_features
|
||||
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
|
||||
@@ -177,8 +177,8 @@ class HighfreqWorkflow(object):
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest, = dataset.prepare(["test"])
|
||||
backtest_test, = dataset_backtest.prepare(["test"])
|
||||
(xtest,) = dataset.prepare(["test"])
|
||||
(backtest_test,) = dataset_backtest.prepare(["test"])
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
@@ -105,7 +105,7 @@ class RollingDataWorkflow(object):
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs":{
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
@@ -126,7 +126,9 @@ class RollingDataWorkflow(object):
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={"init_type": DataHandlerLP.IT_FIT_SEQ,}
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
|
||||
@@ -35,7 +35,7 @@ class Dataset(Serializable):
|
||||
|
||||
def config(self, *arg, **kwargs):
|
||||
"""
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(*arg, **kwargs)
|
||||
|
||||
@@ -117,7 +117,7 @@ class DatasetH(Dataset):
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs:dict = None, segments:dict = None, **kwargs):
|
||||
def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
@@ -130,7 +130,7 @@ class DatasetH(Dataset):
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
@@ -141,8 +141,6 @@ class DatasetH(Dataset):
|
||||
if segments is not None:
|
||||
self.segments = segments.copy()
|
||||
|
||||
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
@@ -151,16 +149,15 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
@@ -464,7 +461,6 @@ class TSDatasetH(DatasetH):
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
|
||||
@@ -119,7 +119,7 @@ class DataHandler(Serializable):
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
Set Up the data.
|
||||
@@ -407,7 +407,7 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs:dict = None, **kwargs):
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
@@ -53,6 +53,7 @@ class DataLoader(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DLWParser(DataLoader):
|
||||
"""
|
||||
(D)ata(L)oader (W)ith (P)arser for features and names
|
||||
|
||||
@@ -202,6 +202,7 @@ class MinMaxNorm(Processor):
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class ZScoreNorm(Processor):
|
||||
"""ZScore Normalization"""
|
||||
|
||||
@@ -229,7 +230,7 @@ class ZScoreNorm(Processor):
|
||||
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
@@ -280,6 +281,7 @@ class RobustZScoreNorm(Processor):
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class CSZScoreNorm(Processor):
|
||||
"""Cross Sectional ZScore Normalization"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user