1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

Add backtest and backforward task (#1568)

* add TrainTask & BacktestTask;
* add BackForwardTask;
* adjust prompt_template.yaml, whose default config failed to backtest;
* run workflow in loop
* add update method to prompt_template.py

* remove debug code

* Adjust Learn Process
* add LearnManager class & use LearnManager to update system prompt;
* use qrun to replace recorder for training and backtesting;

* Adjust analyser
* analyser independent of recorder;
* rename analyser's workspace attribute;
* analyser loads variables via the recorder.

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
This commit is contained in:
Fivele-Li
2023-06-30 10:04:43 +08:00
committed by GitHub
parent 1326ac614d
commit 7e84f3aae2
8 changed files with 179 additions and 92 deletions

View File

@@ -3,10 +3,6 @@ import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from qlib.utils import class_casting
from ..data.dataset import DatasetH
from ..data.dataset.handler import DataHandlerLP
from ..log import get_module_logger
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
@@ -14,8 +10,25 @@ logger = get_module_logger("analysis", logging.INFO)
class AnalyzerTemp:
def __init__(self, workspace=None, **kwargs):
self.workspace = Path(workspace) if workspace else "./"
def __init__(self, recorder, output_dir=None, **kwargs):
self.recorder = recorder
self.output_dir = Path(output_dir) if output_dir else "./"
def load(self, name: str):
    """
    Fetch a stored object through the attached recorder.

    This is a thin convenience wrapper that forwards to ``self.recorder.load_object``.

    Parameters
    ----------
    name : str
        the name of the file to be loaded.

    Return
    ------
    Whatever ``self.recorder.load_object`` returns for ``name``.
    """
    fetch = self.recorder.load_object
    return fetch(name)
def analyse(self, **kwargs):
"""
@@ -42,7 +55,10 @@ class HFAnalyzer(AnalyzerTemp):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self, pred=None, label=None):
def analyse(self):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
@@ -65,13 +81,13 @@ class HFAnalyzer(AnalyzerTemp):
table = [[k, v] for (k, v) in metrics.items()]
plt.table(cellText=table, loc="center")
plt.axis("off")
plt.savefig(self.workspace.joinpath("HFAnalyzerTable.jpeg"))
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
plt.clf()
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
plt.title("HFAnalyzer")
plt.savefig(self.workspace.joinpath("HFAnalyzer.jpeg"))
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
return "HFAnalyzer.jpeg"
@@ -86,24 +102,10 @@ class SignalAnalyzer(AnalyzerTemp):
super().__init__(**kwargs)
def analyse(self, dataset=None, **kwargs):
label = self.load("label.pkl")
with class_casting(dataset, DatasetH):
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
try:
# Assume the backend handler is DataHandlerLP
raw_label = dataset.prepare(**params)
except TypeError:
# The argument number is not right
del params["data_key"]
# The backend handler should be DataHandler
raw_label = dataset.prepare(**params)
except AttributeError as e:
# The data handler is initialized with `drop_raw=True`...
# So raw_label is not available
logger.warning(f"Exception: {e}")
raw_label = None
plt.hist(raw_label)
plt.hist(label)
plt.title("SignalAnalyzer")
plt.savefig(self.workspace.joinpath("signalAnalysis.jpeg"))
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
return "signalAnalysis.jpeg"

15
qlib/finco/cli_learn.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import LearnManager
from dotenv import load_dotenv
from qlib import auto_init
def main(prompt=None):
    """Load settings from the dotenv file, then start a LearnManager run with ``prompt``."""
    load_dotenv(verbose=True, override=True)
    LearnManager().run(prompt)
if __name__ == "__main__":
auto_init()
fire.Fire(main)

View File

@@ -1,18 +1,37 @@
from typing import Union
from pathlib import Path
from jinja2 import Template
import yaml
from qlib.finco.utils import Singleton
from qlib.finco import get_finco_path
import yaml
import os
class PormptTemplate(Singleton):
class PromptTemplate(Singleton):
def __init__(self) -> None:
super().__init__()
_template = yaml.load(open(os.path.join(get_finco_path(), "prompt_template.yaml"), "r"), Loader=yaml.FullLoader)
_template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
Loader=yaml.FullLoader)
for k, v in _template.items():
if k == "mods":
continue
self.__setattr__(k, Template(v))
for target_name, module_to_render_params in _template["mods"].items():
for module_name, params in module_to_render_params.items():
self.__setattr__(f"{target_name}_{module_name}", Template(self.__getattribute__(target_name).render(**params)))
self.__setattr__(f"{target_name}_{module_name}",
Template(self.__getattribute__(target_name).render(**params)))
def get(self, key: str):
return self.__dict__.get(key, Template(""))
def update(self, key: str, value):
self.__setattr__(key, value)
def save(self, file_path: Union[str, Path]):
    """
    Dump the current attributes of this object to a YAML file.

    Parameters
    ----------
    file_path : Union[str, Path]
        destination file; missing parent directories are created.
    """
    file_path = Path(file_path)
    # parents=True: the target directory may be nested below folders that do not exist yet,
    # in which case a plain mkdir raises FileNotFoundError.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the attributes set in __init__/update are jinja2 Template objects, so this
    # presumably dumps them with python-object tags rather than the raw template text --
    # confirm the checkpoint can be loaded back as intended.
    with open(file_path, 'w') as f:
        yaml.dump(self.__dict__, f)

View File

@@ -193,6 +193,13 @@ SummarizeTask_user : |-
Here is my information: '{{information}}'
My intention is: {{user_prompt}}. Please provide me with a summary and recommendation based on my intention and the information I have provided. There are some figures which absolute path are: {{figure_path}}, You must display these images in markdown using the appropriate image format.
BackForwardTask_system : |-
Your task is adjusting system prompt in each task to fulfill user's intention
BackForwardTask_user : |-
Here is the final summary: '{{summary}}'
Tasks I have run are: {{task_finished}}, {{task}}'s system prompt is: {{system}}. User's intention is: {{user_prompt}}. you will adjust it to:
mods:
ConfigActionTask_system:
Dataset:
@@ -382,7 +389,7 @@ mods:
```
Reason: I choose the backtest parameters above because they are suitable for a low turnover strategy focusing on long-term returns in the China A stock market. The start and end times are set to cover a 4-year period, which is reasonable for a long-term strategy. The account value is set to 1,000,000 as a starting point, and the benchmark is set to SH000300, which represents the China A stock market.
Improve suggestion: You can try different time ranges for the backtest to evaluate the performance of the strategy in different market conditions. Also, you can adjust the costs (open_cost, close_cost, and min_cost) to better reflect the actual trading costs in the China A stock market.
ConfigActionTask_user:
Dataset:
target_component : |-
@@ -402,7 +409,7 @@ mods:
Backtest:
target_component : |-
Backtest
ImplementActionTask_system:
Dataset:
target_component : |-

View File

@@ -11,10 +11,9 @@ import platform
from qlib.finco.llm import APIBackend
from qlib.finco.tpl import get_tpl_path
from qlib.finco.prompt_template import PormptTemplate
from qlib.finco.prompt_template import PromptTemplate
from qlib.workflow.record_temp import HFSignalRecord, SignalRecord
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.finco.log import FinCoLog, LogColors
from qlib.finco.conf import Config
@@ -41,7 +40,7 @@ class Task:
def __init__(self) -> None:
self._context_manager = None
self.prompt_template = PormptTemplate()
self.prompt_template = PromptTemplate()
self.executed = False
self.continuous = Config().continuous_mode
self.logger = FinCoLog()
@@ -96,13 +95,11 @@ class Task:
@property
def system(self):
return self.prompt_template.__getattribute__(
self.__class__.__name__ + "_system"
)
return self.prompt_template.get(self.__class__.__name__ + "_system")
@property
def user(self):
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_user")
return self.prompt_template.get(self.__class__.__name__ + "_user")
def __str__(self):
return self.__class__.__name__
@@ -150,7 +147,7 @@ class PlanTask(Task):
class SLPlanTask(PlanTask):
def __init__(self,) -> None:
def __init__(self, ) -> None:
super().__init__()
def execute(self):
@@ -220,13 +217,14 @@ class RLPlanTask(PlanTask):
return []
class RecorderTask(Task):
class TrainTask(Task):
"""
This Recorder task is responsible for analysing data such as index and distribution.
This train task is responsible for training model configure by yaml file.
"""
def __init__(self):
super().__init__()
self._output = None
def execute(self):
workflow_config = (
@@ -234,6 +232,7 @@ class RecorderTask(Task):
if self._context_manager.get_context("workflow_config")
else "workflow_config.yaml"
)
workspace = self._context_manager.get_context("workspace")
workflow_path = workspace.joinpath(workflow_config)
with workflow_path.open() as f:
@@ -246,24 +245,19 @@ class RecorderTask(Task):
if confirm is False:
return []
model = init_instance_by_config(workflow["task"]["model"])
dataset = init_instance_by_config(workflow["task"]["dataset"])
with R.start(experiment_name="finCo"):
model.fit(dataset)
R.save_objects(trained_model=model)
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
self._context_manager.set_context("model", model)
self._context_manager.set_context("dataset", dataset)
self._context_manager.set_context("recorder", recorder)
command = f"qrun {workflow_path}"
self._output = subprocess.check_output(command, shell=True, cwd=workspace)
return [AnalysisTask()]
def summarize(self):
    """
    Store the captured output of the training command in the workflow context.

    Nothing is stored when no output has been captured (``self._output`` is None).
    Otherwise the decoded text is saved under this task's class name.
    """
    if self._output is None:
        return

    import locale

    # "ANSI" is only registered as a codec alias on Windows; on other platforms
    # bytes.decode("ANSI") raises LookupError. The locale's preferred encoding is
    # the ANSI code page on Windows and is valid everywhere else too.
    encoding = locale.getpreferredencoding(False)
    # TODO: it will be overrides by later commands
    self._context_manager.set_context(
        # errors="replace": keep the summary best-effort instead of failing on stray bytes
        self.__class__.__name__, self._output.decode(encoding, errors="replace")
    )
class AnalysisTask(Task):
"""
@@ -271,8 +265,8 @@ class AnalysisTask(Task):
"""
__ANALYZERS_PROJECT = {
HFAnalyzer.__name__: HFSignalRecord,
SignalAnalyzer.__name__: SignalRecord,
HFAnalyzer.__name__: HFAnalyzer,
SignalAnalyzer.__name__: SignalAnalyzer,
}
__ANALYZERS_DOCS = {
HFAnalyzer.__name__: HFAnalyzer.__doc__,
@@ -303,7 +297,7 @@ class AnalysisTask(Task):
ANALYZERS_DOCS=self.__ANALYZERS_DOCS,
),
)
analysers = response.split(",")
analysers = response.replace(" ", "").split(",")
confirm = self.interact(f"I select these analysers: {analysers}\n"
f"Are you sure you want to use? yes(Y/y), no(N/n) or prompt:")
if confirm is False:
@@ -317,15 +311,26 @@ class AnalysisTask(Task):
if isinstance(analysers, list) and len(analysers):
self.logger.info(f"selected analysers: {analysers}", plain=True)
workflow_config = (
self._context_manager.get_context("workflow_config")
if self._context_manager.get_context("workflow_config")
else "workflow_config.yaml"
)
workspace = self._context_manager.get_context("workspace")
workflow_path = workspace.joinpath(workflow_config)
with workflow_path.open() as f:
workflow = yaml.safe_load(f)
experiment_name = workflow["experiment_name"] if "experiment_name" in workflow else "workflow"
R.set_uri(Path.joinpath(workspace, 'mlruns').as_uri())
tasks = []
for analyser in analysers:
if analyser in self.__ANALYZERS_PROJECT.keys():
tasks.append(
self.__ANALYZERS_PROJECT.get(analyser)(
workspace=self._context_manager.get_context("workspace"),
model=self._context_manager.get_context("model"),
dataset=self._context_manager.get_context("dataset"),
recorder=self._context_manager.get_context("recorder"),
recorder=R.get_recorder(experiment_name=experiment_name),
output_dir=workspace
)
)
@@ -575,11 +580,14 @@ class SummarizeTask(Task):
information=information, figure_path=figure_path, user_prompt=user_prompt
)
# todo: remove 'be' after test
be = APIBackend()
be.debug_mode = False
response = be.build_messages_and_create_chat_completion(
user_prompt=prompt_workflow_selection, system_prompt=self.system.render()
)
self._context_manager.set_context("summary", response)
self.save_markdown(content=response)
self.logger.info(f"Report has saved to {self.__DEFAULT_REPORT_NAME}", title="End")

View File

@@ -1,6 +1,7 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
experiment_name: finCo
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
@@ -14,7 +15,7 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5

View File

@@ -3,9 +3,11 @@ import copy
import shutil
from pathlib import Path
from qlib.finco.task import WorkflowTask, PlanTask, ActionTask, SummarizeTask, RecorderTask, AnalysisTask
from qlib.finco.task import WorkflowTask, SummarizeTask, TrainTask
from qlib.finco.prompt_template import PromptTemplate, Template
from qlib.finco.log import FinCoLog, LogColors
from qlib.finco.utils import similarity
from qlib.finco.llm import APIBackend
class WorkflowContextManager:
@@ -51,9 +53,16 @@ class WorkflowContextManager:
max_score_key = max(scores, key=scores.get)
return {max_score_key: self.context.get(max_score_key)}
def clear(self, reserve: list = None):
    """
    Replace the context with a fresh dict built only from the keys in ``reserve``.

    Each reserved key is looked up through ``self.get_context`` before the old
    context is discarded; with no ``reserve`` the context becomes empty.
    """
    kept_keys = [] if reserve is None else reserve
    survivors = {}
    for key in kept_keys:
        survivors[key] = self.get_context(key)
    self.context = survivors
class WorkflowManager:
"""This manange the whole task automation workflow including tasks and actions"""
"""This manage the whole task automation workflow including tasks and actions"""
def __init__(self, workspace=None) -> None:
self.logger = FinCoLog()
@@ -63,8 +72,10 @@ class WorkflowManager:
else:
self._workspace = Path(workspace)
self._confirm_and_rm()
self._context = WorkflowContextManager()
self._context.set_context("workspace", self._workspace)
self.prompt_template = PromptTemplate()
self.context = WorkflowContextManager()
self.context.set_context("workspace", self._workspace)
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China a stock market. Please help to pick one third of the factors in Alpha360 and use lightGBM model."
def _confirm_and_rm(self):
@@ -87,10 +98,10 @@ class WorkflowManager:
def set_context(self, key, value):
"""Direct call set_context method of the context manager"""
self._context.set_context(key, value)
self.context.set_context(key, value)
def get_context(self) -> WorkflowContextManager:
return self._context
return self.context
def run(self, prompt: str) -> Path:
"""
@@ -124,7 +135,7 @@ class WorkflowManager:
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
# NOTE: list may not be enough for general task list
task_list = [WorkflowTask(), RecorderTask(), SummarizeTask()]
task_list = [WorkflowTask(), TrainTask(), SummarizeTask()]
task_finished = []
while len(task_list):
task_list_info = [str(task) for task in task_list]
@@ -138,15 +149,51 @@ class WorkflowManager:
f"Executing task: {str(t)}",
title="Task")
t.assign_context_manager(self._context)
t.assign_context_manager(self.context)
res = t.execute()
t.summarize()
task_finished.append(t)
self.context.set_context("task_finished", task_finished)
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
for _ in res:
if not isinstance(t, (WorkflowTask, PlanTask, ActionTask, RecorderTask, AnalysisTask, SummarizeTask)):
raise NotImplementedError(f"Unsupported Task type {_}")
task_list = res + task_list
return self._workspace
class LearnManager:
    """Run the workflow repeatedly and let the LLM rewrite task system prompts after each run."""

    def __init__(self):
        # number of completed run/learn rounds; also used to name prompt checkpoints
        self.epoch = 0
        self.wm = WorkflowManager()

    def run(self, prompt, max_epochs=10):
        """
        Alternate between running the workflow and learning from its result.

        Parameters
        ----------
        prompt :
            user prompt forwarded to ``WorkflowManager.run``.
        max_epochs : int
            number of run/learn rounds; defaults to 10, the previously hard-coded value.
        """
        # todo: add early stop condition
        for _ in range(max_epochs):
            self.wm.run(prompt)
            self.learn()
            self.epoch += 1

    def learn(self):
        """
        Ask the LLM for an adjusted system prompt for every finished task.

        Each response replaces the ``<TaskClass>_system`` template, the prompt
        templates are saved to ``prompts/checkpoint_<epoch>.yml`` under the
        workspace, and the workflow context is cleared except for ``workspace``.
        """
        context = self.wm.context
        workspace = context.get_context("workspace")
        # tolerate a run that recorded no finished tasks
        task_finished = context.get_context("task_finished") or []
        user_prompt = context.get_context("user_prompt")
        summary = context.get_context("summary")
        finished_names = [str(t) for t in task_finished]

        for task in task_finished:
            task_name = task.__class__.__name__
            system_prompt = task.system.render()
            # Pass text, not objects: a class or a jinja2 Template placed in the prompt
            # would be rendered as its repr.
            # NOTE(review): these keyword names match the variables of the
            # BackForwardTask_user template, yet each task's own user template is
            # rendered here -- confirm which template is intended.
            prompt_workflow_selection = task.user.render(
                summary=summary,
                task_finished=finished_names,
                task=task_name,
                system=system_prompt,
                user_prompt=user_prompt,
            )

            response = APIBackend().build_messages_and_create_chat_completion(
                user_prompt=prompt_workflow_selection, system_prompt=system_prompt
            )

            # todo: response assertion
            task.prompt_template.update(key=f"{task_name}_system", value=Template(response))

        self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
        self.wm.context.clear(reserve=["workspace"])

