1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

black format

This commit is contained in:
bxdd
2021-03-29 20:16:00 +08:00
parent fb7f84f31e
commit 8743576f72
7 changed files with 18 additions and 17 deletions

View File

@@ -70,7 +70,7 @@ class HighFreqNorm(Processor):
columns=["FEATURE_%d" % i for i in range(12 * 240)],
).sort_index()
return df_new_features
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time

View File

@@ -177,8 +177,8 @@ class HighfreqWorkflow(object):
dataset_backtest.setup_data(handler_kwargs={})
##=============get data=============
xtest, = dataset.prepare(["test"])
backtest_test, = dataset_backtest.prepare(["test"])
(xtest,) = dataset.prepare(["test"])
(backtest_test,) = dataset_backtest.prepare(["test"])
print(xtest, backtest_test)
return

View File

@@ -105,7 +105,7 @@ class RollingDataWorkflow(object):
handler_kwargs={
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
"processor_kwargs":{
"processor_kwargs": {
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
},
@@ -126,7 +126,9 @@ class RollingDataWorkflow(object):
},
)
dataset.setup_data(
handler_kwargs={"init_type": DataHandlerLP.IT_FIT_SEQ,}
handler_kwargs={
"init_type": DataHandlerLP.IT_FIT_SEQ,
}
)
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])

View File

@@ -35,7 +35,7 @@ class Dataset(Serializable):
def config(self, *arg, **kwargs):
"""
config is designed to configure and parameters that cannot be learned from the data
config is designed to configure and parameters that cannot be learned from the data
"""
super().config(*arg, **kwargs)
@@ -117,7 +117,7 @@ class DatasetH(Dataset):
self.segments = segments.copy()
super().__init__(**kwargs)
def config(self, handler_kwargs:dict = None, segments:dict = None, **kwargs):
def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs):
"""
Initialize the DatasetH
@@ -130,7 +130,7 @@ class DatasetH(Dataset):
kwargs : dict
Config of DatasetH, such as
- segments : dict
Config of segments which is same as 'segments' in self.__init__
@@ -141,8 +141,6 @@ class DatasetH(Dataset):
if segments is not None:
self.segments = segments.copy()
def setup_data(self, handler_kwargs: dict = None, **kwargs):
"""
Setup the Data
@@ -151,16 +149,15 @@ class DatasetH(Dataset):
----------
handler_kwargs : dict
init arguments of DataHanlder, which could include the following arguments:
- init_type : Init Type of Handler
- enable_cache : wheter to enable cache
"""
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
def __repr__(self):
return "{name}(handler={handler}, segments={segments})".format(
@@ -464,7 +461,6 @@ class TSDatasetH(DatasetH):
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
# Dataset decide how to slice data(Get more data for timeseries).

View File

@@ -119,7 +119,7 @@ class DataHandler(Serializable):
self.start_time = start_time
if end_time:
self.end_time = end_time
def setup_data(self, enable_cache: bool = False):
"""
Set Up the data.
@@ -407,7 +407,7 @@ class DataHandlerLP(DataHandler):
if self.drop_raw:
del self._data
def config(self, processor_kwargs:dict = None, **kwargs):
def config(self, processor_kwargs: dict = None, **kwargs):
"""
configuration of data.
# what data to be loaded from data source

View File

@@ -53,6 +53,7 @@ class DataLoader(abc.ABC):
"""
pass
class DLWParser(DataLoader):
"""
(D)ata(L)oader (W)ith (P)arser for features and names

View File

@@ -202,6 +202,7 @@ class MinMaxNorm(Processor):
self.fit_end_time = fit_end_time
super().config(**kwargs)
class ZScoreNorm(Processor):
"""ZScore Normalization"""
@@ -229,7 +230,7 @@ class ZScoreNorm(Processor):
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
return df
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time
@@ -280,6 +281,7 @@ class RobustZScoreNorm(Processor):
self.fit_end_time = fit_end_time
super().config(**kwargs)
class CSZScoreNorm(Processor):
"""Cross Sectional ZScore Normalization"""