1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

fix some small bug

This commit is contained in:
Young
2020-11-14 08:23:19 +00:00
parent ea5f14ce12
commit 7a79028a72
6 changed files with 40 additions and 17 deletions

View File

@@ -664,9 +664,11 @@ class LocalExpressionProvider(ExpressionProvider):
lft_etd, rght_etd = expression.get_extended_window_size()
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
# Ensure that each column type is consistent
# FIXME: The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
# FIXME:
# 1) The stock data is currently float. If there are other types of data, this part needs to be re-implemented.
# 2) The precision should be configurable
try:
series = series.astype(float)
series = series.astype(np.float32)
except ValueError:
pass
if not series.empty:

View File

@@ -99,10 +99,11 @@ class DataHandler(Serializable):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
# TODO: cache
CS_ALL = "__all"
CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
if not isinstance(df.columns, pd.MultiIndex):
if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
return df
elif col_set == self.CS_ALL:
return df.droplevel(axis=1, level=0)
@@ -111,7 +112,7 @@ class DataHandler(Serializable):
def fetch(
self,
selector: Union[pd.Timestamp, slice, str],
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
@@ -371,7 +372,7 @@ class DataHandlerLP(DataHandler):
def fetch(
self,
selector: Union[pd.Timestamp, slice, str],
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,

View File

@@ -52,4 +52,4 @@ def fetch_df_by_index(
idx_slc = (selector, slice(None, None))
if get_level_index(df, level) == 1:
idx_slc = idx_slc[1], idx_slc[0]
return df.loc(axis=0)[idx_slc]
return df.loc[pd.IndexSlice[idx_slc], ] # This could be faster than df.loc(axis=0)[idx_slc]

View File

@@ -185,7 +185,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
if isinstance(config, dict):
# raise AttributeError
klass = getattr(module, config["class"])
kwargs = config["kwargs"]
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
klass = getattr(module, config)
kwargs = {}
@@ -619,6 +619,28 @@ def exists_qlib_data(qlib_dir):
return True
def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return `df` with a sorted index, skipping the sort when it is already ordered.

    `df.sort_index()` will take a lot of time even when the index is already sorted.
    This function avoids that cost by checking the ordering first.

    Parameters
    ----------
    df : pd.DataFrame

    Returns
    -------
    pd.DataFrame:
        `df` itself when its index is already sorted, otherwise the result of
        `df.sort_index()`.
    """
    # `is_monotonic_increasing` is available on every index type, whereas
    # `is_lexsorted()` is MultiIndex-only (AttributeError on a flat index)
    # and was removed in pandas 2.0.
    if df.index.is_monotonic_increasing:
        return df
    return df.sort_index()
#################### Wrapper #####################
class Wrapper(object):
"""Data Provider Wrapper"""

View File

@@ -14,7 +14,7 @@ logger = get_module_logger("workflow", "INFO")
class ExpManager:
"""
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
@@ -34,10 +34,6 @@ class ExpManager:
name of the active experiment.
uri : str
the current tracking URI.
artifact_location : str
the location to store all the artifacts.
nested : boolean
controls whether run is nested in parent run.
Returns
-------
@@ -99,7 +95,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_exp(self, experiment_id=None, experiment_name=None):
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
"""
Retrieve an experiment by experiment_id from the backend store.
@@ -107,6 +103,8 @@ class ExpManager:
----------
experiment_id : str
the experiment id to return.
create : boolean
create the experiment if it does not exist
Returns
-------

View File

@@ -139,13 +139,13 @@ class Recorder:
"""
raise NotImplementedError(f"Please implement the `delete_tags` method.")
def list_artifacts(self, artifact_path=None):
def list_artifacts(self, artifact_path: str = None):
"""
List all the artifacts of a recorder.
Parameters
----------
artifact_path=None : str
artifact_path : str
the relative path for the artifact to be stored in the URI.
Returns
@@ -186,7 +186,7 @@ class MLflowRecorder(Recorder):
assert status in ["SCHEDULED", "RUNNING", "FINISHED", "FAILED"], f"The status type {status} is not supported."
mlflow.end_run(status)
self.end_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.status is not "FINISHED":
if self.status != "FINISHED":
self.status = status
shutil.rmtree(self.temp_dir)