mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
|
import math
|
|
import json
|
|
|
|
from .obs_rule import RuleObs
|
|
|
|
|
|
class PPOObs(RuleObs):
    """Observation defined in IJCAI 2020.

    The private state of every step also carries the action passed in for
    that step, and the whole history of private states is emitted.
    """

    def get_obs(
        self,
        raw_df,
        feature_dfs,
        t,
        interval,
        position,
        target,
        is_buy,
        max_step_num,
        interval_num,
        action=0,
    ):
        """Build the flat observation vector for the current step.

        The result is the concatenation of three parts, in this order:

        1. the public state returned by ``self.get_feature_res(...)`` with
           ``whole_day=True``;
        2. every private state recorded so far (three values per step:
           ``position / target``, ``(t + 1) / max_step_num`` and ``action``),
           followed by zero padding for the ``interval_num + 1`` slots that
           have not been filled yet;
        3. the current ``interval`` as a single trailing element.

        Args:
            raw_df: Not used by this implementation.
            feature_dfs: Forwarded unchanged to ``get_feature_res``.
            t: Current step index; ``-1`` resets the recorded history.
            interval: Forwarded to ``get_feature_res`` and appended as the
                last element of the observation.
            position: Numerator of the first private-state entry.
            target: Denominator of the first private-state entry.
            is_buy: Not used by this implementation.
            max_step_num: Denominator of the second private-state entry.
            interval_num: Padding is sized against ``interval_num + 1``
                private-state slots.
            action: Third private-state entry; defaults to ``0``.

        Returns:
            The array produced by ``np.concatenate`` over the three parts.
        """
        # A step index of -1 starts a fresh history of private states.
        if t == -1:
            self.private_states = []

        market_part = self.get_feature_res(feature_dfs, t, interval, whole_day=True)

        # Record this step's private state before flattening the history.
        current_private = np.array(
            [position / target, (t + 1) / max_step_num, action]
        )
        self.private_states.append(current_private)
        recorded = np.concatenate(self.private_states)

        # Zero-fill the slots that have not been reached yet.
        # NOTE: the multiplication order is kept as written; regrouping it
        # changes the result when ``interval_num`` is a NumPy integer.
        unfilled_slots = interval_num + 1 - len(self.private_states)
        zero_tail = [0.0] * 3 * unfilled_slots
        history_part = np.concatenate((recorded, zero_tail))

        interval_part = np.array([interval])
        return np.concatenate((market_part, history_part, interval_part))
|