diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index ffeb66483..edaee5576 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -87,9 +87,10 @@ jobs: # E1102: not-callable # E1136: unsubscriptable-object # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 + # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). - name: Check Qlib with pylint run: | - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/examples/rl/scripts/collect_pickle_dataframe.py b/examples/rl/scripts/collect_pickle_dataframe.py index 8950ec203..64dc94bdb 100644 --- a/examples/rl/scripts/collect_pickle_dataframe.py +++ b/examples/rl/scripts/collect_pickle_dataframe.py @@ -13,7 +13,7 @@ for tag in ("backtest", "feature"): df = pd.concat(list(df.values())).reset_index() df["date"] = df["datetime"].dt.date.astype("datetime64") instruments = sorted(set(df["instrument"])) - + os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True) for instrument in tqdm(instruments): cur = df[df["instrument"] == instrument].sort_values(by=["datetime"]) diff --git a/examples/rl/scripts/gen_backtest_orders.py b/examples/rl/scripts/gen_backtest_orders.py index c3d0e4ef9..cdf6f9cb8 100644 --- a/examples/rl/scripts/gen_backtest_orders.py +++ b/examples/rl/scripts/gen_backtest_orders.py @@ -22,19 +22,21 @@ instruments = sorted(set(df["instrument"])) df_list = [] for instrument in instruments: print(instrument) - + cur_df = df[df["instrument"] == instrument] - + dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]])) - + n = args.num_order df_list.append( - pd.DataFrame({ - "date": sorted(np.random.choice(dates, size=n, replace=False)), - "instrument": [instrument] * n, - "amount": np.random.randint(low=3, high=11, size=n) * 100.0, - "order_type": np.random.randint(low=0, high=2, size=n), - }).set_index(["date", "instrument"]), + pd.DataFrame( + { + "date": sorted(np.random.choice(dates, size=n, replace=False)), + "instrument": [instrument] * n, + "amount": np.random.randint(low=3, high=11, size=n) * 100.0, + "order_type": np.random.randint(low=0, high=2, size=n), + } + ).set_index(["date", "instrument"]), ) total_df = pd.concat(df_list) diff --git a/examples/rl/scripts/gen_pickle_data.py b/examples/rl/scripts/gen_pickle_data.py index 3cb74f314..f2dbbf115 100755 --- a/examples/rl/scripts/gen_pickle_data.py +++ b/examples/rl/scripts/gen_pickle_data.py @@ -30,8 +30,8 @@ if __name__ == "__main__": if "backtest_conf" in conf: backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf)) - provider.feature_conf['path'] = os.path.splitext(provider.feature_conf['path'])[0] + '/' - provider.backtest_conf['path'] = os.path.splitext(provider.backtest_conf['path'])[0] + '/' + provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/" + provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/" # Split by date if args.split == "date" or args.split == "both": provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature") diff --git a/examples/rl/scripts/gen_training_orders.py b/examples/rl/scripts/gen_training_orders.py index 07383c860..5dd1e96c6 100644 --- a/examples/rl/scripts/gen_training_orders.py +++ b/examples/rl/scripts/gen_training_orders.py @@ -23,15 +23,17 @@ for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_siz path = os.path.join("data", "pickle", f"backtest{group}.pkl") df = pickle.load(open(path, "rb")).reset_index() df["date"] = df["datetime"].dt.date.astype("datetime64") - + dates = sorted(set([str(d).split(" ")[0] for d in df["date"]])) - data_df = pd.DataFrame({ - "date": sorted(np.random.choice(dates, size=n, replace=False)), - "instrument": [args.stock] * n, - "amount": np.random.randint(low=3, high=11, size=n) * 100.0, - "order_type": [0] * n, - }).set_index(["date", "instrument"]) + data_df = pd.DataFrame( + { + "date": sorted(np.random.choice(dates, size=n, replace=False)), + "instrument": [args.stock] * n, + "amount": np.random.randint(low=3, high=11, size=n) * 100.0, + "order_type": [0] * n, + } + ).set_index(["date", "instrument"]) os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True) pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb")) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index af45d1e67..27026b25e 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -579,8 +579,11 @@ class TradeDecisionWO(BaseTradeDecision[Order]): class TradeDecisionWithDetails(TradeDecisionWO): - """Decision with detail information. Detail information is used to generate execution reports. """ + Decision with detail information. + Detail information is used to generate execution reports. + """ + def __init__( self, order_list: List[Order], diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 4cd101150..695c13d2e 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -8,13 +8,14 @@ import os import pickle from collections import defaultdict from pathlib import Path -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import pandas as pd import torch from joblib import Parallel, delayed +from qlib.typehint import Literal from qlib.backtest import collect_data_loop, get_strategy_executor from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor diff --git a/setup.py b/setup.py index 0ca9f26ba..4527cf910 100644 --- a/setup.py +++ b/setup.py @@ -142,7 +142,11 @@ setup( "setuptools", "black", "pylint", - "mypy", + # Using the latest versions(0.981 and 0.982) of mypy, + # the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py", + # If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy. + # References: https://github.com/python/typeshed/issues/8799 + "mypy<0.981", "flake8", "readthedocs_sphinx_ext", "cmake", diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 30c0b5010..7dc904bce 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -56,39 +56,8 @@ def train(uri_path: str = None): ic = sar.load("ic.pkl") ric = sar.load("ric.pkl") - return pred_score, {"ic": ic, "ric": ric}, rid - - -def train_with_sigana(uri_path: str = None): - """train model followed by SigAnaRecord - - Returns - ------- - pred_score: pandas.DataFrame - predict scores - performance: dict - model performance - """ - model = init_instance_by_config(CSI300_GBDT_TASK["model"]) - dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) - # start exp - with R.start(experiment_name="workflow_with_sigana", uri=uri_path): - R.log_params(**flatten_dict(CSI300_GBDT_TASK)) - model.fit(dataset) - recorder = R.get_recorder() - - sr = SignalRecord(model, dataset, recorder) - sr.generate() - pred_score = sr.load("pred.pkl") - - # predict and calculate ic and ric - sar = SigAnaRecord(recorder) - sar.generate() - ic = sar.load("ic.pkl") - ric = sar.load("ric.pkl") - uri_path = R.get_uri() - return pred_score, {"ic": ic, "ric": ric}, uri_path + return pred_score, {"ic": ic, "ric": ric}, rid, uri_path def fake_experiment(): @@ -186,19 +155,13 @@ class TestAllFlow(TestAutoData): shutil.rmtree(cls.URI_PATH.lstrip("file:")) @pytest.mark.slow - def test_0_train_with_sigana(self): - TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana(self.URI_PATH) + def test_0_train(self): + TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID, uri_path = train(self.URI_PATH) self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") @pytest.mark.slow - def test_1_train(self): - TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train(self.URI_PATH) - self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") - self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") - - @pytest.mark.slow - def test_2_backtest(self): + def test_1_backtest(self): analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH) self.assertGreaterEqual( analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], @@ -208,7 +171,7 @@ class TestAllFlow(TestAutoData): self.assertTrue(not analyze_df.isna().any().any(), "backtest failed") @pytest.mark.slow - def test_3_expmanager(self): + def test_2_expmanager(self): pass_default, pass_current, uri_path = fake_experiment() self.assertTrue(pass_default, msg="default uri is incorrect") self.assertTrue(pass_current, msg="current uri is incorrect") @@ -217,10 +180,9 @@ class TestAllFlow(TestAutoData): def suite(): _suite = unittest.TestSuite() - _suite.addTest(TestAllFlow("test_0_train_with_sigana")) - _suite.addTest(TestAllFlow("test_1_train")) - _suite.addTest(TestAllFlow("test_2_backtest")) - _suite.addTest(TestAllFlow("test_3_expmanager")) + _suite.addTest(TestAllFlow("test_0_train")) + _suite.addTest(TestAllFlow("test_1_backtest")) + _suite.addTest(TestAllFlow("test_2_expmanager")) return _suite diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index 3d250a142..c556472c0 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -11,7 +11,24 @@ from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.tests import TestAutoData -from qlib.tests.config import CSI300_GBDT_TASK +from qlib.tests.config import GBDT_MODEL, get_dataset_config, CSI300_MARKET + + +CSI300_GBDT_TASK = { + "model": GBDT_MODEL, + "dataset": get_dataset_config( + train=("2020-05-01", "2020-06-01"), + valid=("2020-06-01", "2020-07-01"), + test=("2020-07-01", "2020-08-01"), + handler_kwargs={ + "start_time": "2020-05-01", + "end_time": "2020-08-01", + "fit_start_time": "", + "fit_end_time": "", + "instruments": CSI300_MARKET, + }, + ), +} def train_multiseg(uri_path: str = None): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ebb4aaa55..dc2ec812f 100755 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -19,10 +19,10 @@ class TestDataset(TestAutoData): "class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": { - "start_time": "2008-01-01", + "start_time": "2017-01-01", "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", + "fit_start_time": "2017-01-01", + "fit_end_time": "2017-12-31", "instruments": "csi300", "infer_processors": [ {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, @@ -36,9 +36,9 @@ class TestDataset(TestAutoData): }, }, segments={ - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), + "train": ("2017-01-01", "2017-12-31"), + "valid": ("2018-01-01", "2018-12-31"), + "test": ("2019-01-01", "2020-08-01"), }, ) tsds_train = tsdh.prepare("train", data_key=DataHandlerLP.DK_L) # Test the correctness @@ -63,13 +63,13 @@ class TestDataset(TestAutoData): tsds[len(tsds) - 1] # 2) sample by index - data_from_ds = tsds["2016-12-31", "SZ300315"] + data_from_ds = tsds["2017-12-31", "SZ300315"] # Check the data # Get data from DataFrame Directly data_from_df = ( tsdh.handler.fetch(data_key=DataHandlerLP.DK_L) - .loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"] + .loc(axis=0)["2017-01-01":"2017-12-31", "SZ300315"] .iloc[-30:] .values )