mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Compare commits
35 Commits
4933fcefc4
...
high-freq-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56edc16089 | ||
|
|
2b8462d137 | ||
|
|
1979cac50a | ||
|
|
424a48d0fb | ||
|
|
202bbea272 | ||
|
|
6a22136366 | ||
|
|
603c282415 | ||
|
|
22abe852f7 | ||
|
|
e3f463010b | ||
|
|
80aa08215f | ||
|
|
b3893067f7 | ||
|
|
e6dfccce2f | ||
|
|
f9c30f9834 | ||
|
|
f164bf8411 | ||
|
|
1f28044d84 | ||
|
|
3cf0d27a07 | ||
|
|
bcae4bb22e | ||
|
|
f680a564a0 | ||
|
|
9cd41e5a81 | ||
|
|
e23022e9d8 | ||
|
|
ebbbec2a6c | ||
|
|
13d39e6bbc | ||
|
|
b96aab6bef | ||
|
|
700eef4164 | ||
|
|
31c7d72485 | ||
|
|
30ad1967a2 | ||
|
|
0c6cad1d7b | ||
|
|
a0f22571de | ||
|
|
6835b2f67e | ||
|
|
7c4971e566 | ||
|
|
70a9d42c7d | ||
|
|
bcadf47f32 | ||
|
|
4dc14a2489 | ||
|
|
a03b08bb4c | ||
|
|
98086e4fdc |
@@ -31,6 +31,7 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [High-frequency execution](#high-frequency-execution)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
@@ -270,6 +271,14 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# High-Frequency Execution
|
||||
High-frequency order execution is a fundamental problem in quantitative finance.
|
||||
It aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument.
|
||||
AI has the potential to mine patterns from a huge mass of high-frequency market data and helps traders make better decisions during order execution.
|
||||
Here is a list of solutions built on `Qlib`.
|
||||
- [Universal Trading for Order Execution with Oracle Policy Distillation](examples/trade/)
|
||||
|
||||
|
||||
# More About Qlib
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
|
||||
28
examples/highfreq/README.md
Normal file
28
examples/highfreq/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# High-Frequency Dataset
|
||||
|
||||
This dataset is an example for RL high frequency trading.
|
||||
|
||||
## Get High-Frequency Data
|
||||
|
||||
Get high-frequency data by running the following command:
|
||||
```bash
|
||||
python workflow.py get_data
|
||||
```
|
||||
|
||||
## Dump & Reload & Reinitialize the Dataset
|
||||
|
||||
|
||||
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in the `workflow.py`. `DatatsetH` is the subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped in or loaded from disk in `pickle` format.
|
||||
|
||||
### About Reinitialization
|
||||
|
||||
After reloading `Dataset` from disk, `Qlib` also support reinitializing the dataset. It means that users can reset some states of `Dataset` or `DataHandler` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states.
|
||||
|
||||
The example is given in `workflow.py`, users can run the code as follows.
|
||||
|
||||
### Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
@@ -10,7 +10,6 @@ class HighFreqHandler(DataHandlerLP):
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="1min",
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
@@ -37,13 +36,13 @@ class HighFreqHandler(DataHandlerLP):
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
freq=freq,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
@@ -63,9 +62,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
"""Get normalized price feature ops"""
|
||||
if shift == 0:
|
||||
template_norm = "{0}/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
|
||||
|
||||
feature_ops = template_norm.format(
|
||||
template_if.format(
|
||||
@@ -91,7 +90,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [
|
||||
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -102,7 +101,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume"]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -113,7 +112,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += [template_paused.format("Date($close)")]
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
@@ -124,20 +123,19 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="1min",
|
||||
):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
"freq": "1min",
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
freq=freq,
|
||||
data_loader=data_loader,
|
||||
)
|
||||
|
||||
@@ -151,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
|
||||
]
|
||||
names += ["$close0"]
|
||||
fields += [
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
"Cut({0}, 240, None)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
|
||||
@@ -8,6 +8,20 @@ from qlib.data.data import Cal
|
||||
|
||||
|
||||
def get_calendar_day(freq="day", future=False):
|
||||
"""Load High-Freq Calendar Date Using Memcache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
future : bool
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
_calendar:
|
||||
array of date.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
@@ -18,6 +32,19 @@ def get_calendar_day(freq="day", future=False):
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
"""DayLast Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value equals the last value of its day
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
@@ -25,18 +52,57 @@ class DayLast(ElemOperator):
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
"""FFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a forward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class BFillNan(ElemOperator):
|
||||
"""BFillNan Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a backfoward fill nan feature
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
"""Date Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
a series of that each value is the date corresponding to feature.index
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
@@ -44,6 +110,22 @@ class Date(ElemOperator):
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
"""Select Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance, select condition
|
||||
feature_right : Expression
|
||||
feature instance, select value
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
value(feature_right) that meets the condition(feature_left)
|
||||
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
@@ -51,6 +133,58 @@ class Select(PairOperator):
|
||||
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
"""IsNull Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series indicating whether the feature is nan
|
||||
"""
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
"""Cut Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
l : int
|
||||
l > 0, delete the first l elements of feature (default is None, which means 0)
|
||||
r : int
|
||||
r < 0, delete the last -r elements of feature (default is None, which means 0)
|
||||
Returns
|
||||
----------
|
||||
feature:
|
||||
A series with the first l and last -r elements deleted from the feature.
|
||||
Note: It is deleted from the raw data, not the sliced data
|
||||
"""
|
||||
|
||||
def __init__(self, feature, l=None, r=None):
|
||||
self.l = l
|
||||
self.r = r
|
||||
if (self.l is not None and self.l <= 0) or (self.r is not None and self.r >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.l : self.r]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.l is None else self.l
|
||||
rr = 0 if self.r is None else abs(self.r)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
|
||||
@@ -9,7 +9,7 @@ import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
@@ -24,25 +24,24 @@ from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-14 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
start_time = pd.Timestamp("2020-09-15 00:00:00")
|
||||
end_time = pd.Timestamp("2021-01-18 16:00:00")
|
||||
train_end_time = pd.Timestamp("2020-11-30 16:00:00")
|
||||
test_start_time = pd.Timestamp("2020-12-01 00:00:00")
|
||||
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
@@ -51,7 +50,6 @@ class HighfreqWorkflow(object):
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
@@ -125,8 +123,7 @@ class HighfreqWorkflow(object):
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
print(backtest_train, backtest_test)
|
||||
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
return
|
||||
|
||||
def dump_and_load_dataset(self):
|
||||
"""dump and load dataset state on disk"""
|
||||
@@ -148,19 +145,73 @@ class HighfreqWorkflow(object):
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reload_dataset=============
|
||||
dataset.init(init_type=DataHandlerLP.IT_LS)
|
||||
dataset_backtest.init()
|
||||
##=============reinit dataset=============
|
||||
dataset.init(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
##=============get data=============
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
|
||||
print(xtrain, xtest)
|
||||
print(backtest_train, backtest_test)
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
|
||||
def get_high_freq_data(self, data_path):
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
import os
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
normed_feature = pd.concat([xtrain, xtest]).sort_index()
|
||||
dic = dict(tuple(normed_feature.groupby("instrument")))
|
||||
feature_path = os.path.join(data_path, "normed_feature/")
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(feature_path + f"{k}.pkl")
|
||||
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
backtest = pd.concat([backtest_train, backtest_test]).sort_index()
|
||||
backtest['date'] = backtest.index.map(lambda x: x[1].date())
|
||||
backtest.set_index('date', append=True, drop=True, inplace=True)
|
||||
dic = dict(tuple(backtest.groupby("instrument")))
|
||||
backtest_path = os.path.join(data_path, "backtest/")
|
||||
if not os.path.exists(backtest_path):
|
||||
os.makedirs(backtest_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(backtest_path + f"{k}.pkl.backtest")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(HighfreqWorkflow)
|
||||
#fire.Fire(HighfreqWorkflow)
|
||||
data_path = '../data/'
|
||||
workflow = HighfreqWorkflow()
|
||||
workflow.get_high_freq_data(data_path)
|
||||
|
||||
|
||||
104
examples/trade/README.md
Normal file
104
examples/trade/README.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Universal Trading for Order Execution with Oracle Policy Distillation
|
||||
This is the experiment code for our AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)", including the implementations of all the compared methods in the paper and a general reinforcement learning framework for order execution in quantitative finance.
|
||||
|
||||
## Abstract
|
||||
As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Towards effective execution strategy, recent years have witnessed the shift from the analytical view with model-based market assumptions to model-free perspective, i.e., reinforcement learning, due to its nature of sequential decision optimization. However, the noisy and yet imperfect market information that can be leveraged by the policy has made it quite challenging to build up sample efficient reinforcement learning methods to achieve effective order execution. In this paper, we propose a novel universal trading policy optimization framework to bridge the gap between the noisy yet imperfect market states and the optimal action sequences for order execution. Particularly, this framework leverages a policy distillation method that can better guide the learning of the common policy towards practically optimal execution by an oracle teacher with perfect information to approximate the optimal trading strategy. The extensive experiments have shown significant improvements of our method over various strong baselines, with reasonable trading actions.
|
||||
|
||||
## Environment Dependencies
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
```
|
||||
|
||||
### Environment Variable
|
||||
|
||||
`EXP_PATH` Absolute path to your config folder, we give folder `exp` as an example.
|
||||
|
||||
`OUTPUT_DIR` Absolute path to your log folder.
|
||||
|
||||
## Data Processing
|
||||
|
||||
For Feature processing, we take Yahoo dataset as an example, which can be precessed in `qlib/examples/highfreq/workflow.py` file. If you have a need to change your data storage path, you can change the `data_path` in `workflow.py`, and then do the following.
|
||||
|
||||
```
|
||||
python workflow.py
|
||||
```
|
||||
|
||||
For order generation, if you have changed change the the `data_path` in `workflow.py`, change `data_path` in `order_gen.py` again, then do the following.
|
||||
|
||||
```
|
||||
python order_gen.py
|
||||
```
|
||||
|
||||
## Training and backtest
|
||||
|
||||
### Config file
|
||||
|
||||
Config file is need to start our project, we take `PPO`, `OPDS` and `OPD` as an example in folder `exp/example`. If you want to use our given config, make sure the `data_path` you set before matches the config file.
|
||||
|
||||
### Baseline method
|
||||
|
||||
To run a method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config={config_path}
|
||||
```
|
||||
|
||||
Where `{config_path}` means the relative path from your config.yml to `EXP_PATH`.
|
||||
|
||||
If you need to run our given method such as PPO method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config=example/PPO/config.yml
|
||||
```
|
||||
|
||||
### OPD method
|
||||
|
||||
OPD method is a multi step method, at first you should run OPDT as the teacher in OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT/config.yml
|
||||
```
|
||||
|
||||
After training, find the `policy_best` file in your OPDT log file and copy it to `trade` file for backtest. Also you can change `policy_path` in the `example/OPDT_b/config.yml` to your `policy_best` file. Then run the backtest method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT_b/config.yml
|
||||
```
|
||||
|
||||
then processed feature from teacher. Remember to change `log_path` if you have changed `log_dir` in `OPDT_b/config.yml`.
|
||||
|
||||
```
|
||||
python teacher_feature.py
|
||||
```
|
||||
|
||||
and finally start our OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPD/config.yml
|
||||
```
|
||||
|
||||
## Citation
|
||||
You are more than welcome to citetmu our paper:
|
||||
```
|
||||
@inproceedings{fang2021universal,
|
||||
title={Universal Trading for Order Execution with Oracle Policy Distillation},
|
||||
author={Fang, Yuchen and Ren, Kan and Liu, Weiqing and Zhou, Dong and Zhang, Weinan and Bian, Jiang and Yu, Yong and Liu, Tie-Yan},
|
||||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||
volume={35},
|
||||
number={1},
|
||||
pages={107--115},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
10
examples/trade/__init__.py
Normal file
10
examples/trade/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# from rl4execution import env, trainer, exploration
|
||||
|
||||
# __all__ = [
|
||||
# "env",
|
||||
# "data",
|
||||
# "utils",
|
||||
# "policy",
|
||||
# "trainer",
|
||||
# "exploration",
|
||||
# ]
|
||||
4
examples/trade/action/__init__.py
Normal file
4
examples/trade/action/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .action_rl import *
|
||||
from .action_rule import *
|
||||
from .action_rl import *
|
||||
27
examples/trade/action/action_rl.py
Normal file
27
examples/trade/action/action_rl.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Static_Action(Base_Action):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self.action_num = config["action_num"]
|
||||
self.action_map = config["action_map"]
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Discrete(self.action_num)
|
||||
|
||||
def get_action(self, action, target, position, **kargs):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
:param position:
|
||||
:param target:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return min(target * self.action_map[action], position)
|
||||
46
examples/trade/action/action_rule.py
Normal file
46
examples/trade/action/action_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Dynamic(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (max_step_num - (t + 1)) * action
|
||||
|
||||
|
||||
class Rule_Static(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / max_step_num * action
|
||||
20
examples/trade/action/base.py
Normal file
20
examples/trade/action/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
|
||||
class Base_Action(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_action(*args, **kargs)
|
||||
|
||||
def get_action(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
return action
|
||||
46
examples/trade/action/interval_rule.py
Normal file
46
examples/trade/action/interval_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Static_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / (interval_num) * action
|
||||
|
||||
|
||||
class Rule_Dynamic_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (interval_num - interval) * action
|
||||
1
examples/trade/agent/__init__.py
Normal file
1
examples/trade/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import *
|
||||
69
examples/trade/agent/basic.py
Normal file
69
examples/trade/agent/basic.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
from env import nan_weighted_avg
|
||||
|
||||
|
||||
class TWAP(BasePolicy):
|
||||
""" The TWAP strategy. """
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.num_cpus = config["num_cpus"]
|
||||
|
||||
# @njit(parallel=True)
|
||||
def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
|
||||
act = [1] * len(batch.obs.private)
|
||||
return Batch(act=act, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class VWAP(BasePolicy):
|
||||
""" The VWAP strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
r = np.stack(obs.prediction).reshape(-1)
|
||||
return Batch(act=r, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class AC(VWAP):
|
||||
"""Almgren-Chriss strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.T = config["max_step_num"]
|
||||
self.gamma = 0
|
||||
self.tau = 1
|
||||
self.lamb = config["lambda"]
|
||||
self.eps = 0.0625
|
||||
self.alpha = 0.02
|
||||
self.eta = 2.5e-6
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
sig = np.stack(obs.prediction).reshape(-1)
|
||||
sell = ~np.stack(obs.is_buy).astype(np.bool)
|
||||
data = np.stack(obs.private)
|
||||
t = data[:, 2]
|
||||
t = t + 1
|
||||
k_tild = self.lamb / self.eta * sig * sig
|
||||
k = np.arccosh(k_tild / 2 + 1)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(k * self.T)
|
||||
return Batch(act=act, state=state)
|
||||
342
examples/trade/collector.py
Normal file
342
examples/trade/collector.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import gym
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from vecenv import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data.collector import _batch_set_item
|
||||
|
||||
|
||||
class Collector(object):
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
env: Union[gym.Env, BaseVectorEnv],
|
||||
testing=False,
|
||||
buffer: Optional[ReplayBuffer] = None,
|
||||
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
||||
action_noise: Optional[BaseNoise] = None,
|
||||
reward_metric: Optional[Callable[[np.ndarray], float]] = np.sum,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(env, BaseVectorEnv):
|
||||
env = DummyVectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
# this means all environments in synchronous simulation
|
||||
# but only a subset of environments in asynchronous simulation
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = env.is_async
|
||||
self.testing = testing
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.buffer = buffer
|
||||
self.policy = policy
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.process_fn = policy.process_fn
|
||||
# self._action_space = env.action_space
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
# self.reset()
|
||||
|
||||
@staticmethod
|
||||
def _default_rew_metric(x: Union[Number, np.number]) -> Union[Number, np.number]:
|
||||
# this internal function is designed for single-agent RL
|
||||
# for multi-agent RL, a reward_metric must be provided
|
||||
assert np.asanyarray(x).size == 1, "Please specify the reward_metric " "since the reward is not a scalar."
|
||||
return x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
if self._action_noise is not None:
|
||||
self._action_noise.reset()
|
||||
|
||||
def reset_stat(self) -> None:
|
||||
"""Reset the statistic variables."""
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
|
||||
|
||||
def reset_buffer(self) -> None:
|
||||
"""Reset the main data buffer."""
|
||||
if self.buffer is not None:
|
||||
self.buffer.reset()
|
||||
|
||||
def get_env_num(self) -> int:
|
||||
""" """
|
||||
return self.env_num
|
||||
|
||||
def reset_env(self) -> None:
|
||||
"""Reset all of the environment(s)' states and the cache buffers."""
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
self.env.reset_sampler()
|
||||
obs, stop_id = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
||||
self.data.obs = obs
|
||||
for b in self._cached_buf:
|
||||
b.reset()
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in stop_id])
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
elif isinstance(state, np.ndarray):
|
||||
state[id] = None if state.dtype == np.object else 0
|
||||
elif isinstance(state, Batch):
|
||||
state.empty_(id)
|
||||
|
||||
def collect(
|
||||
self,
|
||||
n_step: Optional[int] = None,
|
||||
n_episode: Optional[Union[int, List[int]]] = None,
|
||||
random: bool = False,
|
||||
render: Optional[float] = None,
|
||||
log_fn=None,
|
||||
no_grad: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""Collect a specified number of step or episode.
|
||||
|
||||
:param int: n_step: how many steps you want to collect.
|
||||
:param n_episode: how many episodes you want to collect. If it is an
|
||||
int, it means to collect at lease ``n_episode`` episodes; if it is
|
||||
a list, it means to collect exactly ``n_episode[i]`` episodes in
|
||||
the i-th environment
|
||||
:param bool: random: whether to use random policy for collecting data,
|
||||
defaults to False.
|
||||
:param float: render: the sleep time between rendering consecutive
|
||||
frames, defaults to None (no rendering).
|
||||
:param bool: no_grad: whether to retain gradient in policy.forward,
|
||||
defaults to True (no gradient retaining).
|
||||
|
||||
.. note::
|
||||
|
||||
One and only one collection number specification is permitted,
|
||||
either ``n_step`` or ``n_episode``.
|
||||
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param log_fn: Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:returns: A dict including the following keys
|
||||
|
||||
* ``n/ep`` the collected number of episodes.
|
||||
* ``n/st`` the collected number of steps.
|
||||
* ``v/st`` the speed of steps per second.
|
||||
* ``v/ep`` the speed of episode per second.
|
||||
* ``rew`` the mean reward over collected episodes.
|
||||
* ``len`` the mean length over collected episodes.
|
||||
|
||||
"""
|
||||
assert (
|
||||
(n_step is not None and n_episode is None and n_step > 0)
|
||||
or (n_step is None and n_episode is not None and np.sum(n_episode) > 0)
|
||||
or self.testing
|
||||
), "Only one of n_step or n_episode is allowed in Collector.collect, "
|
||||
f"got n_step = {n_step}, n_episode = {n_episode}."
|
||||
start_time = time.time()
|
||||
step_count = 0
|
||||
step_time = 0.0
|
||||
reset_time = 0.0
|
||||
model_time = 0.0
|
||||
# episode of each environment
|
||||
episode_count = np.zeros(self.env_num)
|
||||
# If n_episode is a list, and some envs have collected the required
|
||||
# number of episodes, these envs will be recorded in this list, and
|
||||
# they will not be stepped.
|
||||
finished_env_ids = []
|
||||
rewards = []
|
||||
whole_data = Batch()
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
finished_env_ids = [i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
"There are already many steps in an episode. "
|
||||
"You should add a time limitation to your environment!",
|
||||
Warning,
|
||||
)
|
||||
|
||||
is_async = self.is_async or len(finished_env_ids) > 0
|
||||
if is_async:
|
||||
# self.data are the data for all environments in async
|
||||
# simulation or some envs have finished,
|
||||
# **only a subset of data are disposed**,
|
||||
# so we store the whole data in ``whole_data``, let self.data
|
||||
# to be the data available in ready environments, and finally
|
||||
# set these back into all the data
|
||||
whole_data = self.data
|
||||
self.data = self.data[self._ready_env_ids]
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if isinstance(last_state, Batch) and last_state.is_empty():
|
||||
last_state = None
|
||||
self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
|
||||
|
||||
# calculate the next action
|
||||
start = time.time()
|
||||
if random:
|
||||
spaces = self._action_space
|
||||
result = Batch(act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
if no_grad:
|
||||
with torch.no_grad(): # faster than retain_grad version
|
||||
result = self.policy(self.data, last_state)
|
||||
else:
|
||||
result = self.policy(self.data, last_state)
|
||||
model_time += time.time() - start
|
||||
state = result.get("state", Batch())
|
||||
# convert None to Batch(), since None is reserved for 0-init
|
||||
if state is None:
|
||||
state = Batch()
|
||||
self.data.update(state=state, policy=result.get("policy", Batch()))
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
if not (isinstance(state, Batch) and state.is_empty()):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
assert isinstance(self.data.act, np.ndarray)
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
start = time.time()
|
||||
if not is_async:
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
|
||||
step_time += time.time() - start
|
||||
# move data to self.data
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=[{} for i in info])
|
||||
|
||||
if render:
|
||||
self.env.render()
|
||||
time.sleep(render)
|
||||
|
||||
# add data into the buffer
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data) # type: ignore
|
||||
self.data.update(result)
|
||||
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
# j is the index in current ready_env_ids
|
||||
# i is the index in all environments
|
||||
if self.buffer is None:
|
||||
# users do not want to store data, so we store
|
||||
# small fake data here to make the code clean
|
||||
self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
|
||||
else:
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
rewards.append(self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0)))
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if isinstance(n_episode, list) and episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
self._cached_buf[i].reset()
|
||||
self._reset_state(j)
|
||||
obs_next = self.data.obs_next
|
||||
start = time.time()
|
||||
if sum(done):
|
||||
env_ind_local = np.where(done)[0].tolist()
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset, stop_id = self.env.reset(env_ind_global)
|
||||
_ready_env_ids = self._ready_env_ids.tolist()
|
||||
for i in stop_id:
|
||||
finished_env_ids.append(i)
|
||||
# env_ind_local.remove(_ready_env_ids.index(i))
|
||||
if len(env_ind_local) > 0:
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
reset_time += time.time() - start
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
else:
|
||||
if isinstance(n_episode, int) and episode_count.sum() >= n_episode:
|
||||
break
|
||||
if isinstance(n_episode, list) and (episode_count >= n_episode).all():
|
||||
break
|
||||
if len(self._ready_env_ids) == 0 and self.testing:
|
||||
break
|
||||
|
||||
# finished envs are ready, and can be used for the next collection
|
||||
self._ready_env_ids = np.array(self._ready_env_ids.tolist() + finished_env_ids)
|
||||
|
||||
# generate the statistics
|
||||
episode_count = sum(episode_count)
|
||||
duration = max(time.time() - start_time, 1e-9)
|
||||
self.collect_step += step_count
|
||||
self.collect_episode += episode_count
|
||||
self.collect_time += duration
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
"n/st": step_count,
|
||||
"v/st": step_count / duration,
|
||||
"v/ep": episode_count / duration,
|
||||
"t/st": step_time / step_count,
|
||||
"t/re": reset_time / episode_count,
|
||||
"t/mo": model_time / step_count,
|
||||
"rew": np.mean(rewards),
|
||||
"rew_std": np.std(rewards),
|
||||
"len": step_count / episode_count,
|
||||
}
|
||||
1
examples/trade/env/__init__.py
vendored
Normal file
1
examples/trade/env/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from .env_rl import *
|
||||
481
examples/trade/env/env_rl.py
vendored
Normal file
481
examples/trade/env/env_rl.py
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
import gym
|
||||
|
||||
gym.logger.set_level(40)
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import datetime
|
||||
import random
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import tianshou as ts
|
||||
import copy
|
||||
from multiprocessing import Process, Pipe, Queue
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
from scipy.stats import pearsonr
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import merge_dicts, nan_weighted_avg, robust_auc
|
||||
import reward
|
||||
import observation
|
||||
import action
|
||||
|
||||
ZERO = 1e-7
|
||||
|
||||
|
||||
class StockEnv(gym.Env):
|
||||
"""Single-assert environment"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.limit = config["limit"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.interval_num = config["interval_num"]
|
||||
self.offset = config["offset"] if "offset" in config else 0
|
||||
if "last_reward" in config:
|
||||
self.last_reward = config["last_reward"]
|
||||
else:
|
||||
self.last_reward = None
|
||||
if "log" in config:
|
||||
self.log = config["log"]
|
||||
else:
|
||||
self.log = True
|
||||
# loader_conf = config['loader']['config']
|
||||
obs_conf = config["obs"]["config"]
|
||||
obs_conf["features"] = config["features"]
|
||||
obs_conf["time_interval"] = self.time_interval
|
||||
obs_conf["max_step_num"] = self.max_step_num
|
||||
self.obs = getattr(observation, config["obs"]["name"])(obs_conf)
|
||||
self.action_func = getattr(action, config["action"]["name"])(config["action"]["config"])
|
||||
self.reward_func_list = []
|
||||
self.reward_log_dict = {}
|
||||
self.reward_coef = []
|
||||
for name, conf in config["reward"].items():
|
||||
self.reward_coef.append(conf.pop("coefficient"))
|
||||
self.reward_func_list.append(getattr(reward, name)(conf))
|
||||
self.reward_log_dict[name] = 0.0
|
||||
self.observation_space = self.obs.get_space()
|
||||
self.action_space = self.action_func.get_space()
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.log = log
|
||||
|
||||
def reset(self, sample):
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
|
||||
for key in self.reward_log_dict.keys():
|
||||
self.reward_log_dict[key] = 0.0
|
||||
if not sample is None:
|
||||
(
|
||||
self.ins,
|
||||
self.date,
|
||||
self.raw_df_values,
|
||||
self.raw_df_columns,
|
||||
self.raw_df_index,
|
||||
self.feature_dfs,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
) = sample
|
||||
self.raw_df = pd.DataFrame(index=self.raw_df_index, data=self.raw_df_values, columns=self.raw_df_columns,)
|
||||
del self.raw_df_values, self.raw_df_columns, self.raw_df_index
|
||||
start_time = time.time()
|
||||
self.load_time = time.time() - start_time
|
||||
self.day_vwap = nan_weighted_avg(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
self.raw_df["$volume0"].values[self.offset : self.offset + self.max_step_num],
|
||||
)
|
||||
try:
|
||||
assert not (np.isnan(self.day_vwap) or np.isinf(self.day_vwap))
|
||||
except:
|
||||
print(self.raw_df)
|
||||
print(self.ins)
|
||||
print(self.day_vwap)
|
||||
self.raw_df.to_pickle("/nfs_data1/kanren/error_df.pkl")
|
||||
self.day_twap = np.nanmean(self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num])
|
||||
self.t = -1 + self.offset
|
||||
self.interval = 0
|
||||
self.position = self.target
|
||||
self.eps_start = time.time()
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
)
|
||||
if self.log:
|
||||
index_array = [
|
||||
np.array([self.ins] * self.max_step_num),
|
||||
self.raw_df.index.to_numpy()[self.offset : self.offset + self.max_step_num],
|
||||
np.array([self.date] * self.max_step_num),
|
||||
]
|
||||
self.traded_log = pd.DataFrame(
|
||||
data={
|
||||
"$v_t": np.nan,
|
||||
"$max_vol_t": (self.raw_df["$volume0"] * self.limit).values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$traded_t": np.nan,
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
"action": np.nan,
|
||||
},
|
||||
index=index_array,
|
||||
)
|
||||
# v_t: The amount of shares the agent hope to trade
|
||||
# max_vol_t: The max amount of shares can be traded
|
||||
# traded_t: The amount of shares that is acually traded
|
||||
# action: the action of agent, may have various meanings in different settings.
|
||||
self.done = False
|
||||
if self.limit > 1:
|
||||
self.this_valid = np.inf
|
||||
else:
|
||||
self.this_valid = np.nansum(self.raw_df["$volume0"].values) * self.limit
|
||||
self.this_cash = 0
|
||||
|
||||
self.step_time = []
|
||||
self.action_log = [np.nan] * self.interval_num
|
||||
self.reset_time = time.time() - start_time
|
||||
self.real_eps_time = self.reset_time
|
||||
self.total_reward = 0
|
||||
self.total_instant_rew = 0
|
||||
self.last_rew = 0
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
|
||||
for i in range(self.time_interval):
|
||||
v_t = volume_t / min(self.time_interval, time_left)
|
||||
self.t += 1
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
v_t = self.position
|
||||
if self.log:
|
||||
log_index = self.t - self.offset
|
||||
self.traded_log.iat[log_index, 0] = v_t
|
||||
self.traded_log.iat[log_index, 4] = action
|
||||
vwap_t, vol_t = self.raw_df.iloc[self.t][["$vwap0", "$volume0"]]
|
||||
max_vol_t = self.limit * vol_t
|
||||
if self.limit >= 1:
|
||||
max_vol_t = np.inf
|
||||
if v_t > min(self.position, max_vol_t):
|
||||
if self.position <= max_vol_t:
|
||||
v_t = self.position
|
||||
else:
|
||||
v_t = max_vol_t
|
||||
self.position -= v_t
|
||||
self.this_cash += vwap_t * v_t
|
||||
if self.log:
|
||||
self.traded_log.iat[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - vwap_t / self.day_vwap) * 10000
|
||||
PA_t = (1 - vwap_t / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (vwap_t / self.day_vwap - 1) * 10000
|
||||
PA_t = (vwap_t / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
break
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
|
||||
|
||||
class StockEnv_Acc(StockEnv):
|
||||
def step(self, action):
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
time_left = min(self.time_interval, time_left)
|
||||
|
||||
v_t = np.repeat(volume_t / time_left, time_left)
|
||||
minutes = np.arange(self.t + 1, self.t + time_left + 1)
|
||||
if self.log:
|
||||
log_index = minutes - self.offset
|
||||
self.traded_log.iloc[log_index, 0] = v_t
|
||||
self.traded_log.iloc[log_index, 4] = action
|
||||
vwap_t = self.raw_df.iloc[minutes]["$vwap0"].values
|
||||
vol_t = self.raw_df.iloc[minutes]["$volume0"].values
|
||||
max_vol_t = self.limit * vol_t if self.limit < 1 else np.inf
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
if self.t + time_left == self.max_step_num - 1 + self.offset:
|
||||
left = self.position - v_t.sum()
|
||||
v_t[-1] += left
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
this_money = (v_t * vwap_t).sum()
|
||||
this_vol = v_t.sum()
|
||||
this_vwap = np.nan_to_num(this_money / this_vol)
|
||||
self.t += time_left
|
||||
self.position -= this_vol
|
||||
self.this_cash += this_money
|
||||
if self.log:
|
||||
self.traded_log.iloc[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vwap / self.day_vwap) * 10000
|
||||
PA_t = (1 - this_vwap / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (this_vwap / self.day_vwap - 1) * 10000
|
||||
PA_t = (this_vwap / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
351
examples/trade/executor.py
Normal file
351
examples/trade/executor.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import env
|
||||
from vecenv import *
|
||||
import sampler
|
||||
import logger
|
||||
import json
|
||||
import os
|
||||
import agent
|
||||
import network
|
||||
import policy
|
||||
import random
|
||||
import tianshou as ts
|
||||
import tqdm
|
||||
from tianshou.utils import tqdm_config, MovAvg
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collector import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
|
||||
def get_best_gpu(force=None):
|
||||
if force is not None:
|
||||
return force
|
||||
s = os.popen("nvidia-smi --query-gpu=memory.free --format=csv")
|
||||
a = []
|
||||
ss = s.read().replace("MiB", "").replace("memory.free", "").split("\n")
|
||||
s.close()
|
||||
for i in range(1, len(ss) - 1):
|
||||
a.append(int(ss[i]))
|
||||
best = int(np.argmax(a))
|
||||
print("the best GPU is ", best, " with free memories of ", ss[best + 1])
|
||||
return best
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
"""
|
||||
|
||||
:param seed:
|
||||
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
class BaseExecutor(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
):
|
||||
"""A base class for executor
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param optim: Optimization configuration, defaults to None
|
||||
:type optim: dict, optional
|
||||
:param policy_conf: Configurations for the RL algorithm, defaults to None
|
||||
:type policy_conf: dict, optional
|
||||
:param network_conf: Configurations for policy network_conf, defaults to None
|
||||
:type network_conf: dict, optional
|
||||
:param policy_path: If is not None, would load the policy from this path, defaults to None
|
||||
:type policy_path: string, optional
|
||||
:param seed: Random seed, defaults to None
|
||||
:type seed: int, optional
|
||||
"""
|
||||
# self.config = config
|
||||
self.log_dir = log_dir
|
||||
print(self.log_dir)
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
if resources["device"] == "cuda":
|
||||
resources["device"] = "cuda:" + str(get_best_gpu())
|
||||
self.device = torch.device(resources["device"])
|
||||
if seed:
|
||||
setup_seed(seed)
|
||||
|
||||
assert not policy_path is None or not policy_conf is None, "Policy must be defined"
|
||||
if policy_path:
|
||||
self.policy = torch.load(policy_path, map_location=self.device)
|
||||
self.policy.actor.extractor.device = self.device
|
||||
# policy.eval()
|
||||
elif hasattr(agent, policy_conf["name"]):
|
||||
policy_conf["config"] = merge_dicts(policy_conf["config"], resources)
|
||||
self.policy = getattr(agent, policy_conf["name"])(policy_conf["config"])
|
||||
# print(self.policy)
|
||||
else:
|
||||
assert not network_conf is None
|
||||
if "extractor" in network_conf.keys():
|
||||
net = getattr(network, network_conf["extractor"]["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
else:
|
||||
net = getattr(network, network_conf["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
net.to(self.device)
|
||||
actor = getattr(network, network_conf["name"] + "_Actor")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
actor.to(self.device)
|
||||
critic = getattr(network, network_conf["name"] + "_Critic")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
critic.to(self.device)
|
||||
self.optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=optim["lr"],
|
||||
weight_decay=optim["weight_decay"] if "weight_decay" in optim else 0.0,
|
||||
)
|
||||
self.dist = torch.distributions.Categorical
|
||||
try:
|
||||
self.policy = getattr(ts.policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
except:
|
||||
self.policy = getattr(policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
self.writer = SummaryWriter(self.log_dir)
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""Run the whole training process.
|
||||
|
||||
:param max_epoch: The total number of epoch.
|
||||
:param step_per_epoch: The times of bp in one epoch.
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
:param iteration: The iteration when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param global_step: The number of steps when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param early_stopping: If the test reward does not reach a new high in `early_stopping` iterations, the training would stop. (Default value = 5)
|
||||
:returns: The result on test set.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
"""Do an round of training
|
||||
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
"""Evaluate the policy on orders in order_dir
|
||||
|
||||
:param order_dir: the orders to be evaluated on.
|
||||
:param save_res: whether the result of evaluation be saved to self.logdir/res.json (Default value = False)
|
||||
:param logdir: the place to save the .log and .pkl log files to. If None, don't save logfiles. (Default value = None)
|
||||
:returns: The result of evaluation.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
train_paths,
|
||||
valid_paths,
|
||||
test_paths,
|
||||
io_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
share_memory=False,
|
||||
buffer_size=200000,
|
||||
q_learning=False,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""[summary]
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param train_paths: The paths of training datasets including orders, backtest files and features.
|
||||
:type train_paths: string
|
||||
:param valid_paths: The paths of validation datasets including orders, backtest files and features.
|
||||
:type valid_paths: string
|
||||
:param test_paths: The paths of test datasets including orders, backtest files and features.
|
||||
:type test_paths: string
|
||||
:param io_conf: Configuration for sampler and loggers.
|
||||
:type io_conf: dict
|
||||
:param share_memory: Whether to use shared memory vecnev, defaults to False
|
||||
:type share_memory: bool, optional
|
||||
:param buffer_size: The size of replay buffer, defaults to 200000
|
||||
:type buffer_size: int, optional
|
||||
"""
|
||||
super().__init__(log_dir, resources, env_conf, optim, policy_conf, network_conf, policy_path, seed)
|
||||
single_env = getattr(env, env_conf["name"])
|
||||
env_conf = merge_dicts(env_conf, train_paths)
|
||||
env_conf["log"] = True
|
||||
print("CPU_COUNT:", resources["num_cpus"])
|
||||
if share_memory:
|
||||
self.env = ShmemVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
else:
|
||||
self.env = SubprocVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
self.test_collector = Collector(policy=self.policy, env=self.env, testing=True, reward_metric=np.sum)
|
||||
self.train_collector = Collector(
|
||||
self.policy, self.env, buffer=ts.data.ReplayBuffer(buffer_size), reward_metric=np.sum,
|
||||
)
|
||||
self.train_paths = train_paths
|
||||
self.test_paths = test_paths
|
||||
self.valid_paths = valid_paths
|
||||
train_sampler_conf = train_paths
|
||||
train_sampler_conf["features"] = env_conf["features"]
|
||||
test_sampler_conf = test_paths
|
||||
test_sampler_conf["features"] = env_conf["features"]
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(train_sampler_conf)
|
||||
self.test_sampler = getattr(sampler, io_conf["test_sampler"])(test_sampler_conf)
|
||||
self.train_logger = logger.InfoLogger()
|
||||
self.test_logger = getattr(logger, io_conf["test_logger"])
|
||||
|
||||
self.q_learning = q_learning
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
train_step_min=0,
|
||||
log_valid=True,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
best_epoch, best_reward = -1, -1
|
||||
stat = {}
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
while t.n < t.total:
|
||||
result, losses = self.train_round(repeat_per_collect, collect_per_step, batch_size, iteration)
|
||||
global_step += result["n/st"]
|
||||
iteration += 1
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
self.writer.add_scalar("Train/" + k, stat[k].get(), global_step=global_step)
|
||||
t.update(1)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
result = self.eval(
|
||||
self.valid_paths["order_dir"], logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Valid/" + k, result[k], global_step=global_step)
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
best_state = self.policy.state_dict()
|
||||
early_stop_round = 0
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_best")
|
||||
elif global_step >= train_step_min:
|
||||
early_stop_round += 1
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_{epoch}")
|
||||
print(
|
||||
f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, ' # train_reward: {result_train["rew"]:.4f}, '
|
||||
f"best_reward: {best_reward:.4f} in #{best_epoch}"
|
||||
)
|
||||
if early_stop_round >= early_stopping:
|
||||
print("Early stopped")
|
||||
break
|
||||
print("Testing...")
|
||||
self.policy.load_state_dict(best_state)
|
||||
result = self.eval(self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
self.policy.train()
|
||||
self.env.toggle_log(False)
|
||||
self.env.sampler = self.train_sampler
|
||||
if not self.q_learning:
|
||||
self.train_collector.reset()
|
||||
result = self.train_collector.collect(n_episode=collect_per_step, log_fn=self.train_logger)
|
||||
result = merge_dicts(result, self.train_logger.summary())
|
||||
if not self.q_learning:
|
||||
losses = self.policy.update(
|
||||
0, self.train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect,
|
||||
)
|
||||
else:
|
||||
losses = self.policy.update(batch_size, self.train_collector.buffer,)
|
||||
return result, losses
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
print(f"start evaluating on {order_dir}")
|
||||
self.policy.eval()
|
||||
self.env.toggle_log(True)
|
||||
self.test_sampler.reset(order_dir)
|
||||
self.env.sampler = self.test_sampler
|
||||
self.test_collector.reset()
|
||||
if not logdir is None:
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
eval_logger = self.test_logger(logdir, order_dir)
|
||||
eval_logger.reset()
|
||||
else:
|
||||
eval_logger = self.train_logger
|
||||
result = self.test_collector.collect(log_fn=eval_logger)
|
||||
result = merge_dicts(result, eval_logger.summary())
|
||||
if save_res:
|
||||
with open(self.log_dir + "/res.json", "w") as f:
|
||||
json.dump(result, f, sort_keys=True, indent=4)
|
||||
print(f"finish evaluating on {order_dir}")
|
||||
return result
|
||||
76
examples/trade/exp/example/OPD/config.yml
Normal file
76
examples/trade/exp/example/OPD/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPD
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
- name: teacher_action
|
||||
type: interval
|
||||
size: 1
|
||||
loc: ../data/feature/teacher/
|
||||
obs:
|
||||
name: RuleTeacher
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO_sup
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
sup_coef: 0.01
|
||||
network_conf:
|
||||
name: OPD
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDS/config.yml
Normal file
71
examples/trade/exp/example/OPDS/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDS
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDT/config.yml
Normal file
71
examples/trade/exp/example/OPDT/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDT
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: eval
|
||||
log_dir: example/OPDT_b
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/all/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_path: policy_best
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
search:
|
||||
optim.weight_decay:
|
||||
type: choice
|
||||
value: [0.]
|
||||
70
examples/trade/exp/example/PPO/config.yml
Normal file
70
examples/trade/exp/example/PPO/config.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/PPO
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
PPO_Reward:
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
1
examples/trade/logger/__init__.py
Normal file
1
examples/trade/logger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_logger import *
|
||||
231
examples/trade/logger/single_logger.py
Normal file
231
examples/trade/logger/single_logger.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from multiprocessing import Queue, Process
|
||||
import time
|
||||
|
||||
|
||||
def GLR(values):
|
||||
"""
|
||||
|
||||
Calculate -P(value | value > 0) / P(value | value < 0)
|
||||
|
||||
"""
|
||||
pos = []
|
||||
neg = []
|
||||
for i in values:
|
||||
if i > 0:
|
||||
pos.append(i)
|
||||
elif i < 0:
|
||||
neg.append(i)
|
||||
return -np.mean(pos) / np.mean(neg)
|
||||
|
||||
|
||||
class DFLogger(object):
|
||||
"""The logger for single-assert backtest.
|
||||
Would save .pkl and .log in log_dir
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir, order_dir, writer=None):
|
||||
self.order_dir = order_dir + "/"
|
||||
self.log_dir = log_dir + "/"
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
self.queue = Queue(100000)
|
||||
self.raw_log_dir = self.log_dir
|
||||
|
||||
@staticmethod
|
||||
def _worker(log_dir, order_dir, queue):
|
||||
df_cache = {}
|
||||
stat_cache = {}
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
# summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache['money_sell'])
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
break
|
||||
elif len(info) == 0:
|
||||
continue
|
||||
else:
|
||||
df = info.pop("df")
|
||||
res = info.pop("res")
|
||||
ins = df.index[0][0]
|
||||
if ins not in df_cache:
|
||||
df_cache[ins] = (
|
||||
[],
|
||||
[],
|
||||
(pd.read_pickle(order_dir + ins + ".pkl.target")['amount'] != 0).sum(),
|
||||
)
|
||||
df_cache[ins][0].append(df)
|
||||
df_cache[ins][1].append(res)
|
||||
if len(df_cache[ins][0]) == df_cache[ins][2]:
|
||||
pd.concat(df_cache[ins][0]).to_pickle(log_dir + ins + ".log")
|
||||
pd.concat(df_cache[ins][1]).to_pickle(log_dir + ins + ".pkl")
|
||||
del df_cache[ins]
|
||||
for k, v in info.items():
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
assert self.queue.empty()
|
||||
self.child = Process(target=self._worker, args=(self.log_dir, self.order_dir, self.queue), daemon=True,)
|
||||
self.child.start()
|
||||
|
||||
def set_step(self, step):
|
||||
|
||||
self.log_dir = f"{self.raw_log_dir}{step}/"
|
||||
self.reset()
|
||||
|
||||
def __call__(self, infos):
|
||||
for info in infos:
|
||||
if "env_id" in info:
|
||||
info.pop("env_id")
|
||||
self.update(infos)
|
||||
|
||||
def update(self, infos):
|
||||
"""store values in info into the logger"""
|
||||
for info in infos:
|
||||
self.queue.put(info, block=True)
|
||||
|
||||
def summary(self):
|
||||
""":return: The mean and std of values in infos stored in logger"""
|
||||
summary = {}
|
||||
self.queue.put("stop", block=True)
|
||||
self.child.join()
|
||||
self.child.close()
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class InfoLogger(DFLogger):
|
||||
""" """
|
||||
|
||||
def __init__(self, *args):
|
||||
self.stat_cache = {}
|
||||
self.queue = Queue(10000)
|
||||
self.child = Process(target=self._worker, args=(self.queue,), daemon=True)
|
||||
self.child.start()
|
||||
|
||||
def _worker(logdir, queue):
|
||||
stat_cache = {}
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
stat_cache = {}
|
||||
time.sleep(5)
|
||||
continue
|
||||
if len(info) == 0:
|
||||
continue
|
||||
for k, v in info.items():
|
||||
if k == "res" or k == "df":
|
||||
continue
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def _update(self, info):
|
||||
if len(info) == 0:
|
||||
return
|
||||
ins = df.index[0][0]
|
||||
for k, v in info.items():
|
||||
if k not in self.stat_cache:
|
||||
self.stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
self.stat_cache[k] += list(v)
|
||||
else:
|
||||
self.stat_cache[k].append(v)
|
||||
|
||||
def summary(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
# print('not empty')
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
self.queue.put("stop")
|
||||
# self.child.join()
|
||||
time.sleep(1)
|
||||
while not self.queue.qsize() == 1:
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
def set_step(self, step):
|
||||
return
|
||||
135
examples/trade/main.py
Normal file
135
examples/trade/main.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import re
|
||||
import os
|
||||
import argparse
|
||||
import yaml
|
||||
from executor import Executor
|
||||
import warnings
|
||||
import redis
|
||||
import subprocess
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
loader = yaml.FullLoader
|
||||
loader.add_implicit_resolver(
|
||||
"tag:yaml.org,2002:float",
|
||||
re.compile(
|
||||
"""^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$""",
|
||||
re.X,
|
||||
),
|
||||
list("-+0123456789."),
|
||||
)
|
||||
|
||||
|
||||
def get_full_config(config, dir_name):
|
||||
while "base" in config:
|
||||
base_config = os.path.normpath(os.path.join(dir_name, config.pop("base")))
|
||||
dir_name = os.path.dirname(base_config)
|
||||
with open(base_config, "r") as f:
|
||||
base_config = yaml.load(base_config, Loader=yaml.FullLoader)
|
||||
config = merge_dicts(base_config, config)
|
||||
return config
|
||||
|
||||
|
||||
def run(config):
|
||||
log_dir = config["log_dir"]
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
with open(log_dir + "/config.yml", "w") as f:
|
||||
yaml.dump(config, f)
|
||||
executor = Executor(**config)
|
||||
if config["task"] == "train":
|
||||
return executor.train(**config["optim"])
|
||||
elif config["task"] == "eval":
|
||||
return executor.eval(config["test_paths"]["order_dir"], save_res=True, logdir=config["log_dir"] + "/test/",)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str)
|
||||
parser.add_argument("-n", "--index", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(os.cpu_count())
|
||||
|
||||
EXP_PATH = os.environ["EXP_PATH"]
|
||||
config_path = os.path.normpath(os.path.join(EXP_PATH, args.config))
|
||||
EXP_NAME = os.path.relpath(config_path, EXP_PATH)
|
||||
if os.path.isdir(config_path):
|
||||
if not args.index is None:
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config = config_list[args.index]
|
||||
if "PT_OUTPUT_DIR" in os.environ:
|
||||
config["log_dir"] = os.environ["PT_OUTPUT_DIR"]
|
||||
else:
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
config = get_full_config(config, config_path)
|
||||
run(config)
|
||||
else:
|
||||
redis_server = redis.Redis(
|
||||
host=os.environ["REDIS_SERVER"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
db=0,
|
||||
charset="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config_num = len(config_list)
|
||||
if not redis_server.exists(EXP_NAME):
|
||||
for i in range(config_num):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
else:
|
||||
if redis_server.llen(EXP_NAME) == 0:
|
||||
for i in range(config_num):
|
||||
if (
|
||||
not redis_server.exists(f"{EXP_NAME}_{i}")
|
||||
or redis_server.get(f"{EXP_NAME}_{i}") == "Failed"
|
||||
):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
print(f"Starting..., {redis_server.llen(EXP_NAME)} trails to run")
|
||||
while True:
|
||||
index = redis_server.lpop(EXP_NAME)
|
||||
if index is None:
|
||||
print("All done")
|
||||
break
|
||||
index = int(index)
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Running")
|
||||
print(f"Trail_{index} is running")
|
||||
try:
|
||||
res = subprocess.run(["python", "main.py", "--config", args.config, "--index", str(index),],)
|
||||
except KeyboardInterrupt:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
break
|
||||
if res.returncode == 0:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Finished")
|
||||
print(f"Finish running one trail, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
else:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
|
||||
elif os.path.isfile(config_path):
|
||||
assert config_path.endswith(".yml"), "Config file should be an yaml file"
|
||||
EXP_NAME = EXP_NAME[:-4]
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.load(f, Loader=loader)
|
||||
config = get_full_config(config, os.path.dirname(config_path))
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
run(config)
|
||||
else:
|
||||
print("The config path should be a relative path from EXP_PATH")
|
||||
5
examples/trade/network/__init__.py
Normal file
5
examples/trade/network/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .ppo import *
|
||||
from .qmodel import *
|
||||
from .teacher import *
|
||||
from .util import *
|
||||
from .opd import *
|
||||
74
examples/trade/network/opd.py
Normal file
74
examples/trade/network/opd.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class OPD_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
teacher_action = inp[:, 0]
|
||||
inp = inp[:, 1:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
feature = self.fc(fc_in)
|
||||
return feature, teacher_action / 2
|
||||
|
||||
|
||||
class OPD_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class OPD_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
79
examples/trade/network/ppo.py
Normal file
79
examples/trade/network/ppo.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class PPO_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
# inp = torch.from_numpy(inp).to(torch.device('cpu'))
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, -19:-1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
assert not torch.isnan(cnn_out).any()
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
assert not torch.isnan(rnn_in).any()
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
assert not torch.isnan(rnn2_in).any()
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
assert not torch.isnan(rnn2_out).any()
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
assert not torch.isnan(rnn_out).any()
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class PPO_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
assert not (torch.isnan(self.feature).any() | torch.isinf(self.feature).any()), f"{self.feature}"
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPO_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(dim=-1)
|
||||
52
examples/trade/network/qmodel.py
Normal file
52
examples/trade/network/qmodel.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class RNNQModel(nn.Module):
|
||||
def __init__(self, device="cpu", out_shape=10, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, out_shape),
|
||||
)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
inp = to_torch(obs, dtype=torch.float32, device=self.device)
|
||||
inp = inp[:, 182:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
out = self.fc(fc_in)
|
||||
return out, state
|
||||
69
examples/trade/network/teacher.py
Normal file
69
examples/trade/network/teacher.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Teacher_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", feature_size=180, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240].reshape(-1, 30, 6).transpose(1, 2) ## public part of state
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2) ## private part of state
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 8, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0][:, -1, :]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class Teacher_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class Teacher_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(-1)
|
||||
191
examples/trade/network/util.py
Normal file
191
examples/trade/network/util.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key):
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze() # B * l
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1) # B * l * 1
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class MaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
# seq_len: (batch,)
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1]) # (batch, 9, 64)
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1) # (batch, 9)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class TFMaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
mask = mask.repeat(1, 3) # (batch, 9*3)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class NNAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.q_net = nn.Linear(in_dim, out_dim)
|
||||
self.k_net = nn.Linear(in_dim, out_dim)
|
||||
self.v_net = nn.Linear(in_dim, out_dim)
|
||||
|
||||
def forward(self, Q, K, V):
|
||||
q = self.q_net(Q)
|
||||
k = self.k_net(K)
|
||||
v = self.v_net(V)
|
||||
|
||||
attn = torch.einsum("ijk,ilk->ijl", q, k)
|
||||
attn = attn.to(Q.device)
|
||||
attn_prob = torch.softmax(attn, dim=-1)
|
||||
|
||||
attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v)
|
||||
|
||||
return attn_vec
|
||||
|
||||
|
||||
class Reshape(nn.Module):
|
||||
def __init__(self, *args):
|
||||
super(Reshape, self).__init__()
|
||||
self.shape = args
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(self.shape)
|
||||
|
||||
|
||||
class DARNN(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.emb_dim = kargs["emb_dim"]
|
||||
self.hidden_size = kargs["hidden_size"]
|
||||
self.num_layers = kargs["num_layers"]
|
||||
self.is_bidir = kargs["is_bidir"]
|
||||
self.dropout = kargs["dropout"]
|
||||
self.seq_len = kargs["seq_len"]
|
||||
self.interval = kargs["interval"]
|
||||
self.today_length = 238
|
||||
self.prev_length = 240
|
||||
self.input_length = 480
|
||||
self.input_size = 6
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=self.input_size + self.emb_dim,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.prev_rnn = nn.LSTM(
|
||||
input_size=self.input_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(in_features=self.hidden_size * 2, out_features=1)
|
||||
self.attention = NNAttention(self.hidden_size, self.hidden_size)
|
||||
self.act_out = nn.Sigmoid()
|
||||
if self.emb_dim != 0:
|
||||
self.pos_emb = nn.Embedding(self.input_length, self.emb_dim)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.view(-1, self.input_length, self.input_size) # [B, T, F]
|
||||
today_input = inputs[:, : self.today_length, :]
|
||||
today_input = torch.cat((torch.zeros_like(today_input[:, :1, :]), today_input), dim=1)
|
||||
prev_input = inputs[:, 240 : 240 + self.prev_length, :]
|
||||
if self.emb_dim != 0:
|
||||
embedding = self.pos_emb(torch.arange(end=self.today_length + 1, device=inputs.device))
|
||||
embedding = embedding.repeat([today_input.size()[0], 1, 1])
|
||||
today_input = torch.cat((today_input, embedding), dim=-1)
|
||||
prev_outs, _ = self.prev_rnn(prev_input)
|
||||
today_outs, _ = self.rnn(today_input)
|
||||
|
||||
outs = self.attention(today_outs, prev_outs, prev_outs)
|
||||
outs = torch.cat((today_outs, outs), dim=-1)
|
||||
outs = outs[:, range(0, self.seq_len * self.interval, self.interval), :]
|
||||
# outs = self.fc_out(outs).squeeze()
|
||||
return self.act_out(self.fc_out(outs).squeeze(-1)), outs
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim1=0, dim2=1):
|
||||
super().__init__()
|
||||
self.dim1 = dim1
|
||||
self.dim2 = dim2
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(self.dim1, self.dim2)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, *args, **kargs):
|
||||
super().__init__()
|
||||
self.attention = nn.MultiheadAttention(*args, **kargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(x, x, x)[0]
|
||||
|
||||
|
||||
def onehot_enc(y, len):
|
||||
y = y.unsqueeze(-1)
|
||||
y_onehot = torch.zeros(y.shape[0], len)
|
||||
# y_onehot.zero_()
|
||||
y_onehot.scatter(1, y, 1)
|
||||
return y_onehot
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
mask = ~(torch.ones((len(lengths), maxlen), device=device).cumsum(dim=1).t() > lengths).t()
|
||||
mask.type(dtype)
|
||||
return mask
|
||||
3
examples/trade/observation/__init__.py
Normal file
3
examples/trade/observation/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ppo_obs import *
|
||||
from .teacher_obs import *
|
||||
from .obs_rule import *
|
||||
136
examples/trade/observation/obs_rule.py
Normal file
136
examples/trade/observation/obs_rule.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
|
||||
class BaseObs(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self._observation_space = None
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return self._observation_space
|
||||
|
||||
def get_obs(self, t):
|
||||
pass
|
||||
|
||||
|
||||
class RuleObs(BaseObs):
|
||||
"""The observation for minute-level rule-based agents, which consists of prediction, private state and direction information."""
|
||||
|
||||
def __init__(self, config):
|
||||
feature_size = 0
|
||||
self.features = config["features"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.max_step_num = config["max_step_num"]
|
||||
for feature in self.features:
|
||||
feature_size += feature["size"]
|
||||
|
||||
self._observation_space = Tuple(
|
||||
(
|
||||
Box(-np.inf, np.inf, shape=(feature_size,), dtype=np.float32),
|
||||
Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),
|
||||
Discrete(2),
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_obs(*args, **kargs)
|
||||
|
||||
def get_feature_res(self, df_list, time, interval, whole_day=False, interval_num=8):
|
||||
"""
|
||||
This method would extract the needed feature from the feature dataframe based on the feature name
|
||||
and the description in feature config.
|
||||
|
||||
:param df_list: The dataframes of features, the order is consistent with the feature list.
|
||||
:param time: The index of current minute of the day (starting from -1).
|
||||
:param interval: The index of interval or decition making.
|
||||
:param whole_day: if True, this method would return the concatenate of all dataframe.(Default value = False)
|
||||
|
||||
"""
|
||||
predictions = []
|
||||
if whole_day:
|
||||
try:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
except:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
for i, p in enumerate(prediction):
|
||||
if len(p) < interval_num:
|
||||
prediction[i] = np.concatenate((p, np.zeros(interval_num - len(p))), axis=-1)
|
||||
# res = np.stack(prediction).transpose().reshape(-1)
|
||||
return np.concatenate(prediction)
|
||||
for i in range(len(self.features)):
|
||||
feature = self.features[i]
|
||||
df = df_list[i]
|
||||
size = feature["size"]
|
||||
if feature["type"] == "inday":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[size * time : size * (time + 1)].reshape(-1).tolist()
|
||||
elif feature["type"] == "daily":
|
||||
predictions += df.reshape(-1)[:size].tolist()
|
||||
elif feature["type"] == "range":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[time : size + time].reshape(-1).tolist()
|
||||
elif feature["type"] == "interval":
|
||||
if len(df[interval * size : (interval + 1) * size].reshape(-1)) == size:
|
||||
predictions += df[interval * size : (interval + 1) * size].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
elif feature["type"] == "step":
|
||||
if len(df[size * (time + 1) : size * (time + 2)].reshape(-1)) == size:
|
||||
predictions += df[size * (time + 1) : size * (time + 2)].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
|
||||
return np.array(predictions)
|
||||
|
||||
def get_obs(self, raw_df, feature_dfs, t, interval, position, target, is_buy, *args, **kargs):
|
||||
private_state = np.array([position, target, t, self.max_step_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
}
|
||||
|
||||
|
||||
class RuleInterval(RuleObs):
|
||||
"""
|
||||
The observation for interval_level rule based strategy.
|
||||
|
||||
Consist of interval prediction, private state, direction
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
action=1.0,
|
||||
*args,
|
||||
**kargs
|
||||
):
|
||||
private_state = np.array([position, target, interval - 1, interval_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
"action": action,
|
||||
}
|
||||
28
examples/trade/observation/ppo_obs.py
Normal file
28
examples/trade/observation/ppo_obs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class PPOObs(RuleObs):
|
||||
"""The observation defined in IJCAI 2020. The action of previous state is included in private state"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, action=0,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
# market_state = feature_dfs[0].reshape(-1)[:6*240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num, action])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 3 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
55
examples/trade/observation/teacher_obs.py
Normal file
55
examples/trade/observation/teacher_obs.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class TeacherObs(RuleObs):
|
||||
"""
|
||||
The Observation used for OPD method.
|
||||
|
||||
Consist of public state(raw feature), private state, seqlen
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
assert not (
|
||||
np.isnan(list_private_state).any() | np.isinf(list_private_state).any()
|
||||
), f"{private_state}, {target}"
|
||||
assert not (np.isnan(public_state).any() | np.isinf(public_state).any()), f"{public_state}"
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
|
||||
|
||||
class RuleTeacher(RuleObs):
|
||||
""" """
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = feature_dfs[0].reshape(-1)[: 6 * 240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
teacher_action = self.get_feature_res(feature_dfs, t, interval)[-self.features[1]["size"] :]
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((teacher_action, public_state, list_private_state, seqlen))
|
||||
62
examples/trade/order_gen.py
Normal file
62
examples/trade/order_gen.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
data_path = '../data/'
|
||||
in_dir = os.path.join(data_path, 'backtest/')
|
||||
|
||||
### create order folders ####
|
||||
|
||||
def generate_order(df, start, end):
|
||||
# df['date'] = df.index.map(lambda x: x[1].date())
|
||||
# df.set_index('date', append=True, inplace=True)
|
||||
df = df.groupby('date').take(range(start, end)).droplevel(level=0)
|
||||
div = df['$volume0'].rolling((end - start)*60).mean().shift(1).groupby(level='date').transform('first')
|
||||
order = df.groupby(level=(2, 0)).mean().dropna()
|
||||
order = pd.DataFrame(order)
|
||||
order['amount'] = np.random.lognormal(-3.28, 1.14) * order['$volume0']
|
||||
order['order_type'] = 0
|
||||
order = order.drop(columns=["$volume0", "$vwap0"])
|
||||
return order
|
||||
|
||||
def w_order(f, start, end):
|
||||
df = pd.read_pickle(in_dir + f)
|
||||
#df['date'] = df.index.get_level_values(1).map(lambda x: x.date())
|
||||
#df = df.set_index('date', append=True, drop=True)
|
||||
|
||||
order = generate_order(df, start, end)
|
||||
order_train = order[order.index.get_level_values(0) < '2020-12-01']
|
||||
order_test = order[order.index.get_level_values(0) >= '2020-12-01']
|
||||
order_valid = order_test[order_test.index.get_level_values(0) < '2021-01-01']
|
||||
order_test = order_test[order_test.index.get_level_values(0) >= '2021-01-01']
|
||||
if len(order_train) > 0:
|
||||
order_train.to_pickle(train_path + f[:-9] + '.target')
|
||||
if len(order_valid) > 0:
|
||||
order_valid.to_pickle(valid_path + f[:-9] + '.target')
|
||||
if len(order_test) > 0:
|
||||
order_test.to_pickle(test_path + f[:-9] + '.target')
|
||||
if len(order) > 0:
|
||||
order.to_pickle(all_path + f[:-9] + '.target')
|
||||
return 0
|
||||
|
||||
train_path = os.path.join(data_path, "order/train/")
|
||||
if not os.path.exists(train_path):
|
||||
os.makedirs(train_path)
|
||||
|
||||
valid_path = os.path.join(data_path, "order/valid/")
|
||||
if not os.path.exists(valid_path):
|
||||
os.makedirs(valid_path)
|
||||
|
||||
test_path = os.path.join(data_path, "order/test/")
|
||||
if not os.path.exists(test_path):
|
||||
os.makedirs(test_path)
|
||||
|
||||
all_path = os.path.join(data_path, "order/all/")
|
||||
if not os.path.exists(all_path):
|
||||
os.makedirs(all_path)
|
||||
|
||||
res = Parallel(n_jobs=64)(delayed(w_order)(f, 0, 239) for f in os.listdir(in_dir))
|
||||
print(sum(res))
|
||||
2
examples/trade/policy/__init__.py
Normal file
2
examples/trade/policy/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .ppo_supervision import *
|
||||
from .ppo import *
|
||||
255
examples/trade/policy/ppo.py
Normal file
255
examples/trade/policy/ppo.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
|
||||
def _episodic_return(
|
||||
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, gamma: float, gae_lambda: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 4.1s -> 0.057s."""
|
||||
returns = np.roll(v_s_, 1)
|
||||
m = (1.0 - done) * gamma
|
||||
delta = rew + v_s_ * m - returns
|
||||
m *= gae_lambda
|
||||
gae = 0.0
|
||||
for i in range(len(rew) - 1, -1, -1):
|
||||
gae_new = delta[i] + m[i] * gae
|
||||
gae = gae_new
|
||||
returns[i] += gae
|
||||
return returns
|
||||
|
||||
|
||||
class PPO(PGPolicy):
|
||||
""" The PPO policy with Teacher supervision"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
teacher=None,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
if not teacher is None:
|
||||
self.teacher = torch.load(teacher, map_location=torch.device("cpu"))
|
||||
self.teacher.to(self.actor.device)
|
||||
self.teacher.actor.extractor.device = self.actor.device
|
||||
else:
|
||||
self.teacher = None
|
||||
|
||||
@staticmethod
|
||||
def compute_episodic_return(
|
||||
batch: Batch,
|
||||
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
rew_norm: bool = False,
|
||||
) -> Batch:
|
||||
"""Compute returns over given full-length episodes.
|
||||
Implementation of Generalized Advantage Estimator (arXiv:1506.02438).
|
||||
:param batch: a data batch which contains several full-episode data
|
||||
chronologically.
|
||||
:type batch: :class:`~tianshou.data.Batch`
|
||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
||||
:type v_s_: numpy.ndarray
|
||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
||||
to 0.99.
|
||||
:param float gae_lambda: the parameter for Generalized Advantage
|
||||
Estimation, should be in [0, 1], defaults to 0.95.
|
||||
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
|
||||
to False.
|
||||
:return: a Batch. The result will be stored in batch.returns as a numpy
|
||||
array with shape (bsz, ).
|
||||
"""
|
||||
rew = batch.rew
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
|
||||
assert not np.isnan(v_s_).any()
|
||||
assert not np.isnan(rew).any()
|
||||
assert not np.isnan(batch.done).any()
|
||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||
assert not np.isnan(returns).any()
|
||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||
returns = (returns - returns.mean()) / returns.std()
|
||||
assert not np.isnan(returns).any()
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
assert not np.isnan(batch.rew).any()
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
assert not np.isnan(v_).any()
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
"""Compute action over the given batch data."""
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
try:
|
||||
act = dist.sample()
|
||||
except:
|
||||
print(logits)
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses = [], [], [], [], []
|
||||
if self.teacher is not None:
|
||||
supervision_losses = []
|
||||
v = []
|
||||
old_log_prob = []
|
||||
feature = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
if not self.teacher is None:
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
self.teacher(b)
|
||||
feature.append(self.teacher.actor.feature)
|
||||
batch.old_feature = torch.cat(feature, dim=0)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
dist = self(b).dist
|
||||
value = self.critic(b.obs)
|
||||
if not self.teacher is None:
|
||||
feature = self.actor.feature
|
||||
# print(feature.pow(2).mean())
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
if not self.teacher is None:
|
||||
supervision_loss = (b.old_feature - feature).pow(2).mean()
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
if self.teacher is not None:
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
}
|
||||
if not self.teacher is None:
|
||||
res["loss/supervision"] = supervision_losses
|
||||
return res
|
||||
|
||||
|
||||
Student_new = PPO
|
||||
187
examples/trade/policy/ppo_supervision.py
Normal file
187
examples/trade/policy/ppo_supervision.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
from .ppo import _episodic_return
|
||||
|
||||
|
||||
class PPO_sup(PGPolicy):
|
||||
"""The PPO policy with a log-likelihood supervision loss"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses, supervision_losses = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
v = []
|
||||
old_log_prob = []
|
||||
teacher_action = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
teacher_action.append(self.actor.teacher_action)
|
||||
|
||||
batch.teacher_action = torch.cat(teacher_action, dim=0).to(torch.long)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
res = self(b)
|
||||
logits = res.logits
|
||||
dist = res.dist
|
||||
value = self.critic(b.obs)
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
supervision_loss = F.nll_loss(logits.log(), b.teacher_action)
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
if hasattr(self.actor, "callback"):
|
||||
self.actor.callback()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
"loss/supervision": supervision_losses,
|
||||
}
|
||||
return res
|
||||
10
examples/trade/requirements.txt
Normal file
10
examples/trade/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
4
examples/trade/reward/__init__.py
Normal file
4
examples/trade/reward/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .pa_penalty import *
|
||||
from .ppo_reward import *
|
||||
from .vp_penalty import *
|
||||
38
examples/trade/reward/base.py
Normal file
38
examples/trade/reward/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Abs_Reward(object):
|
||||
"""The abstract class for Reward."""
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def get_reward(self):
|
||||
""":return: reward"""
|
||||
reward = 0
|
||||
return reward
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_reward(*args, **kargs)
|
||||
|
||||
def isinstant(self):
|
||||
""":return: Whether the reward should be given at every timestep or only at the end of this episode."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Instant_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return True
|
||||
|
||||
|
||||
class EndEpisode_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
14
examples/trade/reward/pa_penalty.py
Normal file
14
examples/trade/reward/pa_penalty.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class PA_Penalty(Instant_Reward):
|
||||
"""Reward: (Abs(tt_ratio_t - 1) * 10000 * v_t / target - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, PA_t, *args):
|
||||
reward = PA_t * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
return reward / 100
|
||||
22
examples/trade/reward/ppo_reward.py
Normal file
22
examples/trade/reward/ppo_reward.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
from .base import Abs_Reward
|
||||
|
||||
|
||||
class PPO_Reward(Abs_Reward):
|
||||
"""The reward function defined in IJCAI 2020"""
|
||||
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
|
||||
def get_reward(self, performace_raise, ffr, this_tt_ratio, is_buy):
|
||||
if is_buy:
|
||||
this_tt_ratio = 1 / this_tt_ratio
|
||||
if this_tt_ratio < 1:
|
||||
return -1.0
|
||||
elif this_tt_ratio < 1.1:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0
|
||||
37
examples/trade/reward/vp_penalty.py
Normal file
37
examples/trade/reward/vp_penalty.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class VP_Penalty_small(Instant_Reward):
|
||||
"""Reward: (Abs(vv_ratio_t - 1) * 10000 - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
|
||||
|
||||
class VP_Penalty_small_vec(VP_Penalty_small):
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t.sum() / target
|
||||
reward -= self.penalty * ((v_t / target) ** 2).sum()
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
1
examples/trade/sampler/__init__.py
Normal file
1
examples/trade/sampler/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_sampler import *
|
||||
184
examples/trade/sampler/single_sampler.py
Normal file
184
examples/trade/sampler/single_sampler.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Queue
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
|
||||
def toArray(data):
|
||||
if type(data) == np.ndarray:
|
||||
return data
|
||||
|
||||
elif type(data) == list:
|
||||
data = np.array(data)
|
||||
return data
|
||||
|
||||
elif type(data) == pd.DataFrame:
|
||||
share_index = toArray(data.index)
|
||||
share_value = toArray(data.values)
|
||||
share_colmns = toArray(data.columns)
|
||||
return share_index, share_value, share_colmns
|
||||
|
||||
else:
|
||||
try:
|
||||
share_array = np.array(data)
|
||||
return share_array
|
||||
except:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Sampler:
|
||||
"""The sampler for training of single-assert RL."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.raw_dir = config["raw_dir"] + "/"
|
||||
self.order_dir = config["order_dir"] + "/"
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
self.features = config["features"]
|
||||
self.queue = Queue(1000)
|
||||
self.child = None
|
||||
self.ins = None
|
||||
self.raw_df = None
|
||||
self.df_list = None
|
||||
self.order_df = None
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
ins = None
|
||||
index = 0
|
||||
date_list = []
|
||||
while True:
|
||||
if ins is None or index == len(date_list):
|
||||
ins = np.random.choice(ins_list, 1)[0]
|
||||
# print(ins)
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
feature_df_list = []
|
||||
for feature in features:
|
||||
feature_df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
index = 0
|
||||
date = date_list[index]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
index += 1
|
||||
if target == 0:
|
||||
continue
|
||||
day_feature_dfs = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in feature_df_list:
|
||||
day_feature_dfs.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_feature_dfs)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(ins, date, day_raw_df_value, day_raw_df_column, day_raw_df_index, day_feature_dfs_, target, is_buy,),
|
||||
block=True,
|
||||
)
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
return np.random.choice(self.ins_list, 1)[0]
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
if self.child is None:
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
|
||||
def sample(self):
|
||||
""" """
|
||||
sample = self.queue.get(block=True)
|
||||
return sample
|
||||
|
||||
def stop(self):
|
||||
""" """
|
||||
try:
|
||||
self.child.terminate()
|
||||
except:
|
||||
for p in self.child:
|
||||
p.terminate()
|
||||
|
||||
|
||||
class TestSampler(Sampler):
|
||||
"""The sampler for backtest of single-assert strategies."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.ins_index = -1
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
self.ins_index += 1
|
||||
if self.ins_index >= len(self.ins_list):
|
||||
return None
|
||||
else:
|
||||
return self.ins_list[self.ins_index]
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
for ins in ins_list:
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
df_list = []
|
||||
for feature in features:
|
||||
df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
for index in range(len(date_list)):
|
||||
date = date_list[index]
|
||||
day_df_list = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
if target == 0:
|
||||
continue
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in df_list:
|
||||
day_df_list.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_df_list)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(
|
||||
ins,
|
||||
date,
|
||||
day_raw_df_value,
|
||||
day_raw_df_column,
|
||||
day_raw_df_index,
|
||||
day_feature_dfs_,
|
||||
target,
|
||||
is_buy,
|
||||
),
|
||||
block=True,
|
||||
)
|
||||
for _ in range(100):
|
||||
queue.put(None)
|
||||
|
||||
def reset(self, order_dir=None):
|
||||
"""
|
||||
|
||||
reset the sampler and change self.order_dir if order_dir is not None.
|
||||
|
||||
"""
|
||||
if order_dir:
|
||||
self.order_dir = order_dir
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
if not self.child is None:
|
||||
self.child.terminate()
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
28
examples/trade/teacher_feature.py
Normal file
28
examples/trade/teacher_feature.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
data_path = '../data/'
|
||||
feature_path = os.path.join(data_path, 'feature/teacher/')
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
|
||||
|
||||
log_file = os.path.join(os.environ.get('OUTPUT_DIR'),'example/OPDT_b/test/')
|
||||
|
||||
files = os.listdir(log_file)
|
||||
|
||||
for f in files:
|
||||
if f.endswith(".log"):
|
||||
df = pd.read_pickle(log_file + f)
|
||||
|
||||
#df['datetime'] = df.index.get_level_values(1).map(lambda x: x[1])
|
||||
df['datetime'] = df.index.get_level_values(1)
|
||||
df.set_index('datetime', append=True, drop=True, inplace=True)
|
||||
action = df['action']
|
||||
action = action.reset_index(level=1, drop=True)
|
||||
action.index = action.index.map(lambda x: (x[0], x[1], x[2].time()))
|
||||
action = action.unstack().iloc[:, ::30] * 2
|
||||
action = action.fillna(0)
|
||||
train_action = action.astype("int")
|
||||
final = train_action
|
||||
final.to_pickle(feature_path + f[:-4] + '.pkl')
|
||||
303
examples/trade/util.py
Normal file
303
examples/trade/util.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from collections import namedtuple
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
from typing import Union, Optional
|
||||
from numbers import Number
|
||||
|
||||
|
||||
def nan_weighted_avg(vals, weights, axis=None):
|
||||
"""
|
||||
|
||||
:param vals: The values to be averaged on.
|
||||
:param weights: The weights of weighted avrage.
|
||||
:param axis: On which axis to calculate the weighted avrage. (Default value = None)
|
||||
|
||||
"""
|
||||
assert vals.shape == weights.shape, AssertionError(f"{vals.shape} & {weights.shape}")
|
||||
vals = vals.copy()
|
||||
weights = weights.copy()
|
||||
res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
|
||||
return np.nan_to_num(res, nan=vals[0])
|
||||
|
||||
|
||||
def robust_auc(y_true, y_pred):
|
||||
"""
|
||||
|
||||
Calculate AUC.
|
||||
|
||||
"""
|
||||
try:
|
||||
return roc_auc_score(y_true, y_pred)
|
||||
except:
|
||||
return np.nan
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""
|
||||
|
||||
:param d1: Dict 1.
|
||||
:type d1: dict
|
||||
:param d2: Dict 2.
|
||||
:returns: A new dict that is d1 and d2 deep merged.
|
||||
:rtype: dict
|
||||
|
||||
"""
|
||||
merged = copy.deepcopy(d1)
|
||||
deep_update(merged, d2, True, [])
|
||||
return merged
|
||||
|
||||
|
||||
def deep_update(
|
||||
original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
|
||||
):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
:param original: Dictionary with default values.
|
||||
:type original: dict
|
||||
:param new_dict(dict: dict): Dictionary with values to be updated
|
||||
:param new_keys_allowed: Whether new keys are allowed. (Default value = False)
|
||||
:type new_keys_allowed: bool
|
||||
:param whitelist: List of keys that correspond to dict
|
||||
values where new subkeys can be introduced. This is only at the top
|
||||
level. (Default value = None)
|
||||
:type whitelist: Optional[List[str]]
|
||||
:param override_all_if_type_changes: List of top level
|
||||
keys with value=dict, for which we always simply override the
|
||||
entire value (dict), iff the "type" key in that value dict changes. (Default value = None)
|
||||
:type override_all_if_type_changes: Optional[List[str]]
|
||||
:param new_dict:
|
||||
|
||||
"""
|
||||
whitelist = whitelist or []
|
||||
override_all_if_type_changes = override_all_if_type_changes or []
|
||||
|
||||
for k, value in new_dict.items():
|
||||
if k not in original and not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
|
||||
# Both orginal value and new one are dicts.
|
||||
if isinstance(original.get(k), dict) and isinstance(value, dict):
|
||||
# Check old type vs old one. If different, override entire value.
|
||||
if (
|
||||
k in override_all_if_type_changes
|
||||
and "type" in value
|
||||
and "type" in original[k]
|
||||
and value["type"] != original[k]["type"]
|
||||
):
|
||||
original[k] = value
|
||||
# Whitelisted key -> ok to add new subkeys.
|
||||
elif k in whitelist:
|
||||
deep_update(original[k], value, True)
|
||||
# Non-whitelisted key.
|
||||
else:
|
||||
deep_update(original[k], value, new_keys_allowed)
|
||||
# Original value not a dict OR new value not a dict:
|
||||
# Override entire value.
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
|
||||
|
||||
def get_seqlen(done_seq):
|
||||
"""
|
||||
|
||||
:param done_seq:
|
||||
|
||||
"""
|
||||
seqlen = []
|
||||
length = 0
|
||||
for i, done in enumerate(done_seq):
|
||||
length += 1
|
||||
if done:
|
||||
seqlen.append(length)
|
||||
length = 0
|
||||
if length > 0:
|
||||
seqlen.append(length)
|
||||
return np.array(seqlen)
|
||||
|
||||
|
||||
def generate_seq(seqlen, list):
|
||||
"""
|
||||
|
||||
:param seqlen: param list:
|
||||
:param list:
|
||||
|
||||
"""
|
||||
res = []
|
||||
index = 0
|
||||
maxlen = np.max(seqlen)
|
||||
for i in seqlen:
|
||||
if isinstance(list, torch.Tensor):
|
||||
res.append(torch.cat((list[index : index + i], torch.zeros_like(list[: maxlen - i])), dim=0,))
|
||||
else:
|
||||
res.append(np.concatenate((list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0))
|
||||
index += i
|
||||
if isinstance(list, torch.Tensor):
|
||||
res = torch.stack(res, dim=0)
|
||||
else:
|
||||
res = np.stack(res, axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def sequence_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
seqlen = get_seqlen(batch.done)
|
||||
# print(seqlen.max())
|
||||
# print(len(seqlen))
|
||||
res = Batch()
|
||||
# print(batch.keys())
|
||||
|
||||
for v in batch.keys():
|
||||
if v not in ["policy", "info"]:
|
||||
res[v] = generate_seq(seqlen, batch[v])
|
||||
else:
|
||||
res[v] = batch[v]
|
||||
res.seqlen = seqlen
|
||||
return res
|
||||
|
||||
|
||||
def flatten_seq(seq, seqlen):
|
||||
"""
|
||||
|
||||
:param seq: param seqlen:
|
||||
:param seqlen:
|
||||
|
||||
"""
|
||||
res = []
|
||||
for i, length in enumerate(seqlen):
|
||||
res.append(seq[i][:length])
|
||||
if isinstance(seq, torch.Tensor):
|
||||
res = torch.cat(res, dim=0)
|
||||
else:
|
||||
res = np.concatenate(res, axis=0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def flatten_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
for v in batch.keys():
|
||||
if v in ["policy", "info", "seqlen"]:
|
||||
continue
|
||||
batch[v] = flatten_seq(batch[v], batch.seqlen)
|
||||
return batch
|
||||
|
||||
|
||||
def to_numpy(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param x: Union[Batch:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_numpy(v)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_numpy()
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_numpy(_parse_value(x))
|
||||
except TypeError:
|
||||
x = [to_numpy(e) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
return x
|
||||
|
||||
|
||||
def to_torch(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param int: param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
x = x.to(device)
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_torch(v, dtype, device)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_torch(dtype, device)
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_torch(_parse_value(x), dtype, device)
|
||||
except TypeError:
|
||||
x = [to_torch(e, dtype, device) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
else:
|
||||
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||
return x
|
||||
|
||||
|
||||
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor) -> Union[dict, Batch, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[torch.Tensor:
|
||||
:param dict: param Batch:
|
||||
:param np: ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param Batch:
|
||||
:param np.ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param y: torch.Tensor:
|
||||
:returns: to_torch(x, dtype=y.dtype, device=y.device)``.
|
||||
|
||||
"""
|
||||
assert isinstance(y, torch.Tensor)
|
||||
return to_torch(x, dtype=y.dtype, device=y.device)
|
||||
695
examples/trade/vecenv.py
Normal file
695
examples/trade/vecenv.py
Normal file
@@ -0,0 +1,695 @@
|
||||
import gym
|
||||
import time
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Array, Pipe, connection, Queue
|
||||
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||
|
||||
from tianshou.env.worker import EnvWorker
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double,
|
||||
}
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type], # type: ignore
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
|
||||
"""
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self) -> np.ndarray:
|
||||
""" """
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
"""
|
||||
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
|
||||
"""
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
def _worker(
|
||||
parent: connection.Connection,
|
||||
p: connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
|
||||
def _encode_obs(obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray],) -> None:
|
||||
"""
|
||||
|
||||
:param obs: Union[dict:
|
||||
:param tuple: param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray:
|
||||
:param obs: Union[dict:
|
||||
:param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray]:
|
||||
:param obs: Union[dict:
|
||||
:param buffer: Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
|
||||
for o, b in zip(obs, buffer):
|
||||
_encode_obs(o, b)
|
||||
elif isinstance(obs, dict) and isinstance(buffer, dict):
|
||||
for k in obs.keys():
|
||||
_encode_obs(obs[k], buffer[k])
|
||||
return None
|
||||
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
cmd, data = p.recv()
|
||||
except EOFError: # the pipe has been closed
|
||||
p.close()
|
||||
break
|
||||
if cmd == "step":
|
||||
obs, reward, done, info = env.step(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == "reset":
|
||||
obs = env.reset(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send(obs)
|
||||
elif cmd == "close":
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == "render":
|
||||
p.send(env.render(**data) if hasattr(env, "render") else None)
|
||||
elif cmd == "seed":
|
||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
elif cmd == "toggle_log":
|
||||
env.toggle_log(data)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
p.close()
|
||||
|
||||
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
|
||||
if self.share_memory:
|
||||
dummy = env_fn()
|
||||
obs_space = dummy.observation_space
|
||||
dummy.close()
|
||||
del dummy
|
||||
self.buffer = _setup_buf(obs_space)
|
||||
args = (
|
||||
self.parent_remote,
|
||||
self.child_remote,
|
||||
CloudpickleWrapper(env_fn),
|
||||
self.buffer,
|
||||
)
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
""" """
|
||||
|
||||
def decode_obs(buffer: Optional[Union[dict, tuple, ShArray]]) -> Union[dict, tuple, np.ndarray]:
|
||||
"""
|
||||
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(buffer, ShArray):
|
||||
return buffer.get()
|
||||
elif isinstance(buffer, tuple):
|
||||
return tuple([decode_obs(b) for b in buffer])
|
||||
elif isinstance(buffer, dict):
|
||||
return {k: decode_obs(v) for k, v in buffer.items()}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return decode_obs(self.buffer)
|
||||
|
||||
def reset(self, sample) -> Any:
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["reset", sample])
|
||||
# obs = self.parent_remote.recv()
|
||||
# if self.share_memory:
|
||||
# obs = self._decode_obs()
|
||||
# return obs
|
||||
|
||||
def get_reset_result(self):
|
||||
""" """
|
||||
obs = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"], wait_num: int, timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
"""
|
||||
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
|
||||
"""
|
||||
remain_conns = conns = [x.parent_remote for x in workers]
|
||||
ready_conns: List[connection.Connection] = []
|
||||
remain_time, t1 = timeout, time.time()
|
||||
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
||||
if timeout:
|
||||
remain_time = timeout - (time.time() - t1)
|
||||
if remain_time <= 0:
|
||||
break
|
||||
# connection.wait hangs if the list is empty
|
||||
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns) # type: ignore
|
||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["step", action])
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.parent_remote.send(["toggle_log", log])
|
||||
|
||||
def get_result(self,) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
""" """
|
||||
obs, rew, done, info = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
"""
|
||||
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["seed", seed])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["render", kwargs])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def close_env(self) -> None:
|
||||
""" """
|
||||
try:
|
||||
self.parent_remote.send(["close", None])
|
||||
# mp may be deleted so it may raise AttributeError
|
||||
self.parent_remote.recv()
|
||||
self.process.join()
|
||||
except (BrokenPipeError, EOFError, AttributeError):
|
||||
pass
|
||||
# ensure the subproc is terminated
|
||||
self.process.terminate()
|
||||
|
||||
|
||||
class BaseVectorEnv(gym.Env):
|
||||
"""Base class for vectorized environments wrapper.
|
||||
Usage:
|
||||
::
|
||||
env_num = 8
|
||||
envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
assert len(envs) == env_num
|
||||
It accepts a list of environment generators. In other words, an environment
|
||||
generator ``efn`` of a specific task means that ``efn()`` returns the
|
||||
environment of the given task, for example, ``gym.make(task)``.
|
||||
All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.
|
||||
Here are some other usages:
|
||||
::
|
||||
envs.seed(2) # which is equal to the next line
|
||||
envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
|
||||
obs = envs.reset() # reset all environments
|
||||
obs = envs.reset([0, 5, 7]) # reset 3 specific environments
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
.. warning::
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
:param env_fns: a list of callable envs
|
||||
:param env:
|
||||
:param worker_fn: a callable worker
|
||||
:param worker: which contains the i
|
||||
:param int: wait_num
|
||||
:param env: step
|
||||
:param environments: to finish a step is time
|
||||
:param return: when
|
||||
:param simulation: in these environments
|
||||
:param is: disabled
|
||||
:param float: timeout
|
||||
:param vectorized: step it only deal with those environments spending time
|
||||
:param within: timeout
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
sampler=None,
|
||||
testing: Optional[bool] = False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
|
||||
# interact with the given envs (one worker <-> one env).
|
||||
self.workers = [worker_fn(fn) for fn in env_fns]
|
||||
self.worker_class = type(self.workers[0])
|
||||
assert issubclass(self.worker_class, EnvWorker)
|
||||
assert all([isinstance(w, self.worker_class) for w in self.workers])
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert self.timeout is None or self.timeout > 0, f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None or testing
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
# but environments in self.waiting_id are just waiting when checked,
|
||||
# and they may be ready now, but this is not known until we check it
|
||||
# in the step() function
|
||||
self.waiting_id: List[int] = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
self.sampler = sampler
|
||||
self.sample_obs = None
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
""" """
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} cannot be called after " "close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
"""Switch the attribute getter depending on the key.
|
||||
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
||||
``action_space``. However, we would like the attribute lookup to go
|
||||
straight into the worker (in fact, this vector env's action_space is
|
||||
always None).
|
||||
"""
|
||||
if key in [
|
||||
"metadata",
|
||||
"reward_range",
|
||||
"spec",
|
||||
"action_space",
|
||||
"observation_space",
|
||||
]: # reserved keys in gym.Env
|
||||
return self.__getattr__(key)
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
This function tries to retrieve an attribute from each individual
|
||||
wrapped environment, if it does not belong to the wrapping vector
|
||||
environment class.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> Union[List[int], np.ndarray]:
|
||||
"""
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
return id
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
"""
|
||||
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
|
||||
"""
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> np.ndarray:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
If id is None, reset the state of all the environments and return
|
||||
initial observations, otherwise reset the specific environments with
|
||||
the given id, either an int or a list.
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs = []
|
||||
stop_id = []
|
||||
for i in id:
|
||||
sample = self.sampler.sample()
|
||||
if sample is None:
|
||||
stop_id.append(i)
|
||||
else:
|
||||
self.workers[i].reset(sample)
|
||||
for i in id:
|
||||
if i in stop_id:
|
||||
obs.append(self.sample_obs)
|
||||
else:
|
||||
this_obs = self.workers[i].get_reset_result()
|
||||
if self.sample_obs is None:
|
||||
self.sample_obs = this_obs
|
||||
for j in range(len(obs)):
|
||||
if obs[j] is None:
|
||||
obs[j] = self.sample_obs
|
||||
obs.append(this_obs)
|
||||
|
||||
if len(obs) > 0:
|
||||
obs = np.stack(obs)
|
||||
# if len(stop_id)> 0:
|
||||
# obs_zero =
|
||||
# print(time.time() - start_timed)
|
||||
|
||||
return obs, stop_id
|
||||
|
||||
def toggle_log(self, log):
|
||||
for worker in self.workers:
|
||||
worker.toggle_log(log)
|
||||
|
||||
def reset_sampler(self):
|
||||
""" """
|
||||
self.sampler.reset()
|
||||
|
||||
def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None) -> List[np.ndarray]:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
If id is None, run one timestep of all the environments’ dynamics;
|
||||
otherwise run one timestep for some environments with given id, either
|
||||
an int or a list. When the end of episode is reached, you are
|
||||
responsible for calling reset(id) to reset this environment’s state.
|
||||
Accept a batch of action and return a tuple (batch_obs, batch_rew,
|
||||
batch_done, batch_info) in numpy format.
|
||||
|
||||
:param numpy: ndarray action: a batch of action provided by the agent.
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:rtype: A tuple including four items
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if not self.is_async:
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send_action(action[i])
|
||||
result = []
|
||||
for j in id:
|
||||
obs, rew, done, info = self.workers[j].get_result()
|
||||
info["env_id"] = j
|
||||
result.append((obs, rew, done, info))
|
||||
else:
|
||||
if action is not None:
|
||||
self._assert_id(id)
|
||||
assert len(action) == len(id)
|
||||
for i, (act, env_id) in enumerate(zip(action, id)):
|
||||
self.workers[env_id].send_action(act)
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns: List[EnvWorker] = []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout)
|
||||
result = []
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
obs, rew, done, info = conn.get_result()
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[Optional[List[int]]]:
|
||||
"""Set the seed for all environments.
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
|
||||
:param seed: Optional[Union[int:
|
||||
:param List: int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:returns: The list of seeds used in this env's random number generators.
|
||||
The first value in the list should be the "main" seed, or the value
|
||||
which a reproducer pass to "seed".
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
seed_list: Union[List[None], List[int]]
|
||||
if seed is None:
|
||||
seed_list = [seed] * self.env_num
|
||||
elif isinstance(seed, int):
|
||||
seed_list = [seed + i for i in range(self.env_num)]
|
||||
else:
|
||||
seed_list = seed
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
|
||||
|
||||
def render(self, **kwargs: Any) -> List[Any]:
|
||||
"""Render all of the environments.
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
if self.is_async and len(self.waiting_id) > 0:
|
||||
raise RuntimeError(f"Environments {self.waiting_id} are still stepping, cannot " "render them now.")
|
||||
return [w.render(**kwargs) for w in self.workers]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments.
|
||||
This function will be called only once (if not, it will be called
|
||||
during garbage collected). This way, ``close`` of all workers can be
|
||||
assured.
|
||||
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
for w in self.workers:
|
||||
w.close()
|
||||
self.is_closed = True
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Redirect to self.close()."""
|
||||
if not self.is_closed:
|
||||
self.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on subprocess.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
"""Optimized SubprocVectorEnv with shared buffers to exchange observations.
|
||||
ShmemVectorEnv has exactly the same API as SubprocVectorEnv.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
|
||||
detailed explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
@@ -90,7 +90,6 @@ _default_config = {
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
"default_disk_cache": 1, # 0:skip/1:use
|
||||
"disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True
|
||||
"mem_cache_size_limit": 500,
|
||||
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
|
||||
# default 1 hour
|
||||
|
||||
@@ -961,8 +961,7 @@ class BaseProvider:
|
||||
is a provider class.
|
||||
"""
|
||||
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
|
||||
if C.disable_disk_cache:
|
||||
disk_cache = False
|
||||
fields = list(fields) # In case of tuple.
|
||||
try:
|
||||
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
except TypeError:
|
||||
|
||||
@@ -57,10 +57,10 @@ class DataHandler(Serializable):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
fetch_orig=True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
@@ -71,14 +71,14 @@ class DataHandler(Serializable):
|
||||
start_time of the original data.
|
||||
end_time :
|
||||
end_time of the original data.
|
||||
freq :
|
||||
frequency of data
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
**kwargs:
|
||||
it will be passed into data_loader
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
@@ -86,23 +86,41 @@ class DataHandler(Serializable):
|
||||
# Setup data loader
|
||||
assert data_loader is not None # to make start_time end_time could have None default value
|
||||
|
||||
# what data source to load data
|
||||
self.data_loader = init_instance_by_config(
|
||||
data_loader,
|
||||
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
|
||||
accept_types=DataLoader,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# what data to be loaded from data source
|
||||
# For IDE auto-completion.
|
||||
self.instruments = instruments
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.freq = freq
|
||||
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
super().__init__()
|
||||
|
||||
def init(self, enable_cache: bool = True):
|
||||
def conf_data(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
"""
|
||||
attr_list = {"instruments", "start_time", "end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
@@ -123,7 +141,7 @@ class DataHandler(Serializable):
|
||||
# Setup data.
|
||||
# _data may be with multiple column index level. The outer level indicates the feature set name
|
||||
with TimeInspector.logt("Loading data"):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
@@ -262,7 +280,6 @@ class DataHandlerLP(DataHandler):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
@@ -328,7 +345,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
self.process_type = process_type
|
||||
self.drop_raw = drop_raw
|
||||
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
|
||||
def get_all_processors(self):
|
||||
return self.infer_processors + self.learn_processors
|
||||
|
||||
@@ -19,7 +19,7 @@ class DataLoader(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
"""
|
||||
load the data as pd.DataFrame.
|
||||
|
||||
@@ -76,6 +76,7 @@ class DLWParser(DataLoader):
|
||||
<config> := <fields_info>
|
||||
|
||||
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
|
||||
# NOTE: list or tuple will be treated as the things when parsing
|
||||
"""
|
||||
self.is_group = isinstance(config, dict)
|
||||
|
||||
@@ -85,18 +86,22 @@ class DLWParser(DataLoader):
|
||||
self.fields = self._parse_fields_info(config)
|
||||
|
||||
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
|
||||
if isinstance(fields_info, list):
|
||||
if len(fields_info) == 0:
|
||||
raise ValueError("The size of fields must be greater than 0")
|
||||
|
||||
if not isinstance(fields_info, (list, tuple)):
|
||||
raise TypeError("Unsupported type")
|
||||
|
||||
if isinstance(fields_info[0], str):
|
||||
exprs = names = fields_info
|
||||
elif isinstance(fields_info, tuple):
|
||||
elif isinstance(fields_info[0], (list, tuple)):
|
||||
exprs, names = fields_info
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return exprs, names
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_group_df(
|
||||
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
|
||||
) -> pd.DataFrame:
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
"""
|
||||
load the dataframe for specific group
|
||||
|
||||
@@ -116,25 +121,25 @@ class DLWParser(DataLoader):
|
||||
"""
|
||||
pass
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
exprs, names = self.fields
|
||||
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
|
||||
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
return df
|
||||
|
||||
|
||||
class QlibDataLoader(DLWParser):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -147,11 +152,10 @@ class QlibDataLoader(DLWParser):
|
||||
"""
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
self.freq = freq
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(
|
||||
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
|
||||
) -> pd.DataFrame:
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
@@ -160,7 +164,7 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq)
|
||||
df = D.features(instruments, exprs, start_time, end_time, self.freq)
|
||||
df.columns = names
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
@@ -185,7 +189,7 @@ class StaticDataLoader(DataLoader):
|
||||
self.join = join
|
||||
self._data = None
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
if instruments is None:
|
||||
df = self._data
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from contextlib import contextmanager
|
||||
from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
|
||||
@@ -165,7 +166,7 @@ class QlibRecorder:
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
|
||||
Reference in New Issue
Block a user