mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
format
This commit is contained in:
@@ -42,7 +42,9 @@ class Task:
|
||||
self._context_manager = None
|
||||
self.prompt_template = PormptTemplate()
|
||||
self.executed = False
|
||||
self.logger: logging.Logger = get_module_logger(f"finco.{self.__class__.__name__}")
|
||||
self.logger: logging.Logger = get_module_logger(
|
||||
f"finco.{self.__class__.__name__}"
|
||||
)
|
||||
|
||||
def summarize(self) -> str:
|
||||
"""After the execution of the task, it is supposed to generated some context about the execution"""
|
||||
@@ -74,12 +76,16 @@ class Task:
|
||||
"""The user can interact with the task"""
|
||||
"""All sub classes should implement the interact method to determine the next task"""
|
||||
"""In continous mode, this method will not be called and the next task will be determined by the execution method only"""
|
||||
raise NotImplementedError("The interact method is not implemented, but workflow not in continous mode")
|
||||
raise NotImplementedError(
|
||||
"The interact method is not implemented, but workflow not in continous mode"
|
||||
)
|
||||
|
||||
@property
|
||||
def system(self):
|
||||
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_system")
|
||||
|
||||
return self.prompt_template.__getattribute__(
|
||||
self.__class__.__name__ + "_system"
|
||||
)
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_user")
|
||||
@@ -149,7 +155,9 @@ class SLPlanTask(PlanTask):
|
||||
|
||||
def execute(self):
|
||||
workflow = self._context_manager.get_context("workflow")
|
||||
assert workflow == "supervised learning", "The workflow is not supervised learning"
|
||||
assert (
|
||||
workflow == "supervised learning"
|
||||
), "The workflow is not supervised learning"
|
||||
|
||||
user_prompt = self._context_manager.get_context("user_prompt")
|
||||
assert user_prompt is not None, "The user prompt is not provided"
|
||||
@@ -157,7 +165,9 @@ class SLPlanTask(PlanTask):
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
prompt_plan_all, self.system.render()
|
||||
)
|
||||
self.save_chat_history_to_context_manager(prompt_plan_all, response, self.system.render())
|
||||
self.save_chat_history_to_context_manager(
|
||||
prompt_plan_all, response, self.system.render()
|
||||
)
|
||||
if "components" not in response:
|
||||
self.logger.warning(
|
||||
"The response is not in the correct format, which probably means the answer is not correct"
|
||||
@@ -216,8 +226,14 @@ class RecorderTask(Task):
|
||||
This Recorder task is responsible for analysing data such as index and distribution.
|
||||
"""
|
||||
|
||||
__ANALYZERS_PROJECT = {HFAnalyzer.__name__: HFSignalRecord, SignalAnalyzer.__name__: SignalRecord}
|
||||
__ANALYZERS_DOCS = {HFAnalyzer.__name__: HFAnalyzer.__doc__, SignalAnalyzer.__name__: SignalAnalyzer.__doc__}
|
||||
__ANALYZERS_PROJECT = {
|
||||
HFAnalyzer.__name__: HFSignalRecord,
|
||||
SignalAnalyzer.__name__: SignalRecord,
|
||||
}
|
||||
__ANALYZERS_DOCS = {
|
||||
HFAnalyzer.__name__: HFAnalyzer.__doc__,
|
||||
SignalAnalyzer.__name__: SignalAnalyzer.__doc__,
|
||||
}
|
||||
# __ANALYZERS_PROJECT = {SignalAnalyzer.__name__: SignalRecord}
|
||||
# __ANALYZERS_DOCS = {SignalAnalyzer.__name__: SignalAnalyzer.__doc__}
|
||||
|
||||
@@ -231,7 +247,13 @@ class RecorderTask(Task):
|
||||
)
|
||||
be = APIBackend()
|
||||
be.debug_mode = False
|
||||
response = be.build_messages_and_create_chat_completion(prompt, self.system.render(ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()), ANALYZERS_DOCS=self.__ANALYZERS_DOCS))
|
||||
response = be.build_messages_and_create_chat_completion(
|
||||
prompt,
|
||||
self.system.render(
|
||||
ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()),
|
||||
ANALYZERS_DOCS=self.__ANALYZERS_DOCS,
|
||||
),
|
||||
)
|
||||
|
||||
# it's better to move to another Task
|
||||
workflow_config = (
|
||||
@@ -263,7 +285,10 @@ class RecorderTask(Task):
|
||||
if analyser in self.__ANALYZERS_PROJECT.keys():
|
||||
tasks.append(
|
||||
self.__ANALYZERS_PROJECT.get(analyser)(
|
||||
workspace=workspace, model=model, dataset=dataset, recorder=recorder
|
||||
workspace=workspace,
|
||||
model=model,
|
||||
dataset=dataset,
|
||||
recorder=recorder,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -282,6 +307,7 @@ class CMDTask(ActionTask):
|
||||
"""
|
||||
This CMD task is responsible for ensuring compatibility across different operating systems.
|
||||
"""
|
||||
|
||||
def __init__(self, cmd_intention: str, cwd=None):
|
||||
self.cwd = cwd
|
||||
self.cmd_intention = cmd_intention
|
||||
@@ -292,7 +318,9 @@ class CMDTask(ActionTask):
|
||||
prompt = self.user.render(
|
||||
cmd_intention=self.cmd_intention, user_os=platform.system()
|
||||
)
|
||||
response = APIBackend().build_messages_and_create_chat_completion(prompt, self.system.render())
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
prompt, self.system.render()
|
||||
)
|
||||
self._output = subprocess.check_output(response, shell=True, cwd=self.cwd)
|
||||
return []
|
||||
|
||||
@@ -300,7 +328,9 @@ class CMDTask(ActionTask):
|
||||
if self._output is not None:
|
||||
# TODO: it will be overrides by later commands
|
||||
# utf8 can't decode normally on Windows
|
||||
self._context_manager.set_context(self.__class__.__name__, self._output.decode("ANSI"))
|
||||
self._context_manager.set_context(
|
||||
self.__class__.__name__, self._output.decode("ANSI")
|
||||
)
|
||||
|
||||
|
||||
class ConfigActionTask(ActionTask):
|
||||
@@ -313,10 +343,16 @@ class ConfigActionTask(ActionTask):
|
||||
component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"]
|
||||
prompt_element_dict = dict()
|
||||
for component in component_list:
|
||||
prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision")
|
||||
prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan")
|
||||
prompt_element_dict[
|
||||
f"{component}_decision"
|
||||
] = self._context_manager.get_context(f"{component}_decision")
|
||||
prompt_element_dict[
|
||||
f"{component}_plan"
|
||||
] = self._context_manager.get_context(f"{component}_plan")
|
||||
|
||||
assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker"
|
||||
assert (
|
||||
None not in prompt_element_dict.values()
|
||||
), "Some decision or plan is not set by plan maker"
|
||||
|
||||
config_prompt = self.user.render(
|
||||
user_requirement=user_prompt,
|
||||
@@ -335,21 +371,29 @@ class ConfigActionTask(ActionTask):
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
config_prompt, self.system.render()
|
||||
)
|
||||
self.save_chat_history_to_context_manager(config_prompt, response, self.system.render())
|
||||
res = re.search(r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S)
|
||||
self.save_chat_history_to_context_manager(
|
||||
config_prompt, response, self.system.render()
|
||||
)
|
||||
res = re.search(
|
||||
r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S
|
||||
)
|
||||
assert (
|
||||
res is not None and len(res.groups()) == 3
|
||||
), "The response of config action task is not in the correct format"
|
||||
|
||||
config = re.search(r"```yaml(.*)```", res.group(1), re.S)
|
||||
assert config is not None, "The config part of config action task response is not in the correct format"
|
||||
assert (
|
||||
config is not None
|
||||
), "The config part of config action task response is not in the correct format"
|
||||
config = config.group(1)
|
||||
reason = res.group(2)
|
||||
improve_suggestion = res.group(3)
|
||||
|
||||
self._context_manager.set_context(f"{self.target_componet}_config", config)
|
||||
self._context_manager.set_context(f"{self.target_componet}_reason", reason)
|
||||
self._context_manager.set_context(f"{self.target_componet}_improve_suggestion", improve_suggestion)
|
||||
self._context_manager.set_context(
|
||||
f"{self.target_componet}_improve_suggestion", improve_suggestion
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@@ -368,10 +412,16 @@ class ImplementActionTask(ActionTask):
|
||||
component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"]
|
||||
prompt_element_dict = dict()
|
||||
for component in component_list:
|
||||
prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision")
|
||||
prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan")
|
||||
prompt_element_dict[
|
||||
f"{component}_decision"
|
||||
] = self._context_manager.get_context(f"{component}_decision")
|
||||
prompt_element_dict[
|
||||
f"{component}_plan"
|
||||
] = self._context_manager.get_context(f"{component}_plan")
|
||||
|
||||
assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker"
|
||||
assert (
|
||||
None not in prompt_element_dict.values()
|
||||
), "Some decision or plan is not set by plan maker"
|
||||
config = self._context_manager.get_context(f"{self.target_component}_config")
|
||||
|
||||
implement_prompt = self.user.render(
|
||||
@@ -396,13 +446,17 @@ class ImplementActionTask(ActionTask):
|
||||
implement_prompt, response, self.system.render()
|
||||
)
|
||||
|
||||
res = re.search(r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S)
|
||||
res = re.search(
|
||||
r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S
|
||||
)
|
||||
assert (
|
||||
res is not None and len(res.groups()) == 3
|
||||
), f"The response of implement action task of component {self.target_component} is not in the correct format"
|
||||
|
||||
code = re.search(r"```python(.*)```", res.group(1), re.S)
|
||||
assert code is not None, "The code part of implementation action task response is not in the correct format"
|
||||
assert (
|
||||
code is not None
|
||||
), "The code part of implementation action task response is not in the correct format"
|
||||
code = code.group(1)
|
||||
explanation = res.group(2)
|
||||
modified_config = re.search(r"```yaml(.*)```", res.group(3), re.S)
|
||||
@@ -412,8 +466,12 @@ class ImplementActionTask(ActionTask):
|
||||
modified_config = modified_config.group(1)
|
||||
|
||||
self._context_manager.set_context(f"{self.target_component}_code", code)
|
||||
self._context_manager.set_context(f"{self.target_component}_code_explanation", explanation)
|
||||
self._context_manager.set_context(f"{self.target_component}_modified_config", modified_config)
|
||||
self._context_manager.set_context(
|
||||
f"{self.target_component}_code_explanation", explanation
|
||||
)
|
||||
self._context_manager.set_context(
|
||||
f"{self.target_component}_modified_config", modified_config
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
@@ -458,7 +516,9 @@ class YamlEditTask(ActionTask):
|
||||
class SummarizeTask(Task):
|
||||
__DEFAULT_WORKSPACE = "./"
|
||||
|
||||
__DEFAULT_USER_PROMPT = "Summarize the information I offered and give me some advice."
|
||||
__DEFAULT_USER_PROMPT = (
|
||||
"Summarize the information I offered and give me some advice."
|
||||
)
|
||||
|
||||
# TODO: 2048 is close to exceed GPT token limit
|
||||
__MAX_LENGTH_OF_FILE = 2048
|
||||
@@ -474,7 +534,9 @@ class SummarizeTask(Task):
|
||||
self.workspace = workspace
|
||||
|
||||
user_prompt = self._context_manager.get_context("user_prompt")
|
||||
user_prompt = user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT
|
||||
user_prompt = (
|
||||
user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT
|
||||
)
|
||||
|
||||
file_info = self.get_info_from_file(workspace)
|
||||
context_info = [] # too long context make response unstable.
|
||||
@@ -519,7 +581,9 @@ class SummarizeTask(Task):
|
||||
self.logger.info(f"file to summarize: {file}")
|
||||
# in case of too large file
|
||||
# TODO: Perhaps summarization method instead of truncation would be a better approach
|
||||
result.append({"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]})
|
||||
result.append(
|
||||
{"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user