diff --git a/examples/trade/order_gen.py b/examples/trade/order_gen.py index 71499523f..51add3918 100644 --- a/examples/trade/order_gen.py +++ b/examples/trade/order_gen.py @@ -33,27 +33,30 @@ def w_order(f, start, end): order_valid = order_test[order_test.index.get_level_values(0) < '2021-01-01'] order_test = order_test[order_test.index.get_level_values(0) >= '2021-01-01'] if len(order_train) > 0: - train_path = os.path.join(data_path, "order/train/") - if not os.path.exists(train_path): - os.makedirs(train_path) order_train.to_pickle(train_path + f[:-9] + '.target') if len(order_valid) > 0: - valid_path = os.path.join(data_path, "order/valid/") - if not os.path.exists(valid_path): - os.makedirs(valid_path) order_valid.to_pickle(valid_path + f[:-9] + '.target') if len(order_test) > 0: - test_path = os.path.join(data_path, "order/test/") - if not os.path.exists(test_path): - os.makedirs(test_path) order_test.to_pickle(test_path + f[:-9] + '.target') if len(order) > 0: - all_path = os.path.join(data_path, "order/all/") - if not os.path.exists(all_path): - os.makedirs(all_path) - order.to_pickle(all_path + f[:-9] + '.target') return 0 +train_path = os.path.join(data_path, "order/train/") +if not os.path.exists(train_path): + os.makedirs(train_path) + +valid_path = os.path.join(data_path, "order/valid/") +if not os.path.exists(valid_path): + os.makedirs(valid_path) + +test_path = os.path.join(data_path, "order/test/") +if not os.path.exists(test_path): + os.makedirs(test_path) + +all_path = os.path.join(data_path, "order/all/") +if not os.path.exists(all_path): + os.makedirs(all_path) + res = Parallel(n_jobs=64)(delayed(w_order)(f, 0, 239) for f in os.listdir(in_dir)) print(sum(res)) diff --git a/qlib/config.py b/qlib/config.py index e7120c23a..52b05568d 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -90,7 +90,6 @@ _default_config = { # How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data. "maxtasksperchild": None, "default_disk_cache": 1, # 0:skip/1:use - "disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True "mem_cache_size_limit": 500, # memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar' # default 1 hour diff --git a/qlib/data/data.py b/qlib/data/data.py index 2a0e569ab..71915a3c3 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -961,8 +961,7 @@ class BaseProvider: is a provider class. """ disk_cache = C.default_disk_cache if disk_cache is None else disk_cache - if C.disable_disk_cache: - disk_cache = False + fields = list(fields) # In case of tuple. try: return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache) except TypeError: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index abcd5a60c..74bafbb80 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -57,10 +57,10 @@ class DataHandler(Serializable): instruments=None, start_time=None, end_time=None, - freq="day", data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, fetch_orig=True, + **kwargs, ): """ Parameters @@ -71,14 +71,14 @@ class DataHandler(Serializable): start_time of the original data. end_time : end_time of the original data. - freq : - frequency of data data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : intialize the original data in the constructor. fetch_orig : bool Return the original data instead of copy if possible. + **kwargs: + it will be passed into data_loader """ # Set logger self.logger = get_module_logger("DataHandler") @@ -86,23 +86,41 @@ class DataHandler(Serializable): # Setup data loader assert data_loader is not None # to make start_time end_time could have None default value + # what data source to load data self.data_loader = init_instance_by_config( data_loader, None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module, accept_types=DataLoader, + **kwargs, ) + # what data to be loaded from data source + # For IDE auto-completion. self.instruments = instruments self.start_time = start_time self.end_time = end_time - self.freq = freq + self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): self.init() super().__init__() - def init(self, enable_cache: bool = True): + def conf_data(self, **kwargs): + """ + configuration of data. + # what data to be loaded from data source + This method will be used when loading pickled handler from dataset. + The data will be initialized with different time range. + """ + attr_list = {"instruments", "start_time", "end_time"} + for k, v in kwargs.items(): + if k in attr_list: + setattr(self, k, v) + else: + raise KeyError("Such config is not supported.") + + def init(self, enable_cache: bool = False): """ initialize the data. In case of running intialization for multiple time, it will do nothing for the second time. @@ -123,7 +141,7 @@ class DataHandler(Serializable): # Setup data. # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): - self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq) + self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) # TODO: cache CS_ALL = "__all" # return all columns with single-level index column @@ -262,7 +280,6 @@ class DataHandlerLP(DataHandler): instruments=None, start_time=None, end_time=None, - freq="day", data_loader: Tuple[dict, str, DataLoader] = None, infer_processors=[], learn_processors=[], @@ -328,7 +345,7 @@ class DataHandlerLP(DataHandler): self.process_type = process_type self.drop_raw = drop_raw - super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs) + super().__init__(instruments, start_time, end_time, data_loader, **kwargs) def get_all_processors(self): return self.infer_processors + self.learn_processors diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 3b33ff749..5e7af6f9b 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -19,7 +19,7 @@ class DataLoader(abc.ABC): """ @abc.abstractmethod - def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame: + def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: """ load the data as pd.DataFrame. @@ -76,6 +76,7 @@ class DLWParser(DataLoader): := := ["expr", ...] | (["expr", ...], ["col_name", ...]) + # NOTE: list or tuple will be treated as the things when parsing """ self.is_group = isinstance(config, dict) @@ -85,18 +86,22 @@ class DLWParser(DataLoader): self.fields = self._parse_fields_info(config) def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]: - if isinstance(fields_info, list): + if len(fields_info) == 0: + raise ValueError("The size of fields must be greater than 0") + + if not isinstance(fields_info, (list, tuple)): + raise TypeError("Unsupported type") + + if isinstance(fields_info[0], str): exprs = names = fields_info - elif isinstance(fields_info, tuple): + elif isinstance(fields_info[0], (list, tuple)): exprs, names = fields_info else: raise NotImplementedError(f"This type of input is not supported") return exprs, names @abc.abstractmethod - def load_group_df( - self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" - ) -> pd.DataFrame: + def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: """ load the dataframe for specific group @@ -116,25 +121,25 @@ class DLWParser(DataLoader): """ pass - def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq) + grp: self.load_group_df(instruments, exprs, names, start_time, end_time) for grp, (exprs, names) in self.fields.items() }, axis=1, ) else: exprs, names = self.fields - df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq) + df = self.load_group_df(instruments, exprs, names, start_time, end_time) return df class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" - def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True): + def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"): """ Parameters ---------- @@ -147,11 +152,10 @@ class QlibDataLoader(DLWParser): """ self.filter_pipe = filter_pipe self.swap_level = swap_level + self.freq = freq super().__init__(config) - def load_group_df( - self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day" - ) -> pd.DataFrame: + def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") instruments = "all" @@ -160,7 +164,7 @@ class QlibDataLoader(DLWParser): elif self.filter_pipe is not None: warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") - df = D.features(instruments, exprs, start_time, end_time, freq) + df = D.features(instruments, exprs, start_time, end_time, self.freq) df.columns = names if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return @@ -185,7 +189,7 @@ class StaticDataLoader(DataLoader): self.join = join self._data = None - def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame: + def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: self._maybe_load_raw_data() if instruments is None: df = self._data diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 15faa0da1..aedf73a9c 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from .expm import MLflowExpManager +from .exp import Experiment from .recorder import Recorder from ..utils import Wrapper @@ -165,7 +166,7 @@ class QlibRecorder: """ return self.get_exp(experiment_id, experiment_name).list_recorders() - def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True): + def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to True, if no valid experiment is found, this method will create one for you. Otherwise, it will