1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00
This commit is contained in:
Xu Yang
2023-06-13 15:27:59 +08:00
parent 80fbc00792
commit 429c9a7c66

View File

@@ -42,7 +42,9 @@ class Task:
self._context_manager = None
self.prompt_template = PormptTemplate()
self.executed = False
self.logger: logging.Logger = get_module_logger(f"finco.{self.__class__.__name__}")
self.logger: logging.Logger = get_module_logger(
f"finco.{self.__class__.__name__}"
)
def summarize(self) -> str:
"""After the execution of the task, it is supposed to generated some context about the execution"""
@@ -74,12 +76,16 @@ class Task:
"""The user can interact with the task"""
"""All sub classes should implement the interact method to determine the next task"""
"""In continous mode, this method will not be called and the next task will be determined by the execution method only"""
raise NotImplementedError("The interact method is not implemented, but workflow not in continous mode")
raise NotImplementedError(
"The interact method is not implemented, but workflow not in continous mode"
)
@property
def system(self):
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_system")
return self.prompt_template.__getattribute__(
self.__class__.__name__ + "_system"
)
@property
def user(self):
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_user")
@@ -149,7 +155,9 @@ class SLPlanTask(PlanTask):
def execute(self):
workflow = self._context_manager.get_context("workflow")
assert workflow == "supervised learning", "The workflow is not supervised learning"
assert (
workflow == "supervised learning"
), "The workflow is not supervised learning"
user_prompt = self._context_manager.get_context("user_prompt")
assert user_prompt is not None, "The user prompt is not provided"
@@ -157,7 +165,9 @@ class SLPlanTask(PlanTask):
response = APIBackend().build_messages_and_create_chat_completion(
prompt_plan_all, self.system.render()
)
self.save_chat_history_to_context_manager(prompt_plan_all, response, self.system.render())
self.save_chat_history_to_context_manager(
prompt_plan_all, response, self.system.render()
)
if "components" not in response:
self.logger.warning(
"The response is not in the correct format, which probably means the answer is not correct"
@@ -216,8 +226,14 @@ class RecorderTask(Task):
This Recorder task is responsible for analysing data such as index and distribution.
"""
__ANALYZERS_PROJECT = {HFAnalyzer.__name__: HFSignalRecord, SignalAnalyzer.__name__: SignalRecord}
__ANALYZERS_DOCS = {HFAnalyzer.__name__: HFAnalyzer.__doc__, SignalAnalyzer.__name__: SignalAnalyzer.__doc__}
__ANALYZERS_PROJECT = {
HFAnalyzer.__name__: HFSignalRecord,
SignalAnalyzer.__name__: SignalRecord,
}
__ANALYZERS_DOCS = {
HFAnalyzer.__name__: HFAnalyzer.__doc__,
SignalAnalyzer.__name__: SignalAnalyzer.__doc__,
}
# __ANALYZERS_PROJECT = {SignalAnalyzer.__name__: SignalRecord}
# __ANALYZERS_DOCS = {SignalAnalyzer.__name__: SignalAnalyzer.__doc__}
@@ -231,7 +247,13 @@ class RecorderTask(Task):
)
be = APIBackend()
be.debug_mode = False
response = be.build_messages_and_create_chat_completion(prompt, self.system.render(ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()), ANALYZERS_DOCS=self.__ANALYZERS_DOCS))
response = be.build_messages_and_create_chat_completion(
prompt,
self.system.render(
ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()),
ANALYZERS_DOCS=self.__ANALYZERS_DOCS,
),
)
# it's better to move to another Task
workflow_config = (
@@ -263,7 +285,10 @@ class RecorderTask(Task):
if analyser in self.__ANALYZERS_PROJECT.keys():
tasks.append(
self.__ANALYZERS_PROJECT.get(analyser)(
workspace=workspace, model=model, dataset=dataset, recorder=recorder
workspace=workspace,
model=model,
dataset=dataset,
recorder=recorder,
)
)
@@ -282,6 +307,7 @@ class CMDTask(ActionTask):
"""
This CMD task is responsible for ensuring compatibility across different operating systems.
"""
def __init__(self, cmd_intention: str, cwd=None):
self.cwd = cwd
self.cmd_intention = cmd_intention
@@ -292,7 +318,9 @@ class CMDTask(ActionTask):
prompt = self.user.render(
cmd_intention=self.cmd_intention, user_os=platform.system()
)
response = APIBackend().build_messages_and_create_chat_completion(prompt, self.system.render())
response = APIBackend().build_messages_and_create_chat_completion(
prompt, self.system.render()
)
self._output = subprocess.check_output(response, shell=True, cwd=self.cwd)
return []
@@ -300,7 +328,9 @@ class CMDTask(ActionTask):
if self._output is not None:
# TODO: it will be overrides by later commands
# utf8 can't decode normally on Windows
self._context_manager.set_context(self.__class__.__name__, self._output.decode("ANSI"))
self._context_manager.set_context(
self.__class__.__name__, self._output.decode("ANSI")
)
class ConfigActionTask(ActionTask):
@@ -313,10 +343,16 @@ class ConfigActionTask(ActionTask):
component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"]
prompt_element_dict = dict()
for component in component_list:
prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision")
prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan")
prompt_element_dict[
f"{component}_decision"
] = self._context_manager.get_context(f"{component}_decision")
prompt_element_dict[
f"{component}_plan"
] = self._context_manager.get_context(f"{component}_plan")
assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker"
assert (
None not in prompt_element_dict.values()
), "Some decision or plan is not set by plan maker"
config_prompt = self.user.render(
user_requirement=user_prompt,
@@ -335,21 +371,29 @@ class ConfigActionTask(ActionTask):
response = APIBackend().build_messages_and_create_chat_completion(
config_prompt, self.system.render()
)
self.save_chat_history_to_context_manager(config_prompt, response, self.system.render())
res = re.search(r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S)
self.save_chat_history_to_context_manager(
config_prompt, response, self.system.render()
)
res = re.search(
r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S
)
assert (
res is not None and len(res.groups()) == 3
), "The response of config action task is not in the correct format"
config = re.search(r"```yaml(.*)```", res.group(1), re.S)
assert config is not None, "The config part of config action task response is not in the correct format"
assert (
config is not None
), "The config part of config action task response is not in the correct format"
config = config.group(1)
reason = res.group(2)
improve_suggestion = res.group(3)
self._context_manager.set_context(f"{self.target_componet}_config", config)
self._context_manager.set_context(f"{self.target_componet}_reason", reason)
self._context_manager.set_context(f"{self.target_componet}_improve_suggestion", improve_suggestion)
self._context_manager.set_context(
f"{self.target_componet}_improve_suggestion", improve_suggestion
)
return []
@@ -368,10 +412,16 @@ class ImplementActionTask(ActionTask):
component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"]
prompt_element_dict = dict()
for component in component_list:
prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision")
prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan")
prompt_element_dict[
f"{component}_decision"
] = self._context_manager.get_context(f"{component}_decision")
prompt_element_dict[
f"{component}_plan"
] = self._context_manager.get_context(f"{component}_plan")
assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker"
assert (
None not in prompt_element_dict.values()
), "Some decision or plan is not set by plan maker"
config = self._context_manager.get_context(f"{self.target_component}_config")
implement_prompt = self.user.render(
@@ -396,13 +446,17 @@ class ImplementActionTask(ActionTask):
implement_prompt, response, self.system.render()
)
res = re.search(r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S)
res = re.search(
r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S
)
assert (
res is not None and len(res.groups()) == 3
), f"The response of implement action task of component {self.target_component} is not in the correct format"
code = re.search(r"```python(.*)```", res.group(1), re.S)
assert code is not None, "The code part of implementation action task response is not in the correct format"
assert (
code is not None
), "The code part of implementation action task response is not in the correct format"
code = code.group(1)
explanation = res.group(2)
modified_config = re.search(r"```yaml(.*)```", res.group(3), re.S)
@@ -412,8 +466,12 @@ class ImplementActionTask(ActionTask):
modified_config = modified_config.group(1)
self._context_manager.set_context(f"{self.target_component}_code", code)
self._context_manager.set_context(f"{self.target_component}_code_explanation", explanation)
self._context_manager.set_context(f"{self.target_component}_modified_config", modified_config)
self._context_manager.set_context(
f"{self.target_component}_code_explanation", explanation
)
self._context_manager.set_context(
f"{self.target_component}_modified_config", modified_config
)
return []
@@ -458,7 +516,9 @@ class YamlEditTask(ActionTask):
class SummarizeTask(Task):
__DEFAULT_WORKSPACE = "./"
__DEFAULT_USER_PROMPT = "Summarize the information I offered and give me some advice."
__DEFAULT_USER_PROMPT = (
"Summarize the information I offered and give me some advice."
)
# TODO: 2048 is close to exceed GPT token limit
__MAX_LENGTH_OF_FILE = 2048
@@ -474,7 +534,9 @@ class SummarizeTask(Task):
self.workspace = workspace
user_prompt = self._context_manager.get_context("user_prompt")
user_prompt = user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT
user_prompt = (
user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT
)
file_info = self.get_info_from_file(workspace)
context_info = [] # too long context make response unstable.
@@ -519,7 +581,9 @@ class SummarizeTask(Task):
self.logger.info(f"file to summarize: {file}")
# in case of too large file
# TODO: Perhaps summarization method instead of truncation would be a better approach
result.append({"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]})
result.append(
{"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]}
)
return result