diff --git a/reinforcement_learning/reinforce.py b/reinforcement_learning/reinforce.py
index 961598174c..20b7058f09 100644
--- a/reinforcement_learning/reinforce.py
+++ b/reinforcement_learning/reinforce.py
@@ -19,20 +19,20 @@
                     help='render the environment')
 parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='interval between training status logs (default: 10)')
+parser.add_argument('--env-id', type=str, default='CartPole-v1')
 args = parser.parse_args()
 
-
-env = gym.make('CartPole-v1')
+env = gym.make(args.env_id)
 env.reset(seed=args.seed)
 torch.manual_seed(args.seed)
 
 
 class Policy(nn.Module):
-    def __init__(self):
+    def __init__(self, n_observation, n_actions):
         super(Policy, self).__init__()
-        self.affine1 = nn.Linear(4, 128)
+        self.affine1 = nn.Linear(n_observation, 128)
         self.dropout = nn.Dropout(p=0.6)
-        self.affine2 = nn.Linear(128, 2)
+        self.affine2 = nn.Linear(128, n_actions)
 
         self.saved_log_probs = []
         self.rewards = []
@@ -44,8 +44,7 @@ def forward(self, x):
         action_scores = self.affine2(x)
         return F.softmax(action_scores, dim=1)
 
-
-policy = Policy()
+policy = Policy(env.observation_space.shape[0], env.action_space.n)
 optimizer = optim.Adam(policy.parameters(), lr=1e-2)
 eps = np.finfo(np.float32).eps.item()
 
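
A minimal sketch (not part of the patch) of what the change buys: with the layer
sizes taken from the environment instead of hardcoded to CartPole's 4 observations
and 2 actions, the same Policy fits any Gym env with a Box observation space and a
Discrete action space. Acrobot-v1 and the seed are illustrative choices only; the
Policy class is reproduced from the patched file so the snippet is self-contained,
and it assumes the newer Gym reset API (returning obs, info) that the patched
script's env.reset(seed=...) call already targets.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F

# Policy reproduced from the patched reinforce.py; the constructor
# arguments replace the hardcoded Linear(4, 128) / Linear(128, 2).
class Policy(nn.Module):
    def __init__(self, n_observation, n_actions):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(n_observation, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.dropout(self.affine1(x)))
        return F.softmax(self.affine2(x), dim=1)

# Illustrative env: Acrobot-v1 has 6 observations and 3 discrete actions,
# so the old hardcoded layer sizes would not fit it.
env = gym.make('Acrobot-v1')
policy = Policy(env.observation_space.shape[0], env.action_space.n)

obs, _ = env.reset(seed=543)  # gym>=0.26 reset returns (obs, info)
probs = policy(torch.from_numpy(obs).float().unsqueeze(0))
assert probs.shape == (1, env.action_space.n)  # one probability per action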