mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Add backtest and backforward task (#1568)
* * add TrainTask & BacktestTask; * add BackForwardTask; * adjust prompt_template.yaml which default config failed to backtest; * run workflow in loop * add update method to prompt_template.py * remove debug code * Adjust Learn Process * add LearnManager class & use LearnManager to update system prompt; * use qrun to replace recorder for training and backtesting; * Adjust analyser * analyser independent of recorder; * rename analyser's workspace attribution; * analyser load variable by recorder. --------- Co-authored-by: Cadenza-Li <362237642@qq.com>
This commit is contained in:
@@ -3,10 +3,6 @@ import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from qlib.utils import class_casting
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..log import get_module_logger
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
@@ -14,8 +10,25 @@ logger = get_module_logger("analysis", logging.INFO)
|
||||
|
||||
|
||||
class AnalyzerTemp:
|
||||
def __init__(self, workspace=None, **kwargs):
|
||||
self.workspace = Path(workspace) if workspace else "./"
|
||||
def __init__(self, recorder, output_dir=None, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.output_dir = Path(output_dir) if output_dir else "./"
|
||||
|
||||
def load(self, name: str):
|
||||
"""
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
return self.recorder.load_object(name)
|
||||
|
||||
def analyse(self, **kwargs):
|
||||
"""
|
||||
@@ -42,7 +55,10 @@ class HFAnalyzer(AnalyzerTemp):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self, pred=None, label=None):
|
||||
def analyse(self):
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
@@ -65,13 +81,13 @@ class HFAnalyzer(AnalyzerTemp):
|
||||
table = [[k, v] for (k, v) in metrics.items()]
|
||||
plt.table(cellText=table, loc="center")
|
||||
plt.axis("off")
|
||||
plt.savefig(self.workspace.joinpath("HFAnalyzerTable.jpeg"))
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
|
||||
plt.clf()
|
||||
|
||||
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
|
||||
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
|
||||
plt.title("HFAnalyzer")
|
||||
plt.savefig(self.workspace.joinpath("HFAnalyzer.jpeg"))
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
|
||||
return "HFAnalyzer.jpeg"
|
||||
|
||||
|
||||
@@ -86,24 +102,10 @@ class SignalAnalyzer(AnalyzerTemp):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self, dataset=None, **kwargs):
|
||||
label = self.load("label.pkl")
|
||||
|
||||
with class_casting(dataset, DatasetH):
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError as e:
|
||||
# The data handler is initialized with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
logger.warning(f"Exception: {e}")
|
||||
raw_label = None
|
||||
plt.hist(raw_label)
|
||||
plt.hist(label)
|
||||
plt.title("SignalAnalyzer")
|
||||
plt.savefig(self.workspace.joinpath("signalAnalysis.jpeg"))
|
||||
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
|
||||
|
||||
return "signalAnalysis.jpeg"
|
||||
|
||||
15
qlib/finco/cli_learn.py
Normal file
15
qlib/finco/cli_learn.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import LearnManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
lm = LearnManager()
|
||||
lm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
@@ -1,18 +1,37 @@
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
import yaml
|
||||
|
||||
from qlib.finco.utils import Singleton
|
||||
from qlib.finco import get_finco_path
|
||||
import yaml
|
||||
import os
|
||||
|
||||
class PormptTemplate(Singleton):
|
||||
|
||||
class PromptTemplate(Singleton):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
_template = yaml.load(open(os.path.join(get_finco_path(), "prompt_template.yaml"), "r"), Loader=yaml.FullLoader)
|
||||
_template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
|
||||
Loader=yaml.FullLoader)
|
||||
for k, v in _template.items():
|
||||
if k == "mods":
|
||||
continue
|
||||
self.__setattr__(k, Template(v))
|
||||
|
||||
|
||||
for target_name, module_to_render_params in _template["mods"].items():
|
||||
for module_name, params in module_to_render_params.items():
|
||||
self.__setattr__(f"{target_name}_{module_name}", Template(self.__getattribute__(target_name).render(**params)))
|
||||
self.__setattr__(f"{target_name}_{module_name}",
|
||||
Template(self.__getattribute__(target_name).render(**params)))
|
||||
|
||||
def get(self, key: str):
|
||||
return self.__dict__.get(key, Template(""))
|
||||
|
||||
def update(self, key: str, value):
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def save(self, file_path: Union[str, Path]):
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
Path.mkdir(file_path.parent, exist_ok=True)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
yaml.dump(self.__dict__, f)
|
||||
|
||||
@@ -193,6 +193,13 @@ SummarizeTask_user : |-
|
||||
Here is my information: '{{information}}'
|
||||
My intention is: {{user_prompt}}. Please provide me with a summary and recommendation based on my intention and the information I have provided. There are some figures which absolute path are: {{figure_path}}, You must display these images in markdown using the appropriate image format.
|
||||
|
||||
BackForwardTask_system : |-
|
||||
Your task is adjusting system prompt in each task to fulfill user's intention
|
||||
|
||||
BackForwardTask_user : |-
|
||||
Here is the final summary: '{{summary}}'
|
||||
Tasks I have run are: {{task_finished}}, {{task}}'s system prompt is: {{system}}. User's intention is: {{user_prompt}}. you will adjust it to:
|
||||
|
||||
mods:
|
||||
ConfigActionTask_system:
|
||||
Dataset:
|
||||
@@ -382,7 +389,7 @@ mods:
|
||||
```
|
||||
Reason: I choose the backtest parameters above because they are suitable for a low turnover strategy focusing on long-term returns in the China A stock market. The start and end times are set to cover a 4-year period, which is reasonable for a long-term strategy. The account value is set to 1,000,000 as a starting point, and the benchmark is set to SH000300, which represents the China A stock market.
|
||||
Improve suggestion: You can try different time ranges for the backtest to evaluate the performance of the strategy in different market conditions. Also, you can adjust the costs (open_cost, close_cost, and min_cost) to better reflect the actual trading costs in the China A stock market.
|
||||
|
||||
|
||||
ConfigActionTask_user:
|
||||
Dataset:
|
||||
target_component : |-
|
||||
@@ -402,7 +409,7 @@ mods:
|
||||
Backtest:
|
||||
target_component : |-
|
||||
Backtest
|
||||
|
||||
|
||||
ImplementActionTask_system:
|
||||
Dataset:
|
||||
target_component : |-
|
||||
|
||||
@@ -11,10 +11,9 @@ import platform
|
||||
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
from qlib.finco.prompt_template import PormptTemplate
|
||||
from qlib.finco.prompt_template import PromptTemplate
|
||||
from qlib.workflow.record_temp import HFSignalRecord, SignalRecord
|
||||
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.conf import Config
|
||||
@@ -41,7 +40,7 @@ class Task:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._context_manager = None
|
||||
self.prompt_template = PormptTemplate()
|
||||
self.prompt_template = PromptTemplate()
|
||||
self.executed = False
|
||||
self.continuous = Config().continuous_mode
|
||||
self.logger = FinCoLog()
|
||||
@@ -96,13 +95,11 @@ class Task:
|
||||
|
||||
@property
|
||||
def system(self):
|
||||
return self.prompt_template.__getattribute__(
|
||||
self.__class__.__name__ + "_system"
|
||||
)
|
||||
return self.prompt_template.get(self.__class__.__name__ + "_system")
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return self.prompt_template.__getattribute__(self.__class__.__name__ + "_user")
|
||||
return self.prompt_template.get(self.__class__.__name__ + "_user")
|
||||
|
||||
def __str__(self):
|
||||
return self.__class__.__name__
|
||||
@@ -150,7 +147,7 @@ class PlanTask(Task):
|
||||
|
||||
|
||||
class SLPlanTask(PlanTask):
|
||||
def __init__(self,) -> None:
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__()
|
||||
|
||||
def execute(self):
|
||||
@@ -220,13 +217,14 @@ class RLPlanTask(PlanTask):
|
||||
return []
|
||||
|
||||
|
||||
class RecorderTask(Task):
|
||||
class TrainTask(Task):
|
||||
"""
|
||||
This Recorder task is responsible for analysing data such as index and distribution.
|
||||
This train task is responsible for training model configure by yaml file.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._output = None
|
||||
|
||||
def execute(self):
|
||||
workflow_config = (
|
||||
@@ -234,6 +232,7 @@ class RecorderTask(Task):
|
||||
if self._context_manager.get_context("workflow_config")
|
||||
else "workflow_config.yaml"
|
||||
)
|
||||
|
||||
workspace = self._context_manager.get_context("workspace")
|
||||
workflow_path = workspace.joinpath(workflow_config)
|
||||
with workflow_path.open() as f:
|
||||
@@ -246,24 +245,19 @@ class RecorderTask(Task):
|
||||
if confirm is False:
|
||||
return []
|
||||
|
||||
model = init_instance_by_config(workflow["task"]["model"])
|
||||
dataset = init_instance_by_config(workflow["task"]["dataset"])
|
||||
|
||||
with R.start(experiment_name="finCo"):
|
||||
model.fit(dataset)
|
||||
R.save_objects(trained_model=model)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
self._context_manager.set_context("model", model)
|
||||
self._context_manager.set_context("dataset", dataset)
|
||||
self._context_manager.set_context("recorder", recorder)
|
||||
command = f"qrun {workflow_path}"
|
||||
self._output = subprocess.check_output(command, shell=True, cwd=workspace)
|
||||
|
||||
return [AnalysisTask()]
|
||||
|
||||
def summarize(self):
|
||||
if self._output is not None:
|
||||
# TODO: it will be overrides by later commands
|
||||
# utf8 can't decode normally on Windows
|
||||
self._context_manager.set_context(
|
||||
self.__class__.__name__, self._output.decode("ANSI")
|
||||
)
|
||||
|
||||
|
||||
class AnalysisTask(Task):
|
||||
"""
|
||||
@@ -271,8 +265,8 @@ class AnalysisTask(Task):
|
||||
"""
|
||||
|
||||
__ANALYZERS_PROJECT = {
|
||||
HFAnalyzer.__name__: HFSignalRecord,
|
||||
SignalAnalyzer.__name__: SignalRecord,
|
||||
HFAnalyzer.__name__: HFAnalyzer,
|
||||
SignalAnalyzer.__name__: SignalAnalyzer,
|
||||
}
|
||||
__ANALYZERS_DOCS = {
|
||||
HFAnalyzer.__name__: HFAnalyzer.__doc__,
|
||||
@@ -303,7 +297,7 @@ class AnalysisTask(Task):
|
||||
ANALYZERS_DOCS=self.__ANALYZERS_DOCS,
|
||||
),
|
||||
)
|
||||
analysers = response.split(",")
|
||||
analysers = response.replace(" ", "").split(",")
|
||||
confirm = self.interact(f"I select these analysers: {analysers}\n"
|
||||
f"Are you sure you want to use? yes(Y/y), no(N/n) or prompt:")
|
||||
if confirm is False:
|
||||
@@ -317,15 +311,26 @@ class AnalysisTask(Task):
|
||||
if isinstance(analysers, list) and len(analysers):
|
||||
self.logger.info(f"selected analysers: {analysers}", plain=True)
|
||||
|
||||
workflow_config = (
|
||||
self._context_manager.get_context("workflow_config")
|
||||
if self._context_manager.get_context("workflow_config")
|
||||
else "workflow_config.yaml"
|
||||
)
|
||||
workspace = self._context_manager.get_context("workspace")
|
||||
workflow_path = workspace.joinpath(workflow_config)
|
||||
with workflow_path.open() as f:
|
||||
workflow = yaml.safe_load(f)
|
||||
|
||||
experiment_name = workflow["experiment_name"] if "experiment_name" in workflow else "workflow"
|
||||
R.set_uri(Path.joinpath(workspace, 'mlruns').as_uri())
|
||||
|
||||
tasks = []
|
||||
for analyser in analysers:
|
||||
if analyser in self.__ANALYZERS_PROJECT.keys():
|
||||
tasks.append(
|
||||
self.__ANALYZERS_PROJECT.get(analyser)(
|
||||
workspace=self._context_manager.get_context("workspace"),
|
||||
model=self._context_manager.get_context("model"),
|
||||
dataset=self._context_manager.get_context("dataset"),
|
||||
recorder=self._context_manager.get_context("recorder"),
|
||||
recorder=R.get_recorder(experiment_name=experiment_name),
|
||||
output_dir=workspace
|
||||
)
|
||||
)
|
||||
|
||||
@@ -575,11 +580,14 @@ class SummarizeTask(Task):
|
||||
information=information, figure_path=figure_path, user_prompt=user_prompt
|
||||
)
|
||||
|
||||
# todo: remove 'be' after test
|
||||
be = APIBackend()
|
||||
be.debug_mode = False
|
||||
response = be.build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection, system_prompt=self.system.render()
|
||||
)
|
||||
|
||||
self._context_manager.set_context("summary", response)
|
||||
self.save_markdown(content=response)
|
||||
self.logger.info(f"Report has saved to {self.__DEFAULT_REPORT_NAME}", title="End")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
experiment_name: finCo
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -14,7 +15,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
|
||||
@@ -3,9 +3,11 @@ import copy
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.finco.task import WorkflowTask, PlanTask, ActionTask, SummarizeTask, RecorderTask, AnalysisTask
|
||||
from qlib.finco.task import WorkflowTask, SummarizeTask, TrainTask
|
||||
from qlib.finco.prompt_template import PromptTemplate, Template
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.utils import similarity
|
||||
from qlib.finco.llm import APIBackend
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
@@ -51,9 +53,16 @@ class WorkflowContextManager:
|
||||
max_score_key = max(scores, key=scores.get)
|
||||
return {max_score_key: self.context.get(max_score_key)}
|
||||
|
||||
def clear(self, reserve: list = None):
|
||||
if reserve is None:
|
||||
reserve = []
|
||||
|
||||
_context = {k: self.get_context(k) for k in reserve}
|
||||
self.context = _context
|
||||
|
||||
|
||||
class WorkflowManager:
|
||||
"""This manange the whole task automation workflow including tasks and actions"""
|
||||
"""This manage the whole task automation workflow including tasks and actions"""
|
||||
|
||||
def __init__(self, workspace=None) -> None:
|
||||
self.logger = FinCoLog()
|
||||
@@ -63,8 +72,10 @@ class WorkflowManager:
|
||||
else:
|
||||
self._workspace = Path(workspace)
|
||||
self._confirm_and_rm()
|
||||
self._context = WorkflowContextManager()
|
||||
self._context.set_context("workspace", self._workspace)
|
||||
|
||||
self.prompt_template = PromptTemplate()
|
||||
self.context = WorkflowContextManager()
|
||||
self.context.set_context("workspace", self._workspace)
|
||||
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China a stock market. Please help to pick one third of the factors in Alpha360 and use lightGBM model."
|
||||
|
||||
def _confirm_and_rm(self):
|
||||
@@ -87,10 +98,10 @@ class WorkflowManager:
|
||||
|
||||
def set_context(self, key, value):
|
||||
"""Direct call set_context method of the context manager"""
|
||||
self._context.set_context(key, value)
|
||||
self.context.set_context(key, value)
|
||||
|
||||
def get_context(self) -> WorkflowContextManager:
|
||||
return self._context
|
||||
return self.context
|
||||
|
||||
def run(self, prompt: str) -> Path:
|
||||
"""
|
||||
@@ -124,7 +135,7 @@ class WorkflowManager:
|
||||
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
|
||||
|
||||
# NOTE: list may not be enough for general task list
|
||||
task_list = [WorkflowTask(), RecorderTask(), SummarizeTask()]
|
||||
task_list = [WorkflowTask(), TrainTask(), SummarizeTask()]
|
||||
task_finished = []
|
||||
while len(task_list):
|
||||
task_list_info = [str(task) for task in task_list]
|
||||
@@ -138,15 +149,51 @@ class WorkflowManager:
|
||||
f"Executing task: {str(t)}",
|
||||
title="Task")
|
||||
|
||||
t.assign_context_manager(self._context)
|
||||
t.assign_context_manager(self.context)
|
||||
res = t.execute()
|
||||
t.summarize()
|
||||
task_finished.append(t)
|
||||
self.context.set_context("task_finished", task_finished)
|
||||
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
|
||||
|
||||
for _ in res:
|
||||
if not isinstance(t, (WorkflowTask, PlanTask, ActionTask, RecorderTask, AnalysisTask, SummarizeTask)):
|
||||
raise NotImplementedError(f"Unsupported Task type {_}")
|
||||
task_list = res + task_list
|
||||
|
||||
return self._workspace
|
||||
|
||||
|
||||
class LearnManager:
|
||||
|
||||
def __init__(self):
|
||||
self.epoch = 0
|
||||
self.wm = WorkflowManager()
|
||||
|
||||
def run(self, prompt):
|
||||
|
||||
# todo: add early stop condition
|
||||
for i in range(10):
|
||||
self.wm.run(prompt)
|
||||
self.learn()
|
||||
self.epoch += 1
|
||||
|
||||
def learn(self):
|
||||
workspace = self.wm.context.get_context("workspace")
|
||||
task_finished = self.wm.context.get_context("task_finished")
|
||||
|
||||
user_prompt = self.wm.context.get_context("user_prompt")
|
||||
summary = self.wm.context.get_context("summary")
|
||||
|
||||
for task in task_finished:
|
||||
prompt_workflow_selection = task.user.render(
|
||||
summary=summary, task_finished=[str(task) for task in task_finished],
|
||||
task=task.__class__, system=task.system, user_prompt=user_prompt
|
||||
)
|
||||
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection, system_prompt=task.system.render()
|
||||
)
|
||||
|
||||
# todo: response assertion
|
||||
task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
|
||||
|
||||
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
|
||||
self.wm.context.clear(reserve=["workspace"])
|
||||
|
||||
@@ -165,11 +165,10 @@ class SignalRecord(RecordTemp):
|
||||
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, dataset=None, recorder=None, workspace=None):
|
||||
def __init__(self, model=None, dataset=None, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.workspace = workspace
|
||||
|
||||
@staticmethod
|
||||
def generate_label(dataset):
|
||||
@@ -208,10 +207,6 @@ class SignalRecord(RecordTemp):
|
||||
raw_label = self.generate_label(self.dataset)
|
||||
self.save(**{"label.pkl": raw_label})
|
||||
|
||||
def analyse(self):
|
||||
res = SignalAnalyzer(workspace=self.workspace).analyse(dataset=self.dataset)
|
||||
return res
|
||||
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
@@ -253,9 +248,8 @@ class HFSignalRecord(SignalRecord):
|
||||
artifact_path = "hg_sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, workspace=None, **kwargs):
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder)
|
||||
self.workspace = workspace
|
||||
|
||||
def generate(self):
|
||||
pred = self.load("pred.pkl")
|
||||
@@ -289,12 +283,6 @@ class HFSignalRecord(SignalRecord):
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def analyse(self):
|
||||
pred = self.load("pred.pkl")
|
||||
raw_label = self.load("label.pkl")
|
||||
res = HFAnalyzer(workspace=self.workspace).analyse(pred=pred, label=raw_label)
|
||||
return res
|
||||
|
||||
def list(self):
|
||||
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user