Reinforcement Learning (DQN) Tutorial
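
The optimize_model() function below is the optimization step of the tutorial's training loop. It relies on names defined earlier in the full tutorial: the Transition namedtuple, a ReplayMemory instance called memory, the policy_net and target_net models, the optimizer, and the BATCH_SIZE and GAMMA hyperparameters. As a point of reference, here is a minimal sketch of the replay-memory pieces consistent with how optimize_model() uses them; the tutorial's own definitions may differ in detail. The small demo at the end illustrates the zip(*transitions) "transpose" that the first comment in optimize_model() refers to.

import random
from collections import deque, namedtuple

# A single environment transition, as consumed by optimize_model().
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory:
    """Fixed-size buffer of recent transitions (sketch; assumes a deque-backed store)."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Sample a random mini-batch of transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

if __name__ == "__main__":
    # zip(*transitions) turns a list of Transition tuples into a single
    # Transition whose fields are tuples (batch-array of Transitions ->
    # Transition of batch-arrays).
    ts = [Transition(1, 'a', 2, 0.0), Transition(2, 'b', None, 1.0)]
    batch = Transition(*zip(*ts))
    print(batch.state)       # (1, 2)
    print(batch.next_state)  # (2, None)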

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
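
Restating what optimize_model() computes: the temporal-difference error uses policy_net for Q and the "older" target_net for the bootstrap term, and nn.SmoothL1Loss (the Huber loss, with its default threshold of 1) averages the per-sample terms over the batch B. In LaTeX form, with the max term taken to be 0 whenever s_{t+1} is final (which is what the non_final_mask bookkeeping implements):

\delta = Q(s_t, a_t) - \bigl(r_t + \gamma \max_{a} Q_{\text{target}}(s_{t+1}, a)\bigr)

\mathcal{L} = \frac{1}{|B|} \sum_{\delta \in B} \mathcal{L}_{\text{Huber}}(\delta),
\qquad
\mathcal{L}_{\text{Huber}}(\delta) =
\begin{cases}
  \tfrac{1}{2}\delta^{2} & \text{for } |\delta| \le 1, \\
  |\delta| - \tfrac{1}{2} & \text{otherwise.}
\end{cases}

The final loop clamps every gradient element of policy_net to the range [-1, 1] before the optimizer step, an element-wise gradient-clipping choice made for training stability.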
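
For a quick sanity check that the function runs, one can wire up minimal stand-ins for the remaining globals it expects (policy_net, target_net, optimizer, device, and the hyperparameters) and push a few synthetic transitions. This is only a smoke-test sketch under assumed shapes and names, not the tutorial's actual setup, which builds a convolutional DQN over rendered screen images; it assumes optimize_model() and the ReplayMemory sketch above live in the same script.

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical smoke-test values; the tutorial's real hyperparameters differ.
BATCH_SIZE = 8
GAMMA = 0.999
N_OBS, N_ACTIONS = 4, 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tiny fully connected stand-ins for the tutorial's convolutional DQN.
policy_net = nn.Sequential(nn.Linear(N_OBS, 32), nn.ReLU(),
                           nn.Linear(32, N_ACTIONS)).to(device)
target_net = nn.Sequential(nn.Linear(N_OBS, 32), nn.ReLU(),
                           nn.Linear(32, N_ACTIONS)).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(100)  # from the sketch near the top of this section

# Push synthetic transitions, including occasional terminal ones (next_state=None).
for i in range(20):
    state = torch.randn(1, N_OBS, device=device)
    action = torch.randint(N_ACTIONS, (1, 1), device=device)
    next_state = None if i % 10 == 9 else torch.randn(1, N_OBS, device=device)
    reward = torch.tensor([1.0], device=device)
    memory.push(state, action, next_state, reward)

optimize_model()  # should perform one gradient step without error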