This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2c76178

koz4k authored and copybara-github committed
Implement SimulatedEnvProblem
PiperOrigin-RevId: 258613974
1 parent 6e8476f commit 2c76178

File tree: 6 files changed, +456, -47 lines


tensor2tensor/envs/env_problem.py  (+12, -18)

@@ -54,6 +54,9 @@ class EnvProblem(Env, problem.Problem):
 
   Subclasses *should* override the following functions:
     - initialize_environments
+    - observation_space
+    - action_space
+    - reward_range
     - _reset
     - _step
     - _render
@@ -95,7 +98,6 @@ class EnvProblem(Env, problem.Problem):
 
   def __init__(self,
                batch_size=None,
-               reward_range=(-np.inf, np.inf),
               discrete_rewards=True,
               parallelism=1,
               **env_kwargs):
@@ -104,9 +106,6 @@ def __init__(self,
     Args:
       batch_size: (int or None) How many envs to make in the non natively
        batched mode.
-      reward_range: (tuple(number, number)) the first element is the minimum
-        reward and the second is the maximum reward, used to clip and process
-        the raw reward in `process_rewards`.
      discrete_rewards: (bool) whether to round the rewards to the nearest
        integer.
      parallelism: (int) If this is greater than one then we run the envs in
@@ -124,18 +123,11 @@ def __init__(self,
     # to an appropriate directory.
     self._agent_id = "default"
 
-    # We clip rewards to this range before processing them further, as described
-    # in `process_rewards`.
-    self._reward_range = reward_range
-
     # If set, we discretize the rewards and treat them as integers.
     self._discrete_rewards = discrete_rewards
 
     self._parallelism = None
 
-    self._observation_space = None
-    self._action_space = None
-
     # A data structure to hold the `batch_size` currently active trajectories
     # and also the ones that are completed, i.e. done.
     self._trajectories = None
@@ -168,10 +160,10 @@ def initialize(self, batch_size=1, **kwargs):
 
     # Assert that *all* the above are now set, we should do this since
     # subclasses can override `initialize_environments`.
-    assert self._envs is not None
-    assert self._observation_space is not None
-    assert self._action_space is not None
-    assert self._reward_range is not None
+    self.assert_common_preconditions()
+    assert self.observation_space is not None
+    assert self.action_space is not None
+    assert self.reward_range is not None
 
   def initialize_environments(self, batch_size=1, parallelism=1, **kwargs):
     """Initializes the environments.
@@ -189,7 +181,7 @@ def assert_common_preconditions(self):
 
   @property
   def observation_space(self):
-    return self._observation_space
+    raise NotImplementedError
 
   @property
   def observation_spec(self):
@@ -210,7 +202,7 @@ def process_observations(self, observations):
 
   @property
   def action_space(self):
-    return self._action_space
+    raise NotImplementedError
 
   @property
   def action_spec(self):
@@ -228,7 +220,9 @@ def num_actions(self):
 
   @property
   def reward_range(self):
-    return self._reward_range
+    # We clip rewards to this range before processing them further, as described
+    # in `process_rewards`.
+    raise NotImplementedError
 
   @property
   def is_reward_range_finite(self):
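
After this change `EnvProblem` no longer stores the spaces or the reward range itself; the base properties raise `NotImplementedError` and subclasses are expected to expose them. Below is a minimal sketch, not part of the commit, of what a conforming subclass might look like; the class name and the fixed toy spaces are hypothetical, and only the pieces touched by this commit are shown.

import gym
import numpy as np

from tensor2tensor.envs import env_problem


class ToyEnvProblem(env_problem.EnvProblem):
  """Hypothetical subclass; initialize_environments/_reset/_step are omitted."""

  @property
  def observation_space(self):
    # Previously assigned to self._observation_space by the base __init__;
    # now each subclass derives it from its own environments.
    return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,))

  @property
  def action_space(self):
    return gym.spaces.Discrete(2)

  @property
  def reward_range(self):
    # Rewards are clipped to this range in `process_rewards`.
    return (-1.0, 1.0)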

tensor2tensor/envs/gym_env_problem.py  (+29, -20)

@@ -75,14 +75,19 @@ class GymEnvProblem(env_problem.EnvProblem):
   the following properties: observation_space, action_space, reward_range.
   """
 
-  def __init__(self, base_env_name=None, env_wrapper_fn=None, **kwargs):
+  def __init__(self, base_env_name=None, env_wrapper_fn=None, reward_range=None,
+               **kwargs):
     """Initializes this class by creating the envs and managing trajectories.
 
     Args:
      base_env_name: (string) passed to `gym.make` to make the underlying
        environment.
      env_wrapper_fn: (callable(env): env) Applies gym wrappers to the base
        environment.
+      reward_range: (tuple(number, number) or None) the first element is the
+        minimum reward and the second is the maximum reward, used to clip and
+        process the raw reward in `process_rewards`. If None, this is inferred
+        from the inner environments.
      **kwargs: (dict) Arguments passed to the base class.
    """
    # Name for the base environment, will be used in `gym.make` in
@@ -96,6 +101,10 @@ def __init__(self, base_env_name=None, env_wrapper_fn=None, **kwargs):
    # to an appropriate directory.
    self._agent_id = "default"
 
+    # We clip rewards to this range before processing them further, as described
+    # in `process_rewards`.
+    self._reward_range = reward_range
+
    # Initialize the environment(s).
 
    # This can either be a list of environments of len `batch_size` or this can
@@ -171,25 +180,6 @@ def initialize_environments(self, batch_size=1, parallelism=1, **kwargs):
    if self._env_wrapper_fn is not None:
      self._envs = list(map(self._env_wrapper_fn, self._envs))
 
-    # If self.observation_space and self.action_space aren't None, then it means
-    # that this is a re-initialization of this class, in that case make sure
-    # that this matches our previous behaviour.
-    if self._observation_space:
-      assert str(self._observation_space) == str(
-          self._envs[0].observation_space)
-    else:
-      # This means that we are initializing this class for the first time.
-      #
-      # We set this equal to the first env's observation space, later on we'll
-      # verify that all envs have the same observation space.
-      self._observation_space = self._envs[0].observation_space
-
-    # Similarly for action_space
-    if self._action_space:
-      assert str(self._action_space) == str(self._envs[0].action_space)
-    else:
-      self._action_space = self._envs[0].action_space
-
    self._verify_same_spaces()
 
    # If self.reward_range is None, i.e. this means that we should take the
@@ -203,6 +193,25 @@ def initialize_environments(self, batch_size=1, parallelism=1, **kwargs):
    # is still valuable to store the trajectories separately.
    self._trajectories = trajectory.BatchTrajectory(batch_size=batch_size)
 
+  def assert_common_preconditions(self):
+    # Asserts on the common pre-conditions of:
+    # - self._envs is initialized.
+    # - self._envs is a list.
+    assert self._envs
+    assert isinstance(self._envs, list)
+
+  @property
+  def observation_space(self):
+    return self._envs[0].observation_space
+
+  @property
+  def action_space(self):
+    return self._envs[0].action_space
+
+  @property
+  def reward_range(self):
+    return self._reward_range
+
  def seed(self, seed=None):
    if not self._envs:
      tf.logging.info("`seed` called on non-existent envs, doing nothing.")
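
A usage sketch under the new `GymEnvProblem` signature, mirroring the test changes further down; whether `initialize` is invoked by the constructor or must be called explicitly is outside this diff, so the sketch only states the relationships the diff establishes.

from tensor2tensor.envs import gym_env_problem

env = gym_env_problem.GymEnvProblem(
    base_env_name="CartPole-v0",
    batch_size=1,
    reward_range=(-1, 1),   # or None to infer it from the inner gym envs
    discrete_rewards=False)

# Once the inner envs exist, the properties added in this commit delegate to
# the first one:
#   env.observation_space  -> env._envs[0].observation_space
#   env.action_space       -> env._envs[0].action_space
#   env.reward_range       -> the constructor argument (or the inferred range)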

tensor2tensor/trax/backend.py  (+2)

@@ -118,6 +118,8 @@ def jax_avg_pool(x, pool_size, strides, padding):
     "name": "numpy",
     "np": onp,
     "jit": (lambda f: f),
+    "random_get_prng": lambda seed: None,
+    "random_split": lambda prng, num=2: (None,) * num,
 }
 

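The two new numpy-backend entries are stubs so that code written against the JAX-style random API still runs when the numpy backend is selected. A standalone illustration of just the lambdas follows; the dispatch that routes trax calls to this dict is not shown here.

# The same lambdas as in the diff, evaluated directly.
random_get_prng = lambda seed: None
random_split = lambda prng, num=2: (None,) * num

prng = random_get_prng(seed=0)         # None: the numpy backend has no PRNG keys
key_a, key_b = random_split(prng)      # (None, None)
four_keys = random_split(prng, num=4)  # (None, None, None, None)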
tensor2tensor/trax/rlax/ppo_training_loop_test.py  (+101, -9)

@@ -21,19 +21,25 @@
 
 import contextlib
 import functools
+import itertools
 import os
 import tempfile
 
 import gin
+import gym
 import numpy as np
 
 from tensor2tensor.envs import gym_env_problem
 from tensor2tensor.rl import gym_utils
 from tensor2tensor.trax import inputs as trax_inputs
 from tensor2tensor.trax import layers
+from tensor2tensor.trax import learning_rate as lr
 from tensor2tensor.trax import models
+from tensor2tensor.trax import optimizers as trax_opt
+from tensor2tensor.trax import trax
 from tensor2tensor.trax.rlax import envs  # pylint: disable=unused-import
 from tensor2tensor.trax.rlax import ppo
+from tensor2tensor.trax.rlax import simulated_env_problem
 from tensorflow import test
 from tensorflow.io import gfile
 
@@ -55,7 +61,6 @@ def get_wrapped_env(self, name="CartPole-v0", max_episode_steps=2):
     return gym_env_problem.GymEnvProblem(base_env_name=name,
                                          batch_size=1,
                                          env_wrapper_fn=wrapper_fn,
-                                         reward_range=(-1, 1),
                                          discrete_rewards=False)
 
   @contextlib.contextmanager
@@ -64,9 +69,7 @@ def tmp_dir(self):
     yield tmp
     gfile.rmtree(tmp)
 
-  def _run_training_loop(self, env_name, output_dir):
-    env = self.get_wrapped_env(env_name, 2)
-    eval_env = self.get_wrapped_env(env_name, 2)
+  def _run_training_loop(self, env, eval_env, output_dir):
     n_epochs = 2
     # Run the training loop.
     ppo.training_loop(
@@ -79,28 +82,117 @@ def _run_training_loop(self, env_name, output_dir):
         policy_and_value_optimizer_fn=ppo.optimizer_fn,
         n_optimizer_steps=1,
         output_dir=output_dir,
-        env_name=env_name,
+        env_name="SomeEnv",
         random_seed=0)
 
   def test_training_loop_cartpole(self):
     with self.tmp_dir() as output_dir:
-      self._run_training_loop("CartPole-v0", output_dir)
+      self._run_training_loop(
+          env=self.get_wrapped_env("CartPole-v0", 2),
+          eval_env=self.get_wrapped_env("CartPole-v0", 2),
+          output_dir=output_dir,
+      )
 
   def test_training_loop_onlinetune(self):
     with self.tmp_dir() as output_dir:
       gin.bind_parameter("OnlineTuneEnv.model", functools.partial(
-          models.MLP, n_hidden_layers=0, n_output_classes=1))
+          models.MLP,
+          n_hidden_layers=0,
+          n_output_classes=1,
+      ))
       gin.bind_parameter("OnlineTuneEnv.inputs", functools.partial(
          trax_inputs.random_inputs,
          input_shape=(1, 1),
          input_dtype=np.float32,
          output_shape=(1, 1),
-          output_dtype=np.float32))
+          output_dtype=np.float32,
+      ))
      gin.bind_parameter("OnlineTuneEnv.train_steps", 2)
      gin.bind_parameter("OnlineTuneEnv.eval_steps", 2)
      gin.bind_parameter(
          "OnlineTuneEnv.output_dir", os.path.join(output_dir, "envs"))
-      self._run_training_loop("OnlineTuneEnv-v0", output_dir)
+      self._run_training_loop(
+          env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
+          eval_env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
+          output_dir=output_dir,
+      )
+
+  def test_training_loop_simulated(self):
+    n_actions = 5
+    history_shape = (3, 2, 3)
+    action_shape = (3,)
+    obs_shape = (3, 3)
+    reward_shape = (3, 1)
+
+    def model(mode):
+      del mode
+      return layers.Serial(
+          layers.Parallel(
+              layers.Flatten(),  # Observation stack.
+              layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
+          ),
+          layers.Concatenate(),
+          layers.Dense(n_units=1),
+          layers.Dup(),
+          layers.Parallel(
+              layers.Dense(n_units=obs_shape[1]),  # New observation.
+              None,  # Reward.
+          )
+      )
+
+    def inputs(n_devices):
+      del n_devices
+      stream = itertools.repeat((
+          (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32)),
+          (np.zeros(obs_shape), np.zeros(reward_shape)),
+      ))
+      return trax_inputs.Inputs(
+          train_stream=lambda: stream,
+          train_eval_stream=lambda: stream,
+          eval_stream=lambda: stream,
+          input_shape=(history_shape[1:], action_shape[1:]),
+          input_dtype=(np.float32, np.int32),
+      )
+
+    def loss(*args, **kwargs):
+      del args
+      del kwargs
+      return 0.0
+
+    with self.tmp_dir() as output_dir:
+      # Run fake training just to save the parameters.
+      trainer = trax.Trainer(
+          model=model,
+          loss_fn=loss,
+          inputs=inputs,
+          optimizer=trax_opt.SM3,
+          lr_schedule=lr.MultifactorSchedule,
+          output_dir=output_dir,
+      )
+      trainer.train_epoch(epoch_steps=1, eval_steps=1)
+
+      # Repeat the initial observations over and over again.
+      stream = itertools.repeat(np.zeros(history_shape))
+      env_fn = functools.partial(
+          simulated_env_problem.SimulatedEnvProblem,
+          model=model,
+          history_length=history_shape[1],
+          trajectory_length=3,
+          batch_size=history_shape[0],
+          observation_space=gym.spaces.Box(
+              low=-np.inf, high=np.inf, shape=(obs_shape[1],)),
+          action_space=gym.spaces.Discrete(n=n_actions),
+          reward_range=(-1, 1),
+          discrete_rewards=False,
+          initial_observation_stream=stream,
+          output_dir=output_dir,
+      )
+
+      self._run_training_loop(
+          env=env_fn(),
+          eval_env=env_fn(),
+          output_dir=output_dir,
+      )
 
 
 if __name__ == "__main__":
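
For readers following `test_training_loop_simulated`, the shapes appear to fit together as sketched below; the axis interpretation is an assumption, since the test itself only fixes the numeric shapes.

import itertools
import numpy as np

batch_size, history_length, obs_dim = 3, 2, 3
history_shape = (batch_size, history_length, obs_dim)  # stacked past observations
action_shape = (batch_size,)                           # one discrete action per env
obs_shape = (batch_size, obs_dim)                      # next observation per env
reward_shape = (batch_size, 1)                         # scalar reward per env

# The simulated env is primed with an endless stream of (zero) observation
# stacks, exactly as the test does before building SimulatedEnvProblem.
initial_observation_stream = itertools.repeat(np.zeros(history_shape))
assert next(initial_observation_stream).shape == (3, 2, 3)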
