mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
fix some small bug
This commit is contained in:
@@ -664,9 +664,11 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
# Ensure that each column type is consistent
|
||||
# FIXME: The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
# FIXME:
|
||||
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
# 2) The the precision should be configurable
|
||||
try:
|
||||
series = series.astype(float)
|
||||
series = series.astype(np.float32)
|
||||
except ValueError:
|
||||
pass
|
||||
if not series.empty:
|
||||
|
||||
@@ -99,10 +99,11 @@ class DataHandler(Serializable):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all"
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
|
||||
if not isinstance(df.columns, pd.MultiIndex):
|
||||
if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
|
||||
return df
|
||||
elif col_set == self.CS_ALL:
|
||||
return df.droplevel(axis=1, level=0)
|
||||
@@ -111,7 +112,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
@@ -371,7 +372,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str],
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
|
||||
@@ -52,4 +52,4 @@ def fetch_df_by_index(
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc(axis=0)[idx_slc]
|
||||
return df.loc[pd.IndexSlice[idx_slc], ] # This could be faster than df.loc(axis=0)[idx_slc]
|
||||
|
||||
@@ -185,7 +185,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
if isinstance(config, dict):
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
kwargs = config["kwargs"]
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
klass = getattr(module, config)
|
||||
kwargs = {}
|
||||
@@ -619,6 +619,28 @@ def exists_qlib_data(qlib_dir):
|
||||
return True
|
||||
|
||||
|
||||
def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
make the df index lexsorted
|
||||
|
||||
df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`
|
||||
This function could avoid such case
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
sorted dataframe
|
||||
"""
|
||||
if df.index.is_lexsorted():
|
||||
return df
|
||||
else:
|
||||
return df.sort_index()
|
||||
|
||||
|
||||
#################### Wrapper #####################
|
||||
class Wrapper(object):
|
||||
"""Data Provider Wrapper"""
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
class ExpManager:
|
||||
"""
|
||||
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
|
||||
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
@@ -34,10 +34,6 @@ class ExpManager:
|
||||
name of the active experiment.
|
||||
uri : str
|
||||
the current tracking URI.
|
||||
artifact_location : str
|
||||
the location to store all the artifacts.
|
||||
nested : boolean
|
||||
controls whether run is nested in parent run.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -99,7 +95,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
"""
|
||||
Retrieve an experiment by experiment_id from the backend store.
|
||||
|
||||
@@ -107,6 +103,8 @@ class ExpManager:
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id to return.
|
||||
create : boolean
|
||||
create the experiment if it does not exists
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -139,13 +139,13 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_tags` method.")
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
def list_artifacts(self, artifact_path: str = None):
|
||||
"""
|
||||
List all the artifacts of a recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
artifact_path=None : str
|
||||
artifact_path : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
|
||||
Returns
|
||||
@@ -186,7 +186,7 @@ class MLflowRecorder(Recorder):
|
||||
assert status in ["SCHEDULED", "RUNNING", "FINISHED", "FAILED"], f"The status type {status} is not supported."
|
||||
mlflow.end_run(status)
|
||||
self.end_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status is not "FINISHED":
|
||||
if self.status != "FINISHED":
|
||||
self.status = status
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user