mirror of https://github.com/microsoft/qlib.git synced 2026-06-29 00:51:19 +08:00

Compare commits

..

47 Commits

Author SHA1 Message Date
Xu Yang
2df211c320 merge all commit 2023-07-13 16:29:44 +08:00
Fivele-Li
effed382e9 Optimize prompt for entire learn loop (#1589)
* Adjust prompt and fix cases
* adjust summarizeTask & learn prompts;
* fix typos & drop duplicate task method;

* adjust learn prompts;
2023-07-11 18:13:52 +08:00
Fivele-Li
86ffd1799d Add knowledge module and tune summarizeTask (#1582)
* Add knowledge module
* add KnowledgeExperiment add KnowledgeBase;
* add knowledge associate prompts to template;

* Add Topic class
* add Topic to summarize knowledge;
* add recorder's metric to summarizeTask;

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-07-06 11:39:36 +08:00
Young
aef11536e3 rename & test 2023-07-04 20:28:08 +08:00
Xu Yang
8b0fdf1623 Merge pull request #1581 from microsoft/xuyang1/fix_singleton_bug
fix singleton bug
2023-07-04 16:51:51 +08:00
Xu Yang
9a36f8da20 fix singleton bug 2023-07-04 16:20:02 +08:00
Xu Yang
b7757d5008 Merge pull request #1580 from microsoft/xuyang1/refine_workflow_to_increase_success_rate
refine workflow to increase success rate
2023-07-03 17:59:54 +08:00
Xu Yang
ee5e5cfdd8 remove useless code 2023-07-03 17:57:13 +08:00
Xu Yang
6cb87ecfd1 refine code to use qrun 2023-07-03 17:56:22 +08:00
Xu Yang
9119bcdd3c Merge pull request #1576 from microsoft/xuyang1/add_config_and_code_dump_task
refine workflow and prompts
2023-06-30 14:43:49 +08:00
Xu Yang
4fccf8112d fix one workflow 2023-06-30 14:33:41 +08:00
Xu Yang
73bd79ca1a merge into one commit 2023-06-30 14:23:40 +08:00
Fivele-Li
7e84f3aae2 Add backtest and backforward task (#1568)
* * add TrainTask & BacktestTask;
* add BackForwardTask;
* adjust prompt_template.yaml which default config failed to backtest;
* run workflow in loop
* add update method to prompt_template.py

* remove debug code

* Adjust Learn Process
* add LearnManager class & use LearnManager to update system prompt;
* use qrun to replace recorder for training and backtesting;

* Adjust analyser
* analyser independent of recorder;
* rename analyser's workspace attribution;
* analyser load variable by recorder.

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-30 10:04:43 +08:00
Fivele-Li
1326ac614d Add docs to context and retrieve (#1566)
* add analyser docstring to context;
* add retrieve method to context manager;

* add notes to retrieve
2023-06-24 21:47:27 +08:00
Fivele-Li
f12184cc0f Add analyser task and optimize interact (#1552)
* * optimize interact
* add AnalyserTask
* optimize logger format and add render feature

* format optimize
2023-06-16 11:42:45 +08:00
Xu Yang
a70386ad52 Merge pull request #1550 from microsoft/xuyang1/refine_task_prompts
add datahandler and design action task according to component
2023-06-14 14:52:42 +08:00
Xu Yang
74619ed8d8 fix using defaut in record strategy and backtest 2023-06-14 14:52:16 +08:00
Fivele-Li
1a523df007 Optimize log and interact of FinCo (#1549)
* use FinCoLog for a better interact experience

* addition file changes

* optimize format

* optimize format
2023-06-14 14:48:17 +08:00
Xu Yang
f9cc8a5aaa remove useless prompt 2023-06-14 10:46:38 +08:00
Xu Yang
7762c5a1fd add datahandler and design action task according to component 2023-06-13 23:28:27 +08:00
Xu Yang
fa7ef29281 Merge pull request #1548 from microsoft/xuyang1/add_dump_to_file_task
add simple readme & move prompt templates to outer yaml file to make the code clean
2023-06-13 15:29:13 +08:00
Xu Yang
429c9a7c66 format 2023-06-13 15:27:59 +08:00
Xu Yang
80fbc00792 move prompt templates to yaml file to make code clean 2023-06-13 15:21:19 +08:00
Xu Yang
01accec24c update code 2023-06-12 16:25:16 +08:00
Fivele-Li
1d88830b0d Add recorder task and visualize (#1542)
* add recorder task

* add batch generate summarize report unittest.

* * add recorder to RecorderTask;
* add matplot figure to analyzer.py

* add image to markdown;

* Add some log

* update figure path.

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-06-12 15:48:00 +08:00
you-n-g
ad7498e287 Edit yaml task (#1538)
* Edit yaml task

* update comments
2023-06-02 00:44:41 +08:00
you-n-g
73d51f05b4 Init workspace and CMDTask (#1537)
* Update setup.py and config

* WIP

* init_workspace and CMDTask

* Delete test_sumarize.py
2023-06-01 23:32:35 +08:00
Fivele-Li
3b56b8e6c0 Optimize summarize task prompt and others (#1533)
* 1.update prompt;
2.update fetch information method.

* 1.update prompt;
2.save result to markdown;

* 1.get context info from context_manager;
2.run the entire process successfully.
2023-06-01 21:22:24 +08:00
you-n-g
40e0c329ba Add configurable dataset (#1535) 2023-06-01 20:05:02 +08:00
Xu Yang
e376648860 Merge pull request #1536 from microsoft/xuyang1/add_debug_mode_to_save_cache
add a debug mode to speed up debug process
2023-06-01 19:44:17 +08:00
Xu Yang
5f37f32184 update code 2023-06-01 19:38:26 +08:00
Xu Yang
d46b4c1ebf Merge pull request #1534 from microsoft/xuyang1/add_code_implementation_task
add code implementation task
2023-06-01 18:13:05 +08:00
Xu Yang
0515524b51 add code implementation code 2023-06-01 18:04:31 +08:00
Xu Yang
cda32d5703 Merge pull request #1532 from microsoft/xuyang1/add-plan-and-config-task-implementation
add the initial version of plan and config task implementation
2023-06-01 11:20:04 +08:00
Xu Yang
e2332a004b imporove some words in prompt 2023-06-01 01:09:14 +08:00
Xu Yang
08d9dbccc9 update v1 code containing SLplan and config action 2023-06-01 00:36:04 +08:00
Fivele-Li
e7cd93a36d add base method for summarization; (#1530) 2023-05-31 15:50:34 +08:00
Xu Yang
3919678028 split task into workflow and task to make the strcture more clear 2023-05-31 11:45:25 +08:00
Xu Yang
421b1403b2 Merge pull request #1528 from microsoft/xuyang1/refine_task_and_implement_workflow_task_as_example
Xuyang1/refine task and implement workflow task as example
2023-05-31 11:36:36 +08:00
Xu Yang
94102fb742 remove tasktype variable 2023-05-31 11:35:54 +08:00
Cadenza-Li
74a5d7c8af add parse method for summarization; 2023-05-31 00:08:21 +08:00
Xu Yang
ce39b4b6f8 add qlib auto init so logger can display info 2023-05-30 21:52:35 +08:00
Xu Yang
2af35d9c89 second commit 2023-05-30 20:20:16 +08:00
Xu Yang
f37643550b first round 2023-05-30 20:19:58 +08:00
Xu Yang
55611aa43e Merge pull request #1527 from microsoft/xuyang1/add_openai_api_support
add openai interface support
2023-05-30 13:44:10 +08:00
Xu Yang
f24253efd2 add openai interface support 2023-05-30 13:42:01 +08:00
Young
7c4f3b8a7d Initial interface for discussion 2023-05-24 12:18:31 +08:00
46 changed files with 4056 additions and 886 deletions

1
.gitignore vendored
View File

@@ -22,6 +22,7 @@ dist/
qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
qlib/finco/prompt_cache.json
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/

View File

@@ -179,7 +179,7 @@ def get_strategy_executor(
executor: Union[str, dict, object, Path],
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: Union[dict, Exchange] = {}, # TODO: rename parameter
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
@@ -197,15 +197,12 @@ def get_strategy_executor(
pos_type=pos_type,
)
if isinstance(exchange_kwargs, Exchange):
trade_exchange = exchange_kwargs
else:
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
exchange_kwargs["start_time"] = start_time
if "end_time" not in exchange_kwargs:
exchange_kwargs["end_time"] = end_time
trade_exchange = get_exchange(**exchange_kwargs)
exchange_kwargs = copy.copy(exchange_kwargs)
if "start_time" not in exchange_kwargs:
exchange_kwargs["start_time"] = start_time
if "end_time" not in exchange_kwargs:
exchange_kwargs["end_time"] = end_time
trade_exchange = get_exchange(**exchange_kwargs)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
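
The hunk above copies `exchange_kwargs` and fills in `start_time` and `end_time` only when the caller did not supply them. A minimal sketch of that pattern, using a hypothetical helper name `fill_exchange_kwargs` (not part of Qlib) and plain dictionaries:

```python
import copy


def fill_exchange_kwargs(exchange_kwargs, start_time, end_time):
    """Return a shallow copy of exchange_kwargs with default time bounds filled in.

    Hypothetical helper for illustration; Qlib does this inline.
    """
    kwargs = copy.copy(exchange_kwargs)  # do not mutate the caller's dict or a shared default
    kwargs.setdefault("start_time", start_time)  # only set when the key is absent
    kwargs.setdefault("end_time", end_time)
    return kwargs


original = {"deal_price": "close", "end_time": "2020-08-01"}
filled = fill_exchange_kwargs(original, "2017-01-01", "2020-12-31")
print(filled)
print(original)
```

Copying first matters because the signature uses a mutable default (`{}`); writing into it directly would leak values between calls.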

View File

@@ -56,7 +56,6 @@ def collect_data_loop(
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
return_value: dict | None = None,
show_progress: bool = True,
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
"""Generator for collecting the trade decision data for rl training
@@ -75,8 +74,6 @@ def collect_data_loop(
the outermost executor
return_value : dict
used for backtest_loop
show_progress: bool
whether to show execution progress
Yields
-------
@@ -86,8 +83,7 @@ def collect_data_loop(
trade_executor.reset(start_time=start_time, end_time=end_time)
trade_strategy.reset(level_infra=trade_executor.get_level_infra())
disable = not show_progress
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop", disable=disable) as bar:
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
_execute_result = None
while not trade_executor.finished():
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)

View File

@@ -177,7 +177,7 @@ class Exchange:
necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self.limit_type == self.LT_TP_EXP:
assert isinstance(limit_threshold, tuple) or (isinstance(limit_threshold, list) and len(limit_threshold) == 2)
assert isinstance(limit_threshold, tuple)
for exp in limit_threshold:
necessary_fields.add(exp)
all_fields = list(necessary_fields | set(vol_lt_fields) | set(subscribe_fields))
@@ -263,9 +263,6 @@ class Exchange:
"""get limit type"""
if isinstance(limit_threshold, tuple):
return self.LT_TP_EXP
if isinstance(limit_threshold, list):
assert len(limit_threshold) == 2
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
return self.LT_FLT
elif limit_threshold is None:
@@ -328,7 +325,7 @@ class Exchange:
assert isinstance(volume_threshold, dict)
for key, vol_limit in volume_threshold.items():
assert isinstance(vol_limit, tuple) or (isinstance(vol_limit, list) and len(vol_limit) == 2)
assert isinstance(vol_limit, tuple)
fields.add(vol_limit[1])
if key in ("buy", "all"):
@@ -806,7 +803,7 @@ class Exchange:
vol_limit_num: List[float] = []
for limit in vol_limit:
assert isinstance(limit, tuple) or (isinstance(limit, list) and len(limit) == 2)
assert isinstance(limit, tuple)
if limit[0] == "current":
limit_value = self.quote.get_data(
order.stock_id,

111
qlib/contrib/analyzer.py Normal file
View File

@@ -0,0 +1,111 @@
import logging
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from ..log import get_module_logger
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
logger = get_module_logger("analysis", logging.INFO)
class AnalyzerTemp:
def __init__(self, recorder, output_dir=None, **kwargs):
self.recorder = recorder
self.output_dir = Path(output_dir) if output_dir else Path("./")
def load(self, name: str):
"""
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name of the file to be loaded.
Return
------
The stored records.
"""
return self.recorder.load_object(name)
def analyse(self, **kwargs):
"""
Analyse data index, distribution, etc.
Parameters
----------
Return
------
The handled data.
"""
raise NotImplementedError(f"Please implement the `analyse` method.")
class HFAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "HFAnalyzerTable.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
"Long precision": long_pre.mean(),
"Short precision": short_pre.mean(),
}
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
metrics.update(
{
"Long-Short Average Return": long_short_r.mean(),
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
}
)
table = [[k, v] for (k, v) in metrics.items()]
plt.table(cellText=table, loc="center")
plt.axis("off")
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
plt.clf()
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
plt.title("HFAnalyzer")
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
return "HFAnalyzer.jpeg"
class SignalAnalyzer(AnalyzerTemp):
"""
This is the Signal Analysis class that generates the analysis results such as IC and IR.
default output image filename is "signalAnalysis.jpeg"
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def analyse(self, dataset=None, **kwargs):
label = self.load("label.pkl")
plt.hist(label)
plt.title("SignalAnalyzer")
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
return "signalAnalysis.jpeg"
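
`HFAnalyzer.analyse` reports ICIR and Rank ICIR as the mean of a series divided by its standard deviation. A small standard-library sketch of that ratio follows; the IC values are made up for illustration, and `information_ratio` is a hypothetical name rather than a Qlib function:

```python
from statistics import mean, stdev


def information_ratio(series):
    """Mean divided by the sample standard deviation (ddof=1, the pandas default)."""
    return mean(series) / stdev(series)


daily_ic = [0.02, 0.04, 0.06]  # made-up daily IC values
icir = information_ratio(daily_ic)
print(f"ICIR = {icir:.2f}")
```

A series with zero variance would raise a division error here, so real code may want to guard that case.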

View File

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from qlib.utils.data import update_config
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_callable_kwargs
@@ -57,12 +59,13 @@ class Alpha360(DataHandlerLP):
fit_end_time=None,
filter_pipe=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
@@ -74,12 +77,14 @@ class Alpha360(DataHandlerLP):
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
**kwargs
@@ -153,12 +158,13 @@ class Alpha158(DataHandlerLP):
process_type=DataHandlerLP.PTYPE_A,
filter_pipe=None,
inst_processors=None,
data_loader: Optional[dict] = None,
**kwargs
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
_data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
@@ -170,11 +176,13 @@ class Alpha158(DataHandlerLP):
"inst_processors": inst_processors,
},
}
if data_loader is not None:
update_config(_data_loader, data_loader)
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
data_loader=_data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
process_type=process_type,
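
Both handlers now build a default `_data_loader` dictionary and merge a user-supplied `data_loader` into it via `update_config`. The body of `update_config` is not shown in this diff; assuming it performs a recursive dictionary merge, the idea can be sketched with a hypothetical `deep_update` helper:

```python
def deep_update(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base` in place and return `base`.

    Illustrative only; this is not Qlib's `update_config` implementation.
    """
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)  # descend into nested dicts
        else:
            base[key] = value  # leaf value or type mismatch: overwrite
    return base


default_loader = {
    "class": "QlibDataLoader",
    "kwargs": {"config": {"feature": "<default features>"}, "inst_processors": None},
}
user_loader = {"kwargs": {"inst_processors": ["<my processor>"]}}
merged = deep_update(default_loader, user_loader)
print(merged)
```

With a merge like this, a caller overrides only the keys it names and keeps the handler's default feature configuration.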

18
qlib/finco/.env.example Normal file
View File

@@ -0,0 +1,18 @@
OPENAI_API_KEY=your_api_key
# USE_AZURE=True
# AZURE_API_BASE=your_api_base
# AZURE_API_VERSION=your_api_version
# Using gpt-4 allows more tokens but increases the wait time
# MODEL=gpt-4
# MAX_TOKENS=1600
# MAX_RETRY=1000
MAX_TOKENS=1600
MAX_RETRY=120
CONTINOUS_MODE=True
DEBUG_MODE=True

22
qlib/finco/README.md Normal file
View File

@@ -0,0 +1,22 @@
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
## Installation
- To run this module, first install Qlib by following the instructions in [install-from-source](/README.md#install-from-source), or run:
```sh
python -m pip install git+https://github.com/microsoft/qlib.git@finco
```
- Then install the other dependencies of finco:
```sh
python -m pip install pydantic openai python-dotenv
```
## Quick run
To run this module, you can start the workflow easily with one command:
```sh
cd qlib/finco; python cli.py "your prompt"
```

13
qlib/finco/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path

DIRNAME = Path(__file__).absolute().resolve().parent


def get_finco_path() -> Path:
    """
    Return the path of the finco package directory.
    The package may be installed anywhere, so the location is derived from this module's __file__.
    """
    return DIRNAME

15
qlib/finco/cli.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import WorkflowManager
from dotenv import load_dotenv
from qlib import auto_init


def main(prompt=None):
    load_dotenv(verbose=True, override=True)
    wm = WorkflowManager()
    wm.run(prompt)


if __name__ == "__main__":
    auto_init()
    fire.Fire(main)

15
qlib/finco/cli_learn.py Normal file
View File

@@ -0,0 +1,15 @@
import fire
from qlib.finco.workflow import LearnManager
from dotenv import load_dotenv
from qlib import auto_init


def main(prompt=None):
    load_dotenv(verbose=True, override=True)
    lm = LearnManager()
    lm.run(prompt)


if __name__ == "__main__":
    auto_init()
    fire.Fire(main)

32
qlib/finco/conf.py Normal file
View File

@@ -0,0 +1,32 @@
# TODO: use pydantic for other modules in Qlib
from pydantic import BaseSettings
from qlib.finco.utils import SingletonBaseClass
import os


class Config(SingletonBaseClass):
    """
    This config is for fast demo purpose.
    Please use BaseSettings instead in the future
    """

    def __init__(self):
        self.use_azure = os.getenv("USE_AZURE") == "True"
        self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
        self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.azure_api_base = os.getenv("AZURE_API_BASE")
        self.azure_api_version = os.getenv("AZURE_API_VERSION")
        self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
        self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
        self.continuous_mode = (
            os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
        )
        self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
        self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
        self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2
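
The `Config` class reads every setting from environment variables, treating only the literal string `"True"` as a true flag and rounding `MAX_PAST_MESSAGE_INCLUDE` down to an even number (presumably so that user/assistant message pairs stay together). A minimal sketch of those two rules, with hypothetical helper names `env_flag` and `even_message_window`:

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Return True only when the variable is set to the literal string "True"."""
    value = os.getenv(name)
    return default if value is None else value == "True"


def even_message_window(raw, default=6):
    """Parse the window size and round it down to the nearest even number."""
    return int(raw or default) // 2 * 2


os.environ["FINCO_DEMO_FLAG"] = "True"  # hypothetical variable, for this demo only
print(env_flag("FINCO_DEMO_FLAG"))
print(even_message_window(None), even_message_window("7"))
```

Note that values such as `"true"` or `"1"` are treated as false under this rule, which is easy to trip over in a `.env` file.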

156
qlib/finco/knowledge.py Normal file
View File

@@ -0,0 +1,156 @@
from pathlib import Path
from jinja2 import Template
from typing import List
from qlib.workflow import R
from qlib.finco.log import FinCoLog
from qlib.finco.llm import APIBackend
class Knowledge:
"""
Used to handle knowledge in finCo, such as experiments and outside domain information
"""
def __init__(self):
self.logger = FinCoLog()
def load(self, **kwargs):
"""
Load knowledge in memory
Parameters
----------
Return
------
"""
raise NotImplementedError(f"Please implement the `load` method.")
def brief(self, **kwargs):
"""
Return a brief summary of knowledge
Parameters
----------
Return
------
"""
raise NotImplementedError(f"Please implement the `brief` method.")
class KnowledgeExperiment(Knowledge):
"""
Handle knowledge from experiments
"""
def __init__(self, exp_name, rec_id=None):
super().__init__()
self.exp_name = exp_name
self.exp = None
self.recs = []
self.load(exp_name=exp_name, rec_id=rec_id)
def load(self, exp_name, rec_id=None):
recs = []
self.exp = R.get_exp(experiment_name=exp_name)
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
if rec_id is not None and r.id != rec_id:
continue
recs.append(r)
self.recs.extend(recs)
def brief(self):
docs = []
for recorder in self.recs:
docs.append({"exp_name": self.exp.name, "record_info": recorder.info,
"config": recorder.load_object("config"),
"context_summary": recorder.load_object("context_summary")})
return docs
class Topic:
def __init__(self, name: str, describe: Template):
self.name = name
self.describe = describe
self.docs = []
self.knowledge = None
self.logger = FinCoLog()
def summarize(self, docs: list):
self.logger.info(f"Summarize topic: \nname: {self.name}\ndescribe: {self.describe.module}")
prompt_workflow_selection = self.describe.render(docs=docs)
response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=prompt_workflow_selection
)
self.knowledge = response
self.docs = docs
class KnowledgeBase:
"""
Load knowledge, offer brief information of knowledge and common handle interfaces
"""
def __init__(self, init_path=None, topics: List[Topic] = None):
self.logger = FinCoLog()
init_path = init_path if init_path else Path.cwd()
if not init_path.exists():
self.logger.warning(f"{init_path} not exist, create empty directory.")
Path.mkdir(init_path)
self.knowledge = self.load(path=init_path)
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
# literal search/semantic search
self.docs = self.brief(knowledge=self.knowledge)
self.topics = topics if topics else []
def load(self, path) -> List:
if isinstance(path, str):
path = Path(path)
knowledge = []
path = path if path.name == "mlruns" else path.joinpath("mlruns")
R.set_uri(path.as_uri())
for exp_name in R.list_experiments():
knowledge.append(KnowledgeExperiment(exp_name=exp_name))
self.logger.plain_info(f"Load knowledge from: {path} finished.")
return knowledge
def update(self, path):
# note: only update new knowledge in future
knowledge = self.load(path)
self.knowledge = knowledge
self.docs = self.brief(self.knowledge)
self.logger.plain_info(f"Update knowledge finished.")
def brief(self, knowledge: List[Knowledge]) -> List:
docs = []
for k in knowledge:
docs.extend(k.brief())
self.logger.plain_info(f"Generate brief knowledge summary finished.")
return docs
def query(self, content: str = None):
# todo: query by DSL
return self.docs
def query_topics(self):
knowledge_of_topics = []
for topic in self.topics:
knowledge_of_topics.append({topic.name: topic.knowledge})
return knowledge_of_topics
def summarize_by_topic(self):
for topic in self.topics:
topic.summarize(self.docs)

111
qlib/finco/llm.py Normal file
View File

@@ -0,0 +1,111 @@
import os
import time
import openai
import json
from typing import Optional
from qlib.finco.conf import Config
from qlib.finco.utils import SingletonBaseClass
from qlib.finco.log import FinCoLog


class APIBackend(SingletonBaseClass):
    def __init__(self):
        self.cfg = Config()
        openai.api_key = self.cfg.openai_api_key
        if self.cfg.use_azure:
            openai.api_type = "azure"
            openai.api_base = self.cfg.azure_api_base
            openai.api_version = self.cfg.azure_api_version
        self.use_azure = self.cfg.use_azure

        self.debug_mode = False
        if self.cfg.debug_mode:
            self.debug_mode = True
            cwd = os.getcwd()
            self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
            self.cache = (
                json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
            )

    def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
        """build the messages to avoid implementing several redundant lines of code"""
        cfg = Config()
        # TODO: system prompt should always be provided. In development stage we can use default value
        if system_prompt is None:
            try:
                system_prompt = cfg.system_prompt
            except AttributeError:
                FinCoLog().warning("system_prompt is not set, using default value.")
                system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            }
        ]
        messages.extend(former_messages[-1*cfg.max_past_message_include:])
        messages.append(
            {
                "role": "user",
                "content": user_prompt,
            }
        )
        fcl = FinCoLog()
        response = self.try_create_chat_completion(messages=messages, **kwargs)
        fcl.log_message(messages)
        fcl.log_response(response)
        return response

    def try_create_chat_completion(self, max_retry=10, **kwargs):
        max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
        for i in range(max_retry):
            try:
                response = self.create_chat_completion(**kwargs)
                return response
            except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
                print(e)
                print(f"Retrying {i+1}th time...")
                time.sleep(1)
                continue
            except openai.InvalidRequestError as e:
                print("Invalid request, will try to reduce the messages length and retry...")
                if len(kwargs["messages"]) > 2:
                    # keep the system prompt and drop the oldest user/assistant pair
                    kwargs["messages"] = kwargs["messages"][:1] + kwargs["messages"][3:]
                    continue
                raise e
        raise Exception(f"Failed to create chat completion after {max_retry} retries.")

    def create_chat_completion(
        self,
        messages,
        model=None,
        temperature: float = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        if self.debug_mode:
            key = json.dumps(messages)
            if key in self.cache:
                return self.cache[key]

        if temperature is None:
            temperature = self.cfg.temperature
        if max_tokens is None:
            max_tokens = self.cfg.max_tokens

        if self.cfg.use_azure:
            response = openai.ChatCompletion.create(
                engine=self.cfg.model,
                messages=messages,
                max_tokens=self.cfg.max_tokens,
            )
        else:
            response = openai.ChatCompletion.create(
                model=self.cfg.model,
                messages=messages,
            )
        resp = response.choices[0].message["content"]
        if self.debug_mode:
            self.cache[key] = resp
            json.dump(self.cache, open(self.cache_file_location, "w"))
        return resp
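
`try_create_chat_completion` retries transient failures and, when the request is rejected as invalid, shortens the history by dropping the oldest exchange while keeping the system prompt. The sketch below reproduces that control flow without the `openai` package; the exception classes and the `fake_backend` function are stand-ins invented for the example:

```python
import time


class TransientError(Exception):
    """Stand-in for rate-limit, timeout, and API errors."""


class TooLongError(Exception):
    """Stand-in for an invalid-request error caused by an over-long prompt."""


def call_with_retry(create, messages, max_retry=10, delay=0.0):
    """Call `create(messages)`, retrying on transient errors and trimming on TooLongError."""
    messages = list(messages)
    for _ in range(max_retry):
        try:
            return create(messages)
        except TransientError:
            time.sleep(delay)  # back off, then try again with the same messages
        except TooLongError:
            if len(messages) <= 2:
                raise  # nothing left to trim except system prompt and latest user turn
            # keep messages[0] (system prompt), drop the oldest user/assistant pair
            messages = messages[:1] + messages[3:]
    raise RuntimeError(f"Failed to create chat completion after {max_retry} retries.")


def fake_backend(messages):
    """Pretend model that rejects any conversation longer than three messages."""
    if len(messages) > 3:
        raise TooLongError()
    return f"ok:{len(messages)}"


history = [
    {"role": "system", "content": "You are a finance assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]
result = call_with_retry(fake_backend, history)
print(result)
```

Each trimmed attempt still counts against `max_retry`, so a very long history combined with a small retry budget can exhaust the loop before the prompt fits.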

131
qlib/finco/log.py Normal file
View File

@@ -0,0 +1,131 @@
"""
This module is based on Qlib's logger module and provides some interactive functions.
"""
import logging
from typing import Dict, List
from qlib.finco.utils import SingletonBaseClass
from contextlib import contextmanager
class LogColors:
"""
ANSI color codes for use in console output.
"""
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
WHITE = "\033[97m"
GRAY = "\033[90m"
BLACK = "\033[30m"
BOLD = "\033[1m"
ITALIC = "\033[3m"
END = "\033[0m"
@classmethod
def get_all_colors(cls):
names = dir(cls)
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
var_values = [getattr(cls, name) for name in names]
return var_values
def render(self, text: str, color: str = "", style: str = ""):
"""
Render text with the given color and style. Passing text that is already rendered is not recommended.
"""
# This method is called too frequently, which is not good.
colors = self.get_all_colors()
# Perhaps color and font should be distinguished here.
if color:
assert color in colors, f"color should be in: {colors} but now is: {color}"
if style:
assert style in colors, f"style should be in: {colors} but now is: {style}"
text = f"{color}{text}{self.END}"
text = f"{style}{text}{self.END}"
return text
@contextmanager
def formatting_log(logger, title="Info"):
"""
A context manager that prints lines before and after a function.
"""
length = {"Start": 120, "Task": 120, "Info": 60, "Interact": 60, "End": 120}.get(title, 60)
color, bold = (LogColors.YELLOW, LogColors.BOLD) \
if title in ["Start", "Task", "Info", "Interact", "End"] else (LogColors.CYAN, "")
logger.info("")
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
yield
logger.info("")
class FinCoLog(SingletonBaseClass):
# TODO:
# - config to file logger and save it into workspace
def __init__(self) -> None:
self.logger = logging.Logger("interactive")
# TODO: merge these with Qlib's default logger.
# We can do the same thing by changing the default log dict of Qlib.
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def log_message(self, messages: List[Dict[str, str]]):
"""
messages is some info like this [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
]
"""
with formatting_log(self.logger, "GPT Messages"):
for m in messages:
self.logger.info(
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n")
def log_response(self, response: str):
with formatting_log(self.logger, "GPT Response"):
self.logger.info(
f"{LogColors.CYAN}{response}{LogColors.END}\n")
# TODO:
# It looks weird if we only have logger
def info(self, *args, plain=False, title="Info"):
if plain:
return self.plain_info(*args)
with formatting_log(self.logger, title):
for arg in args:
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
def plain_info(self, *args):
for arg in args:
self.logger.info(
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}")
def warning(self, *args):
for arg in args:
self.logger.warning(
f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
def error(self, *args):
for arg in args:
self.logger.error(
f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")
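
The logging helpers above work by wrapping text in ANSI escape sequences and padding a title with dashes to a fixed width. A simplified sketch of both ideas follows; unlike `LogColors.render`, it emits a single reset code, and the function names are chosen for this example only:

```python
YELLOW = "\033[93m"
BOLD = "\033[1m"
END = "\033[0m"  # resets all attributes


def render(text: str, color: str = "", style: str = "") -> str:
    """Prefix text with the style and color codes and reset afterwards."""
    return f"{style}{color}{text}{END}"


def banner(title: str, length: int = 60) -> str:
    """Build a title line: one dash, the title, then dashes up to `length`."""
    return f"{YELLOW}{BOLD}- {title} {'-' * (length - len(title))}{END}"


print(render("Info:", YELLOW, BOLD))
print(banner("Info"))
```

Terminals that do not interpret ANSI codes will show the raw escape characters, so plain output may be preferable when logging to a file.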

View File

@@ -0,0 +1,32 @@
from typing import Union
from pathlib import Path
from jinja2 import Template
import yaml

from qlib.finco.utils import SingletonBaseClass
from qlib.finco import get_finco_path


class PromptTemplate(SingletonBaseClass):
    def __init__(self) -> None:
        super().__init__()
        _template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
                              Loader=yaml.FullLoader)
        for k, v in _template.items():
            if k == "mods":
                continue
            self.__setattr__(k, Template(v))

    def get(self, key: str):
        return self.__dict__.get(key, Template(""))

    def update(self, key: str, value):
        self.__setattr__(key, value)

    def save(self, file_path: Union[str, Path]):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        Path.mkdir(file_path.parent, exist_ok=True)
        with open(file_path, 'w') as f:
            yaml.dump(self.__dict__, f)

File diff suppressed because it is too large

1110
qlib/finco/task.py Normal file

File diff suppressed because it is too large

12
qlib/finco/tpl/README.md Normal file
View File

@@ -0,0 +1,12 @@
This is a set of templates that should be copied for a new project.
Here are the explanations for the templates folder.
| folder | explanations |
|--------|------------------------------------------------------------------|
| sl | Default configuration for supervised learning |
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
# TODO
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated

View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path

DIRNAME = Path(__file__).absolute().resolve().parent


def get_tpl_path() -> Path:
    """
    return the template path
    Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
    """
    return DIRNAME

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,73 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
experiment_name: finCo
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
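
The `segments` block and the `backtest` window in the config above have to line up: train, valid and test should follow each other without overlap, and the backtest window should fall inside the test segment. A small sketch of such a sanity check follows, with the dates copied from the config. The helpers `check_segments` and `backtest_inside_test` are hypothetical and not part of qlib.

```python
from datetime import date

# Dates copied from the config above.
SEGMENTS = {
    "train": (date(2008, 1, 1), date(2014, 12, 31)),
    "valid": (date(2015, 1, 1), date(2016, 12, 31)),
    "test": (date(2017, 1, 1), date(2020, 8, 1)),
}
BACKTEST = (date(2017, 1, 1), date(2020, 8, 1))


def check_segments(segments: dict, order=("train", "valid", "test")) -> bool:
    """Return True if every segment is well formed and ends before the next one starts."""
    previous_end = None
    for name in order:
        start, end = segments[name]
        if start > end:
            return False
        if previous_end is not None and start <= previous_end:
            return False
        previous_end = end
    return True


def backtest_inside_test(segments: dict, backtest: tuple) -> bool:
    """Return True if the backtest window lies within the test segment."""
    test_start, test_end = segments["test"]
    return test_start <= backtest[0] and backtest[1] <= test_end
```

Both checks pass for the values shown; a `valid` segment starting on or before 2014-12-31 would make `check_segments` return False.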

38
qlib/finco/utils.py Normal file

@@ -0,0 +1,38 @@
import json
from fuzzywuzzy import fuzz
class SingletonMeta(type):
_instance = None
def __call__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
return cls._instance
class SingletonBaseClass(metaclass=SingletonMeta):
"""
This base class exists so that a singleton can be defined with `class A(SingletonBaseClass)` instead of `class A(metaclass=SingletonMeta)`.
"""
# TODO: Move this class to Qlib's general utils.
def parse_json(response):
try:
return json.loads(response)
except json.decoder.JSONDecodeError:
pass
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
def similarity(text1, text2):
text1 = text1 if isinstance(text1, str) else ""
text2 = text2 if isinstance(text2, str) else ""
# Other similarity algorithms, such as TF-IDF, could be used here
return fuzz.ratio(text1, text2)
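
To make the behaviour of `SingletonMeta` concrete, here is a self-contained sketch that repeats the metaclass from the diff and exercises it with hypothetical classes (`Settings`, `Journal`, `ChildSettings` are made up for this example). It illustrates three properties that follow from the code as written: the instance is cached per class, constructor arguments of later calls are ignored, and a subclass created after its parent has been instantiated finds the parent's cached instance through normal attribute lookup.

```python
class SingletonMeta(type):
    _instance = None

    def __call__(cls, *args, **kwargs):
        # The first call creates the instance and stores it on the class itself.
        if cls._instance is None:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


class SingletonBaseClass(metaclass=SingletonMeta):
    pass


# Hypothetical classes, used only for this illustration.
class Settings(SingletonBaseClass):
    def __init__(self, name="default"):
        self.name = name


class Journal(SingletonBaseClass):
    pass


class ChildSettings(Settings):
    pass


first = Settings("a")
second = Settings("b")   # returns the cached instance; "b" is ignored
journal = Journal()      # a sibling class gets its own instance
child = ChildSettings()  # inherits Settings._instance, so no new object is created
```

The last case is worth keeping in mind when subclassing a singleton: whether the subclass gets its own instance depends on whether the parent was instantiated first.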

223
qlib/finco/workflow.py Normal file

@@ -0,0 +1,223 @@
import sys
import copy
import shutil
from pathlib import Path
from typing import List
from qlib.finco.task import HighLevelPlanTask, SummarizeTask, TrainTask
from qlib.finco.prompt_template import PromptTemplate, Template
from qlib.finco.log import FinCoLog, LogColors
from qlib.finco.utils import similarity
from qlib.finco.llm import APIBackend
from qlib.finco.conf import Config
from qlib.finco.knowledge import KnowledgeBase, Topic
class WorkflowContextManager:
"""Context Manager stores the context of the workflow.
All context entries are key-value pairs that save the input, output and status of the whole workflow."""
def __init__(self) -> None:
self.context = {}
self.logger = FinCoLog()
def set_context(self, key, value):
if key in self.context:
self.logger.warning("The key already exists in the context, the value will be overwritten")
self.context[key] = value
def get_context(self, key):
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
return None
return self.context[key]
def update_context(self, key, new_value):
# NOTE: if the key doesn't exist, a warning is logged and the key is added. In the future, we may raise an error to detect abnormal behavior
if key not in self.context:
self.logger.warning("The key doesn't exist in the context")
self.context.update({key: new_value})
def get_all_context(self):
"""Return a deep copy of the context."""
# TODO: do we need to return a deep copy?
return copy.deepcopy(self.context)
def retrieve(self, query: str) -> dict:
if query in self.context.keys():
return {query: self.context.get(query)}
# NOTE: retrieving information from the context by string similarity may be abandoned in the future
scores = {}
for k, v in self.context.items():
scores.update({k: max(similarity(query, k), similarity(query, v))})
max_score_key = max(scores, key=scores.get)
return {max_score_key: self.context.get(max_score_key)}
def clear(self, reserve: list = None):
if reserve is None:
reserve = []
_context = {k: self.get_context(k) for k in reserve}
self.context = _context
class WorkflowManager:
"""This manages the whole task automation workflow, including tasks and actions"""
def __init__(self, workspace=None) -> None:
self.logger = FinCoLog()
if workspace is None:
self._workspace = Path.cwd() / "finco_workspace"
else:
self._workspace = Path(workspace)
self.conf = Config()
self._confirm_and_rm()
self.prompt_template = PromptTemplate()
self.context = WorkflowContextManager()
self.context.set_context("workspace", self._workspace)
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China A csi300. Please help to use lightgbm model."
def _confirm_and_rm(self):
# If the workspace exists, ask the user to confirm and then remove it; exit if the user declines.
if self._workspace.exists() and not self.conf.continuous_mode:
self.logger.info(title="Interact")
flag = input(
LogColors().render(
f"Will be deleted: \n\t{self._workspace}\n"
f"If you do not need to delete {self._workspace},"
f" please change the workspace dir or rename existing files\n"
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
color=LogColors.WHITE)
)
if str(flag) not in ["Y", "y"]:
sys.exit()
else:
# remove self._workspace
shutil.rmtree(self._workspace)
elif self._workspace.exists() and self.conf.continuous_mode:
shutil.rmtree(self._workspace)
def set_context(self, key, value):
"""Directly call the set_context method of the context manager"""
self.context.set_context(key, value)
def get_context(self) -> WorkflowContextManager:
return self.context
def run(self, prompt: str) -> Path:
"""
The workflow manager is supposed to generate a codebase based on the prompt
Parameters
----------
prompt: str
the prompt user gives
Returns
-------
Path
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
The path is returned
The output path should follow a specific format:
- TODO: design
There is a summarized report that the user can start from.
"""
# NOTE: The following items are not designed to make the workflow very flexible.
# - The generated tasks can't be changed after getting new information from the execution results.
# - But this is required in some cases: if we want to build an external dataset, the workflow may have to plan like AutoGPT...
# NOTE: default user prompt might be changed in the future and exposed to the user
if prompt is None:
self.set_context("user_prompt", self.default_user_prompt)
else:
self.set_context("user_prompt", prompt)
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
# NOTE: list may not be enough for general task list
task_list = [HighLevelPlanTask(), SummarizeTask()]
task_finished = []
while len(task_list):
task_list_info = [str(task) for task in task_list]
# the task list is not long, so sorting it is not a big problem
# TODO: sort the task list based on the priority of the task
# task_list = sorted(task_list, key=lambda x: x.task_type)
t = task_list.pop(0)
self.logger.info(f"Task finished: {[str(task) for task in task_finished]}",
f"Task in queue: {task_list_info}",
f"Executing task: {str(t)}",
title="Task")
t.assign_context_manager(self.context)
res = t.execute()
t.summarize()
task_finished.append(t)
self.context.set_context("task_finished", task_finished)
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
task_list = res + task_list
return self._workspace
class LearnManager:
__DEFAULT_TOPICS = ["IC", "MaxDropDown"]
def __init__(self):
self.epoch = 0
self.wm = WorkflowManager()
topics = [Topic(name=topic, describe=self.wm.prompt_template.get(f"Topic_{topic}")) for topic in
self.__DEFAULT_TOPICS]
self.knowledge_base = KnowledgeBase(init_path=Path.cwd().joinpath('knowledge'), topics=topics)
def run(self, prompt):
# todo: add early stop condition
for i in range(10):
self.wm.run(prompt)
self.knowledge_base.update(self.wm._workspace)
self.knowledge_base.summarize_by_topic()
self.learn()
self.epoch += 1
def learn(self):
workspace = self.wm.context.get_context("workspace")
def _drop_duplicate_task(_task: List):
unique_task = {}
for obj in _task:
task_name = obj.__class__.__name__
if task_name not in unique_task:
unique_task[task_name] = obj
return list(unique_task.values())
# one task may be run several times in the workflow
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
user_prompt = self.wm.context.get_context("user_prompt")
summary = self.wm.context.get_context("summary")
for task in task_finished:
prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
summary=summary, brief=self.knowledge_base.query_topics(),
task_finished=[str(t) for t in task_finished],
task=task.__class__.__name__, system=task.system.render(), user_prompt=user_prompt
)
response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=prompt_workflow_selection,
system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render()
)
# todo: response assertion
task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
self.wm.context.clear(reserve=["workspace"])
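
The loop in `WorkflowManager.run` pops the first task, executes it, and puts whatever tasks it returns at the front of the queue (`task_list = res + task_list`), so follow-up tasks run before anything that was queued earlier. A toy sketch of that scheduling order follows. `ToyTask` and `run_queue` are hypothetical stand-ins for the real task classes in qlib/finco/task.py, whose diff is suppressed above, so nothing here reflects their actual behaviour beyond the queue handling.

```python
from typing import List, Optional


class ToyTask:
    """Hypothetical task: executing it simply returns its follow-up tasks."""

    def __init__(self, name: str, children: Optional[List["ToyTask"]] = None):
        self.name = name
        self.children = children or []

    def execute(self) -> List["ToyTask"]:
        return self.children

    def __str__(self) -> str:
        return self.name


def run_queue(tasks: List[ToyTask]) -> List[str]:
    """Run tasks in the same order as the loop in WorkflowManager.run."""
    task_list = list(tasks)  # copy so the caller's list is not mutated
    finished = []
    while task_list:
        t = task_list.pop(0)
        res = t.execute()
        finished.append(str(t))
        # Newly generated tasks go to the front, giving a depth-first order.
        task_list = res + task_list
    return finished


plan = ToyTask("plan", [ToyTask("train", [ToyTask("backtest")]), ToyTask("analyze")])
order = run_queue([plan, ToyTask("summarize")])
```

Because new tasks are prepended, the initially queued `summarize` task runs only after every task spawned by `plan` has finished, which matches the role of `SummarizeTask` in the initial `task_list`.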


@@ -16,12 +16,13 @@ import torch
from joblib import Parallel, delayed
from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import BacktestConfigParser
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.typehint import Literal
@@ -123,13 +124,105 @@ def _generate_report(
return report
def single_with_collect_data_loop(
def single_with_simulator(
backtest_config: dict,
orders: pd.DataFrame,
split: Literal["stock", "day"] = "stock",
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day.
A new simulator will be created and used for every single-day order.
Parameters
----------
backtest_config:
Backtest config
orders:
Orders to be executed. Example format:
datetime instrument amount direction
0 2020-06-01 INST 600.0 0
1 2020-06-02 INST 700.0 1
...
split
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
cash_limit
Limitation of cash.
generate_report
Whether to generate reports.
Returns
-------
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
init_qlib(backtest_config["qlib"])
stocks = orders.instrument.unique().tolist()
reports = []
decisions = []
for _, row in orders.iterrows():
date = pd.Timestamp(row["datetime"])
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day)
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day)
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(row["direction"]),
start_time=start_time,
end_time=end_time,
)
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
data_granularity=backtest_config["data_granularity"],
)
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": backtest_config["data_granularity"],
}
)
simulator = SingleAssetOrderExecution(
order=order,
executor_config=executor_config,
exchange_config=exchange_config,
qlib_config=None,
cash_limit=None,
)
reports.append(simulator.report_dict)
decisions += simulator.decisions
indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()
if generate_report:
_report = _generate_report(decisions, [report["indicator"] for report in reports])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: _report}
return records, report
else:
return records
def single_with_collect_data_loop(
backtest_config: dict,
orders: pd.DataFrame,
time_range: Tuple[str, str],
exchange_config: dict,
strategy_config: dict,
split: Literal["stock", "day"] = "stock",
data_granularity: str = "1min",
cash_limit: float | None = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
@@ -157,42 +250,44 @@ def single_with_collect_data_loop(
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
"""
init_qlib(backtest_config["qlib"])
trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
stocks = orders.instrument.unique().tolist()
top_strategy_config = {
strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"file": orders,
"trade_range": TradeRangeByTime(
pd.Timestamp(time_range[0]).time(),
pd.Timestamp(time_range[1]).time(),
pd.Timestamp(backtest_config["start_time"]).time(),
pd.Timestamp(backtest_config["end_time"]).time(),
),
},
}
top_executor_config = _get_multi_level_executor_config(
strategy_config=strategy_config,
executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
data_granularity=data_granularity,
data_granularity=backtest_config["data_granularity"],
)
exchange_config = {
**exchange_config,
**{
exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
"freq": data_granularity,
},
}
"freq": backtest_config["data_granularity"],
}
)
strategy, executor = get_strategy_executor(
start_time=pd.Timestamp(trade_start_time),
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
strategy=top_strategy_config,
executor=top_executor_config,
strategy=strategy_config,
executor=executor_config,
benchmark=None,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=exchange_config,
@@ -200,7 +295,7 @@ def single_with_collect_data_loop(
)
report_dict: dict = {}
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict, show_progress=False))
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict"))
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
@@ -220,54 +315,46 @@ def single_with_collect_data_loop(
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
init_qlib(backtest_config["simulator"]["qlib"])
order_df = read_order_file(backtest_config["order_file"])
cash_limit = backtest_config["exchange"].pop("cash_limit")
generate_report = backtest_config.pop("generate_report")
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
single = single_with_simulator if with_simulator else single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
single = single_with_collect_data_loop
mp_config = {"n_jobs": backtest_config["runtime"]["concurrency"], "verbose": 10, "backend": "multiprocessing"}
for task_config in backtest_config["tasks"]:
order_df = read_order_file(task_config["order_file"])
exchange_config = task_config["exchange"]
cash_limit = exchange_config.pop("cash_limit")
generate_report = backtest_config["runtime"]["generate_report"]
stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
#
res = Parallel(**mp_config)(
delayed(single)(
orders=order_df[order_df["instrument"] == stock].copy(),
time_range=task_config["time_range"],
exchange_config=task_config["exchange"],
strategy_config=backtest_config["strategies"],
split="stock",
data_granularity=task_config["data_granularity"],
cash_limit=cash_limit,
generate_report=generate_report,
)
for stock in stock_pool
res = Parallel(**mp_config)(
delayed(single)(
backtest_config=backtest_config,
orders=order_df[order_df["instrument"] == stock].copy(),
split="stock",
cash_limit=cash_limit,
generate_report=generate_report,
)
#
output_path = Path(task_config["output_dir"])
os.makedirs(output_path, exist_ok=True)
if generate_report:
with (output_path / "report.pkl").open("wb") as f:
report = {}
for r in res:
report.update(r[1])
pickle.dump(report, f)
res = pd.concat([r[0] for r in res], 0)
else:
res = pd.concat(res)
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
# return res # TODO
for stock in stock_pool
)
output_path = Path(backtest_config["output_dir"])
if generate_report:
with (output_path / "report.pkl").open("wb") as f:
report = {}
for r in res:
report.update(r[1])
pickle.dump(report, f)
res = pd.concat([r[0] for r in res], 0)
else:
res = pd.concat(res)
if not output_path.exists():
os.makedirs(output_path)
if "pa" in res.columns:
res["pa"] = res["pa"] * 10000.0 # align with training metrics
res.to_csv(output_path / "backtest_result.csv")
return res
if __name__ == "__main__":
@@ -275,7 +362,6 @@ if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
@@ -288,11 +374,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
config_parser = BacktestConfigParser(args.config_path)
config = config_parser.parse()
if args.n_jobs is not None: # Overwrite concurrency
config["runtime"]["concurrency"] = args.n_jobs
config = get_backtest_config_fromfile(args.config_path)
if args.n_jobs is not None:
config["concurrency"] = args.n_jobs
backtest(
backtest_config=config,


@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import os
import platform
import shutil
@@ -31,7 +30,7 @@ def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist')
raise FileNotFoundError(msg_tmpl.format(filename))
def load_config(path: str) -> dict:
def parse_backtest_config(path: str) -> dict:
abs_path = os.path.abspath(path)
check_file_exist(abs_path)
@@ -66,154 +65,43 @@ def load_config(path: str) -> dict:
base_file_name = [base_file_name]
for f in base_file_name:
base_config = load_config(os.path.join(os.path.dirname(abs_path), f))
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
config = merge_a_into_b(a=config, b=base_config)
return config
class BacktestConfigParser:
def __init__(self, path: str) -> None:
self.raw_config = load_config(path)
def parse(self) -> dict:
self._simulator_config = self._parse_simulator()
self._exchange_config = self._simulator_config.pop("exchange")
config = {
"strategies": self.raw_config["strategies"],
"runtime": self.raw_config["runtime"],
"tasks": self._parse_tasks(),
"simulator": self._simulator_config,
}
return config
def _parse_tasks(self) -> dict:
task_config = []
for task in self.raw_config["tasks"]:
if "output_dir" not in task:
task["output_dir"] = os.path.join("outputs_backtest", task["name"])
if "exchange" not in task:
task["exchange"] = copy.deepcopy(self._exchange_config)
else:
task["exchange"] = self._complete_exchange_config(task["exchange"])
task_config.append(task)
return task_config
def _complete_exchange_config(self, exchange_config: dict) -> dict:
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
}
exchange_config = merge_a_into_b(a=exchange_config, b=exchange_config_default)
return exchange_config
def _parse_simulator(self) -> dict:
config = self.raw_config["simulator"]
return {
"qlib": config["qlib"],
"exchange": self._complete_exchange_config(config["exchange"]),
}
def _convert_all_list_to_tuple(config: dict) -> dict:
for k, v in config.items():
if isinstance(v, list):
config[k] = tuple(v)
elif isinstance(v, dict):
config[k] = _convert_all_list_to_tuple(v)
return config
class TrainingConfigParser:
def __init__(self, path: str) -> None:
self.raw_config = load_config(path)
def get_backtest_config_fromfile(path: str) -> dict:
backtest_config = parse_backtest_config(path)
def parse(self) -> dict:
return {
"general": self._parse_general(),
"policy": self.raw_config["policy"],
"interpreter": self.raw_config["interpreter"],
"runtime": self._parse_runtime(),
"training": self._parse_training(),
"simulator": self._parse_simulator(),
}
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
}
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
def _parse_general(self) -> dict:
default = {
"freq": "1min",
"extra_module_paths": [],
}
return {**default, **self.raw_config["general"]}
backtest_config_default = {
"debug_single_stock": None,
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs_backtest/",
"generate_report": False,
"data_granularity": "1min",
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
def _parse_runtime(self) -> dict:
default = {
"seed": None,
"use_cuda": False,
"concurrency": 1,
"parallel_mode": "dummy",
}
return {**default, **self.raw_config["runtime"]}
def _parse_training(self) -> dict:
default = {
"max_epoch": 100,
"repeat_per_collect": 2,
"earlystop_patience": float("inf"),
"episode_per_collect": 10000,
"batch_size": 256,
"val_every_n_epoch": None,
"checkpoint_path": "./outputs",
"checkpoint_every_n_iters": 10,
}
config = self.raw_config["training"]
assert "order_dir" in config
return {**default, **config}
def _parse_simulator(self) -> dict:
config = self.raw_config["simulator"]
sim_type = config["type"]
assert sim_type in ("simple", "full")
if sim_type == "simple":
return {
"type": sim_type,
"data": {
"feature_root_dir": config["data"]["feature_root_dir"],
"feature_columns_today": config["data"]["feature_columns_today"],
"default_start_time_index": config["data"].get("default_start_time_index", 0),
"default_end_time_index": config["data"].get("default_end_time_index", 240),
},
"time_per_step": config["time_per_step"],
"vol_limit": config["vol_limit"],
}
else:
exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
# "cash_limit": None,
}
exchange_config = {**exchange_config_default, **config["exchange"]}
exchange_config["freq"] = self.raw_config["general"].get("freq", "1min")
ret_config = {
"type": sim_type,
"data": {
"feature_root_dir": config["data"]["feature_root_dir"],
"default_start_time_index": config["data"].get("default_start_time_index", 0),
"default_end_time_index": config["data"].get("default_end_time_index", 240),
},
"qlib": {
"provider_uri_1min": config["qlib"]["provider_uri_1min"],
},
"exchange": exchange_config,
}
return ret_config
if __name__ == "__main__":
parser = TrainingConfigParser("/home/huoran/exp_configs/amc4th_training_refined.yml")
from pprint import pprint
pprint(parser.parse())
return backtest_config
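
`get_backtest_config_fromfile` fills in missing exchange settings from a table of defaults and then converts lists to tuples. The sketch below shows that pattern in isolation. The body of `merge_a_into_b` is not visible in this diff, so plain dict unpacking in a hypothetical `apply_defaults` is used as a stand-in and may differ from the real helper (for example in how nested dicts are merged). The key `example_list` is made up for the demonstration.

```python
# Default values copied from the diff above.
EXCHANGE_DEFAULTS = {
    "open_cost": 0.0005,
    "close_cost": 0.0015,
    "min_cost": 5.0,
    "trade_unit": 100.0,
    "cash_limit": None,
}


def apply_defaults(user: dict, defaults: dict) -> dict:
    """Stand-in for merge_a_into_b: user values win, defaults fill the gaps."""
    return {**defaults, **user}


def convert_all_list_to_tuple(config: dict) -> dict:
    """Recursively replace list values with tuples, as in the diff."""
    for k, v in config.items():
        if isinstance(v, list):
            config[k] = tuple(v)
        elif isinstance(v, dict):
            config[k] = convert_all_list_to_tuple(v)
    return config


# "example_list" is a hypothetical key used only to show the conversion.
user_exchange = {"open_cost": 0.001, "example_list": [1, 2]}
exchange = convert_all_list_to_tuple(apply_defaults(user_exchange, EXCHANGE_DEFAULTS))
```

Note that the shallow merge used here does not touch `user_exchange` itself, while the conversion step mutates the merged dict in place.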


@@ -1,362 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Callable, cast, List, Optional, Sequence
import numpy as np
import pandas as pd
import torch
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl import Simulator
from qlib.rl.contrib.naive_config_parser import TrainingConfigParser
from qlib.rl.data.integration import init_qlib
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def get_executor_config(freq: int) -> dict:
return {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"inner_executor": {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"inner_executor": {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"generate_report": False,
"time_per_step": f"{freq}min",
"track_data": True,
"trade_type": "serial",
"verbose": False,
},
},
"inner_strategy": {
"class": "TWAPStrategy",
"kwargs": {},
"module_path": "qlib.contrib.strategy.rule_strategy",
},
"time_per_step": "30min",
"track_data": True,
},
},
"inner_strategy": {
"class": "ProxySAOEStrategy",
"module_path": "qlib.rl.order_execution.strategy",
"kwargs": {},
},
"time_per_step": "1day",
"track_data": True,
},
}
def seed_everything(seed: int) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
def _freq_str_to_int(freq: str) -> int:
if freq.endswith("min"):
return int(freq.replace("min", ""))
elif freq.endswith("hour"):
return int(freq.replace("hour", "")) * 60
else:
raise ValueError(f"Unrecognized freq string: {freq}")
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_df: pd.DataFrame,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_df = order_df
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)
def __len__(self) -> int:
return len(self._order_df)
def __getitem__(self, index: int) -> Order:
row = self._order_df.iloc[index]
date = pd.Timestamp(str(row["date"]))
if self._ticks_index is None:
# TODO: We only load the ticks index once, based on the assumption that the ticks indexes of different dates
# TODO: in one experiment are all the same. If that assumption does not hold, we need to load the ticks index
# TODO: of all dates.
data = load_pickle_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
)
self._ticks_index = [t - date for t in data.today.index]
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(int(row["order_type"])),
start_time=date + self._ticks_index[self._default_start_time_index],
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
)
return order
def _split_order_df_by_instrument(df: pd.DataFrame, k: int) -> List[pd.DataFrame]:
df = df.copy()
df["group"] = df["instrument"].apply(lambda s: hash(s) % k)
dfs = [df[df["group"] == i].drop(columns=["group"]) for i in range(k)]
return dfs
def _get_simulator_factory(
sim_type: str,
data_dir: Path,
freq_min: int,
simulator_config: dict,
) -> Callable[[Order], Simulator]:
if sim_type == "simple":
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
simulator = SingleAssetOrderExecutionSimple(
order=order,
data_dir=data_dir,
feature_columns_today=simulator_config["data"]["feature_columns_today"],
data_granularity=freq_min,
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)
return simulator
return _simulator_factory_simple
elif sim_type == "full":
init_qlib(simulator_config["qlib"])
executor_config = get_executor_config(freq_min)
exchange_config = simulator_config["exchange"]
def _simulator_factory_full(order: Order) -> SingleAssetOrderExecution:
simulator = SingleAssetOrderExecution(
order=order,
executor_config=executor_config,
exchange_config=exchange_config, # `codes` will be set in SingleAssetOrderExecution.__init__()
qlib_config=None,
cash_limit=None,
)
return simulator
return _simulator_factory_full
else:
raise ValueError(f"Unknown simulator type: {sim_type}")
def train_and_test(
freq: str,
concurrency: int,
parallel_mode: str,
training_config: dict,
simulator_config: dict,
policy: BasePolicy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
freq_min: int = _freq_str_to_int(freq)
order_root_path = Path(training_config["order_dir"])
feature_root_dir = simulator_config["data"]["feature_root_dir"]
assert simulator_config["data"]["default_start_time_index"] % freq_min == 0
assert simulator_config["data"]["default_end_time_index"] % freq_min == 0
_simulator_factory = _get_simulator_factory(
sim_type=simulator_config["type"],
data_dir=feature_root_dir,
freq_min=freq_min,
simulator_config=simulator_config,
)
# Load orders
load_data_tags = []
orders_by_tag = {}
if run_training:
load_data_tags += ["train", "valid"]
if run_backtest:
load_data_tags += ["test"]
for tag in load_data_tags:
order_df = _read_orders(order_root_path / tag).reset_index()
dfs = _split_order_df_by_instrument(order_df, concurrency)
datasets = [
LazyLoadDataset(
data_dir=feature_root_dir,
order_df=df,
default_start_time_index=simulator_config["data"]["default_start_time_index"] // freq_min,
default_end_time_index=simulator_config["data"]["default_end_time_index"] // freq_min,
)
for df in dfs
]
orders_by_tag[tag] = datasets
if run_training:
callbacks: List[Callback] = [
MetricsWriter(dirpath=Path(training_config["checkpoint_path"])),
Checkpoint(
dirpath=Path(training_config["checkpoint_path"]) / "checkpoints",
every_n_iters=training_config["checkpoint_every_n_iters"],
save_latest="copy",
),
EarlyStopping(
patience=training_config["earlystop_patience"],
monitor="val/pa",
),
]
train(
simulator_fn=_simulator_factory,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Sequence[Order]], orders_by_tag["train"]),
trainer_kwargs={
"max_iters": training_config["max_epoch"],
"finite_env_type": parallel_mode,
"concurrency": concurrency,
"val_every_n_iters": training_config["val_every_n_epoch"],
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": training_config["episode_per_collect"],
"update_kwargs": {
"batch_size": training_config["batch_size"],
"repeat": training_config["repeat_per_collect"],
},
"val_initial_states": cast(List[Sequence[Order]], orders_by_tag["valid"]),
},
)
if run_backtest:
backtest(
simulator_fn=_simulator_factory,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
initial_states=cast(List[Sequence[Order]], orders_by_tag["test"]),
policy=policy,
logger=CsvWriter(Path(training_config["checkpoint_path"])),
reward=reward,
finite_env_type=parallel_mode, # type: ignore[arg-type]
concurrency=concurrency,
)
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
seed = config["runtime"]["seed"]
if seed is not None:
seed_everything(seed)
for extra_module_path in config["general"]["extra_module_paths"]:
sys.path.append(extra_module_path)
state_interpreter: StateInterpreter = init_instance_by_config(config["interpreter"]["state"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["interpreter"]["action"])
reward: Reward = init_instance_by_config(config["interpreter"]["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "network" in config["policy"]:
network_config = config["policy"]["network"]
network_config["kwargs"] = {
**network_config.get("kwargs", {}),
**{"obs_space": state_interpreter.observation_space},
}
additional_policy_kwargs["network"] = init_instance_by_config(network_config)
# Create policy
policy_config = config["policy"]["policy"]
policy_config["kwargs"] = {**policy_config.get("kwargs", {}), **additional_policy_kwargs}
policy: BasePolicy = init_instance_by_config(policy_config)
use_cuda = config["runtime"]["use_cuda"]
if use_cuda:
policy.cuda()
train_and_test(
freq=config["general"]["freq"],
concurrency=config["runtime"]["concurrency"],
parallel_mode=config["runtime"]["parallel_mode"],
training_config=config["training"],
simulator_config=config["simulator"],
policy=policy,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
config_parser = TrainingConfigParser(args.config_path)
config = config_parser.parse()
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
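The entry point above turns two command-line switches into the `run_training` and `run_backtest` booleans. A minimal sketch of that mapping follows; the helper name `parse_flags` is hypothetical and only the three arguments shown in the script are reproduced.

```python
import argparse
from typing import List, Tuple


def parse_flags(argv: List[str]) -> Tuple[str, bool, bool]:
    """Return (config_path, run_training, run_backtest) for the given argv.

    `parse_flags` is an illustrative helper, not part of the script above.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
    parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
    args = parser.parse_args(argv)
    # Training is on by default and switched off explicitly;
    # backtest is off by default and switched on explicitly.
    return args.config_path, not args.no_training, args.run_backtest
```

Passing `--no_training` without `--run_backtest` yields two `False` values, which is the case `main` answers with a warning and an early return.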


@@ -0,0 +1,268 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import cast, List, Optional
import numpy as np
import pandas as pd
import torch
import yaml
from qlib.backtest import Order
from qlib.backtest.decision import OrderDir
from qlib.constant import ONE_MIN
from qlib.rl.data.native import load_handler_intraday_processed_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
from qlib.rl.reward import Reward
from qlib.rl.trainer import Checkpoint, backtest, train
from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
from qlib.rl.utils.log import CsvWriter
from qlib.utils import init_instance_by_config
from tianshou.policy import BasePolicy
from torch.utils.data import Dataset
def seed_everything(seed: int) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def _read_orders(order_dir: Path) -> pd.DataFrame:
if os.path.isfile(order_dir):
return pd.read_pickle(order_dir)
else:
orders = []
for file in order_dir.iterdir():
order_data = pd.read_pickle(file)
orders.append(order_data)
return pd.concat(orders)
class LazyLoadDataset(Dataset):
def __init__(
self,
data_dir: str,
order_file_path: Path,
default_start_time_index: int,
default_end_time_index: int,
) -> None:
self._default_start_time_index = default_start_time_index
self._default_end_time_index = default_end_time_index
self._order_df = _read_orders(order_file_path).reset_index()
self._ticks_index: Optional[pd.DatetimeIndex] = None
self._data_dir = Path(data_dir)
def __len__(self) -> int:
return len(self._order_df)
def __getitem__(self, index: int) -> Order:
row = self._order_df.iloc[index]
date = pd.Timestamp(str(row["date"]))
if self._ticks_index is None:
# TODO: We only load the ticks index once, based on the assumption that the ticks indices of different dates
# TODO: in one experiment are all the same. If that assumption does not hold, we need to load the ticks index
# TODO: of all dates.
data = load_handler_intraday_processed_data(
data_dir=self._data_dir,
stock_id=row["instrument"],
date=date,
feature_columns_today=[],
feature_columns_yesterday=[],
backtest=True,
index_only=True,
)
self._ticks_index = [t - date for t in data.today.index]
order = Order(
stock_id=row["instrument"],
amount=row["amount"],
direction=OrderDir(int(row["order_type"])),
start_time=date + self._ticks_index[self._default_start_time_index],
end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
)
return order
def train_and_test(
env_config: dict,
simulator_config: dict,
trainer_config: dict,
data_config: dict,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
policy: BasePolicy,
reward: Reward,
run_training: bool,
run_backtest: bool,
) -> None:
order_root_path = Path(data_config["source"]["order_dir"])
data_granularity = simulator_config.get("data_granularity", 1)
def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
return SingleAssetOrderExecutionSimple(
order=order,
data_dir=data_config["source"]["feature_root_dir"],
feature_columns_today=data_config["source"]["feature_columns_today"],
feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"],
data_granularity=data_granularity,
ticks_per_step=simulator_config["time_per_step"],
vol_threshold=simulator_config["vol_limit"],
)
assert data_config["source"]["default_start_time_index"] % data_granularity == 0
assert data_config["source"]["default_end_time_index"] % data_granularity == 0
if run_training:
train_dataset, valid_dataset = [
LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / tag,
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
for tag in ("train", "valid")
]
callbacks: List[Callback] = []
if "checkpoint_path" in trainer_config:
callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
callbacks.append(
Checkpoint(
dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
save_latest="copy",
),
)
if "earlystop_patience" in trainer_config:
callbacks.append(
EarlyStopping(
patience=trainer_config["earlystop_patience"],
monitor="val/pa",
)
)
train(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
policy=policy,
reward=reward,
initial_states=cast(List[Order], train_dataset),
trainer_kwargs={
"max_iters": trainer_config["max_epoch"],
"finite_env_type": env_config["parallel_mode"],
"concurrency": env_config["concurrency"],
"val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
"callbacks": callbacks,
},
vessel_kwargs={
"episode_per_iter": trainer_config["episode_per_collect"],
"update_kwargs": {
"batch_size": trainer_config["batch_size"],
"repeat": trainer_config["repeat_per_collect"],
},
"val_initial_states": valid_dataset,
},
)
if run_backtest:
test_dataset = LazyLoadDataset(
data_dir=data_config["source"]["feature_root_dir"],
order_file_path=order_root_path / "test",
default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
)
backtest(
simulator_fn=_simulator_factory_simple,
state_interpreter=state_interpreter,
action_interpreter=action_interpreter,
initial_states=test_dataset,
policy=policy,
logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
reward=reward,
finite_env_type=env_config["parallel_mode"],
concurrency=env_config["concurrency"],
)
def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if not run_training and not run_backtest:
warnings.warn("Skip the entire job since training and backtest are both skipped.")
return
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])
for extra_module_path in config["env"].get("extra_module_paths", []):
sys.path.append(extra_module_path)
state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])
additional_policy_kwargs = {
"obs_space": state_interpreter.observation_space,
"action_space": action_interpreter.action_space,
}
# Create torch network
if "network" in config:
if "kwargs" not in config["network"]:
config["network"]["kwargs"] = {}
config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
additional_policy_kwargs["network"] = init_instance_by_config(config["network"])
# Create policy
if "kwargs" not in config["policy"]:
config["policy"]["kwargs"] = {}
config["policy"]["kwargs"].update(additional_policy_kwargs)
policy: BasePolicy = init_instance_by_config(config["policy"])
use_cuda = config["runtime"].get("use_cuda", False)
if use_cuda:
policy.cuda()
train_and_test(
env_config=config["env"],
simulator_config=config["simulator"],
data_config=config["data"],
trainer_config=config["trainer"],
action_interpreter=action_interpreter,
state_interpreter=state_interpreter,
policy=policy,
reward=reward,
run_training=run_training,
run_backtest=run_backtest,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
parser.add_argument("--no_training", action="store_true", help="Skip training workflow.")
parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.")
args = parser.parse_args()
with open(args.config_path, "r") as input_stream:
config = yaml.safe_load(input_stream)
main(config, run_training=not args.no_training, run_backtest=args.run_backtest)
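Both versions of `main` inject runtime values such as `obs_space` into a component's `kwargs` before instantiating it, one with dict unpacking and one with an in-place `update`. The sketch below shows the in-place form on plain dictionaries; `inject_kwargs` is a hypothetical helper and the config keys are made up for illustration.

```python
from typing import Any, Dict


def inject_kwargs(component_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Merge extra_kwargs into component_config["kwargs"] in place.

    The "kwargs" entry is created when missing. Injected values overwrite
    values with the same key that came from the config file.
    """
    component_config.setdefault("kwargs", {}).update(extra_kwargs)
    return component_config


# Illustrative config fragment (keys and values are assumptions).
policy_config = {"class": "PPO", "kwargs": {"lr": 1e-4}}
inject_kwargs(policy_config, {"obs_space": "observation-space-object"})
```

Because the merge mutates the config dictionary, the loaded config no longer matches the file on disk after this step, which matters if the config is dumped again later.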


@@ -13,7 +13,6 @@ import os
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.constant import EPS_T
from qlib.data.dataset import DatasetH
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
@@ -141,16 +140,6 @@ def load_backtest_data(
return backtest_data
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda path: path,
)
def _load_handler_pickle(path: str) -> DatasetH:
with open(path, "rb") as fstream:
obj = pickle.load(fstream)
return obj
class HandlerIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
@@ -162,6 +151,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> None:
def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
df = df.reset_index()
@@ -171,17 +161,31 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
dataset = _load_handler_pickle(path)
with open(path, "rb") as fstream:
dataset = pickle.load(fstream)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
if index_only:
self.today = _drop_stock_id(data[[]])
self.yesterday = _drop_stock_id(data[[]])
else:
self.today = _drop_stock_id(data[feature_columns_today])
self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
stock_id,
date,
backtest,
index_only,
),
)
def load_handler_intraday_processed_data(
data_dir: Path,
stock_id: str,
@@ -189,14 +193,10 @@ def load_handler_intraday_processed_data(
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
index_only: bool = False,
) -> HandlerIntradayProcessedData:
return HandlerIntradayProcessedData(
data_dir,
stock_id,
date,
feature_columns_today,
feature_columns_yesterday,
backtest,
data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
)
@@ -229,4 +229,5 @@ class HandlerProcessedDataProvider(ProcessedDataProvider):
self.feature_columns_today,
self.feature_columns_yesterday,
backtest=self.backtest,
index_only=False,
)
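The cache key used for `load_handler_intraday_processed_data` is built from `stock_id`, `date`, `backtest` and `index_only` only, so the feature column lists (which are unhashable) and `data_dir` do not take part in the lookup. The sketch below reproduces that behaviour with a small standard-library stand-in for a keyed LRU cache. It is not the `cachetools` implementation, and `cached_by_key` and `load` are hypothetical names.

```python
from collections import OrderedDict
from typing import Any, Callable, Hashable, List


def cached_by_key(maxsize: int, key: Callable[..., Hashable]) -> Callable:
    """Tiny LRU cache decorator whose key is computed by a user-supplied function."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        cache = OrderedDict()

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            k = key(*args, **kwargs)
            if k in cache:
                cache.move_to_end(k)  # mark as most recently used
                return cache[k]
            value = fn(*args, **kwargs)
            cache[k] = value
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used entry
            return value

        return wrapper

    return decorator


calls: List[tuple] = []


@cached_by_key(maxsize=100, key=lambda stock_id, date, columns, index_only: (stock_id, date, index_only))
def load(stock_id: str, date: str, columns: List[str], index_only: bool) -> List[str]:
    # Stand-in for an expensive load; records every real invocation.
    calls.append((stock_id, date, tuple(columns), index_only))
    return [] if index_only else list(columns)
```

With such a key, two calls that differ only in the requested columns share one cache entry, so the second call receives the result of the first. That is harmless as long as every caller in a process asks for the same columns, which is an assumption worth keeping in mind when reusing the loader.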


@@ -26,6 +26,7 @@ from typing import List, Sequence, cast
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import Order, OrderDir
from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
@@ -157,15 +158,6 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData):
return cast(pd.DatetimeIndex, self.data.index)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda path: path,
)
def _load_df_pickle(path: str) -> pd.DataFrame:
df = pd.read_pickle(path)
return df
class PickleIntradayProcessedData(BaseIntradayProcessedData):
"""Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
@@ -174,18 +166,36 @@ class PickleIntradayProcessedData(BaseIntradayProcessedData):
data_dir: Path | str,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool,
feature_dim: int,
time_index: pd.Index,
) -> None:
if isinstance(data_dir, str):
data_dir = Path(data_dir)
path = data_dir / ("backtest" if backtest else "feature") / f"{stock_id}.pkl"
df = _load_df_pickle(str(path))
df = df.loc[pd.IndexSlice[stock_id, :, date]]
proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
self.today = df[feature_columns_today]
self.yesterday = df[feature_columns_yesterday]
# We have to infer the names here because,
# unfortunately, they are not included in the original data.
cnames = _infer_processed_data_column_names(feature_dim)
time_length: int = len(time_index)
try:
# new data format
proc = proc.loc[pd.IndexSlice[stock_id, :, date]]
assert len(proc) == time_length and len(proc.columns) == feature_dim * 2
proc_today = proc[cnames]
proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2])
except (IndexError, KeyError):
# legacy data
proc = proc.loc[pd.IndexSlice[stock_id, date]]
assert time_length * feature_dim * 2 == len(proc)
proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim))
proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim))
proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames)
proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames)
self.today: pd.DataFrame = proc_today
self.yesterday: pd.DataFrame = proc_yesterday
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
@@ -203,38 +213,25 @@ def load_simple_intraday_backtest_data(
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(100), # 100 * 50K = 5MB
key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
)
def load_pickle_intraday_processed_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
feature_dim: int,
time_index: pd.Index,
) -> BaseIntradayProcessedData:
return PickleIntradayProcessedData(
data_dir,
stock_id,
date,
feature_columns_today,
feature_columns_yesterday,
backtest,
)
return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
class PickleProcessedDataProvider(ProcessedDataProvider):
def __init__(
self,
data_dir: Path,
feature_columns_today: List[str],
feature_columns_yesterday: List[str],
backtest: bool = False,
) -> None:
def __init__(self, data_dir: Path) -> None:
super().__init__()
self._data_dir = data_dir
self._backtest = backtest
self._feature_columns_today = feature_columns_today
self._feature_columns_yesterday = feature_columns_yesterday
def get_data(
self,
@@ -247,9 +244,8 @@ class PickleProcessedDataProvider(ProcessedDataProvider):
data_dir=self._data_dir,
stock_id=stock_id,
date=date,
feature_columns_today=self._feature_columns_today,
feature_columns_yesterday=self._feature_columns_yesterday,
backtest=self._backtest,
feature_dim=feature_dim,
time_index=time_index,
)
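The legacy branch of `PickleIntradayProcessedData` treats one record as a flat vector of `time_length * feature_dim * 2` values: the first half is reshaped into today's `(time_length, feature_dim)` table and the second half into yesterday's. A minimal sketch of that layout on plain lists, assuming row-major order as in a default NumPy reshape; `split_legacy_row` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def split_legacy_row(
    flat: Sequence[float], time_length: int, feature_dim: int
) -> Tuple[List[List[float]], List[List[float]]]:
    """Split a flat legacy record into (today, yesterday) tables of shape (time_length, feature_dim)."""
    half = time_length * feature_dim
    assert len(flat) == 2 * half, "record length must be time_length * feature_dim * 2"

    def to_rows(values: Sequence[float]) -> List[List[float]]:
        # Row-major reshape: consecutive values fill one time step's features.
        return [list(values[i * feature_dim : (i + 1) * feature_dim]) for i in range(time_length)]

    return to_rows(flat[:half]), to_rows(flat[half:])
```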


@@ -4,11 +4,10 @@
from __future__ import annotations
from typing import Generator, List, Optional
import cachetools
import pandas as pd
from qlib.backtest import collect_data_loop, Exchange, get_exchange, get_strategy_executor
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime
from qlib.backtest.executor import NestedExecutor
from qlib.rl.data.integration import init_qlib
@@ -17,18 +16,6 @@ from .state import SAOEState
from .strategy import SAOEStateAdapter, SAOEStrategy
@cachetools.cached( # type: ignore
cache=cachetools.LRUCache(1000),
key=lambda order, _: order.stock_id,
)
def _create_exchange(order: Order, exchange_config: dict) -> Exchange:
exchange_kwargs = {
**exchange_config,
"codes": [order.stock_id],
}
return get_exchange(**exchange_kwargs)
class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
@@ -89,7 +76,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
executor=executor_config,
benchmark=order.stock_id,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=_create_exchange(order, exchange_config),
exchange_kwargs=exchange_config,
pos_type="Position" if cash_limit is not None else "InfPosition",
)
@@ -103,7 +90,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
trade_strategy=strategy,
trade_executor=self._executor,
return_value=self.report_dict,
show_progress=False,
)
assert isinstance(self._collect_data_loop, Generator)


@@ -12,8 +12,7 @@ from pathlib import Path
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS, EPS_T, float_or_ndarray
from qlib.rl.data.base import BaseIntradayBacktestData
from qlib.rl.data.native import DataframeIntradayBacktestData
from qlib.rl.data.pickle_styled import load_pickle_intraday_processed_data
from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data
from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.utils import LogLevel
@@ -43,6 +42,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
Path to load backtest data.
feature_columns_today
Columns of today's feature.
feature_columns_yesterday
Columns of yesterday's feature.
data_granularity
Number of ticks between consecutive data entries.
ticks_per_step
@@ -79,6 +80,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
order: Order,
data_dir: Path,
feature_columns_today: List[str] = [],
feature_columns_yesterday: List[str] = [],
data_granularity: int = 1,
ticks_per_step: int = 30,
vol_threshold: Optional[float] = None,
@@ -90,6 +92,7 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
self.order = order
self.data_dir = data_dir
self.feature_columns_today = feature_columns_today
self.feature_columns_yesterday = feature_columns_yesterday
self.ticks_per_step: int = ticks_per_step // data_granularity
self.vol_threshold = vol_threshold
@@ -119,13 +122,14 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
def get_backtest_data(self) -> BaseIntradayBacktestData:
try:
data = load_pickle_intraday_processed_data(
data = load_handler_intraday_processed_data(
data_dir=self.data_dir,
stock_id=self.order.stock_id,
date=pd.Timestamp(self.order.start_time.date()),
feature_columns_today=self.feature_columns_today,
feature_columns_yesterday=[],
feature_columns_yesterday=self.feature_columns_yesterday,
backtest=True,
index_only=False,
)
return DataframeIntradayBacktestData(data.today)
except (AttributeError, FileNotFoundError):


@@ -451,7 +451,6 @@ class SAOEIntStrategy(SAOEStrategy):
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
network: dict | torch.nn.Module | None = None,
immediate_addition: bool = False,
outer_trade_decision: BaseTradeDecision | None = None,
level_infra: LevelInfrastructure | None = None,
common_infra: CommonInfrastructure | None = None,
@@ -502,12 +501,9 @@ class SAOEIntStrategy(SAOEStrategy):
if self._policy is not None:
self._policy.eval()
self.immediate_addition = immediate_addition
def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
self.trade_amount_planned = collections.defaultdict(float)
def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
assert hasattr(self.outer_trade_decision, "order_list")
@@ -543,15 +539,9 @@ class SAOEIntStrategy(SAOEStrategy):
oh = self.trade_exchange.get_order_helper()
order_list = []
for decision, exec_vol, state in zip(self.outer_trade_decision.get_decision(), exec_vols, states):
order = cast(Order, decision)
if self.immediate_addition:
self.trade_amount_planned[order.stock_id] += exec_vol
amount_planned = self.trade_amount_planned[order.stock_id]
amount_finished = order.amount - state.position
exec_vol = min(state.position, amount_planned - amount_finished)
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
if exec_vol != 0:
order = cast(Order, decision)
order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
return TradeDecisionWithDetails(


@@ -20,7 +20,7 @@ def train(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: List[Sequence[InitialStateType]],
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
reward: Reward,
vessel_kwargs: Dict[str, Any],
@@ -39,9 +39,7 @@ def train(
action_interpreter
Interprets the policy actions.
initial_states
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
state will be run exactly once. Otherwise, every worker will have its own iterator.
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to train against.
reward
@@ -69,7 +67,7 @@ def backtest(
simulator_fn: Callable[[InitialStateType], Simulator],
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
initial_states: List[Sequence[InitialStateType]],
initial_states: Sequence[InitialStateType],
policy: BasePolicy,
logger: LogWriter | List[LogWriter],
reward: Reward | None = None,
@@ -89,9 +87,7 @@ def backtest(
action_interpreter
Interprets the policy actions.
initial_states
List of Initial state iterators to iterate over. There should be 1 or `concurrency` initial state iterators in
the list. If there is only 1 initial state iterator, this iterator will be shared by all workers and every
state will be run exactly once. Otherwise, every worker will have its own iterator.
Initial states to iterate over. Every state will be run exactly once.
policy
Policy to test against.
logger
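With `initial_states` typed as `Sequence[InitialStateType]`, any object that supports `len()` and integer indexing can be passed, which is what lets a dataset build each order only when it is requested. The sketch below shows such a lazy sequence using the standard library only; `LazyStates` and the row format are assumptions for illustration, not qlib classes.

```python
from collections.abc import Sequence
from typing import Any, Callable, List


class LazyStates(Sequence):
    """Sequence that builds each initial state on access instead of up front."""

    def __init__(self, rows: List[Any], build: Callable[[Any], Any]) -> None:
        self._rows = rows
        self._build = build
        self.built = 0  # number of states constructed so far

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, index: int) -> Any:
        if isinstance(index, slice):
            raise TypeError("LazyStates supports integer indices only")
        row = self._rows[index]  # raises IndexError past the end, which ends iteration
        self.built += 1
        return self._build(row)


states = LazyStates(
    [("A", 100.0), ("B", 200.0)],
    lambda row: {"stock_id": row[0], "amount": row[1]},
)
```

Nothing is built until an element is read, so the memory cost stays proportional to the raw rows rather than to the constructed states.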


@@ -5,9 +5,8 @@ from __future__ import annotations
import collections
import copy
from contextlib import AbstractContextManager, ExitStack, contextmanager
from contextlib import AbstractContextManager, contextmanager
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast
@@ -207,50 +206,45 @@ class Trainer:
self._call_callback_hooks("on_fit_start")
with _wrap_context(vessel.train_seed_iterators()) as train_iterators, _wrap_context(
vessel.val_seed_iterators()
) as valid_iterators:
train_vector_env = self.venv_from_iterator(train_iterators)
valid_vector_env = self.venv_from_iterator(valid_iterators)
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
_logger.info(msg)
while not self.should_stop:
msg = f"\n{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tTrain iteration {self.current_iter + 1}/{self.max_iters}"
print(msg)
_logger.info(msg)
self.initialize_iter()
self.initialize_iter()
self._call_callback_hooks("on_iter_start")
self._call_callback_hooks("on_iter_start")
self.current_stage = "train"
self._call_callback_hooks("on_train_start")
self.current_stage = "train"
self._call_callback_hooks("on_train_start")
# TODO
# Add a feature that supports reloading the training environment every few iterations.
with _wrap_context(vessel.train_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.train(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
# TODO
# Add a feature that supports reloading the training environment every few iterations.
self.vessel.train(train_vector_env)
self._call_callback_hooks("on_train_end")
self._call_callback_hooks("on_train_end")
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
# Implementation of validation loop
self.current_stage = "val"
self._call_callback_hooks("on_validate_start")
with _wrap_context(vessel.val_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.validate(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0:
# Implementation of validation loop
self.current_stage = "val"
self._call_callback_hooks("on_validate_start")
self._call_callback_hooks("on_validate_end")
self.vessel.validate(valid_vector_env)
# This iteration is considered complete.
# Bumping the current iteration counter.
self.current_iter += 1
self._call_callback_hooks("on_validate_end")
if self.max_iters is not None and self.current_iter >= self.max_iters:
self.should_stop = True
# This iteration is considered complete.
# Bumping the current iteration counter.
self.current_iter += 1
if self.max_iters is not None and self.current_iter >= self.max_iters:
self.should_stop = True
self._call_callback_hooks("on_iter_end")
del train_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
del valid_vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_iter_end")
self._call_callback_hooks("on_fit_end")
@@ -271,16 +265,16 @@ class Trainer:
self.current_stage = "test"
self._call_callback_hooks("on_test_start")
with _wrap_context(vessel.test_seed_iterators()) as iterators:
vector_env = self.venv_from_iterator(iterators)
with _wrap_context(vessel.test_seed_iterator()) as iterator:
vector_env = self.venv_from_iterator(iterator)
self.vessel.test(vector_env)
del vector_env # FIXME: Explicitly delete this object to avoid memory leak.
self._call_callback_hooks("on_test_end")
def venv_from_iterator(self, iterators: List[Iterable[InitialStateType]]) -> FiniteVectorEnv:
def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv:
"""Create a vectorized environment from iterator and the training vessel."""
def env_factory(iterator):
def env_factory():
# FIXME: state_interpreter and action_interpreter are stateful (having a weakref of env),
# and could be thread unsafe.
# I'm not sure whether it's a design flaw.
@@ -306,7 +300,7 @@ class Trainer:
)
return vectorize_env(
[partial(env_factory, iterator=it) for it in iterators],
env_factory,
self.finite_env_type,
self.concurrency,
self.loggers,
@@ -340,11 +334,8 @@ class Trainer:
@contextmanager
def _wrap_context(obj):
"""Make any object a (possibly dummy) context manager."""
if isinstance(obj, list) and isinstance(obj[0], AbstractContextManager):
with ExitStack() as stack:
yield [stack.enter_context(e) for e in obj]
stack.pop_all().close()
elif isinstance(obj, AbstractContextManager):
if isinstance(obj, AbstractContextManager):
# obj has __enter__ and __exit__
with obj as ctx:
yield ctx
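The single-iterator form of `_wrap_context` lets the trainer use one `with` statement whether the seed iterator is a context manager or a plain iterable. A minimal sketch of that idea follows. The `else` branch is an assumption, since the hunk above ends before the rest of the function, and the name `wrap_context` is used here to mark the sketch as separate from the original.

```python
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Iterator


@contextmanager
def wrap_context(obj: Any) -> Iterator[Any]:
    """Make any object usable in a with-statement."""
    if isinstance(obj, AbstractContextManager):
        # obj has __enter__ and __exit__: delegate to it so cleanup still runs.
        with obj as ctx:
            yield ctx
    else:
        # Assumed fallback: plain objects are passed through unchanged.
        yield obj
```

`AbstractContextManager` recognises any class that defines `__enter__` and `__exit__`, so the check also works for objects that never inherit from it.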


@@ -4,7 +4,7 @@
from __future__ import annotations
import weakref
from typing import List, TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
import numpy as np
from tianshou.data import Collector, VectorReplayBuffer
@@ -49,23 +49,19 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType,
def assign_trainer(self, trainer: Trainer) -> None:
self.trainer = weakref.proxy(trainer) # type: ignore
def train_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for training.
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for training.
If the iterable is a context manager, the whole training will be invoked in the with-block,
and the iterator will be automatically closed after the training is done."""
raise SeedIteratorNotAvailable("Seed iterators for training is not available.")
raise SeedIteratorNotAvailable("Seed iterator for training is not available.")
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for validation."""
raise SeedIteratorNotAvailable("Seed iterators for validation is not available.")
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for validation."""
raise SeedIteratorNotAvailable("Seed iterator for validation is not available.")
def test_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
"""Override this to create a seed iterators for testing."""
raise SeedIteratorNotAvailable("Seed iterators for testing is not available.")
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
"""Override this to create a seed iterator for testing."""
raise SeedIteratorNotAvailable("Seed iterator for testing is not available.")
def train(self, vector_env: BaseVectorEnv) -> Dict[str, Any]:
"""Implement this to train one iteration. In RL, one iteration usually refers to one collect."""
@@ -124,9 +120,9 @@ class TrainingVessel(TrainingVesselBase):
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
policy: BasePolicy,
reward: Reward,
train_initial_states: List[Sequence[InitialStateType]] | None = None,
val_initial_states: List[Sequence[InitialStateType]] | None = None,
test_initial_states: List[Sequence[InitialStateType]] | None = None,
train_initial_states: Sequence[InitialStateType] | None = None,
val_initial_states: Sequence[InitialStateType] | None = None,
test_initial_states: Sequence[InitialStateType] | None = None,
buffer_size: int = 20000,
episode_per_iter: int = 1000,
update_kwargs: Dict[str, Any] = cast(Dict[str, Any], None),
@@ -136,49 +132,34 @@ class TrainingVessel(TrainingVesselBase):
self.action_interpreter = action_interpreter
self.policy = policy
self.reward = reward
self.train_initial_states = None if train_initial_states is None else train_initial_states
self.val_initial_states = None if val_initial_states is None else val_initial_states
self.test_initial_states = None if test_initial_states is None else test_initial_states
self.train_initial_states = train_initial_states
self.val_initial_states = val_initial_states
self.test_initial_states = test_initial_states
self.buffer_size = buffer_size
self.episode_per_iter = episode_per_iter
self.update_kwargs = update_kwargs or {}
def train_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.train_initial_states is not None:
_logger.info(f"Training initial states collection sizes: {[len(e) for e in self.train_initial_states]}")
train_initial_states = [
self._random_subset("train", e, self.trainer.fast_dev_run) for e in self.train_initial_states
]
iterators = [DataQueue(e, repeat=-1, shuffle=True) for e in train_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().train_seed_iterators()
_logger.info("Training initial states collection size: %d", len(self.train_initial_states))
# Implement fast_dev_run here.
train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run)
return DataQueue(train_initial_states, repeat=-1, shuffle=True)
return super().train_seed_iterator()
def val_seed_iterators(self) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.val_initial_states is not None:
_logger.info(f"Validation initial states collection sizes: {[len(e) for e in self.val_initial_states]}")
val_initial_states = [
self._random_subset("val", e, self.trainer.fast_dev_run) for e in self.val_initial_states
]
iterators = [DataQueue(e, repeat=1) for e in val_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().val_seed_iterators()
_logger.info("Validation initial states collection size: %d", len(self.val_initial_states))
val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run)
return DataQueue(val_initial_states, repeat=1)
return super().val_seed_iterator()
def test_seed_iterators(
self,
) -> List[ContextManager[Iterable[InitialStateType]]] | List[Iterable[InitialStateType]]:
def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]:
if self.test_initial_states is not None:
_logger.info(f"Testing initial states collection sizes: {[len(e) for e in self.test_initial_states]}")
test_initial_states = [
self._random_subset("test", e, self.trainer.fast_dev_run) for e in self.test_initial_states
]
iterators = [DataQueue(e, repeat=1) for e in test_initial_states]
return cast(List[Iterable[InitialStateType]], iterators)
else:
return super().test_seed_iterators()
_logger.info("Testing initial states collection size: %d", len(self.test_initial_states))
test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run)
return DataQueue(test_initial_states, repeat=1)
return super().test_seed_iterator()
def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]:
"""Create a collector and collects ``episode_per_iter`` episodes.


@@ -258,46 +258,6 @@ class FiniteVectorEnv(BaseVectorEnv):
return np.stack(obs)
def step2(
self,
action: np.ndarray,
id: int | List[int] | np.ndarray | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert not self._zombie
wrapped_id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(wrapped_id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
result = {}
# ask super to step alive envs and remap to current index
if request_id:
valid_act = np.stack([action[id2idx[i]] for i in request_id])
tmp = super().step(valid_act, request_id)
for obs_next, rew, done, info in zip(*tmp):
obs_next = self._postproc_env_obs(obs_next)
result[info["env_id"]] = [obs_next, rew, done, info]
# logging
for i, r in result.items():
if i in self._alive_env_ids and r[0] is not None:
for logger in self._logger:
logger.on_env_step(i, *r)
for _, reward, __, info in result.values():
self._set_default_info(info)
self._set_default_rew(reward)
for r in result.values():
if r[0] is None:
r[0] = self._get_default_obs()
if r[1] is None:
r[1] = self._get_default_rew()
if r[3] is None:
r[3] = self._get_default_info()
ret = list(map(np.stack, zip(*result.values())))
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
def step(
self,
action: np.ndarray,
@@ -351,7 +311,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
def vectorize_env(
env_factories: List[Callable[..., gym.Env]],
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | List[LogWriter],
@@ -374,10 +334,9 @@ def vectorize_env(
Parameters
----------
env_factories
Callables to instantiate one single ``gym.Env``.
There should be 1 or `concurrency` env_factories. If there is 1 env_factory, all concurrent workers will have
the same env_factory. Otherwise, each worker will have its own env_factory.
env_factory
Callable to instantiate one single ``gym.Env``.
All concurrent workers will have the same ``env_factory``.
env_type
dummy or subproc or shmem. Corresponding to
`parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
@@ -399,8 +358,6 @@ def vectorize_env(
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
assert len(env_factories) in (1, concurrency)
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
@@ -409,7 +366,4 @@ def vectorize_env(
finite_env_cls = env_type_cls_mapping[env_type]
if len(env_factories) == 1:
return finite_env_cls(logger, [env_factories[0] for _ in range(concurrency)])
else:
return finite_env_cls(logger, env_factories)
return finite_env_cls(logger, [env_factory for _ in range(concurrency)])
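
Reviewer sketch: after this change every worker is built from the same `env_factory`. The replication step can be illustrated without tianshou or gym. `replicate_factory` and `DummyEnv` are hypothetical names used only for this example.

```python
from typing import Callable, List


def replicate_factory(env_factory: Callable[[], object], concurrency: int) -> List[Callable[[], object]]:
    """Return ``concurrency`` references to the same factory (hypothetical helper)."""
    if concurrency < 1:
        raise ValueError("concurrency must be at least 1")
    # Same callable repeated; each worker calls it to build its own env instance.
    return [env_factory for _ in range(concurrency)]


class DummyEnv:
    """Stand-in for a ``gym.Env``; it carries no behaviour."""


factories = replicate_factory(DummyEnv, 4)
envs = [factory() for factory in factories]
```

The factory object is shared, but each call still produces a separate environment instance, so workers do not share environment state.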


@@ -1,30 +0,0 @@
import time
from contextlib import contextmanager
from typing import Callable, Generator
from line_profiler import LineProfiler
@contextmanager
def simple_perf(desc: str = "", out_path: str = None) -> Generator[None, None, None]:
s = time.perf_counter()
yield
e = time.perf_counter()
msg = f"{desc}: {(e - s) * 1000.0:.4f} ms"
if out_path is not None:
with open(out_path, "a") as fstream:
fstream.write(msg + "\n")
else:
print(msg)
def lprofile(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
lp = LineProfiler()
lpw = lp(func)
res = lpw(*args, **kwargs)
lp.print_stats()
return res
return wrapper


@@ -18,7 +18,7 @@ from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_s
from ..utils.time import Freq
from ..utils.data import deepcopy_basic_type
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
logger = get_module_logger("workflow", logging.INFO)
@@ -156,6 +156,9 @@ class RecordTemp:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
def analyse(self):
raise NotImplementedError("Please implement the `analyse` method.")
class SignalRecord(RecordTemp):
"""

15
scripts/finco/README.md Normal file

@@ -0,0 +1,15 @@
# Requirements
Use the following command to install the project together with the finco extras.
```
pip install -e '.[finco]'
```
# TODOs
- [ ] Select the appropriate LLM API
- Which API is more suitable for meeting our requirements - the original API or an alternative like LangChain?

15
scripts/finco/cmd.sh Normal file

@@ -0,0 +1,15 @@
#!/bin/bash
set -x # show command
set -e # exit on error
DIR="$(
cd "$(dirname "$(readlink -f "$0")")" || exit
pwd -P
)"
# load the credentials
if [ -e "$DIR/cridential.sh" ]; then
source "$DIR/cridential.sh"
fi
# run the command
python -m qlib.finco.cli "please help me build a low turnover strategy that focuses more on long-term return"


@@ -0,0 +1,3 @@
export OPENAI_API_TYPE=azure # This is only necessary for Azure OpenAI
export OPENAI_API_KEY=
export OPENAI_API_BASE=


@@ -173,6 +173,14 @@ setup(
"tianshou<=0.4.10",
"torch",
],
"finco": [
# finco is not needed by all Qlib users, so it gets its own extras section.
"openai",
"pydantic", # Please add it to the basic requirements once the pydantic-based design is stable.
"python-dotenv", # I don't think this is necessary if we use pydantic.
"fuzzywuzzy",
"python-Levenshtein" # not necessary but accelerate fuzzywuzzy calculation
],
},
include_package_data=True,
classifiers=[

71
tests/finco/test_cfg.py Normal file

@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import shutil
import difflib
from qlib.finco.tpl import get_tpl_path
import ruamel.yaml as yaml
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
from qlib.tests import TestAutoData
from pathlib import Path
from qlib.finco.task import YamlEditTask
DIRNAME = Path(__file__).absolute().resolve().parent
class FincoTpl(TestAutoData):
def test_tpl_consistence(self):
"""Motivation: make sure the configurable template is consistent with the default config"""
tpl_p = get_tpl_path()
with (tpl_p / "sl" / "workflow_config.yaml").open("rb") as fp:
config = yaml.safe_load(fp)
# init_data_handler
hd: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
# NOTE: The config in workflow_config.yaml is generated by the following code:
# dump in yaml format to file without auto linebreak
# print(yaml.dump(hd.data_loader.fields, width=10000, stream=open("_tmp", "w")))
with (tpl_p / "sl-cfg" / "workflow_config.yaml").open("rb") as fp:
config = yaml.safe_load(fp)
hd_ds: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
self.assertEqual(hd_ds.data_loader.fields, hd.data_loader.fields)
check = hd_ds.fetch().fillna(0.0) == hd.fetch().fillna(0.0)
self.assertTrue(check.all().all())
def test_update_yaml(self):
p = get_tpl_path() / "sl" / "workflow_config.yaml"
p_new = DIRNAME / "_test_config.yaml"
shutil.copy(p, p_new)
updated_content = """
class: LGBModelTest
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 1.8879
learning_rate: 0.3
subsample: 0.8790
lambda_l1: 205.7000
lambda_l2: 580.9769
max_depth: 9
num_leaves: 211
num_threads: 21
"""
t = YamlEditTask(p_new, "task.model", updated_content)
t.execute()
# NOTE: the format is changed by ruamel.yaml, so the files can't be compared as text directly.
# print the diff between p and p_new with difflib
# with p.open("r") as fp:
# content = fp.read()
# with p_new.open("r") as fp:
# content_new = fp.read()
# for line in difflib.unified_diff(content.splitlines(), content_new.splitlines(), fromfile="original", tofile="new", lineterm=""):
# print(line)
if __name__ == "__main__":
unittest.main()
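
Reviewer sketch: `test_update_yaml` passes `"task.model"` to `YamlEditTask`. This diff does not show how that string is resolved. Assuming it is a dotted path into the nested config mapping, the update could look like the following. `set_by_dotted_path` is a hypothetical helper that works on plain dicts and is not the qlib implementation.

```python
from typing import Any, Dict


def set_by_dotted_path(config: Dict[str, Any], dotted_path: str, value: Any) -> None:
    """Replace the node addressed by ``dotted_path`` (e.g. ``"task.model"``) in place."""
    keys = dotted_path.split(".")
    node = config
    for key in keys[:-1]:
        # Walk down, creating intermediate mappings if they are missing.
        node = node.setdefault(key, {})
    node[keys[-1]] = value


config = {"task": {"model": {"class": "LGBModel"}, "dataset": {"class": "DatasetH"}}}
set_by_dotted_path(config, "task.model", {"class": "LGBModelTest", "kwargs": {"loss": "mse"}})
```

Only the addressed subtree is replaced, and sibling keys such as `task.dataset` are left untouched. A test could assert exactly that on the parsed YAML instead of comparing file text.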


@@ -0,0 +1,66 @@
import unittest
import os
import shutil
from dotenv import load_dotenv
# pydantic can load .env files itself, so load_dotenv will be dropped in the future.
from qlib.finco.task import SummarizeTask
from qlib.finco.workflow import WorkflowContextManager
from qlib.finco.llm import APIBackend
from qlib.finco.workflow import WorkflowManager
load_dotenv(verbose=True, override=True)
class TestSummarize(unittest.TestCase):
def test_chat(self):
messages = [
{
"role": "system",
"content": "You are a professional financial assistant.",
},
{
"role": "user",
"content": "How to write a perfect quant strategy.",
},
]
response = APIBackend().try_create_chat_completion(messages=messages)
print(response)
def test_execution(self):
task = SummarizeTask()
context = WorkflowContextManager()
context.set_context("workspace", "../../examples/benchmarks/Linear")
context.set_context("user_prompt", "My main focus is on the performance of the strategy's return. "
"Please summarize the information and give me some advice.")
task.assign_context_manager(context)
resp = task.execute()
print(resp)
def test_generate_batch_result(self):
wm = WorkflowManager()
prompt = wm.default_user_prompt
# prompt = ""
workdir = os.path.dirname(wm.get_context().get_context("workspace"))
summaries_path = os.path.join(workdir, "summaries")
if not os.path.exists(summaries_path):
os.makedirs(summaries_path)
for i in range(10):
wm.run(prompt)
if os.path.exists(f"{workdir}/finCoReport.md"):
shutil.move(f"{workdir}/finCoReport.md", f"{workdir}/summaries/finCoReport{i}.md")
def test_parse2txt(self):
task = SummarizeTask()
resp = task.get_info_from_file("")
print(resp)
if __name__ == "__main__":
unittest.main()

23
tests/finco/test_utils.py Normal file

@@ -0,0 +1,23 @@
import unittest
from qlib.finco.utils import SingletonBaseClass
class TimeUtils(unittest.TestCase):
def test_singleton(self):
# self.assertEqual(self.to_str(data.tail()), self.to_str(res))
closure_checker = []
class A(SingletonBaseClass):
def __init__(self) -> None:
closure_checker.append(0)
A()
self.assertEqual(len(closure_checker), 1)
A()
self.assertEqual(len(closure_checker), 1)
if __name__ == "__main__":
unittest.main()
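
Reviewer sketch: `test_singleton` checks that `__init__` runs only once even when the class is instantiated twice. One common way to get that behaviour is a metaclass that caches the first instance. The code below is an assumption about the general technique, not the actual `SingletonBaseClass` from `qlib.finco.utils`, and `SingletonMeta` and `SingletonSketch` are hypothetical names.

```python
class SingletonMeta(type):
    """Metaclass that builds each class at most once and then reuses the instance."""

    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            # First call: run __new__ and __init__ as usual, then cache the result.
            cls._instances[cls] = super().__call__(*args, **kwargs)
        # Later calls skip __init__ entirely and return the cached object.
        return cls._instances[cls]


class SingletonSketch(metaclass=SingletonMeta):
    """Base class for illustration; subclasses become singletons."""


calls = []


class A(SingletonSketch):
    def __init__(self) -> None:
        calls.append(0)


first = A()
second = A()
```

Overriding only `__new__` would not be enough here, because Python would still call `__init__` on every `A()`. Intercepting `type.__call__` avoids that second initialisation.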