From 429c9a7c6631c40bd93b44bcda25e34a73b80929 Mon Sep 17 00:00:00 2001 From: Xu Yang Date: Tue, 13 Jun 2023 15:27:59 +0800 Subject: [PATCH] format --- qlib/finco/task.py | 122 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 29 deletions(-) diff --git a/qlib/finco/task.py b/qlib/finco/task.py index f5c09267b..3486e9a91 100644 --- a/qlib/finco/task.py +++ b/qlib/finco/task.py @@ -42,7 +42,9 @@ class Task: self._context_manager = None self.prompt_template = PormptTemplate() self.executed = False - self.logger: logging.Logger = get_module_logger(f"finco.{self.__class__.__name__}") + self.logger: logging.Logger = get_module_logger( + f"finco.{self.__class__.__name__}" + ) def summarize(self) -> str: """After the execution of the task, it is supposed to generated some context about the execution""" @@ -74,12 +76,16 @@ class Task: """The user can interact with the task""" """All sub classes should implement the interact method to determine the next task""" """In continous mode, this method will not be called and the next task will be determined by the execution method only""" - raise NotImplementedError("The interact method is not implemented, but workflow not in continous mode") + raise NotImplementedError( + "The interact method is not implemented, but workflow not in continous mode" + ) @property def system(self): - return self.prompt_template.__getattribute__(self.__class__.__name__ + "_system") - + return self.prompt_template.__getattribute__( + self.__class__.__name__ + "_system" + ) + @property def user(self): return self.prompt_template.__getattribute__(self.__class__.__name__ + "_user") @@ -149,7 +155,9 @@ class SLPlanTask(PlanTask): def execute(self): workflow = self._context_manager.get_context("workflow") - assert workflow == "supervised learning", "The workflow is not supervised learning" + assert ( + workflow == "supervised learning" + ), "The workflow is not supervised learning" user_prompt = self._context_manager.get_context("user_prompt") assert user_prompt is not None, "The user prompt is not provided" @@ -157,7 +165,9 @@ class SLPlanTask(PlanTask): response = APIBackend().build_messages_and_create_chat_completion( prompt_plan_all, self.system.render() ) - self.save_chat_history_to_context_manager(prompt_plan_all, response, self.system.render()) + self.save_chat_history_to_context_manager( + prompt_plan_all, response, self.system.render() + ) if "components" not in response: self.logger.warning( "The response is not in the correct format, which probably means the answer is not correct" @@ -216,8 +226,14 @@ class RecorderTask(Task): This Recorder task is responsible for analysing data such as index and distribution. """ - __ANALYZERS_PROJECT = {HFAnalyzer.__name__: HFSignalRecord, SignalAnalyzer.__name__: SignalRecord} - __ANALYZERS_DOCS = {HFAnalyzer.__name__: HFAnalyzer.__doc__, SignalAnalyzer.__name__: SignalAnalyzer.__doc__} + __ANALYZERS_PROJECT = { + HFAnalyzer.__name__: HFSignalRecord, + SignalAnalyzer.__name__: SignalRecord, + } + __ANALYZERS_DOCS = { + HFAnalyzer.__name__: HFAnalyzer.__doc__, + SignalAnalyzer.__name__: SignalAnalyzer.__doc__, + } # __ANALYZERS_PROJECT = {SignalAnalyzer.__name__: SignalRecord} # __ANALYZERS_DOCS = {SignalAnalyzer.__name__: SignalAnalyzer.__doc__} @@ -231,7 +247,13 @@ class RecorderTask(Task): ) be = APIBackend() be.debug_mode = False - response = be.build_messages_and_create_chat_completion(prompt, self.system.render(ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()), ANALYZERS_DOCS=self.__ANALYZERS_DOCS)) + response = be.build_messages_and_create_chat_completion( + prompt, + self.system.render( + ANALYZERS_list=list(self.__ANALYZERS_DOCS.keys()), + ANALYZERS_DOCS=self.__ANALYZERS_DOCS, + ), + ) # it's better to move to another Task workflow_config = ( @@ -263,7 +285,10 @@ class RecorderTask(Task): if analyser in self.__ANALYZERS_PROJECT.keys(): tasks.append( self.__ANALYZERS_PROJECT.get(analyser)( - workspace=workspace, model=model, dataset=dataset, recorder=recorder + workspace=workspace, + model=model, + dataset=dataset, + recorder=recorder, ) ) @@ -282,6 +307,7 @@ class CMDTask(ActionTask): """ This CMD task is responsible for ensuring compatibility across different operating systems. """ + def __init__(self, cmd_intention: str, cwd=None): self.cwd = cwd self.cmd_intention = cmd_intention @@ -292,7 +318,9 @@ class CMDTask(ActionTask): prompt = self.user.render( cmd_intention=self.cmd_intention, user_os=platform.system() ) - response = APIBackend().build_messages_and_create_chat_completion(prompt, self.system.render()) + response = APIBackend().build_messages_and_create_chat_completion( + prompt, self.system.render() + ) self._output = subprocess.check_output(response, shell=True, cwd=self.cwd) return [] @@ -300,7 +328,9 @@ class CMDTask(ActionTask): if self._output is not None: # TODO: it will be overrides by later commands # utf8 can't decode normally on Windows - self._context_manager.set_context(self.__class__.__name__, self._output.decode("ANSI")) + self._context_manager.set_context( + self.__class__.__name__, self._output.decode("ANSI") + ) class ConfigActionTask(ActionTask): @@ -313,10 +343,16 @@ class ConfigActionTask(ActionTask): component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"] prompt_element_dict = dict() for component in component_list: - prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision") - prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan") + prompt_element_dict[ + f"{component}_decision" + ] = self._context_manager.get_context(f"{component}_decision") + prompt_element_dict[ + f"{component}_plan" + ] = self._context_manager.get_context(f"{component}_plan") - assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker" + assert ( + None not in prompt_element_dict.values() + ), "Some decision or plan is not set by plan maker" config_prompt = self.user.render( user_requirement=user_prompt, @@ -335,21 +371,29 @@ class ConfigActionTask(ActionTask): response = APIBackend().build_messages_and_create_chat_completion( config_prompt, self.system.render() ) - self.save_chat_history_to_context_manager(config_prompt, response, self.system.render()) - res = re.search(r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S) + self.save_chat_history_to_context_manager( + config_prompt, response, self.system.render() + ) + res = re.search( + r"Config:(.*)Reason:(.*)Improve suggestion:(.*)", response, re.S + ) assert ( res is not None and len(res.groups()) == 3 ), "The response of config action task is not in the correct format" config = re.search(r"```yaml(.*)```", res.group(1), re.S) - assert config is not None, "The config part of config action task response is not in the correct format" + assert ( + config is not None + ), "The config part of config action task response is not in the correct format" config = config.group(1) reason = res.group(2) improve_suggestion = res.group(3) self._context_manager.set_context(f"{self.target_componet}_config", config) self._context_manager.set_context(f"{self.target_componet}_reason", reason) - self._context_manager.set_context(f"{self.target_componet}_improve_suggestion", improve_suggestion) + self._context_manager.set_context( + f"{self.target_componet}_improve_suggestion", improve_suggestion + ) return [] @@ -368,10 +412,16 @@ class ImplementActionTask(ActionTask): component_list = ["Dataset", "Model", "Record", "Strategy", "Backtest"] prompt_element_dict = dict() for component in component_list: - prompt_element_dict[f"{component}_decision"] = self._context_manager.get_context(f"{component}_decision") - prompt_element_dict[f"{component}_plan"] = self._context_manager.get_context(f"{component}_plan") + prompt_element_dict[ + f"{component}_decision" + ] = self._context_manager.get_context(f"{component}_decision") + prompt_element_dict[ + f"{component}_plan" + ] = self._context_manager.get_context(f"{component}_plan") - assert None not in prompt_element_dict.values(), "Some decision or plan is not set by plan maker" + assert ( + None not in prompt_element_dict.values() + ), "Some decision or plan is not set by plan maker" config = self._context_manager.get_context(f"{self.target_component}_config") implement_prompt = self.user.render( @@ -396,13 +446,17 @@ class ImplementActionTask(ActionTask): implement_prompt, response, self.system.render() ) - res = re.search(r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S) + res = re.search( + r"Code:(.*)Explanation:(.*)Modified config:(.*)", response, re.S + ) assert ( res is not None and len(res.groups()) == 3 ), f"The response of implement action task of component {self.target_component} is not in the correct format" code = re.search(r"```python(.*)```", res.group(1), re.S) - assert code is not None, "The code part of implementation action task response is not in the correct format" + assert ( + code is not None + ), "The code part of implementation action task response is not in the correct format" code = code.group(1) explanation = res.group(2) modified_config = re.search(r"```yaml(.*)```", res.group(3), re.S) @@ -412,8 +466,12 @@ class ImplementActionTask(ActionTask): modified_config = modified_config.group(1) self._context_manager.set_context(f"{self.target_component}_code", code) - self._context_manager.set_context(f"{self.target_component}_code_explanation", explanation) - self._context_manager.set_context(f"{self.target_component}_modified_config", modified_config) + self._context_manager.set_context( + f"{self.target_component}_code_explanation", explanation + ) + self._context_manager.set_context( + f"{self.target_component}_modified_config", modified_config + ) return [] @@ -458,7 +516,9 @@ class YamlEditTask(ActionTask): class SummarizeTask(Task): __DEFAULT_WORKSPACE = "./" - __DEFAULT_USER_PROMPT = "Summarize the information I offered and give me some advice." + __DEFAULT_USER_PROMPT = ( + "Summarize the information I offered and give me some advice." + ) # TODO: 2048 is close to exceed GPT token limit __MAX_LENGTH_OF_FILE = 2048 @@ -474,7 +534,9 @@ class SummarizeTask(Task): self.workspace = workspace user_prompt = self._context_manager.get_context("user_prompt") - user_prompt = user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT + user_prompt = ( + user_prompt if user_prompt is not None else self.__DEFAULT_USER_PROMPT + ) file_info = self.get_info_from_file(workspace) context_info = [] # too long context make response unstable. @@ -519,7 +581,9 @@ class SummarizeTask(Task): self.logger.info(f"file to summarize: {file}") # in case of too large file # TODO: Perhaps summarization method instead of truncation would be a better approach - result.append({"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]}) + result.append( + {"file": file, "content": content[: self.__MAX_LENGTH_OF_FILE]} + ) return result