View File

@@ -165,11 +165,10 @@ class SignalRecord(RecordTemp):
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
"""
def __init__(self, model=None, dataset=None, recorder=None, workspace=None):
def __init__(self, model=None, dataset=None, recorder=None):
super().__init__(recorder=recorder)
self.model = model
self.dataset = dataset
self.workspace = workspace
@staticmethod
def generate_label(dataset):
@@ -208,10 +207,6 @@ class SignalRecord(RecordTemp):
raw_label = self.generate_label(self.dataset)
self.save(**{"label.pkl": raw_label})
def analyse(self):
res = SignalAnalyzer(workspace=self.workspace).analyse(dataset=self.dataset)
return res
def list(self):
return ["pred.pkl", "label.pkl"]
@@ -253,9 +248,8 @@ class HFSignalRecord(SignalRecord):
artifact_path = "hg_sig_analysis"
depend_cls = SignalRecord
def __init__(self, recorder, workspace=None, **kwargs):
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder)
self.workspace = workspace
def generate(self):
pred = self.load("pred.pkl")
@@ -289,12 +283,6 @@ class HFSignalRecord(SignalRecord):
self.save(**objects)
pprint(metrics)
def analyse(self):
pred = self.load("pred.pkl")
raw_label = self.load("label.pkl")
res = HFAnalyzer(workspace=self.workspace).analyse(pred=pred, label=raw_label)
return res
def list(self):
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]