mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
format
This commit is contained in:
@@ -25,4 +25,3 @@ class Static_Action(Base_Action):
|
||||
|
||||
"""
|
||||
return min(target * self.action_map[action], position)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class Rule_Static_Interval(Base_Action):
|
||||
"""
|
||||
return target / (interval_num) * action
|
||||
|
||||
|
||||
class Rule_Dynamic_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
@@ -42,4 +43,4 @@ class Rule_Dynamic_Interval(Base_Action):
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (interval_num - interval) * action
|
||||
return position / (interval_num - interval) * action
|
||||
|
||||
@@ -67,7 +67,5 @@ class AC(VWAP):
|
||||
t = t + 1
|
||||
k_tild = self.lamb / self.eta * sig * sig
|
||||
k = np.arccosh(k_tild / 2 + 1)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(
|
||||
k * self.T
|
||||
)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(k * self.T)
|
||||
return Batch(act=act, state=state)
|
||||
|
||||
@@ -55,18 +55,14 @@ class Collector(object):
|
||||
def _default_rew_metric(x: Union[Number, np.number]) -> Union[Number, np.number]:
|
||||
# this internal function is designed for single-agent RL
|
||||
# for multi-agent RL, a reward_metric must be provided
|
||||
assert np.asanyarray(x).size == 1, (
|
||||
"Please specify the reward_metric " "since the reward is not a scalar."
|
||||
)
|
||||
assert np.asanyarray(x).size == 1, "Please specify the reward_metric " "since the reward is not a scalar."
|
||||
return x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(
|
||||
state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={}
|
||||
)
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
@@ -96,9 +92,7 @@ class Collector(object):
|
||||
self.data.obs = obs
|
||||
for b in self._cached_buf:
|
||||
b.reset()
|
||||
self._ready_env_ids = np.array(
|
||||
[x for x in self._ready_env_ids if x not in stop_id]
|
||||
)
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in stop_id])
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
@@ -187,9 +181,7 @@ class Collector(object):
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
finished_env_ids = [i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array(
|
||||
[x for x in self._ready_env_ids if x not in finished_env_ids]
|
||||
)
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
@@ -249,13 +241,9 @@ class Collector(object):
|
||||
log_fn(info)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(
|
||||
whole_data, self._ready_env_ids, self.data, self.env_num
|
||||
)
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(
|
||||
self.data.act, id=self._ready_env_ids
|
||||
)
|
||||
obs_next, rew, done, info = self.env.step(self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
@@ -264,9 +252,7 @@ class Collector(object):
|
||||
|
||||
step_time += time.time() - start
|
||||
# move data to self.data
|
||||
self.data.update(
|
||||
obs_next=obs_next, rew=rew, done=done, info=[{} for i in info]
|
||||
)
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=[{} for i in info])
|
||||
|
||||
if render:
|
||||
self.env.render()
|
||||
@@ -288,20 +274,13 @@ class Collector(object):
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (
|
||||
isinstance(n_episode, list) and episode_count[i] >= n_episode[i]
|
||||
):
|
||||
if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
rewards.append(
|
||||
self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0))
|
||||
)
|
||||
rewards.append(self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0)))
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if (
|
||||
isinstance(n_episode, list)
|
||||
and episode_count[i] >= n_episode[i]
|
||||
):
|
||||
if isinstance(n_episode, list) and episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
self._cached_buf[i].reset()
|
||||
@@ -318,23 +297,17 @@ class Collector(object):
|
||||
# env_ind_local.remove(_ready_env_ids.index(i))
|
||||
if len(env_ind_local) > 0:
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get(
|
||||
"obs", obs_reset
|
||||
)
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
reset_time += time.time() - start
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
|
||||
_batch_set_item(
|
||||
whole_data, self._ready_env_ids, self.data, self.env_num
|
||||
)
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
self._ready_env_ids = np.array(
|
||||
[x for x in self._ready_env_ids if x not in finished_env_ids]
|
||||
)
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
|
||||
50
examples/trade/env/env_rl.py
vendored
50
examples/trade/env/env_rl.py
vendored
@@ -51,9 +51,7 @@ class StockEnv(gym.Env):
|
||||
obs_conf["time_interval"] = self.time_interval
|
||||
obs_conf["max_step_num"] = self.max_step_num
|
||||
self.obs = getattr(observation, config["obs"]["name"])(obs_conf)
|
||||
self.action_func = getattr(action, config["action"]["name"])(
|
||||
config["action"]["config"]
|
||||
)
|
||||
self.action_func = getattr(action, config["action"]["name"])(config["action"]["config"])
|
||||
self.reward_func_list = []
|
||||
self.reward_log_dict = {}
|
||||
self.reward_coef = []
|
||||
@@ -87,19 +85,13 @@ class StockEnv(gym.Env):
|
||||
self.target,
|
||||
self.is_buy,
|
||||
) = sample
|
||||
self.raw_df = pd.DataFrame(
|
||||
index=self.raw_df_index,
|
||||
data=self.raw_df_values,
|
||||
columns=self.raw_df_columns,
|
||||
)
|
||||
self.raw_df = pd.DataFrame(index=self.raw_df_index, data=self.raw_df_values, columns=self.raw_df_columns,)
|
||||
del self.raw_df_values, self.raw_df_columns, self.raw_df_index
|
||||
start_time = time.time()
|
||||
self.load_time = time.time() - start_time
|
||||
self.day_vwap = nan_weighted_avg(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
self.raw_df["$volume0"].values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
self.raw_df["$volume0"].values[self.offset : self.offset + self.max_step_num],
|
||||
)
|
||||
try:
|
||||
assert not (np.isnan(self.day_vwap) or np.isinf(self.day_vwap))
|
||||
@@ -108,9 +100,7 @@ class StockEnv(gym.Env):
|
||||
print(self.ins)
|
||||
print(self.day_vwap)
|
||||
self.raw_df.to_pickle("/nfs_data1/kanren/error_df.pkl")
|
||||
self.day_twap = np.nanmean(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num]
|
||||
)
|
||||
self.day_twap = np.nanmean(self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num])
|
||||
self.t = -1 + self.offset
|
||||
self.interval = 0
|
||||
self.position = self.target
|
||||
@@ -130,9 +120,7 @@ class StockEnv(gym.Env):
|
||||
if self.log:
|
||||
index_array = [
|
||||
np.array([self.ins] * self.max_step_num),
|
||||
self.raw_df.index.to_numpy()[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
self.raw_df.index.to_numpy()[self.offset : self.offset + self.max_step_num],
|
||||
np.array([self.date] * self.max_step_num),
|
||||
]
|
||||
self.traded_log = pd.DataFrame(
|
||||
@@ -142,9 +130,7 @@ class StockEnv(gym.Env):
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$traded_t": np.nan,
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
"action": np.nan,
|
||||
},
|
||||
index=index_array,
|
||||
@@ -239,18 +225,14 @@ class StockEnv(gym.Env):
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (
|
||||
(self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
)
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[
|
||||
self.offset : self.max_step_num + self.offset
|
||||
]
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
@@ -262,9 +244,7 @@ class StockEnv(gym.Env):
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(
|
||||
performance_raise, this_ffr, this_tt_ratio, self.is_buy
|
||||
)
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
@@ -405,18 +385,14 @@ class StockEnv_Acc(StockEnv):
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (
|
||||
(self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
)
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[
|
||||
self.offset : self.max_step_num + self.offset
|
||||
]
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
@@ -428,9 +404,7 @@ class StockEnv_Acc(StockEnv):
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(
|
||||
performance_raise, this_ffr, this_tt_ratio, self.is_buy
|
||||
)
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
|
||||
@@ -48,15 +48,7 @@ def setup_seed(seed):
|
||||
|
||||
class BaseExecutor(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
self, log_dir, resources, env_conf, optim=None, policy_conf=None, network=None, policy_path=None, seed=None,
|
||||
):
|
||||
"""A base class for executor
|
||||
|
||||
@@ -88,9 +80,7 @@ class BaseExecutor(object):
|
||||
if seed:
|
||||
setup_seed(seed)
|
||||
|
||||
assert (
|
||||
not policy_path is None or not policy_conf is None
|
||||
), "Policy must be defined"
|
||||
assert not policy_path is None or not policy_conf is None, "Policy must be defined"
|
||||
if policy_path:
|
||||
self.policy = torch.load(policy_path, map_location=self.device)
|
||||
self.policy.actor.extractor.device = self.device
|
||||
@@ -106,17 +96,11 @@ class BaseExecutor(object):
|
||||
device=self.device, **network["config"]
|
||||
)
|
||||
else:
|
||||
net = getattr(model, network["name"] + "_Extractor")(
|
||||
device=self.device, **network["config"]
|
||||
)
|
||||
net = getattr(model, network["name"] + "_Extractor")(device=self.device, **network["config"])
|
||||
net.to(self.device)
|
||||
actor = getattr(model, network["name"] + "_Actor")(
|
||||
extractor=net, device=self.device, **network["config"]
|
||||
)
|
||||
actor = getattr(model, network["name"] + "_Actor")(extractor=net, device=self.device, **network["config"])
|
||||
actor.to(self.device)
|
||||
critic = getattr(model, network["name"] + "_Critic")(
|
||||
extractor=net, device=self.device, **network["config"]
|
||||
)
|
||||
critic = getattr(model, network["name"] + "_Critic")(extractor=net, device=self.device, **network["config"])
|
||||
critic.to(self.device)
|
||||
self.optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
@@ -162,9 +146,7 @@ class BaseExecutor(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_round(
|
||||
self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs
|
||||
):
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
"""Do an round of training
|
||||
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
@@ -228,29 +210,18 @@ class Executor(BaseExecutor):
|
||||
:param buffer_size: The size of replay buffer, defaults to 200000
|
||||
:type buffer_size: int, optional
|
||||
"""
|
||||
super().__init__(
|
||||
log_dir, resources, env_conf, optim, policy_conf, network, policy_path, seed
|
||||
)
|
||||
super().__init__(log_dir, resources, env_conf, optim, policy_conf, network, policy_path, seed)
|
||||
single_env = getattr(env, env_conf["name"])
|
||||
env_conf = merge_dicts(env_conf, train_paths)
|
||||
env_conf["log"] = True
|
||||
print("CPU_COUNT:", resources["num_cpus"])
|
||||
if share_memory:
|
||||
self.env = ShmemVectorEnv(
|
||||
[lambda: single_env(env_conf) for _ in range(resources["num_cpus"])]
|
||||
)
|
||||
self.env = ShmemVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
else:
|
||||
self.env = SubprocVectorEnv(
|
||||
[lambda: single_env(env_conf) for _ in range(resources["num_cpus"])]
|
||||
)
|
||||
self.test_collector = Collector(
|
||||
policy=self.policy, env=self.env, testing=True, reward_metric=np.sum
|
||||
)
|
||||
self.env = SubprocVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
self.test_collector = Collector(policy=self.policy, env=self.env, testing=True, reward_metric=np.sum)
|
||||
self.train_collector = Collector(
|
||||
self.policy,
|
||||
self.env,
|
||||
buffer=ts.data.ReplayBuffer(buffer_size),
|
||||
reward_metric=np.sum,
|
||||
self.policy, self.env, buffer=ts.data.ReplayBuffer(buffer_size), reward_metric=np.sum,
|
||||
)
|
||||
self.train_paths = train_paths
|
||||
self.test_paths = test_paths
|
||||
@@ -259,9 +230,7 @@ class Executor(BaseExecutor):
|
||||
train_sampler_conf["features"] = env_conf["features"]
|
||||
test_sampler_conf = test_paths
|
||||
test_sampler_conf["features"] = env_conf["features"]
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(
|
||||
train_sampler_conf
|
||||
)
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(train_sampler_conf)
|
||||
self.test_sampler = getattr(sampler, io_conf["test_sampler"])(test_sampler_conf)
|
||||
self.train_logger = logger.InfoLogger()
|
||||
self.test_logger = getattr(logger, io_conf["test_logger"])
|
||||
@@ -286,32 +255,23 @@ class Executor(BaseExecutor):
|
||||
best_epoch, best_reward = -1, -1
|
||||
stat = {}
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
while t.n < t.total:
|
||||
result, losses = self.train_round(
|
||||
repeat_per_collect, collect_per_step, batch_size, iteration
|
||||
)
|
||||
result, losses = self.train_round(repeat_per_collect, collect_per_step, batch_size, iteration)
|
||||
global_step += result["n/st"]
|
||||
iteration += 1
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar(
|
||||
"Train/" + k, result[k], global_step=global_step
|
||||
)
|
||||
self.writer.add_scalar("Train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
self.writer.add_scalar(
|
||||
"Train/" + k, stat[k].get(), global_step=global_step
|
||||
)
|
||||
self.writer.add_scalar("Train/" + k, stat[k].get(), global_step=global_step)
|
||||
t.update(1)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
result = self.eval(
|
||||
self.valid_paths["order_dir"],
|
||||
logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
self.valid_paths["order_dir"], logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Valid/" + k, result[k], global_step=global_step)
|
||||
@@ -333,31 +293,22 @@ class Executor(BaseExecutor):
|
||||
break
|
||||
print("Testing...")
|
||||
self.policy.load_state_dict(best_state)
|
||||
result = self.eval(
|
||||
self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True
|
||||
)
|
||||
result = self.eval(self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
def train_round(
|
||||
self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs
|
||||
):
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
self.policy.train()
|
||||
self.env.toggle_log(False)
|
||||
self.env.sampler = self.train_sampler
|
||||
if not self.q_learning:
|
||||
self.train_collector.reset()
|
||||
result = self.train_collector.collect(
|
||||
n_episode=collect_per_step, log_fn=self.train_logger
|
||||
)
|
||||
result = self.train_collector.collect(n_episode=collect_per_step, log_fn=self.train_logger)
|
||||
result = merge_dicts(result, self.train_logger.summary())
|
||||
if not self.q_learning:
|
||||
losses = self.policy.update(
|
||||
0,
|
||||
self.train_collector.buffer,
|
||||
batch_size=batch_size,
|
||||
repeat=repeat_per_collect,
|
||||
0, self.train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect,
|
||||
)
|
||||
else:
|
||||
losses = self.policy.update(batch_size, self.train_collector.buffer,)
|
||||
|
||||
@@ -52,24 +52,18 @@ class DFLogger(object):
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money_sell"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
# summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache['money_sell'])
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money_buy"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
@@ -114,11 +108,7 @@ class DFLogger(object):
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
assert self.queue.empty()
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.log_dir, self.order_dir, self.queue),
|
||||
daemon=True,
|
||||
)
|
||||
self.child = Process(target=self._worker, args=(self.log_dir, self.order_dir, self.queue), daemon=True,)
|
||||
self.child.start()
|
||||
|
||||
def set_step(self, step):
|
||||
@@ -170,23 +160,17 @@ class InfoLogger(DFLogger):
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money_sell"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money_buy"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(
|
||||
stat_cache[k], weights=stat_cache["money"]
|
||||
)
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
|
||||
@@ -48,11 +48,7 @@ def run(config):
|
||||
if config["task"] == "train":
|
||||
return executor.train(**config["optim"])
|
||||
elif config["task"] == "eval":
|
||||
return executor.eval(
|
||||
config["test_paths"]["order_dir"],
|
||||
save_res=True,
|
||||
logdir=config["log_dir"] + "/test/",
|
||||
)
|
||||
return executor.eval(config["test_paths"]["order_dir"], save_res=True, logdir=config["log_dir"] + "/test/",)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -76,9 +72,7 @@ if __name__ == "__main__":
|
||||
if "PT_OUTPUT_DIR" in os.environ:
|
||||
config["log_dir"] = os.environ["PT_OUTPUT_DIR"]
|
||||
else:
|
||||
log_prefix = (
|
||||
os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
)
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
config = get_full_config(config, config_path)
|
||||
run(config)
|
||||
@@ -116,32 +110,17 @@ if __name__ == "__main__":
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Running")
|
||||
print(f"Trail_{index} is running")
|
||||
try:
|
||||
res = subprocess.run(
|
||||
[
|
||||
"python",
|
||||
"main.py",
|
||||
"--config",
|
||||
args.config,
|
||||
"--index",
|
||||
str(index),
|
||||
],
|
||||
)
|
||||
res = subprocess.run(["python", "main.py", "--config", args.config, "--index", str(index),],)
|
||||
except KeyboardInterrupt:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(
|
||||
f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run"
|
||||
)
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
break
|
||||
if res.returncode == 0:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Finished")
|
||||
print(
|
||||
f"Finish running one trail, {redis_server.llen(EXP_NAME)} trails to run"
|
||||
)
|
||||
print(f"Finish running one trail, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
else:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(
|
||||
f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run"
|
||||
)
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
|
||||
elif os.path.isfile(config_path):
|
||||
assert config_path.endswith(".yml"), "Config file should be an yaml file"
|
||||
@@ -149,9 +128,7 @@ if __name__ == "__main__":
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.load(f, Loader=loader)
|
||||
config = get_full_config(config, os.path.dirname(config_path))
|
||||
log_prefix = (
|
||||
os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
)
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
run(config)
|
||||
else:
|
||||
|
||||
@@ -18,24 +18,12 @@ class PPO_Extractor(nn.Module):
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(
|
||||
nn.Linear(2, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv1d(self.cnn_shape[1], 3, 3),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.raw_fc = nn.Sequential(
|
||||
nn.Linear((self.cnn_shape[0] - 2) * 3, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
@@ -74,9 +62,7 @@ class PPO_Actor(nn.Module):
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
assert not (
|
||||
torch.isnan(self.feature).any() | torch.isinf(self.feature).any()
|
||||
), f"{self.feature}"
|
||||
assert not (torch.isnan(self.feature).any() | torch.isinf(self.feature).any()), f"{self.feature}"
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
@@ -18,18 +18,9 @@ class RNNQModel(nn.Module):
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(
|
||||
nn.Linear(2, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv1d(self.cnn_shape[1], 3, 3),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.raw_fc = nn.Sequential(
|
||||
nn.Linear((self.cnn_shape[0] - 2) * 3, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
|
||||
@@ -18,24 +18,12 @@ class Teacher_Extractor(nn.Module):
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(
|
||||
nn.Linear(2, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv1d(self.cnn_shape[1], 3, 3),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.raw_fc = nn.Sequential(
|
||||
nn.Linear((self.cnn_shape[0] - 2) * 3, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
|
||||
@@ -11,14 +11,9 @@ from tianshou.data import to_torch
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(
|
||||
nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1)
|
||||
)
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key):
|
||||
key = key.unsqueeze(dim=1)
|
||||
@@ -34,14 +29,9 @@ class Attention(nn.Module):
|
||||
class MaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(
|
||||
nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1)
|
||||
)
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
# seq_len: (batch,)
|
||||
@@ -61,14 +51,9 @@ class MaskAttention(nn.Module):
|
||||
class TFMaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(
|
||||
nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1)
|
||||
)
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
device = value.device
|
||||
@@ -155,14 +140,10 @@ class DARNN(nn.Module):
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.view(-1, self.input_length, self.input_size) # [B, T, F]
|
||||
today_input = inputs[:, : self.today_length, :]
|
||||
today_input = torch.cat(
|
||||
(torch.zeros_like(today_input[:, :1, :]), today_input), dim=1
|
||||
)
|
||||
today_input = torch.cat((torch.zeros_like(today_input[:, :1, :]), today_input), dim=1)
|
||||
prev_input = inputs[:, 240 : 240 + self.prev_length, :]
|
||||
if self.emb_dim != 0:
|
||||
embedding = self.pos_emb(
|
||||
torch.arange(end=self.today_length + 1, device=inputs.device)
|
||||
)
|
||||
embedding = self.pos_emb(torch.arange(end=self.today_length + 1, device=inputs.device))
|
||||
embedding = embedding.repeat([today_input.size()[0], 1, 1])
|
||||
today_input = torch.cat((today_input, embedding), dim=-1)
|
||||
prev_outs, _ = self.prev_rnn(prev_input)
|
||||
@@ -205,8 +186,6 @@ def onehot_enc(y, len):
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
mask = ~(
|
||||
torch.ones((len(lengths), maxlen), device=device).cumsum(dim=1).t() > lengths
|
||||
).t()
|
||||
mask = ~(torch.ones((len(lengths), maxlen), device=device).cumsum(dim=1).t() > lengths).t()
|
||||
mask.type(dtype)
|
||||
return mask
|
||||
|
||||
@@ -60,9 +60,7 @@ class RuleObs(BaseObs):
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
for i, p in enumerate(prediction):
|
||||
if len(p) < interval_num:
|
||||
prediction[i] = np.concatenate(
|
||||
(p, np.zeros(interval_num - len(p))), axis=-1
|
||||
)
|
||||
prediction[i] = np.concatenate((p, np.zeros(interval_num - len(p))), axis=-1)
|
||||
# res = np.stack(prediction).transpose().reshape(-1)
|
||||
return np.concatenate(prediction)
|
||||
for i in range(len(self.features)):
|
||||
@@ -73,9 +71,7 @@ class RuleObs(BaseObs):
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += (
|
||||
df.iloc[size * time : size * (time + 1)].reshape(-1).tolist()
|
||||
)
|
||||
predictions += df.iloc[size * time : size * (time + 1)].reshape(-1).tolist()
|
||||
elif feature["type"] == "daily":
|
||||
predictions += df.reshape(-1)[:size].tolist()
|
||||
elif feature["type"] == "range":
|
||||
@@ -86,35 +82,19 @@ class RuleObs(BaseObs):
|
||||
else:
|
||||
predictions += df.iloc[time : size + time].reshape(-1).tolist()
|
||||
elif feature["type"] == "interval":
|
||||
if (
|
||||
len(df.iloc[interval * size : (interval + 1) * size].reshape(-1))
|
||||
== size
|
||||
):
|
||||
predictions += (
|
||||
df.iloc[interval * size : (interval + 1) * size]
|
||||
.reshape(-1)
|
||||
.tolist()
|
||||
)
|
||||
if len(df.iloc[interval * size : (interval + 1) * size].reshape(-1)) == size:
|
||||
predictions += df.iloc[interval * size : (interval + 1) * size].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
elif feature["type"] == "step":
|
||||
if (
|
||||
len(df.iloc[size * (time + 1) : size * (time + 2)].reshape(-1))
|
||||
== size
|
||||
):
|
||||
predictions += (
|
||||
df.iloc[size * (time + 1) : size * (time + 2)]
|
||||
.reshape(-1)
|
||||
.tolist()
|
||||
)
|
||||
if len(df.iloc[size * (time + 1) : size * (time + 2)].reshape(-1)) == size:
|
||||
predictions += df.iloc[size * (time + 1) : size * (time + 2)].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
|
||||
return np.array(predictions)
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, *args, **kargs
|
||||
):
|
||||
def get_obs(self, raw_df, feature_dfs, t, interval, position, target, is_buy, *args, **kargs):
|
||||
private_state = np.array([position, target, t, self.max_step_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
|
||||
@@ -11,17 +11,7 @@ class PPOObs(RuleObs):
|
||||
"""The observation defined in IJCAI 2020. The action of previous state is included in private state"""
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
action=0,
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, action=0,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
@@ -32,10 +22,7 @@ class PPOObs(RuleObs):
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(
|
||||
list_private_state,
|
||||
[0.0] * 3 * (interval_num + 1 - len(self.private_states)),
|
||||
)
|
||||
(list_private_state, [0.0] * 3 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
|
||||
@@ -16,18 +16,7 @@ class TeacherObs(RuleObs):
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
*args,
|
||||
**kargs,
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
@@ -36,18 +25,13 @@ class TeacherObs(RuleObs):
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(
|
||||
list_private_state,
|
||||
[0.0] * 2 * (interval_num + 1 - len(self.private_states)),
|
||||
)
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
assert not (
|
||||
np.isnan(list_private_state).any() | np.isinf(list_private_state).any()
|
||||
), f"{private_state}, {target}"
|
||||
assert not (
|
||||
np.isnan(public_state).any() | np.isinf(public_state).any()
|
||||
), f"{public_state}"
|
||||
assert not (np.isnan(public_state).any() | np.isinf(public_state).any()), f"{public_state}"
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
|
||||
|
||||
@@ -55,35 +39,17 @@ class RuleTeacher(RuleObs):
|
||||
""" """
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
*args,
|
||||
**kargs,
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = feature_dfs[0].reshape(-1)[: 6 * 240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
teacher_action = self.get_feature_res(feature_dfs, t, interval)[
|
||||
-self.features[1]["size"] :
|
||||
]
|
||||
teacher_action = self.get_feature_res(feature_dfs, t, interval)[-self.features[1]["size"] :]
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(
|
||||
list_private_state,
|
||||
[0.0] * 2 * (interval_num + 1 - len(self.private_states)),
|
||||
)
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate(
|
||||
(teacher_action, public_state, list_private_state, seqlen)
|
||||
)
|
||||
return np.concatenate((teacher_action, public_state, list_private_state, seqlen))
|
||||
|
||||
@@ -16,11 +16,7 @@ from util import to_numpy, to_torch_as
|
||||
|
||||
|
||||
def _episodic_return(
|
||||
v_s_: np.ndarray,
|
||||
rew: np.ndarray,
|
||||
done: np.ndarray,
|
||||
gamma: float,
|
||||
gae_lambda: float,
|
||||
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, gamma: float, gae_lambda: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 4.1s -> 0.057s."""
|
||||
returns = np.roll(v_s_, 1)
|
||||
@@ -77,9 +73,7 @@ class PPO(PGPolicy):
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert (
|
||||
dual_clip is None or dual_clip > 1
|
||||
), "Dual-clip PPO parameter should greater than 1."
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
@@ -127,18 +121,14 @@ class PPO(PGPolicy):
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
assert not np.isnan(batch.rew).any()
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(
|
||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda
|
||||
)
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
@@ -146,16 +136,9 @@ class PPO(PGPolicy):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
assert not np.isnan(v_).any()
|
||||
return self.compute_episodic_return(
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda
|
||||
)
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs
|
||||
) -> Batch:
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
"""Compute action over the given batch data."""
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
@@ -174,9 +157,7 @@ class PPO(PGPolicy):
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs
|
||||
) -> Dict[str, List[float]]:
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses = [], [], [], [], []
|
||||
if self.teacher is not None:
|
||||
@@ -224,16 +205,12 @@ class PPO(PGPolicy):
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(
|
||||
torch.min(surr1, surr2), self._dual_clip * b.adv
|
||||
).mean()
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(
|
||||
-self._vf_clip_para, self._vf_clip_para
|
||||
)
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
@@ -242,28 +219,20 @@ class PPO(PGPolicy):
|
||||
if not self.teacher is None:
|
||||
supervision_loss = (b.old_feature - feature).pow(2).mean()
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(
|
||||
self.dist_fn(b.old_logits), dist
|
||||
)
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = (
|
||||
clip_loss
|
||||
+ self._w_vf * vf_loss
|
||||
- self._w_ent * e_loss
|
||||
+ self.kl_coef * kl_loss
|
||||
)
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
if self.teacher is not None:
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
self._max_grad_norm,
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
|
||||
@@ -58,40 +58,27 @@ class PPO_sup(PGPolicy):
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert (
|
||||
dual_clip is None or dual_clip > 1
|
||||
), "Dual-clip PPO parameter should greater than 1."
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(
|
||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda
|
||||
)
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
return self.compute_episodic_return(
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda
|
||||
)
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs
|
||||
) -> Batch:
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
@@ -105,9 +92,7 @@ class PPO_sup(PGPolicy):
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs
|
||||
) -> Dict[str, List[float]]:
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses, supervision_losses = (
|
||||
[],
|
||||
@@ -156,16 +141,12 @@ class PPO_sup(PGPolicy):
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(
|
||||
torch.min(surr1, surr2), self._dual_clip * b.adv
|
||||
).mean()
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(
|
||||
-self._vf_clip_para, self._vf_clip_para
|
||||
)
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
@@ -173,27 +154,19 @@ class PPO_sup(PGPolicy):
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
supervision_loss = F.nll_loss(logits.log(), b.teacher_action)
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(
|
||||
self.dist_fn(b.old_logits), dist
|
||||
)
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = (
|
||||
clip_loss
|
||||
+ self._w_vf * vf_loss
|
||||
- self._w_ent * e_loss
|
||||
+ self.kl_coef * kl_loss
|
||||
)
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
self._max_grad_norm,
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
if hasattr(self.actor, "callback"):
|
||||
|
||||
@@ -18,9 +18,7 @@ class VP_Penalty_small(Instant_Reward):
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
assert not (
|
||||
np.isnan(reward) or np.isinf(reward)
|
||||
), f"{performance_raise}, {v_t}, {target}"
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
|
||||
|
||||
@@ -35,7 +33,5 @@ class VP_Penalty_small_vec(VP_Penalty_small):
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t.sum() / target
|
||||
reward -= self.penalty * ((v_t / target) ** 2).sum()
|
||||
assert not (
|
||||
np.isnan(reward) or np.isinf(reward)
|
||||
), f"{performance_raise}, {v_t}, {target}"
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
|
||||
@@ -37,9 +37,7 @@ class Sampler:
|
||||
def __init__(self, config):
|
||||
self.raw_dir = config["raw_dir"] + "/"
|
||||
self.order_dir = config["order_dir"] + "/"
|
||||
self.ins_list = [
|
||||
f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")
|
||||
]
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
self.features = config["features"]
|
||||
self.queue = Queue(1000)
|
||||
self.child = None
|
||||
@@ -60,9 +58,7 @@ class Sampler:
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
feature_df_list = []
|
||||
for feature in features:
|
||||
feature_df_list.append(
|
||||
pd.read_pickle(f"{feature['loc']}/{ins}.pkl")
|
||||
)
|
||||
feature_df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
index = 0
|
||||
@@ -81,16 +77,7 @@ class Sampler:
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(
|
||||
ins,
|
||||
date,
|
||||
day_raw_df_value,
|
||||
day_raw_df_column,
|
||||
day_raw_df_index,
|
||||
day_feature_dfs_,
|
||||
target,
|
||||
is_buy,
|
||||
),
|
||||
(ins, date, day_raw_df_value, day_raw_df_column, day_raw_df_index, day_feature_dfs_, target, is_buy,),
|
||||
block=True,
|
||||
)
|
||||
|
||||
@@ -103,13 +90,7 @@ class Sampler:
|
||||
if self.child is None:
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(
|
||||
self.order_dir,
|
||||
self.raw_dir,
|
||||
self.features,
|
||||
self.ins_list,
|
||||
self.queue,
|
||||
),
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
@@ -164,9 +145,7 @@ class TestSampler(Sampler):
|
||||
for df in df_list:
|
||||
day_df_list.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_df_list)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(
|
||||
day_raw_df
|
||||
)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(
|
||||
@@ -192,22 +171,14 @@ class TestSampler(Sampler):
|
||||
"""
|
||||
if order_dir:
|
||||
self.order_dir = order_dir
|
||||
self.ins_list = [
|
||||
f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")
|
||||
]
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
if not self.child is None:
|
||||
self.child.terminate()
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(
|
||||
self.order_dir,
|
||||
self.raw_dir,
|
||||
self.features,
|
||||
self.ins_list,
|
||||
self.queue,
|
||||
),
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
|
||||
@@ -16,9 +16,7 @@ def nan_weighted_avg(vals, weights, axis=None):
|
||||
:param axis: On which axis to calculate the weighted avrage. (Default value = None)
|
||||
|
||||
"""
|
||||
assert vals.shape == weights.shape, AssertionError(
|
||||
f"{vals.shape} & {weights.shape}"
|
||||
)
|
||||
assert vals.shape == weights.shape, AssertionError(f"{vals.shape} & {weights.shape}")
|
||||
vals = vals.copy()
|
||||
weights = weights.copy()
|
||||
res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
|
||||
@@ -53,11 +51,7 @@ def merge_dicts(d1, d2):
|
||||
|
||||
|
||||
def deep_update(
|
||||
original,
|
||||
new_dict,
|
||||
new_keys_allowed=False,
|
||||
whitelist=None,
|
||||
override_all_if_type_changes=None,
|
||||
original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
|
||||
):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
@@ -140,18 +134,9 @@ def generate_seq(seqlen, list):
|
||||
maxlen = np.max(seqlen)
|
||||
for i in seqlen:
|
||||
if isinstance(list, torch.Tensor):
|
||||
res.append(
|
||||
torch.cat(
|
||||
(list[index : index + i], torch.zeros_like(list[: maxlen - i])),
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
res.append(torch.cat((list[index : index + i], torch.zeros_like(list[: maxlen - i])), dim=0,))
|
||||
else:
|
||||
res.append(
|
||||
np.concatenate(
|
||||
(list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0
|
||||
)
|
||||
)
|
||||
res.append(np.concatenate((list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0))
|
||||
index += i
|
||||
if isinstance(list, torch.Tensor):
|
||||
res = torch.stack(res, dim=0)
|
||||
@@ -298,9 +283,7 @@ def to_torch(
|
||||
return x
|
||||
|
||||
|
||||
def to_torch_as(
|
||||
x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor
|
||||
) -> Union[dict, Batch, torch.Tensor]:
|
||||
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor) -> Union[dict, Batch, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[torch.Tensor:
|
||||
|
||||
@@ -100,9 +100,7 @@ def _worker(
|
||||
|
||||
"""
|
||||
|
||||
def _encode_obs(
|
||||
obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray],
|
||||
) -> None:
|
||||
def _encode_obs(obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray],) -> None:
|
||||
"""
|
||||
|
||||
:param obs: Union[dict:
|
||||
@@ -170,9 +168,7 @@ def _worker(
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
def __init__(
|
||||
self, env_fn: Callable[[], gym.Env], share_memory: bool = False
|
||||
) -> None:
|
||||
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
@@ -200,9 +196,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
""" """
|
||||
|
||||
def decode_obs(
|
||||
buffer: Optional[Union[dict, tuple, ShArray]]
|
||||
) -> Union[dict, tuple, np.ndarray]:
|
||||
def decode_obs(buffer: Optional[Union[dict, tuple, ShArray]]) -> Union[dict, tuple, np.ndarray]:
|
||||
"""
|
||||
|
||||
:param buffer: Optional[Union[dict:
|
||||
@@ -244,9 +238,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
workers: List["SubprocEnvWorker"], wait_num: int, timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
"""
|
||||
|
||||
@@ -389,13 +381,9 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert (
|
||||
1 <= self.wait_num <= len(env_fns)
|
||||
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
assert 1 <= self.wait_num <= len(env_fns), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert (
|
||||
self.timeout is None or self.timeout > 0
|
||||
), f"timeout is {timeout}, it should be positive if provided!"
|
||||
assert self.timeout is None or self.timeout > 0, f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None or testing
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
@@ -411,9 +399,7 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
""" """
|
||||
assert not self.is_closed, (
|
||||
f"Methods of {self.__class__.__name__} cannot be called after " "close."
|
||||
)
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} cannot be called after " "close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
@@ -445,9 +431,7 @@ class BaseVectorEnv(gym.Env):
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> Union[List[int], np.ndarray]:
|
||||
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> Union[List[int], np.ndarray]:
|
||||
"""
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
@@ -474,16 +458,10 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
"""
|
||||
for i in id:
|
||||
assert (
|
||||
i not in self.waiting_id
|
||||
), f"Cannot interact with environment {i} which is stepping now."
|
||||
assert (
|
||||
i in self.ready_id
|
||||
), f"Can only interact with ready environments {self.ready_id}."
|
||||
assert i not in self.waiting_id, f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> np.ndarray:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
If id is None, reset the state of all the environments and return
|
||||
initial observations, otherwise reset the specific environments with
|
||||
@@ -539,9 +517,7 @@ class BaseVectorEnv(gym.Env):
|
||||
""" """
|
||||
self.sampler.reset()
|
||||
|
||||
def step(
|
||||
self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> List[np.ndarray]:
|
||||
def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None) -> List[np.ndarray]:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
If id is None, run one timestep of all the environments’ dynamics;
|
||||
otherwise run one timestep for some environments with given id, either
|
||||
@@ -586,9 +562,7 @@ class BaseVectorEnv(gym.Env):
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns: List[EnvWorker] = []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(
|
||||
self.waiting_conn, self.wait_num, self.timeout
|
||||
)
|
||||
ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout)
|
||||
result = []
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
@@ -600,9 +574,7 @@ class BaseVectorEnv(gym.Env):
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
def seed(
|
||||
self, seed: Optional[Union[int, List[int]]] = None
|
||||
) -> List[Optional[List[int]]]:
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[Optional[List[int]]]:
|
||||
"""Set the seed for all environments.
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
@@ -636,10 +608,7 @@ class BaseVectorEnv(gym.Env):
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
if self.is_async and len(self.waiting_id) > 0:
|
||||
raise RuntimeError(
|
||||
f"Environments {self.waiting_id} are still stepping, cannot "
|
||||
"render them now."
|
||||
)
|
||||
raise RuntimeError(f"Environments {self.waiting_id} are still stepping, cannot " "render them now.")
|
||||
return [w.render(**kwargs) for w in self.workers]
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -690,9 +659,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout
|
||||
)
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
@@ -725,6 +692,4 @@ class ShmemVectorEnv(BaseVectorEnv):
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout
|
||||
)
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
Reference in New Issue
Block a user