From 89f907bf6c0a3bad6b380ee5eb34d87d15c2ecc6 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 29 Nov 2020 09:21:46 +0000 Subject: [PATCH] set the task base class --- README.md | 2 +- qlib/model/task.py | 163 +++++++-------------------------------------- 2 files changed, 25 insertions(+), 140 deletions(-) diff --git a/README.md b/README.md index 89d14e9eb..f9fdf1719 100644 --- a/README.md +++ b/README.md @@ -210,7 +210,7 @@ Your PR of new Quant models is highly welcomed. `Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best: - User can use the tool `qrun` mentioned above to run a model's workflow based from a config file. - User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder. -- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py). +- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm` (the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py). ## Run multiple models `Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.) 
diff --git a/qlib/model/task.py b/qlib/model/task.py index e66159233..f29f513a4 100644 --- a/qlib/model/task.py +++ b/qlib/model/task.py @@ -1,142 +1,27 @@ -''' -Please implement similar function here - -# Rolling relealted +import abc +import typing - def split_rolling_periods( - self, - train_start_date, - train_end_date, - validate_start_date, - validate_end_date, - test_start_date, - test_end_date, - rolling_period, - calendar_freq="day", - ): + +class TaskGen(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> typing.List[dict]: """ - Calculating the Rolling split periods, the period rolling on market calendar. - :param train_start_date: - :param train_end_date: - :param validate_start_date: - :param validate_end_date: - :param test_start_date: - :param test_end_date: - :param rolling_period: The market period of rolling - :param calendar_freq: The frequence of the market calendar - :yield: Rolling split periods + generate + + Parameters + ---------- + args, kwargs: + The info for generating tasks + Example 1): + input: a specific task template + output: rolling version of the tasks + Example 2): + input: a specific task template + output: a set of tasks with different losses + + Returns + ------- + typing.List[dict]: + A list of tasks """ - - def get_start_index(calendar, start_date): - start_index = bisect.bisect_left(calendar, start_date) - return start_index - - def get_end_index(calendar, end_date): - end_index = bisect.bisect_right(calendar, end_date) - return end_index - 1 - - calendar = self.raw_df.index.get_level_values("datetime").unique() - - train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date)) - train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date)) - valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date)) - valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date)) - test_start_index = get_start_index(calendar, 
pd.Timestamp(test_start_date)) - test_end_index = test_start_index + rolling_period - 1 - - need_stop_split = False - - bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date)) - - while not need_stop_split: - - if test_end_index > bound_test_end_index: - test_end_index = bound_test_end_index - need_stop_split = True - - yield ( - calendar[train_start_index], - calendar[train_end_index], - calendar[valid_start_index], - calendar[valid_end_index], - calendar[test_start_index], - calendar[test_end_index], - ) - - train_start_index += rolling_period - train_end_index += rolling_period - valid_start_index += rolling_period - valid_end_index += rolling_period - test_start_index += rolling_period - test_end_index += rolling_period - - def get_rolling_data( - self, - train_start_date, - train_end_date, - validate_start_date, - validate_end_date, - test_start_date, - test_end_date, - rolling_period, - calendar_freq="day", - ): - # Set generator. - for period in self.split_rolling_periods( - train_start_date, - train_end_date, - validate_start_date, - validate_end_date, - test_start_date, - test_end_date, - rolling_period, - calendar_freq, - ): - ( - x_train, - y_train, - x_validate, - y_validate, - x_test, - y_test, - ) = self.get_split_data(*period) - yield x_train, y_train, x_validate, y_validate, x_test, y_test - - def get_split_data( - self, - train_start_date, - train_end_date, - validate_start_date, - validate_end_date, - test_start_date, - test_end_date, - ): - """ - all return types are DataFrame - """ - ## TODO: loc can be slow, expecially when we put it at the second level index. 
- if self.raw_df.index.names[0] == "instrument": - df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date] - df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date] - df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date] - else: - df_train = self.raw_df.loc[train_start_date:train_end_date] - df_validate = self.raw_df.loc[validate_start_date:validate_end_date] - df_test = self.raw_df.loc[test_start_date:test_end_date] - - TimeInspector.set_time_mark() - df_train, df_validate, df_test = self.process_data(df_train, df_validate, df_test) - TimeInspector.log_cost_time("Finished setup processed data.") - - x_train = df_train[self.feature_names] - y_train = df_train[self.label_names] - - x_validate = df_validate[self.feature_names] - y_validate = df_validate[self.label_names] - - x_test = df_test[self.feature_names] - y_test = df_test[self.label_names] - - return x_train, y_train, x_validate, y_validate, x_test, y_test - -''' + pass