mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
47 Commits
6cma
...
xuyang1/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2df211c320 | ||
|
|
effed382e9 | ||
|
|
86ffd1799d | ||
|
|
aef11536e3 | ||
|
|
8b0fdf1623 | ||
|
|
9a36f8da20 | ||
|
|
b7757d5008 | ||
|
|
ee5e5cfdd8 | ||
|
|
6cb87ecfd1 | ||
|
|
9119bcdd3c | ||
|
|
4fccf8112d | ||
|
|
73bd79ca1a | ||
|
|
7e84f3aae2 | ||
|
|
1326ac614d | ||
|
|
f12184cc0f | ||
|
|
a70386ad52 | ||
|
|
74619ed8d8 | ||
|
|
1a523df007 | ||
|
|
f9cc8a5aaa | ||
|
|
7762c5a1fd | ||
|
|
fa7ef29281 | ||
|
|
429c9a7c66 | ||
|
|
80fbc00792 | ||
|
|
01accec24c | ||
|
|
1d88830b0d | ||
|
|
ad7498e287 | ||
|
|
73d51f05b4 | ||
|
|
3b56b8e6c0 | ||
|
|
40e0c329ba | ||
|
|
e376648860 | ||
|
|
5f37f32184 | ||
|
|
d46b4c1ebf | ||
|
|
0515524b51 | ||
|
|
cda32d5703 | ||
|
|
e2332a004b | ||
|
|
08d9dbccc9 | ||
|
|
e7cd93a36d | ||
|
|
3919678028 | ||
|
|
421b1403b2 | ||
|
|
94102fb742 | ||
|
|
74a5d7c8af | ||
|
|
ce39b4b6f8 | ||
|
|
2af35d9c89 | ||
|
|
f37643550b | ||
|
|
55611aa43e | ||
|
|
f24253efd2 | ||
|
|
7c4f3b8a7d |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -22,6 +22,7 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/finco/prompt_cache.json
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
|
||||
@@ -179,7 +179,7 @@ def get_strategy_executor(
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: Optional[str] = "SH000300",
|
||||
account: Union[float, int, dict] = 1e9,
|
||||
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
@@ -197,15 +197,12 @@ def get_strategy_executor(
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
if isinstance(exchange_kwargs, Exchange):
|
||||
trade_exchange = exchange_kwargs
|
||||
else:
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
|
||||
@@ -56,7 +56,6 @@ def collect_data_loop(
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict | None = None,
|
||||
show_progress: bool = True,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
@@ -75,8 +74,6 @@ def collect_data_loop(
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
show_progress: bool
|
||||
whether to show execution progress
|
||||
|
||||
Yields
|
||||
-------
|
||||
@@ -86,8 +83,7 @@ def collect_data_loop(
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
|
||||
|
||||
disable = not show_progress
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
|
||||
@@ -177,7 +177,7 @@ class Exchange:
|
||||
|
||||
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
|
||||
assert isinstance(limit_threshold, tuple)
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
|
||||
@@ -263,9 +263,6 @@ class Exchange:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, tuple):
|
||||
return self.LT_TP_EXP
|
||||
if isinstance(limit_threshold, list):
|
||||
assert len(limit_threshold) == 2
|
||||
return self.LT_TP_EXP
|
||||
elif isinstance(limit_threshold, float):
|
||||
return self.LT_FLT
|
||||
elif limit_threshold is None:
|
||||
@@ -328,7 +325,7 @@ class Exchange:
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
|
||||
assert isinstance(vol_limit, tuple)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
if key in ("buy", "all"):
|
||||
@@ -806,7 +803,7 @@ class Exchange:
|
||||
|
||||
vol_limit_num: List[float] = []
|
||||
for limit in vol_limit:
|
||||
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
|
||||
assert isinstance(limit, tuple)
|
||||
if limit[0] == "current":
|
||||
limit_value = self.quote.get_data(
|
||||
order.stock_id,
|
||||
|
||||
111
qlib/contrib/analyzer.py
Normal file
111
qlib/contrib/analyzer.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
logger = get_module_logger("analysis", logging.INFO)
|
||||
|
||||
|
||||
class AnalyzerTemp:
|
||||
def __init__(self, recorder, output_dir=None, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.output_dir = Path(output_dir) if output_dir else "./"
|
||||
|
||||
def load(self, name: str):
|
||||
"""
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
return self.recorder.load_object(name)
|
||||
|
||||
def analyse(self, **kwargs):
|
||||
"""
|
||||
Analyse data index, distribution .etc
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
||||
Return
|
||||
------
|
||||
The handled data.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class HFAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "HFAnalyzerTable.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self):
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Average Return": long_short_r.mean(),
|
||||
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
|
||||
}
|
||||
)
|
||||
|
||||
table = [[k, v] for (k, v) in metrics.items()]
|
||||
plt.table(cellText=table, loc="center")
|
||||
plt.axis("off")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
|
||||
plt.clf()
|
||||
|
||||
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
|
||||
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
|
||||
plt.title("HFAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
|
||||
return "HFAnalyzer.jpeg"
|
||||
|
||||
|
||||
class SignalAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "signalAnalysis.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self, dataset=None, **kwargs):
|
||||
label = self.load("label.pkl")
|
||||
|
||||
plt.hist(label)
|
||||
plt.title("SignalAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
|
||||
|
||||
return "signalAnalysis.jpeg"
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from qlib.utils.data import update_config
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -57,12 +59,13 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -74,12 +77,14 @@ class Alpha360(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
@@ -153,12 +158,13 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -170,11 +176,13 @@ class Alpha158(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
|
||||
18
qlib/finco/.env.example
Normal file
18
qlib/finco/.env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
OPENAI_API_KEY=your_api_key
|
||||
|
||||
# USE_AZURE=True
|
||||
# AZURE_API_BASE=your_api_base
|
||||
# AZURE_API_VERSION=your_api_version
|
||||
|
||||
# use gpt-4 means more token but more wait time
|
||||
# MODEL=gpt-4
|
||||
# MAX_TOKENS=1600
|
||||
# MAX_RETRY=1000
|
||||
|
||||
|
||||
MAX_TOKENS=1600
|
||||
MAX_RETRY=120
|
||||
|
||||
CONTINOUS_MODE=True
|
||||
DEBUG_MODE=True
|
||||
22
qlib/finco/README.md
Normal file
22
qlib/finco/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
|
||||
|
||||
## Installation
|
||||
|
||||
- To run this module, you need to first install Qlib following the instruction in [install-from-source](/README.md#install-from-source) or follow:
|
||||
|
||||
```python
|
||||
python -m pip install git+https://github.com/microsoft/qlib.git@finco
|
||||
```
|
||||
|
||||
- then you need to install other dependencies of finco:
|
||||
```python
|
||||
python -m pip install pydantic openai python-dotenv
|
||||
```
|
||||
|
||||
## Quick run
|
||||
|
||||
To run this module, you can start the workflow easily with one command:
|
||||
|
||||
```sh
|
||||
cd qlib/finco; python cli.py "your prompt"
|
||||
```
|
||||
13
qlib/finco/__init__.py
Normal file
13
qlib/finco/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_finco_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
15
qlib/finco/cli.py
Normal file
15
qlib/finco/cli.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
wm = WorkflowManager()
|
||||
wm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
15
qlib/finco/cli_learn.py
Normal file
15
qlib/finco/cli_learn.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import LearnManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
lm = LearnManager()
|
||||
lm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
32
qlib/finco/conf.py
Normal file
32
qlib/finco/conf.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# TODO: use pydantic for other modules in Qlib
|
||||
from pydantic import BaseSettings
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class Config(SingletonBaseClass):
|
||||
"""
|
||||
This config is for fast demo purpose.
|
||||
Please use BaseSettings insetead in the future
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
|
||||
self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
|
||||
|
||||
self.openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.azure_api_base = os.getenv("AZURE_API_BASE")
|
||||
self.azure_api_version = os.getenv("AZURE_API_VERSION")
|
||||
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
|
||||
|
||||
self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
|
||||
|
||||
self.continuous_mode = (
|
||||
os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
|
||||
)
|
||||
self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
|
||||
self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
|
||||
self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2
|
||||
156
qlib/finco/knowledge.py
Normal file
156
qlib/finco/knowledge.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
from typing import List
|
||||
|
||||
from qlib.workflow import R
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.finco.llm import APIBackend
|
||||
|
||||
|
||||
class Knowledge:
|
||||
"""
|
||||
Use to handle knowledge in finCo such as experiment and outside domain information
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def load(self, **kwargs):
|
||||
"""
|
||||
Load knowledge in memory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
def brief(self, **kwargs):
|
||||
"""
|
||||
Return a brief summary of knowledge
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
|
||||
class KnowledgeExperiment(Knowledge):
|
||||
"""
|
||||
Handle knowledge from experiments
|
||||
"""
|
||||
|
||||
def __init__(self, exp_name, rec_id=None):
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.exp = None
|
||||
self.recs = []
|
||||
|
||||
self.load(exp_name=exp_name, rec_id=rec_id)
|
||||
|
||||
def load(self, exp_name, rec_id=None):
|
||||
recs = []
|
||||
self.exp = R.get_exp(experiment_name=exp_name)
|
||||
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
|
||||
if rec_id is not None and r.id != rec_id:
|
||||
continue
|
||||
recs.append(r)
|
||||
self.recs.extend(recs)
|
||||
|
||||
def brief(self):
|
||||
docs = []
|
||||
for recorder in self.recs:
|
||||
docs.append({"exp_name": self.exp.name, "record_info": recorder.info,
|
||||
"config": recorder.load_object("config"),
|
||||
"context_summary": recorder.load_object("context_summary")})
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class Topic:
|
||||
|
||||
def __init__(self, name: str, describe: Template):
|
||||
self.name = name
|
||||
self.describe = describe
|
||||
self.docs = []
|
||||
self.knowledge = None
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def summarize(self, docs: list):
|
||||
self.logger.info(f"Summarize topic: \nname: {self.name}\ndescribe: {self.describe.module}")
|
||||
prompt_workflow_selection = self.describe.render(docs=docs)
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection
|
||||
)
|
||||
|
||||
self.knowledge = response
|
||||
self.docs = docs
|
||||
|
||||
|
||||
class KnowledgeBase:
|
||||
"""
|
||||
Load knowledge, offer brief information of knowledge and common handle interfaces
|
||||
"""
|
||||
|
||||
def __init__(self, init_path=None, topics: List[Topic] = None):
|
||||
self.logger = FinCoLog()
|
||||
init_path = init_path if init_path else Path.cwd()
|
||||
|
||||
if not init_path.exists():
|
||||
self.logger.warning(f"{init_path} not exist, create empty directory.")
|
||||
Path.mkdir(init_path)
|
||||
|
||||
self.knowledge = self.load(path=init_path)
|
||||
|
||||
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
|
||||
# literal search/semantic search
|
||||
self.docs = self.brief(knowledge=self.knowledge)
|
||||
|
||||
self.topics = topics if topics else []
|
||||
|
||||
def load(self, path) -> List:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
knowledge = []
|
||||
path = path if path.name == "mlruns" else path.joinpath("mlruns")
|
||||
R.set_uri(path.as_uri())
|
||||
for exp_name in R.list_experiments():
|
||||
knowledge.append(KnowledgeExperiment(exp_name=exp_name))
|
||||
|
||||
self.logger.plain_info(f"Load knowledge from: {path} finished.")
|
||||
return knowledge
|
||||
|
||||
def update(self, path):
|
||||
# note: only update new knowledge in future
|
||||
knowledge = self.load(path)
|
||||
self.knowledge = knowledge
|
||||
self.docs = self.brief(self.knowledge)
|
||||
self.logger.plain_info(f"Update knowledge finished.")
|
||||
|
||||
def brief(self, knowledge: List[Knowledge]) -> List:
|
||||
docs = []
|
||||
for k in knowledge:
|
||||
docs.extend(k.brief())
|
||||
|
||||
self.logger.plain_info(f"Generate brief knowledge summary finished.")
|
||||
return docs
|
||||
|
||||
def query(self, content: str = None):
|
||||
# todo: query by DSL
|
||||
return self.docs
|
||||
|
||||
def query_topics(self):
|
||||
knowledge_of_topics = []
|
||||
for topic in self.topics:
|
||||
knowledge_of_topics.append({topic.name: topic.knowledge})
|
||||
return knowledge_of_topics
|
||||
|
||||
def summarize_by_topic(self):
|
||||
for topic in self.topics:
|
||||
topic.summarize(self.docs)
|
||||
111
qlib/finco/llm.py
Normal file
111
qlib/finco/llm.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
import openai
|
||||
import json
|
||||
from typing import Optional
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco.log import FinCoLog
|
||||
|
||||
|
||||
class APIBackend(SingletonBaseClass):
|
||||
def __init__(self):
|
||||
self.cfg = Config()
|
||||
openai.api_key = self.cfg.openai_api_key
|
||||
if self.cfg.use_azure:
|
||||
openai.api_type = "azure"
|
||||
openai.api_base = self.cfg.azure_api_base
|
||||
openai.api_version = self.cfg.azure_api_version
|
||||
self.use_azure = self.cfg.use_azure
|
||||
|
||||
self.debug_mode = False
|
||||
if self.cfg.debug_mode:
|
||||
self.debug_mode = True
|
||||
cwd = os.getcwd()
|
||||
self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
|
||||
self.cache = (
|
||||
json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
|
||||
)
|
||||
|
||||
def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
|
||||
"""build the messages to avoid implementing several redundant lines of code"""
|
||||
cfg = Config()
|
||||
# TODO: system prompt should always be provided. In development stage we can use default value
|
||||
if system_prompt is None:
|
||||
try:
|
||||
system_prompt = cfg.system_prompt
|
||||
except AttributeError:
|
||||
FinCoLog().warning("system_prompt is not set, using default value.")
|
||||
system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
]
|
||||
messages.extend(former_messages[-1*cfg.max_past_message_include:])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
)
|
||||
fcl = FinCoLog()
|
||||
response = self.try_create_chat_completion(messages=messages, **kwargs)
|
||||
fcl.log_message(messages)
|
||||
fcl.log_response(response)
|
||||
return response
|
||||
|
||||
def try_create_chat_completion(self, max_retry=10, **kwargs):
|
||||
max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
|
||||
for i in range(max_retry):
|
||||
try:
|
||||
response = self.create_chat_completion(**kwargs)
|
||||
return response
|
||||
except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
|
||||
print(e)
|
||||
print(f"Retrying {i+1}th time...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
except openai.InvalidRequestError as e:
|
||||
print("Invalid request, will try to reduce the messages length and retry...")
|
||||
if len(kwargs["messages"]) > 2:
|
||||
kwargs["messages"] = kwargs["messages"][[0]] + kwargs["messages"][3:]
|
||||
continue
|
||||
raise e
|
||||
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages,
|
||||
model=None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
|
||||
if self.debug_mode:
|
||||
key = json.dumps(messages)
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
|
||||
if temperature is None:
|
||||
temperature = self.cfg.temperature
|
||||
if max_tokens is None:
|
||||
max_tokens = self.cfg.max_tokens
|
||||
|
||||
if self.cfg.use_azure:
|
||||
response = openai.ChatCompletion.create(
|
||||
engine=self.cfg.model,
|
||||
messages=messages,
|
||||
max_tokens=self.cfg.max_tokens,
|
||||
)
|
||||
else:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.cfg.model,
|
||||
messages=messages,
|
||||
)
|
||||
resp = response.choices[0].message["content"]
|
||||
if self.debug_mode:
|
||||
self.cache[key] = resp
|
||||
json.dump(self.cache, open(self.cache_file_location, "w"))
|
||||
return resp
|
||||
131
qlib/finco/log.py
Normal file
131
qlib/finco/log.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
This module will base on Qlib's logger module and provides some interactive functions.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from typing import Dict, List
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class LogColors:
|
||||
"""
|
||||
ANSI color codes for use in console output.
|
||||
"""
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
GRAY = "\033[90m"
|
||||
BLACK = "\033[30m"
|
||||
|
||||
BOLD = "\033[1m"
|
||||
ITALIC = "\033[3m"
|
||||
|
||||
END = "\033[0m"
|
||||
|
||||
@classmethod
|
||||
def get_all_colors(cls):
|
||||
names = dir(cls)
|
||||
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
|
||||
var_values = [getattr(cls, name) for name in names]
|
||||
return var_values
|
||||
|
||||
def render(self, text: str, color: str = "", style: str = ""):
|
||||
"""
|
||||
render text by input color and style. It's not recommend that input text is already rendered.
|
||||
"""
|
||||
# This method is called too frequently, which is not good.
|
||||
colors = self.get_all_colors()
|
||||
# Perhaps color and font should be distinguished here.
|
||||
if color:
|
||||
assert color in colors, f"color should be in: {colors} but now is: {color}"
|
||||
if style:
|
||||
assert style in colors, f"style should be in: {colors} but now is: {style}"
|
||||
|
||||
text = f"{color}{text}{self.END}"
|
||||
text = f"{style}{text}{self.END}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def formatting_log(logger, title="Info"):
|
||||
"""
|
||||
a context manager, print liens before and after a function
|
||||
"""
|
||||
length = {"Start": 120, "Task": 120, "Info": 60, "Interact": 60, "End": 120}.get(title, 60)
|
||||
color, bold = (LogColors.YELLOW, LogColors.BOLD) \
|
||||
if title in ["Start", "Task", "Info", "Interact", "End"] else (LogColors.CYAN, "")
|
||||
logger.info("")
|
||||
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
|
||||
yield
|
||||
logger.info("")
|
||||
|
||||
|
||||
class FinCoLog(SingletonBaseClass):
|
||||
# TODO:
|
||||
# - config to file logger and save it into workspace
|
||||
def __init__(self) -> None:
|
||||
self.logger = logging.Logger("interactive")
|
||||
# TODO: merge these with Qlib's default logger.
|
||||
# We can do the same thing by changing the default log dict of Qlib.
|
||||
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
def log_message(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
messages is some info like this [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
"""
|
||||
with formatting_log(self.logger, "GPT Messages"):
|
||||
for m in messages:
|
||||
self.logger.info(
|
||||
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
|
||||
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n")
|
||||
|
||||
def log_response(self, response: str):
|
||||
with formatting_log(self.logger, "GPT Response"):
|
||||
self.logger.info(
|
||||
f"{LogColors.CYAN}{response}{LogColors.END}\n")
|
||||
|
||||
# TODO:
|
||||
# It looks wierd if we only have logger
|
||||
def info(self, *args, plain=False, title="Info"):
|
||||
if plain:
|
||||
return self.plain_info(*args)
|
||||
with formatting_log(self.logger, title):
|
||||
for arg in args:
|
||||
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
|
||||
|
||||
def plain_info(self, *args):
|
||||
for arg in args:
|
||||
self.logger.info(
|
||||
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}")
|
||||
|
||||
def warning(self, *args):
|
||||
for arg in args:
|
||||
self.logger.warning(
|
||||
f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
|
||||
|
||||
def error(self, *args):
|
||||
for arg in args:
|
||||
self.logger.error(
|
||||
f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")
|
||||
32
qlib/finco/prompt_template.py
Normal file
32
qlib/finco/prompt_template.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
import yaml
|
||||
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco import get_finco_path
|
||||
|
||||
|
||||
class PromptTemplate(SingletonBaseClass):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
_template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
|
||||
Loader=yaml.FullLoader)
|
||||
for k, v in _template.items():
|
||||
if k == "mods":
|
||||
continue
|
||||
self.__setattr__(k, Template(v))
|
||||
|
||||
def get(self, key: str):
|
||||
return self.__dict__.get(key, Template(""))
|
||||
|
||||
def update(self, key: str, value):
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def save(self, file_path: Union[str, Path]):
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
Path.mkdir(file_path.parent, exist_ok=True)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
yaml.dump(self.__dict__, f)
|
||||
1012
qlib/finco/prompt_template.yaml
Normal file
1012
qlib/finco/prompt_template.yaml
Normal file
File diff suppressed because it is too large
Load Diff
1110
qlib/finco/task.py
Normal file
1110
qlib/finco/task.py
Normal file
File diff suppressed because it is too large
Load Diff
12
qlib/finco/tpl/README.md
Normal file
12
qlib/finco/tpl/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
This is a set of templates that should be copied for a new project.
|
||||
|
||||
Here are the explanations for the templates folder.
|
||||
|
||||
| folder | explanations |
|
||||
|--------|------------------------------------------------------------------|
|
||||
| sl | Default configuration for supervised learning |
|
||||
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
|
||||
|
||||
|
||||
# TODO
|
||||
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated
|
||||
13
qlib/finco/tpl/__init__.py
Normal file
13
qlib/finco/tpl/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_tpl_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
File diff suppressed because one or more lines are too long
73
qlib/finco/tpl/sl/workflow_config.yaml
Normal file
73
qlib/finco/tpl/sl/workflow_config.yaml
Normal file
@@ -0,0 +1,73 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
experiment_name: finCo
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
38
qlib/finco/utils.py
Normal file
38
qlib/finco/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
|
||||
from fuzzywuzzy import fuzz
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
_instance = None
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SingletonBaseClass(metaclass=SingletonMeta):
|
||||
"""
|
||||
Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)`
|
||||
This class becomes necessary
|
||||
|
||||
"""
|
||||
# TODO: Add move this class to Qlib's general utils.
|
||||
|
||||
|
||||
def parse_json(response):
|
||||
try:
|
||||
return json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
|
||||
|
||||
|
||||
def similarity(text1, text2):
|
||||
text1 = text1 if isinstance(text1, str) else ""
|
||||
text2 = text2 if isinstance(text2, str) else ""
|
||||
|
||||
# Maybe we can use other similarity algorithm such as tfidf
|
||||
return fuzz.ratio(text1, text2)
|
||||
223
qlib/finco/workflow.py
Normal file
223
qlib/finco/workflow.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import sys
|
||||
import copy
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from qlib.finco.task import HighLevelPlanTask, SummarizeTask, TrainTask
|
||||
from qlib.finco.prompt_template import PromptTemplate, Template
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.utils import similarity
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.knowledge import KnowledgeBase, Topic
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
"""Context Manager stores the context of the workflow"""
|
||||
|
||||
"""All context are key value pairs which saves the input, output and status of the whole workflow"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.context = {}
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def set_context(self, key, value):
|
||||
if key in self.context:
|
||||
self.logger.warning("The key already exists in the context, the value will be overwritten")
|
||||
self.context[key] = value
|
||||
|
||||
def get_context(self, key):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
return None
|
||||
return self.context[key]
|
||||
|
||||
def update_context(self, key, new_value):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
self.context.update({key: new_value})
|
||||
|
||||
def get_all_context(self):
|
||||
"""return a deep copy of the context"""
|
||||
"""TODO: do we need to return a deep copy?"""
|
||||
return copy.deepcopy(self.context)
|
||||
|
||||
def retrieve(self, query: str) -> dict:
|
||||
if query in self.context.keys():
|
||||
return {query: self.context.get(query)}
|
||||
|
||||
# Note: retrieve information from context by string similarity maybe abandon in future
|
||||
scores = {}
|
||||
for k, v in self.context.items():
|
||||
scores.update({k: max(similarity(query, k), similarity(query, v))})
|
||||
max_score_key = max(scores, key=scores.get)
|
||||
return {max_score_key: self.context.get(max_score_key)}
|
||||
|
||||
def clear(self, reserve: list = None):
|
||||
if reserve is None:
|
||||
reserve = []
|
||||
|
||||
_context = {k: self.get_context(k) for k in reserve}
|
||||
self.context = _context
|
||||
|
||||
|
||||
class WorkflowManager:
|
||||
"""This manage the whole task automation workflow including tasks and actions"""
|
||||
|
||||
def __init__(self, workspace=None) -> None:
|
||||
self.logger = FinCoLog()
|
||||
|
||||
if workspace is None:
|
||||
self._workspace = Path.cwd() / "finco_workspace"
|
||||
else:
|
||||
self._workspace = Path(workspace)
|
||||
self.conf = Config()
|
||||
self._confirm_and_rm()
|
||||
|
||||
self.prompt_template = PromptTemplate()
|
||||
self.context = WorkflowContextManager()
|
||||
self.context.set_context("workspace", self._workspace)
|
||||
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China A csi300. Please help to use lightgbm model."
|
||||
|
||||
def _confirm_and_rm(self):
|
||||
# if workspace exists, please confirm and remove it. Otherwise exit.
|
||||
if self._workspace.exists() and not self.conf.continuous_mode:
|
||||
self.logger.info(title="Interact")
|
||||
flag = input(
|
||||
LogColors().render(
|
||||
f"Will be deleted: \n\t{self._workspace}\n"
|
||||
f"If you do not need to delete {self._workspace},"
|
||||
f" please change the workspace dir or rename existing files\n"
|
||||
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
|
||||
color=LogColors.WHITE)
|
||||
)
|
||||
if str(flag) not in ["Y", "y"]:
|
||||
sys.exit()
|
||||
else:
|
||||
# remove self._workspace
|
||||
shutil.rmtree(self._workspace)
|
||||
elif self._workspace.exists() and self.conf.continuous_mode:
|
||||
shutil.rmtree(self._workspace)
|
||||
|
||||
def set_context(self, key, value):
|
||||
"""Direct call set_context method of the context manager"""
|
||||
self.context.set_context(key, value)
|
||||
|
||||
def get_context(self) -> WorkflowContextManager:
|
||||
return self.context
|
||||
|
||||
def run(self, prompt: str) -> Path:
|
||||
"""
|
||||
The workflow manager is supposed to generate a codebase based on the prompt
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt: str
|
||||
the prompt user gives
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
|
||||
The path is returned
|
||||
|
||||
The output path should follow a specific format:
|
||||
- TODO: design
|
||||
There is a summarized report where user can start from.
|
||||
"""
|
||||
|
||||
# NOTE: The following items are not designed to make the workflow very flexible.
|
||||
# - The generated tasks can't be changed after geting new information from the execution retuls.
|
||||
# - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...
|
||||
|
||||
# NOTE: default user prompt might be changed in the future and exposed to the user
|
||||
if prompt is None:
|
||||
self.set_context("user_prompt", self.default_user_prompt)
|
||||
else:
|
||||
self.set_context("user_prompt", prompt)
|
||||
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
|
||||
|
||||
# NOTE: list may not be enough for general task list
|
||||
task_list = [HighLevelPlanTask(), SummarizeTask()]
|
||||
task_finished = []
|
||||
while len(task_list):
|
||||
task_list_info = [str(task) for task in task_list]
|
||||
|
||||
# task list is not long, so sort it is not a big problem
|
||||
# TODO: sort the task list based on the priority of the task
|
||||
# task_list = sorted(task_list, key=lambda x: x.task_type)
|
||||
t = task_list.pop(0)
|
||||
self.logger.info(f"Task finished: {[str(task) for task in task_finished]}",
|
||||
f"Task in queue: {task_list_info}",
|
||||
f"Executing task: {str(t)}",
|
||||
title="Task")
|
||||
|
||||
t.assign_context_manager(self.context)
|
||||
res = t.execute()
|
||||
t.summarize()
|
||||
task_finished.append(t)
|
||||
self.context.set_context("task_finished", task_finished)
|
||||
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
|
||||
|
||||
task_list = res + task_list
|
||||
|
||||
return self._workspace
|
||||
|
||||
|
||||
class LearnManager:
|
||||
__DEFAULT_TOPICS = ["IC", "MaxDropDown"]
|
||||
|
||||
def __init__(self):
|
||||
self.epoch = 0
|
||||
self.wm = WorkflowManager()
|
||||
|
||||
topics = [Topic(name=topic, describe=self.wm.prompt_template.get(f"Topic_{topic}")) for topic in
|
||||
self.__DEFAULT_TOPICS]
|
||||
self.knowledge_base = KnowledgeBase(init_path=Path.cwd().joinpath('knowledge'), topics=topics)
|
||||
|
||||
def run(self, prompt):
|
||||
# todo: add early stop condition
|
||||
for i in range(10):
|
||||
self.wm.run(prompt)
|
||||
self.knowledge_base.update(self.wm._workspace)
|
||||
self.knowledge_base.summarize_by_topic()
|
||||
self.learn()
|
||||
self.epoch += 1
|
||||
|
||||
def learn(self):
|
||||
workspace = self.wm.context.get_context("workspace")
|
||||
|
||||
def _drop_duplicate_task(_task: List):
|
||||
unique_task = {}
|
||||
for obj in _task:
|
||||
task_name = obj.__class__.__name__
|
||||
if task_name not in unique_task:
|
||||
unique_task[task_name] = obj
|
||||
return list(unique_task.values())
|
||||
|
||||
# one task maybe run several times in workflow
|
||||
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
|
||||
|
||||
user_prompt = self.wm.context.get_context("user_prompt")
|
||||
summary = self.wm.context.get_context("summary")
|
||||
|
||||
for task in task_finished:
|
||||
prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
|
||||
summary=summary, brief=self.knowledge_base.query_topics(),
|
||||
task_finished=[str(t) for t in task_finished],
|
||||
task=task.__class__.__name__, system=task.system.render(), user_prompt=user_prompt
|
||||
)
|
||||
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection,
|
||||
system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render()
|
||||
)
|
||||
|
||||
# todo: response assertion
|
||||
task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
|
||||
|
||||
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
|
||||
self.wm.context.clear(reserve=["workspace"])
|
||||
@@ -16,12 +16,13 @@ import torch
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
|
||||
from qlib.backtest.executor import SimulatorExecutor
|
||||
from qlib.backtest.high_performance_ds import BaseOrderIndicator
|
||||
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
|
||||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
|
||||
from qlib.rl.contrib.utils import read_order_file
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
@@ -123,13 +124,105 @@ def _generate_report(
|
||||
return report
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
def single_with_simulator(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
|
||||
A new simulator will be created and used for every single-day order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backtest_config:
|
||||
Backtest config
|
||||
orders:
|
||||
Orders to be executed. Example format:
|
||||
datetime instrument amount direction
|
||||
0 2020-06-01 INST 600.0 0
|
||||
1 2020-06-02 INST 700.0 1
|
||||
...
|
||||
split
|
||||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
|
||||
cash_limit
|
||||
Limitation of cash.
|
||||
generate_report
|
||||
Whether to generate reports.
|
||||
|
||||
Returns
|
||||
-------
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
reports = []
|
||||
decisions = []
|
||||
for _, row in orders.iterrows():
|
||||
date = pd.Timestamp(row["datetime"])
|
||||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(row["direction"]),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config,
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
|
||||
reports.append(simulator.report_dict)
|
||||
decisions += simulator.decisions
|
||||
|
||||
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
|
||||
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
|
||||
records = _convert_indicator_to_dataframe(indicator_info)
|
||||
assert records is None or not np.isnan(records["ffr"]).any()
|
||||
|
||||
if generate_report:
|
||||
_report = _generate_report(decisions, [report["indicator"] for report in reports])
|
||||
|
||||
if split == "stock":
|
||||
stock_id = orders.iloc[0].instrument
|
||||
report = {stock_id: _report}
|
||||
else:
|
||||
day = orders.iloc[0].datetime
|
||||
report = {day: _report}
|
||||
|
||||
return records, report
|
||||
else:
|
||||
return records
|
||||
|
||||
|
||||
def single_with_collect_data_loop(
|
||||
backtest_config: dict,
|
||||
orders: pd.DataFrame,
|
||||
time_range: Tuple[str, str],
|
||||
exchange_config: dict,
|
||||
strategy_config: dict,
|
||||
split: Literal["stock", "day"] = "stock",
|
||||
data_granularity: str = "1min",
|
||||
cash_limit: float | None = None,
|
||||
generate_report: bool = False,
|
||||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
|
||||
@@ -157,42 +250,44 @@ def single_with_collect_data_loop(
|
||||
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
|
||||
"""
|
||||
|
||||
init_qlib(backtest_config["qlib"])
|
||||
|
||||
trade_start_time = orders["datetime"].min()
|
||||
trade_end_time = orders["datetime"].max()
|
||||
stocks = orders.instrument.unique().tolist()
|
||||
|
||||
top_strategy_config = {
|
||||
strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"file": orders,
|
||||
"trade_range": TradeRangeByTime(
|
||||
pd.Timestamp(time_range[0]).time(),
|
||||
pd.Timestamp(time_range[1]).time(),
|
||||
pd.Timestamp(backtest_config["start_time"]).time(),
|
||||
pd.Timestamp(backtest_config["end_time"]).time(),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
top_executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=strategy_config,
|
||||
executor_config = _get_multi_level_executor_config(
|
||||
strategy_config=backtest_config["strategies"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
data_granularity=data_granularity,
|
||||
data_granularity=backtest_config["data_granularity"],
|
||||
)
|
||||
|
||||
exchange_config = {
|
||||
**exchange_config,
|
||||
**{
|
||||
exchange_config = copy.deepcopy(backtest_config["exchange"])
|
||||
exchange_config.update(
|
||||
{
|
||||
"codes": stocks,
|
||||
"freq": data_granularity,
|
||||
},
|
||||
}
|
||||
"freq": backtest_config["data_granularity"],
|
||||
}
|
||||
)
|
||||
|
||||
strategy, executor = get_strategy_executor(
|
||||
start_time=pd.Timestamp(trade_start_time),
|
||||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
|
||||
strategy=top_strategy_config,
|
||||
executor=top_executor_config,
|
||||
strategy=strategy_config,
|
||||
executor=executor_config,
|
||||
benchmark=None,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=exchange_config,
|
||||
@@ -200,7 +295,7 @@ def single_with_collect_data_loop(
|
||||
)
|
||||
|
||||
report_dict: dict = {}
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
|
||||
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
|
||||
|
||||
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
|
||||
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
|
||||
@@ -220,54 +315,46 @@ def single_with_collect_data_loop(
|
||||
|
||||
|
||||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
|
||||
init_qlib(backtest_config["simulator"]["qlib"])
|
||||
order_df = read_order_file(backtest_config["order_file"])
|
||||
|
||||
cash_limit = backtest_config["exchange"].pop("cash_limit")
|
||||
generate_report = backtest_config.pop("generate_report")
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
single = single_with_simulator if with_simulator else single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
|
||||
|
||||
single = single_with_collect_data_loop
|
||||
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
|
||||
|
||||
for task_config in backtest_config["tasks"]:
|
||||
order_df = read_order_file(task_config["order_file"])
|
||||
exchange_config = task_config["exchange"]
|
||||
cash_limit = exchange_config.pop("cash_limit")
|
||||
generate_report = backtest_config["runtime"]["generate_report"]
|
||||
|
||||
stock_pool = order_df["instrument"].unique().tolist()
|
||||
stock_pool.sort()
|
||||
|
||||
#
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
time_range=task_config["time_range"],
|
||||
exchange_config=task_config["exchange"],
|
||||
strategy_config=backtest_config["strategies"],
|
||||
split="stock",
|
||||
data_granularity=task_config["data_granularity"],
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
for stock in stock_pool
|
||||
res = Parallel(**mp_config)(
|
||||
delayed(single)(
|
||||
backtest_config=backtest_config,
|
||||
orders=order_df[order_df["instrument"] == stock].copy(),
|
||||
split="stock",
|
||||
cash_limit=cash_limit,
|
||||
generate_report=generate_report,
|
||||
)
|
||||
|
||||
#
|
||||
output_path = Path(task_config["output_dir"])
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
# return res # TODO
|
||||
for stock in stock_pool
|
||||
)
|
||||
|
||||
output_path = Path(backtest_config["output_dir"])
|
||||
if generate_report:
|
||||
with (output_path / "report.pkl").open("wb") as f:
|
||||
report = {}
|
||||
for r in res:
|
||||
report.update(r[1])
|
||||
pickle.dump(report, f)
|
||||
res = pd.concat([r[0] for r in res], 0)
|
||||
else:
|
||||
res = pd.concat(res)
|
||||
|
||||
if not output_path.exists():
|
||||
os.makedirs(output_path)
|
||||
|
||||
if "pa" in res.columns:
|
||||
res["pa"] = res["pa"] * 10000.0 # align with training metrics
|
||||
res.to_csv(output_path / "backtest_result.csv")
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -275,7 +362,6 @@ if __name__ == "__main__":
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
@@ -288,11 +374,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
config_parser = BacktestConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
if args.n_jobs is not None: # Overwrite concurrency
|
||||
config["runtime"]["concurrency"] = args.n_jobs
|
||||
config = get_backtest_config_fromfile(args.config_path)
|
||||
if args.n_jobs is not None:
|
||||
config["concurrency"] = args.n_jobs
|
||||
|
||||
backtest(
|
||||
backtest_config=config,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -31,7 +30,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
|
||||
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||
|
||||
|
||||
def load_config(path: str) -> dict:
|
||||
def parse_backtest_config(path: str) -> dict:
|
||||
abs_path = os.path.abspath(path)
|
||||
check_file_exist(abs_path)
|
||||
|
||||
@@ -66,154 +65,43 @@ def load_config(path: str) -> dict:
|
||||
base_file_name = [base_file_name]
|
||||
|
||||
for f in base_file_name:
|
||||
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
|
||||
config = merge_a_into_b(a=config, b=base_config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class BacktestConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
|
||||
def parse(self) -> dict:
|
||||
self._simulator_config = self._parse_simulator()
|
||||
self._exchange_config = self._simulator_config.pop("exchange")
|
||||
config = {
|
||||
"strategies": self.raw_config["strategies"],
|
||||
"runtime": self.raw_config["runtime"],
|
||||
"tasks": self._parse_tasks(),
|
||||
"simulator": self._simulator_config,
|
||||
}
|
||||
return config
|
||||
|
||||
def _parse_tasks(self) -> dict:
|
||||
task_config = []
|
||||
for task in self.raw_config["tasks"]:
|
||||
if "output_dir" not in task:
|
||||
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
|
||||
if "exchange" not in task:
|
||||
task["exchange"] = copy.deepcopy(self._exchange_config)
|
||||
else:
|
||||
task["exchange"] = self._complete_exchange_config(task["exchange"])
|
||||
task_config.append(task)
|
||||
|
||||
return task_config
|
||||
|
||||
def _complete_exchange_config(self, exchange_config: dict) -> dict:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
|
||||
return exchange_config
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
|
||||
return {
|
||||
"qlib": config["qlib"],
|
||||
"exchange": self._complete_exchange_config(config["exchange"]),
|
||||
}
|
||||
def _convert_all_list_to_tuple(config: dict) -> dict:
|
||||
for k, v in config.items():
|
||||
if isinstance(v, list):
|
||||
config[k] = tuple(v)
|
||||
elif isinstance(v, dict):
|
||||
config[k] = _convert_all_list_to_tuple(v)
|
||||
return config
|
||||
|
||||
|
||||
class TrainingConfigParser:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.raw_config = load_config(path)
|
||||
def get_backtest_config_fromfile(path: str) -> dict:
|
||||
backtest_config = parse_backtest_config(path)
|
||||
|
||||
def parse(self) -> dict:
|
||||
return {
|
||||
"general": self._parse_general(),
|
||||
"policy": self.raw_config["policy"],
|
||||
"interpreter": self.raw_config["interpreter"],
|
||||
"runtime": self._parse_runtime(),
|
||||
"training": self._parse_training(),
|
||||
"simulator": self._parse_simulator(),
|
||||
}
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
"cash_limit": None,
|
||||
}
|
||||
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
|
||||
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
|
||||
|
||||
def _parse_general(self) -> dict:
|
||||
default = {
|
||||
"freq": "1min",
|
||||
"extra_module_paths": [],
|
||||
}
|
||||
return {**default, **self.raw_config["general"]}
|
||||
backtest_config_default = {
|
||||
"debug_single_stock": None,
|
||||
"debug_single_day": None,
|
||||
"concurrency": -1,
|
||||
"multiplier": 1.0,
|
||||
"output_dir": "outputs_backtest/",
|
||||
"generate_report": False,
|
||||
"data_granularity": "1min",
|
||||
}
|
||||
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
|
||||
|
||||
def _parse_runtime(self) -> dict:
|
||||
default = {
|
||||
"seed": None,
|
||||
"use_cuda": False,
|
||||
"concurrency": 1,
|
||||
"parallel_mode": "dummy",
|
||||
}
|
||||
return {**default, **self.raw_config["runtime"]}
|
||||
|
||||
def _parse_training(self) -> dict:
|
||||
default = {
|
||||
"max_epoch": 100,
|
||||
"repeat_per_collect": 2,
|
||||
"earlystop_patience": float("inf"),
|
||||
"episode_per_collect": 10000,
|
||||
"batch_size": 256,
|
||||
"val_every_n_epoch": None,
|
||||
"checkpoint_path": "./outputs",
|
||||
"checkpoint_every_n_iters": 10,
|
||||
}
|
||||
|
||||
config = self.raw_config["training"]
|
||||
assert "order_dir" in config
|
||||
|
||||
return {**default, **config}
|
||||
|
||||
def _parse_simulator(self) -> dict:
|
||||
config = self.raw_config["simulator"]
|
||||
sim_type = config["type"]
|
||||
assert sim_type in ("simple", "full")
|
||||
|
||||
if sim_type == "simple":
|
||||
return {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"feature_columns_today": config["data"]["feature_columns_today"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"time_per_step": config["time_per_step"],
|
||||
"vol_limit": config["vol_limit"],
|
||||
}
|
||||
else:
|
||||
exchange_config_default = {
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5.0,
|
||||
"trade_unit": 100.0,
|
||||
# "cash_limit": None,
|
||||
}
|
||||
exchange_config = {**exchange_config_default, **config["exchange"]}
|
||||
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
|
||||
|
||||
ret_config = {
|
||||
"type": sim_type,
|
||||
"data": {
|
||||
"feature_root_dir": config["data"]["feature_root_dir"],
|
||||
"default_start_time_index": config["data"].get("default_start_time_index", 0),
|
||||
"default_end_time_index": config["data"].get("default_end_time_index", 240),
|
||||
},
|
||||
"qlib": {
|
||||
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
|
||||
},
|
||||
"exchange": exchange_config,
|
||||
}
|
||||
|
||||
return ret_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(parser.parse())
|
||||
return backtest_config
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, cast, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl import Simulator
|
||||
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def get_executor_config(freq: int) -> dict:
|
||||
return {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"generate_report": False,
|
||||
"time_per_step": f"{freq}min",
|
||||
"track_data": True,
|
||||
"trade_type": "serial",
|
||||
"verbose": False,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"kwargs": {},
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"time_per_step": "30min",
|
||||
"track_data": True,
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "ProxySAOEStrategy",
|
||||
"module_path": "qlib.rl.order_execution.strategy",
|
||||
"kwargs": {},
|
||||
},
|
||||
"time_per_step": "1day",
|
||||
"track_data": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
def _freq_str_to_int(freq: str) -> int:
|
||||
if freq.endswith("min"):
|
||||
return int(freq.replace("min", ""))
|
||||
elif freq.endswith("hour"):
|
||||
return int(freq.replace("hour", "") * 60)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized freq string: {freq}")
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_df: pd.DataFrame,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = order_df
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
|
||||
df = df.copy()
|
||||
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
|
||||
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
|
||||
return dfs
|
||||
|
||||
|
||||
def _get_simulator_factory(
|
||||
sim_type: str,
|
||||
data_dir: Path,
|
||||
freq_min: int,
|
||||
simulator_config: dict,
|
||||
) -> Callable[[Order], Simulator]:
|
||||
if sim_type == "simple":
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
simulator = SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_dir,
|
||||
feature_columns_today=simulator_config["data"]["feature_columns_today"],
|
||||
data_granularity=freq_min,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_simple
|
||||
elif sim_type == "full":
|
||||
init_qlib(simulator_config["qlib"])
|
||||
executor_config = get_executor_config(freq_min)
|
||||
exchange_config = simulator_config["exchange"]
|
||||
|
||||
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
|
||||
simulator = SingleAssetOrderExecution(
|
||||
order=order,
|
||||
executor_config=executor_config,
|
||||
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
|
||||
qlib_config=None,
|
||||
cash_limit=None,
|
||||
)
|
||||
return simulator
|
||||
|
||||
return _simulator_factory_full
|
||||
else:
|
||||
raise ValueError(f"Unknown simulator type: {sim_type}")
|
||||
|
||||
|
||||
def train_and_test(
|
||||
freq: str,
|
||||
concurrency: int,
|
||||
parallel_mode: str,
|
||||
training_config: dict,
|
||||
simulator_config: dict,
|
||||
policy: BasePolicy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
freq_min: int = _freq_str_to_int(freq)
|
||||
order_root_path = Path(training_config["order_dir"])
|
||||
feature_root_dir = simulator_config["data"]["feature_root_dir"]
|
||||
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
|
||||
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
|
||||
|
||||
_simulator_factory = _get_simulator_factory(
|
||||
sim_type=simulator_config["type"],
|
||||
data_dir=feature_root_dir,
|
||||
freq_min=freq_min,
|
||||
simulator_config=simulator_config,
|
||||
)
|
||||
|
||||
# Load orders
|
||||
load_data_tags = []
|
||||
orders_by_tag = {}
|
||||
if run_training:
|
||||
load_data_tags += ["train", "valid"]
|
||||
if run_backtest:
|
||||
load_data_tags += ["test"]
|
||||
for tag in load_data_tags:
|
||||
order_df = _read_orders(order_root_path / tag).reset_index()
|
||||
dfs = _split_order_df_by_instrument(order_df, concurrency)
|
||||
datasets = [
|
||||
LazyLoadDataset(
|
||||
data_dir=feature_root_dir,
|
||||
order_df=df,
|
||||
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
|
||||
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
|
||||
)
|
||||
for df in dfs
|
||||
]
|
||||
orders_by_tag[tag] = datasets
|
||||
|
||||
if run_training:
|
||||
callbacks: List[Callback] = [
|
||||
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
|
||||
Checkpoint(
|
||||
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=training_config["checkpoint_every_n_iters"],
|
||||
save_latest="copy",
|
||||
),
|
||||
EarlyStopping(
|
||||
patience=training_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
),
|
||||
]
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
|
||||
trainer_kwargs={
|
||||
"max_iters": training_config["max_epoch"],
|
||||
"finite_env_type": parallel_mode,
|
||||
"concurrency": concurrency,
|
||||
"val_every_n_iters": training_config["val_every_n_epoch"],
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": training_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": training_config["batch_size"],
|
||||
"repeat": training_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(training_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=parallel_mode, # type: ignore[arg-type]
|
||||
concurrency=concurrency,
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
seed = config["runtime"]["seed"]
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
for extra_module_path in config["general"]["extra_module_paths"]:
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
|
||||
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
# Create torch network
|
||||
if "network" in config["policy"]:
|
||||
network_config = config["policy"]["network"]
|
||||
network_config["kwargs"] = {
|
||||
**network_config.get("kwargs", {}),
|
||||
**{"obs_space": state_interpreter.observation_space},
|
||||
}
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
|
||||
|
||||
# Create policy
|
||||
policy_config = config["policy"]["policy"]
|
||||
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
|
||||
policy: BasePolicy = init_instance_by_config(policy_config)
|
||||
|
||||
use_cuda = config["runtime"]["use_cuda"]
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
freq=config["general"]["freq"],
|
||||
concurrency=config["runtime"]["concurrency"],
|
||||
parallel_mode=config["runtime"]["parallel_mode"],
|
||||
training_config=config["training"],
|
||||
simulator_config=config["simulator"],
|
||||
policy=policy,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
config_parser = TrainingConfigParser(args.config_path)
|
||||
config = config_parser.parse()
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
268
qlib/rl/contrib/train_onpolicy.py
Normal file
268
qlib/rl/contrib/train_onpolicy.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from qlib.backtest import Order
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.constant import ONE_MIN
|
||||
from qlib.rl.data.native import load_handler_intraday_processed_data
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.trainer import Checkpoint, backtest, train
|
||||
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
|
||||
from qlib.rl.utils.log import CsvWriter
|
||||
from qlib.utils import init_instance_by_config
|
||||
from tianshou.policy import BasePolicy
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def _read_orders(order_dir: Path) -> pd.DataFrame:
|
||||
if os.path.isfile(order_dir):
|
||||
return pd.read_pickle(order_dir)
|
||||
else:
|
||||
orders = []
|
||||
for file in order_dir.iterdir():
|
||||
order_data = pd.read_pickle(file)
|
||||
orders.append(order_data)
|
||||
return pd.concat(orders)
|
||||
|
||||
|
||||
class LazyLoadDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
order_file_path: Path,
|
||||
default_start_time_index: int,
|
||||
default_end_time_index: int,
|
||||
) -> None:
|
||||
self._default_start_time_index = default_start_time_index
|
||||
self._default_end_time_index = default_end_time_index
|
||||
|
||||
self._order_df = _read_orders(order_file_path).reset_index()
|
||||
self._ticks_index: Optional[pd.DatetimeIndex] = None
|
||||
self._data_dir = Path(data_dir)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._order_df)
|
||||
|
||||
def __getitem__(self, index: int) -> Order:
|
||||
row = self._order_df.iloc[index]
|
||||
date = pd.Timestamp(str(row["date"]))
|
||||
|
||||
if self._ticks_index is None:
|
||||
# TODO: We only load ticks index once based on the assumption that ticks index of different dates
|
||||
# TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index
|
||||
# TODO: of all dates.
|
||||
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self._data_dir,
|
||||
stock_id=row["instrument"],
|
||||
date=date,
|
||||
feature_columns_today=[],
|
||||
feature_columns_yesterday=[],
|
||||
backtest=True,
|
||||
index_only=True,
|
||||
)
|
||||
self._ticks_index = [t - date for t in data.today.index]
|
||||
|
||||
order = Order(
|
||||
stock_id=row["instrument"],
|
||||
amount=row["amount"],
|
||||
direction=OrderDir(int(row["order_type"])),
|
||||
start_time=date + self._ticks_index[self._default_start_time_index],
|
||||
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
|
||||
)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def train_and_test(
|
||||
env_config: dict,
|
||||
simulator_config: dict,
|
||||
trainer_config: dict,
|
||||
data_config: dict,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
run_training: bool,
|
||||
run_backtest: bool,
|
||||
) -> None:
|
||||
order_root_path = Path(data_config["source"]["order_dir"])
|
||||
|
||||
data_granularity = simulator_config.get("data_granularity", 1)
|
||||
|
||||
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
|
||||
return SingleAssetOrderExecutionSimple(
|
||||
order=order,
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
feature_columns_today=data_config["source"]["feature_columns_today"],
|
||||
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
|
||||
data_granularity=data_granularity,
|
||||
ticks_per_step=simulator_config["time_per_step"],
|
||||
vol_threshold=simulator_config["vol_limit"],
|
||||
)
|
||||
|
||||
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
|
||||
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
|
||||
|
||||
if run_training:
|
||||
train_dataset, valid_dataset = [
|
||||
LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / tag,
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
for tag in ("train", "valid")
|
||||
]
|
||||
|
||||
callbacks: List[Callback] = []
|
||||
if "checkpoint_path" in trainer_config:
|
||||
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
|
||||
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
|
||||
save_latest="copy",
|
||||
),
|
||||
)
|
||||
if "earlystop_patience" in trainer_config:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
patience=trainer_config["earlystop_patience"],
|
||||
monitor="val/pa",
|
||||
)
|
||||
)
|
||||
|
||||
train(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
initial_states=cast(List[Order], train_dataset),
|
||||
trainer_kwargs={
|
||||
"max_iters": trainer_config["max_epoch"],
|
||||
"finite_env_type": env_config["parallel_mode"],
|
||||
"concurrency": env_config["concurrency"],
|
||||
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
vessel_kwargs={
|
||||
"episode_per_iter": trainer_config["episode_per_collect"],
|
||||
"update_kwargs": {
|
||||
"batch_size": trainer_config["batch_size"],
|
||||
"repeat": trainer_config["repeat_per_collect"],
|
||||
},
|
||||
"val_initial_states": valid_dataset,
|
||||
},
|
||||
)
|
||||
|
||||
if run_backtest:
|
||||
test_dataset = LazyLoadDataset(
|
||||
data_dir=data_config["source"]["feature_root_dir"],
|
||||
order_file_path=order_root_path / "test",
|
||||
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
|
||||
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
|
||||
)
|
||||
|
||||
backtest(
|
||||
simulator_fn=_simulator_factory_simple,
|
||||
state_interpreter=state_interpreter,
|
||||
action_interpreter=action_interpreter,
|
||||
initial_states=test_dataset,
|
||||
policy=policy,
|
||||
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
|
||||
reward=reward,
|
||||
finite_env_type=env_config["parallel_mode"],
|
||||
concurrency=env_config["concurrency"],
|
||||
)
|
||||
|
||||
|
||||
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
|
||||
if not run_training and not run_backtest:
|
||||
warnings.warn("Skip the entire job since training and backtest are both skipped.")
|
||||
return
|
||||
|
||||
if "seed" in config["runtime"]:
|
||||
seed_everything(config["runtime"]["seed"])
|
||||
|
||||
for extra_module_path in config["env"].get("extra_module_paths", []):
|
||||
sys.path.append(extra_module_path)
|
||||
|
||||
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
|
||||
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
|
||||
reward: Reward = init_instance_by_config(config["reward"])
|
||||
|
||||
additional_policy_kwargs = {
|
||||
"obs_space": state_interpreter.observation_space,
|
||||
"action_space": action_interpreter.action_space,
|
||||
}
|
||||
|
||||
# Create torch network
|
||||
if "network" in config:
|
||||
if "kwargs" not in config["network"]:
|
||||
config["network"]["kwargs"] = {}
|
||||
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
|
||||
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
|
||||
|
||||
# Create policy
|
||||
if "kwargs" not in config["policy"]:
|
||||
config["policy"]["kwargs"] = {}
|
||||
config["policy"]["kwargs"].update(additional_policy_kwargs)
|
||||
policy: BasePolicy = init_instance_by_config(config["policy"])
|
||||
|
||||
use_cuda = config["runtime"].get("use_cuda", False)
|
||||
if use_cuda:
|
||||
policy.cuda()
|
||||
|
||||
train_and_test(
|
||||
env_config=config["env"],
|
||||
simulator_config=config["simulator"],
|
||||
data_config=config["data"],
|
||||
trainer_config=config["trainer"],
|
||||
action_interpreter=action_interpreter,
|
||||
state_interpreter=state_interpreter,
|
||||
policy=policy,
|
||||
reward=reward,
|
||||
run_training=run_training,
|
||||
run_backtest=run_backtest,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
|
||||
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
|
||||
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_path, "r") as input_stream:
|
||||
config = yaml.safe_load(input_stream)
|
||||
|
||||
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
|
||||
@@ -13,7 +13,6 @@ import os
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.data.dataset import DatasetH
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -141,16 +140,6 @@ def load_backtest_data(
|
||||
return backtest_data
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_handler_pickle(path: str) -> DatasetH:
|
||||
with open(path, "rb") as fstream:
|
||||
obj = pickle.load(fstream)
|
||||
return obj
|
||||
|
||||
|
||||
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
|
||||
|
||||
@@ -162,6 +151,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> None:
|
||||
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = df.reset_index()
|
||||
@@ -171,17 +161,31 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
dataset = _load_handler_pickle(path)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
if index_only:
|
||||
self.today = _drop_stock_id(data[[]])
|
||||
self.yesterday = _drop_stock_id(data[[]])
|
||||
else:
|
||||
self.today = _drop_stock_id(data[feature_columns_today])
|
||||
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
|
||||
stock_id,
|
||||
date,
|
||||
backtest,
|
||||
index_only,
|
||||
),
|
||||
)
|
||||
def load_handler_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
@@ -189,14 +193,10 @@ def load_handler_intraday_processed_data(
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
index_only: bool = False,
|
||||
) -> HandlerIntradayProcessedData:
|
||||
return HandlerIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
|
||||
)
|
||||
|
||||
|
||||
@@ -229,4 +229,5 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
|
||||
self.feature_columns_today,
|
||||
self.feature_columns_yesterday,
|
||||
backtest=self.backtest,
|
||||
index_only=False,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import List, Sequence, cast
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
@@ -157,15 +158,6 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
|
||||
return cast(pd.DatetimeIndex, self.data.index)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda path: path,
|
||||
)
|
||||
def _load_df_pickle(path: str) -> pd.DataFrame:
|
||||
df = pd.read_pickle(path)
|
||||
return df
|
||||
|
||||
|
||||
class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
|
||||
|
||||
@@ -174,18 +166,36 @@ class PickleIntradayProcessedData(BaseIntradayProcessedData):
|
||||
data_dir: Path | str,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
if isinstance(data_dir, str):
|
||||
data_dir = Path(data_dir)
|
||||
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
|
||||
df = _load_df_pickle(str(path))
|
||||
df = df.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
|
||||
|
||||
self.today = df[feature_columns_today]
|
||||
self.yesterday = df[feature_columns_yesterday]
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
cnames = _infer_processed_data_column_names(feature_dim)
|
||||
|
||||
time_length: int = len(time_index)
|
||||
|
||||
try:
|
||||
# new data format
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
|
||||
proc_today = proc[cnames]
|
||||
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
|
||||
except (IndexError, KeyError):
|
||||
# legacy data
|
||||
proc = proc.loc[pd.IndexSlice[stock_id, date]]
|
||||
assert time_length * feature_dim * 2 == len(proc)
|
||||
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
|
||||
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
|
||||
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
|
||||
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
|
||||
|
||||
self.today: pd.DataFrame = proc_today
|
||||
self.yesterday: pd.DataFrame = proc_yesterday
|
||||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
@@ -203,38 +213,25 @@ def load_simple_intraday_backtest_data(
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
|
||||
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_pickle_intraday_processed_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> BaseIntradayProcessedData:
|
||||
return PickleIntradayProcessedData(
|
||||
data_dir,
|
||||
stock_id,
|
||||
date,
|
||||
feature_columns_today,
|
||||
feature_columns_yesterday,
|
||||
backtest,
|
||||
)
|
||||
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str],
|
||||
feature_columns_yesterday: List[str],
|
||||
backtest: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, data_dir: Path) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._data_dir = data_dir
|
||||
self._backtest = backtest
|
||||
self._feature_columns_today = feature_columns_today
|
||||
self._feature_columns_yesterday = feature_columns_yesterday
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
@@ -247,9 +244,8 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
|
||||
data_dir=self._data_dir,
|
||||
stock_id=stock_id,
|
||||
date=date,
|
||||
feature_columns_today=self._feature_columns_today,
|
||||
feature_columns_yesterday=self._feature_columns_yesterday,
|
||||
backtest=self._backtest,
|
||||
feature_dim=feature_dim,
|
||||
time_index=time_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, List, Optional
|
||||
import cachetools
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
|
||||
from qlib.backtest import collect_data_loop, get_strategy_executor
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
|
||||
from qlib.backtest.executor import NestedExecutor
|
||||
from qlib.rl.data.integration import init_qlib
|
||||
@@ -17,18 +16,6 @@ from .state import SAOEState
|
||||
from .strategy import SAOEStateAdapter, SAOEStrategy
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
cache=cachetools.LRUCache(1000),
|
||||
key=lambda order, _: order.stock_id,
|
||||
)
|
||||
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
|
||||
exchange_kwargs = {
|
||||
**exchange_config,
|
||||
"codes": [order.stock_id],
|
||||
}
|
||||
return get_exchange(**exchange_kwargs)
|
||||
|
||||
|
||||
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
@@ -89,7 +76,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
executor=executor_config,
|
||||
benchmark=order.stock_id,
|
||||
account=cash_limit if cash_limit is not None else int(1e12),
|
||||
exchange_kwargs=_create_exchange(order, exchange_config),
|
||||
exchange_kwargs=exchange_config,
|
||||
pos_type="Position" if cash_limit is not None else "InfPosition",
|
||||
)
|
||||
|
||||
@@ -103,7 +90,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
||||
trade_strategy=strategy,
|
||||
trade_executor=self._executor,
|
||||
return_value=self.report_dict,
|
||||
show_progress=False,
|
||||
)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ from pathlib import Path
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS, EPS_T, float_or_ndarray
|
||||
from qlib.rl.data.base import BaseIntradayBacktestData
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData
|
||||
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
|
||||
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
|
||||
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.utils import LogLevel
|
||||
@@ -43,6 +42,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
Path to load backtest data.
|
||||
feature_columns_today
|
||||
Columns of today's feature.
|
||||
feature_columns_yesterday
|
||||
Columns of yesterday's feature.
|
||||
data_granularity
|
||||
Number of ticks between consecutive data entries.
|
||||
ticks_per_step
|
||||
@@ -79,6 +80,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
order: Order,
|
||||
data_dir: Path,
|
||||
feature_columns_today: List[str] = [],
|
||||
feature_columns_yesterday: List[str] = [],
|
||||
data_granularity: int = 1,
|
||||
ticks_per_step: int = 30,
|
||||
vol_threshold: Optional[float] = None,
|
||||
@@ -90,6 +92,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
self.order = order
|
||||
self.data_dir = data_dir
|
||||
self.feature_columns_today = feature_columns_today
|
||||
self.feature_columns_yesterday = feature_columns_yesterday
|
||||
self.ticks_per_step: int = ticks_per_step // data_granularity
|
||||
self.vol_threshold = vol_threshold
|
||||
|
||||
@@ -119,13 +122,14 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
|
||||
|
||||
def get_backtest_data(self) -> BaseIntradayBacktestData:
|
||||
try:
|
||||
data = load_pickle_intraday_processed_data(
|
||||
data = load_handler_intraday_processed_data(
|
||||
data_dir=self.data_dir,
|
||||
stock_id=self.order.stock_id,
|
||||
date=pd.Timestamp(self.order.start_time.date()),
|
||||
feature_columns_today=self.feature_columns_today,
|
||||
feature_columns_yesterday=[],
|
||||
feature_columns_yesterday=self.feature_columns_yesterday,
|
||||
backtest=True,
|
||||
index_only=False,
|
||||
)
|
||||
return DataframeIntradayBacktestData(data.today)
|
||||
except (AttributeError, FileNotFoundError):
|
||||
|
||||
@@ -451,7 +451,6 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
network: dict | torch.nn.Module | None = None,
|
||||
immediate_addition: bool = False,
|
||||
outer_trade_decision: BaseTradeDecision | None = None,
|
||||
level_infra: LevelInfrastructure | None = None,
|
||||
common_infra: CommonInfrastructure | None = None,
|
||||
@@ -502,12 +501,9 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
if self._policy is not None:
|
||||
self._policy.eval()
|
||||
|
||||
self.immediate_addition = immediate_addition
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
self.trade_amount_planned = collections.defaultdict(float)
|
||||
|
||||
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
|
||||
assert hasattr(self.outer_trade_decision, "order_list")
|
||||
@@ -543,15 +539,9 @@ class SAOEIntStrategy(SAOEStrategy):
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order_list = []
|
||||
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
|
||||
order = cast(Order, decision)
|
||||
if self.immediate_addition:
|
||||
self.trade_amount_planned[order.stock_id] += exec_vol
|
||||
amount_planned = self.trade_amount_planned[order.stock_id]
|
||||
amount_finished = order.amount - state.position
|
||||
exec_vol = min(state.position, amount_planned - amount_finished)
|
||||
|
||||
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
|
||||
if exec_vol != 0:
|
||||
order = cast(Order, decision)
|
||||
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
|
||||
|
||||
return TradeDecisionWithDetails(
|
||||
|
||||
@@ -20,7 +20,7 @@ def train(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
vessel_kwargs: Dict[str, Any],
|
||||
@@ -39,9 +39,7 @@ def train(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to train against.
|
||||
reward
|
||||
@@ -69,7 +67,7 @@ def backtest(
|
||||
simulator_fn: Callable[[InitialStateType], Simulator],
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
initial_states: List[Sequence[InitialStateType]],
|
||||
initial_states: Sequence[InitialStateType],
|
||||
policy: BasePolicy,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
reward: Reward | None = None,
|
||||
@@ -89,9 +87,7 @@ def backtest(
|
||||
action_interpreter
|
||||
Interprets the policy actions.
|
||||
initial_states
|
||||
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
|
||||
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
|
||||
state will be run exactly once. Otherwise, every worker will have its own iterator.
|
||||
Initial states to iterate over. Every state will be run exactly once.
|
||||
policy
|
||||
Policy to test against.
|
||||
logger
|
||||
|
||||
@@ -5,9 +5,8 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from contextlib import AbstractContextManager, ExitStack, contextmanager
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
|
||||
|
||||
@@ -207,50 +206,45 @@ class Trainer:
|
||||
|
||||
self._call_callback_hooks("on_fit_start")
|
||||
|
||||
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
|
||||
vessel.val_seed_iterators()
|
||||
) as valid_iterators:
|
||||
train_vector_env = self.venv_from_iterator(train_iterators)
|
||||
valid_vector_env = self.venv_from_iterator(valid_iterators)
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
_logger.info(msg)
|
||||
|
||||
while not self.should_stop:
|
||||
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
|
||||
print(msg)
|
||||
_logger.info(msg)
|
||||
self.initialize_iter()
|
||||
|
||||
self.initialize_iter()
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
|
||||
self._call_callback_hooks("on_iter_start")
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
|
||||
self.current_stage = "train"
|
||||
self._call_callback_hooks("on_train_start")
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
with _wrap_context(vessel.train_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.train(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
# TODO
|
||||
# Add a feature that supports reloading the training environment every few iterations.
|
||||
self.vessel.train(train_vector_env)
|
||||
self._call_callback_hooks("on_train_end")
|
||||
|
||||
self._call_callback_hooks("on_train_end")
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
with _wrap_context(vessel.val_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.validate(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
|
||||
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
|
||||
# Implementation of validation loop
|
||||
self.current_stage = "val"
|
||||
self._call_callback_hooks("on_validate_start")
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
|
||||
self.vessel.validate(valid_vector_env)
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
self._call_callback_hooks("on_validate_end")
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
# This iteration is considered complete.
|
||||
# Bumping the current iteration counter.
|
||||
self.current_iter += 1
|
||||
|
||||
if self.max_iters is not None and self.current_iter >= self.max_iters:
|
||||
self.should_stop = True
|
||||
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_iter_end")
|
||||
|
||||
self._call_callback_hooks("on_fit_end")
|
||||
|
||||
@@ -271,16 +265,16 @@ class Trainer:
|
||||
|
||||
self.current_stage = "test"
|
||||
self._call_callback_hooks("on_test_start")
|
||||
with _wrap_context(vessel.test_seed_iterators()) as iterators:
|
||||
vector_env = self.venv_from_iterator(iterators)
|
||||
with _wrap_context(vessel.test_seed_iterator()) as iterator:
|
||||
vector_env = self.venv_from_iterator(iterator)
|
||||
self.vessel.test(vector_env)
|
||||
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
|
||||
self._call_callback_hooks("on_test_end")
|
||||
|
||||
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
|
||||
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
|
||||
"""Create a vectorized environment from iterator and the training vessel."""
|
||||
|
||||
def env_factory(iterator):
|
||||
def env_factory():
|
||||
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
|
||||
# and could be thread unsafe.
|
||||
# I'm not sure whether it's a design flaw.
|
||||
@@ -306,7 +300,7 @@ class Trainer:
|
||||
)
|
||||
|
||||
return vectorize_env(
|
||||
[partial(env_factory, iterator=it) for it in iterators],
|
||||
env_factory,
|
||||
self.finite_env_type,
|
||||
self.concurrency,
|
||||
self.loggers,
|
||||
@@ -340,11 +334,8 @@ class Trainer:
|
||||
@contextmanager
|
||||
def _wrap_context(obj):
|
||||
"""Make any object a (possibly dummy) context manager."""
|
||||
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
|
||||
with ExitStack() as stack:
|
||||
yield [stack.enter_context(e) for e in obj]
|
||||
stack.pop_all().close()
|
||||
elif isinstance(obj, AbstractContextManager):
|
||||
|
||||
if isinstance(obj, AbstractContextManager):
|
||||
# obj has __enter__ and __exit__
|
||||
with obj as ctx:
|
||||
yield ctx
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
@@ -49,23 +49,19 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
|
||||
def assign_trainer(self, trainer: Trainer) -> None:
|
||||
self.trainer = weakref.proxy(trainer) # type: ignore
|
||||
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for training.
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for training.
|
||||
If the iterable is a context manager, the whole training will be invoked in the with-block,
|
||||
and the iterator will be automatically closed after the training is done."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for training is not available.")
|
||||
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
|
||||
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for validation is not available.")
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for validation."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
|
||||
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
"""Override this to create a seed iterators for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterators for testing is not available.")
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
"""Override this to create a seed iterator for testing."""
|
||||
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
|
||||
|
||||
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
|
||||
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
|
||||
@@ -124,9 +120,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
policy: BasePolicy,
|
||||
reward: Reward,
|
||||
train_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
val_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
test_initial_states: List[Sequence[InitialStateType]] | None = None,
|
||||
train_initial_states: Sequence[InitialStateType] | None = None,
|
||||
val_initial_states: Sequence[InitialStateType] | None = None,
|
||||
test_initial_states: Sequence[InitialStateType] | None = None,
|
||||
buffer_size: int = 20000,
|
||||
episode_per_iter: int = 1000,
|
||||
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
@@ -136,49 +132,34 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.action_interpreter = action_interpreter
|
||||
self.policy = policy
|
||||
self.reward = reward
|
||||
self.train_initial_states = None if train_initial_states is None else train_initial_states
|
||||
self.val_initial_states = None if val_initial_states is None else val_initial_states
|
||||
self.test_initial_states = None if test_initial_states is None else test_initial_states
|
||||
self.train_initial_states = train_initial_states
|
||||
self.val_initial_states = val_initial_states
|
||||
self.test_initial_states = test_initial_states
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_per_iter = episode_per_iter
|
||||
self.update_kwargs = update_kwargs or {}
|
||||
|
||||
def train_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.train_initial_states is not None:
|
||||
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
|
||||
train_initial_states = [
|
||||
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().train_seed_iterators()
|
||||
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
|
||||
# Implement fast_dev_run here.
|
||||
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
|
||||
return super().train_seed_iterator()
|
||||
|
||||
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.val_initial_states is not None:
|
||||
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
|
||||
val_initial_states = [
|
||||
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().val_seed_iterators()
|
||||
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
|
||||
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(val_initial_states, repeat=1)
|
||||
return super().val_seed_iterator()
|
||||
|
||||
def test_seed_iterators(
|
||||
self,
|
||||
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
|
||||
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
|
||||
if self.test_initial_states is not None:
|
||||
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
|
||||
test_initial_states = [
|
||||
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
|
||||
]
|
||||
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
|
||||
return cast(List[Iterable[InitialStateType]], iterators)
|
||||
else:
|
||||
return super().test_seed_iterators()
|
||||
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
|
||||
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
|
||||
return DataQueue(test_initial_states, repeat=1)
|
||||
return super().test_seed_iterator()
|
||||
|
||||
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
|
||||
"""Create a collector and collects ``episode_per_iter`` episodes.
|
||||
|
||||
@@ -258,46 +258,6 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step2(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = {}
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
valid_act = np.stack([action[id2idx[i]] for i in request_id])
|
||||
tmp = super().step(valid_act, request_id)
|
||||
|
||||
for obs_next, rew, done, info in zip(*tmp):
|
||||
obs_next = self._postproc_env_obs(obs_next)
|
||||
result[info["env_id"]] = [obs_next, rew, done, info]
|
||||
|
||||
# logging
|
||||
for i, r in result.items():
|
||||
if i in self._alive_env_ids and r[0] is not None:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
|
||||
for _, reward, __, info in result.values():
|
||||
self._set_default_info(info)
|
||||
self._set_default_rew(reward)
|
||||
for r in result.values():
|
||||
if r[0] is None:
|
||||
r[0] = self._get_default_obs()
|
||||
if r[1] is None:
|
||||
r[1] = self._get_default_rew()
|
||||
if r[3] is None:
|
||||
r[3] = self._get_default_info()
|
||||
|
||||
ret = list(map(np.stack, zip(*result.values())))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
@@ -351,7 +311,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
|
||||
|
||||
|
||||
def vectorize_env(
|
||||
env_factories: List[Callable[..., gym.Env]],
|
||||
env_factory: Callable[..., gym.Env],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | List[LogWriter],
|
||||
@@ -374,10 +334,9 @@ def vectorize_env(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_factories
|
||||
Callables to instantiate one single ``gym.Env``.
|
||||
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
|
||||
the same env_factory. Otherwise, each worker will have its own env_factory.
|
||||
env_factory
|
||||
Callable to instantiate one single ``gym.Env``.
|
||||
All concurrent workers will have the same ``env_factory``.
|
||||
env_type
|
||||
dummy or subproc or shmem. Corresponding to
|
||||
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
|
||||
@@ -399,8 +358,6 @@ def vectorize_env(
|
||||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
assert len(env_factories) in (1, concurrency)
|
||||
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
@@ -409,7 +366,4 @@ def vectorize_env(
|
||||
|
||||
finite_env_cls = env_type_cls_mapping[env_type]
|
||||
|
||||
if len(env_factories) == 1:
|
||||
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
|
||||
else:
|
||||
return finite_env_cls(logger, env_factories)
|
||||
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Generator
|
||||
|
||||
from line_profiler import LineProfiler
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]:
|
||||
s = time.perf_counter()
|
||||
yield
|
||||
e = time.perf_counter()
|
||||
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
|
||||
|
||||
if out_path is not None:
|
||||
with open(out_path, "a") as fstream:
|
||||
fstream.write(msg + "\n")
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
|
||||
def lprofile(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
lp = LineProfiler()
|
||||
lpw = lp(func)
|
||||
res = lpw(*args, **kwargs)
|
||||
lp.print_stats()
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
@@ -18,7 +18,7 @@ from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_s
|
||||
from ..utils.time import Freq
|
||||
from ..utils.data import deepcopy_basic_type
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -156,6 +156,9 @@ class RecordTemp:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
def analyse(self):
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
"""
|
||||
|
||||
15
scripts/finco/README.md
Normal file
15
scripts/finco/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
|
||||
# Requirements
|
||||
|
||||
|
||||
Use following install command to complete the project.
|
||||
```
|
||||
pip install -e '.[finco]'
|
||||
```
|
||||
|
||||
|
||||
# TODOs
|
||||
|
||||
- [ ] Select the appropriate LLM API
|
||||
- Which API is more suitable for meeting our requirements - the original API or an alternative like LangChain?
|
||||
15
scripts/finco/cmd.sh
Normal file
15
scripts/finco/cmd.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
set -x # show command
|
||||
set -e # Error on exception
|
||||
|
||||
DIR="$(
|
||||
cd "$(dirname "$(readlink -f "$0")")" || exit
|
||||
pwd -P
|
||||
)"
|
||||
# --load the cridentials
|
||||
if [ -e $DIR/cridential.sh ]; then
|
||||
source $DIR/cridential.sh
|
||||
fi
|
||||
|
||||
# run the command
|
||||
python -m qlib.finco.cli "please help me build a low turnover strategy that focus more on longterm return"
|
||||
3
scripts/finco/cridential.sh.example
Normal file
3
scripts/finco/cridential.sh.example
Normal file
@@ -0,0 +1,3 @@
|
||||
export OPENAI_API_TYPE=azure # This only necessary for Azure OpenAI
|
||||
export OPENAI_API_KEY=
|
||||
export OPENAI_API_BASE=
|
||||
8
setup.py
8
setup.py
@@ -173,6 +173,14 @@ setup(
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
],
|
||||
"finco": [
|
||||
# finco is not necessary for all Qlib users; So a single require section is used for it.
|
||||
"openapi",
|
||||
"pydantic", # Please add it to basic requirements after the design of pydantic is state.
|
||||
"python-dotenv", # I don't think this is necessary if we use pydantic.
|
||||
"fuzzywuzzy",
|
||||
"python-Levenshtein" # not necessary but accelerate fuzzywuzzy calculation
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
|
||||
71
tests/finco/test_cfg.py
Normal file
71
tests/finco/test_cfg.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import shutil
|
||||
import difflib
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from pathlib import Path
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
from qlib.finco.task import YamlEditTask
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
class FincoTpl(TestAutoData):
|
||||
def test_tpl_consistence(self):
|
||||
"""Motivation: make sure the configuable template is consistent with the default config"""
|
||||
tpl_p = get_tpl_path()
|
||||
with (tpl_p / "sl" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# init_data_handler
|
||||
hd: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
# NOTE: The config in workflow_config.yaml is generated by the following code:
|
||||
# dump in yaml format to file without auto linebreak
|
||||
# print(yaml.dump(hd.data_loader.fields, width=10000, stream=open("_tmp", "w")))
|
||||
|
||||
with (tpl_p / "sl-cfg" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
hd_ds: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
self.assertEqual(hd_ds.data_loader.fields, hd.data_loader.fields)
|
||||
|
||||
check = hd_ds.fetch().fillna(0.0) == hd.fetch().fillna(0.0)
|
||||
self.assertTrue(check.all().all())
|
||||
|
||||
def test_update_yaml(self):
|
||||
p = get_tpl_path() / "sl" / "workflow_config.yaml"
|
||||
p_new = DIRNAME / "_test_config.yaml"
|
||||
shutil.copy(p, p_new)
|
||||
updated_content = """
|
||||
class: LGBModelTest
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 1.8879
|
||||
learning_rate: 0.3
|
||||
subsample: 0.8790
|
||||
lambda_l1: 205.7000
|
||||
lambda_l2: 580.9769
|
||||
max_depth: 9
|
||||
num_leaves: 211
|
||||
num_threads: 21
|
||||
"""
|
||||
t = YamlEditTask(p_new, "task.model", updated_content)
|
||||
t.execute()
|
||||
# NOTE: the formmat is changed by ruamel.yaml, so it can't be compared by text directly..
|
||||
# print the diff between p and p_new with difflib
|
||||
# with p.open("r") as fp:
|
||||
# content = fp.read()
|
||||
# with p_new.open("r") as fp:
|
||||
# content_new = fp.read()
|
||||
# for line in difflib.unified_diff(content, content_new, fromfile="original", tofile="new", lineterm=""):
|
||||
# print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
66
tests/finco/test_sumarize.py
Normal file
66
tests/finco/test_sumarize.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import unittest
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from dotenv import load_dotenv
|
||||
# pydantic support load_dotenv, so load_dotenv will be deprecated in the future.
|
||||
|
||||
from qlib.finco.task import SummarizeTask
|
||||
from qlib.finco.workflow import WorkflowContextManager
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
|
||||
class TestSummarize(unittest.TestCase):
|
||||
|
||||
def test_chat(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your are a professional financial assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "How to write a perfect quant strategy.",
|
||||
},
|
||||
]
|
||||
response = APIBackend().try_create_chat_completion(messages=messages)
|
||||
print(response)
|
||||
|
||||
def test_execution(self):
|
||||
task = SummarizeTask()
|
||||
context = WorkflowContextManager()
|
||||
context.set_context("workspace", "../../examples/benchmarks/Linear")
|
||||
context.set_context("user_prompt", "My main focus is on the performance of the strategy's return."
|
||||
"Please summarize the information and give me some advice.")
|
||||
task.assign_context_manager(context)
|
||||
resp = task.execute()
|
||||
print(resp)
|
||||
|
||||
def test_generate_batch_result(self):
|
||||
wm = WorkflowManager()
|
||||
|
||||
prompt = wm.default_user_prompt
|
||||
# prompt = ""
|
||||
|
||||
workdir = os.path.dirname(wm.get_context().get_context("workspace"))
|
||||
summaries_path = os.path.join(workdir, "summaries")
|
||||
|
||||
if not os.path.exists(summaries_path):
|
||||
os.makedirs(summaries_path)
|
||||
|
||||
for i in range(10):
|
||||
wm.run(prompt)
|
||||
if os.path.exists(f"{workdir}/finCoReport.md"):
|
||||
shutil.move(f"{workdir}/finCoReport.md", f"{workdir}/summaries/finCoReport{i}.md")
|
||||
|
||||
def test_parse2txt(self):
|
||||
task = SummarizeTask()
|
||||
resp = task.get_info_from_file("")
|
||||
print(resp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
23
tests/finco/test_utils.py
Normal file
23
tests/finco/test_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import unittest
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
|
||||
class TimeUtils(unittest.TestCase):
|
||||
|
||||
def test_singleton(self):
|
||||
# self.assertEqual(self.to_str(data.tail()), self.to_str(res))
|
||||
closure_checker = []
|
||||
|
||||
class A(SingletonBaseClass):
|
||||
|
||||
def __init__(self) -> None:
|
||||
closure_checker.append(0)
|
||||
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user