Reinforcement Learning (DQN) Tutorial
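Below is the single optimization step of the DQN training loop. It samples a random batch of transitions from replay memory, computes Q(s_t, a_t) with the policy network and a bootstrapped target with the target network, and minimizes the Huber (smooth L1) loss between the two. Written out as equations, the quantity the code below minimizes over a sampled batch B of transitions (s, a, s', r) is:

\[
\delta = Q(s, a) - \Bigl(r + \gamma \max_{a'} Q_{\text{target}}(s', a')\Bigr), \qquad
\mathcal{L} = \frac{1}{|B|} \sum_{(s, a, s', r) \in B} \mathcal{L}_{\delta},
\]
\[
\mathcal{L}_{\delta} =
\begin{cases}
\tfrac{1}{2}\delta^{2} & \text{if } |\delta| \le 1,\\
|\delta| - \tfrac{1}{2} & \text{otherwise,}
\end{cases}
\]

where final states contribute a target of just r, since their next-state value is set to 0.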
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
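The function above relies on several names defined elsewhere in the tutorial: the Transition namedtuple, the ReplayMemory buffer (memory), the policy_net and target_net models, the optimizer, the device, and the hyperparameters BATCH_SIZE and GAMMA. As a rough, self-contained sketch of the replay-memory side of that context (the capacity and hyperparameter values here are illustrative placeholders, not necessarily the tutorial's exact choices):

import random
from collections import namedtuple

import torch

# Illustrative values; the tutorial picks its own hyperparameters.
BATCH_SIZE = 128
GAMMA = 0.999

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One step of experience: state, action taken, resulting state, reward received.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    """A bounded buffer of past transitions, sampled uniformly at random."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Save a transition, overwriting the oldest one once the buffer is full."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return a uniformly random batch of stored transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

memory = ReplayMemory(10000)  # illustrative capacity

In the surrounding training loop, each environment step typically pushes a (state, action, next_state, reward) transition into memory and then calls optimize_model(); every few episodes target_net.load_state_dict(policy_net.state_dict()) refreshes the target network, so the bootstrapped targets change only slowly relative to the policy network being trained.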