```python
import torch
import torch.nn.functional as F


# Single-step update method of the agent class (uses self.actor_critic,
# self.gamma, and self.optimizer).
def update(self, state, action, reward, next_state, done):
    # Convert the transition to tensors; add a batch dimension to the states.
    state = torch.FloatTensor(state).unsqueeze(0)
    next_state = torch.FloatTensor(next_state).unsqueeze(0)
    action = torch.LongTensor([action])
    reward = torch.FloatTensor([reward])
    done = torch.FloatTensor([int(done)])

    # One-step TD target: r + gamma * V(s'), zeroed out on terminal transitions.
    _, next_state_value = self.actor_critic(next_state)
    _, state_value = self.actor_critic(state)
    q_value = reward + self.gamma * next_state_value * (1 - done)

    # Actor loss: policy gradient weighted by the TD target. The target is
    # detached so the actor loss does not backpropagate into the critic.
    log_prob, _ = self.actor_critic(state)
    actor_loss = -(log_prob[0][action] * q_value.detach()).mean()

    # Critic loss: regress V(s) toward the (detached) TD target.
    critic_loss = F.mse_loss(state_value, q_value.detach())
    loss = actor_loss + critic_loss

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
```
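The `update` method assumes a shared actor-critic network that returns log-probabilities and a state value, plus a discount factor and an optimizer stored on the agent. The sketch below shows one way those pieces might be wired together; the names `ActorCritic`, `Agent`, `hidden_dim`, and `lr` are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ActorCritic(nn.Module):
    """Shared-body network returning (log_probs, state_value), as update() expects."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.policy_head = nn.Linear(hidden_dim, action_dim)  # actor: action logits
        self.value_head = nn.Linear(hidden_dim, 1)            # critic: V(s)

    def forward(self, x):
        h = self.body(x)
        log_probs = F.log_softmax(self.policy_head(h), dim=-1)
        value = self.value_head(h)
        return log_probs, value


class Agent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3):
        self.gamma = gamma
        self.actor_critic = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=lr)

    # update(self, state, action, reward, next_state, done) as defined above
```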