Reinforcement Learning (DQN) tutorial — PyTorch Tutorials 0.2.0_4 documentation

last_sync = 0

def optimize_model():
    global last_sync
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for a
    # detailed explanation).
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements.
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))
    # We don't want to backprop through the expected action values, and
    # volatile saves us from temporarily setting the model parameters'
    # requires_grad to False!
    non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                if s is not None]),
                                     volatile=True)
    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of the actions taken.
    state_action_values = model(state_batch).gather(1, action_batch)
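    # (Illustration, not part of the original tutorial: if model(state_batch)
    #  returned [[1.0, 2.0], [3.0, 4.0]] and action_batch were [[1], [0]],
    #  then .gather(1, action_batch) would pick one Q-value per row along
    #  dim 1, yielding [[2.0], [3.0]].)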

    # Compute V(s_{t+1}) for all next states.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
    # Now we don't want to mess up the loss with the volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False.
    next_state_values.volatile = False
    # Compute the expected Q values.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute the Huber loss.
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
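    # (For reference: with its default threshold of 1, the smooth L1 / Huber
    #  loss is 0.5 * d**2 per element when |d| < 1 and |d| - 0.5 otherwise,
    #  where d = Q(s_t, a_t) - (r + GAMMA * max_a Q(s_{t+1}, a)); its gradient
    #  stays bounded, so large TD errors don't blow up the update.)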

    # Optimize the model.
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        # Clip gradients element-wise to [-1, 1] to stabilize training.
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
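Note that Variable, volatile, and ByteTensor belong to the pre-0.4 PyTorch API that these 0.2.0 tutorials target. On current PyTorch releases the same update is typically written with plain tensors, a boolean mask, and torch.no_grad(). The sketch below is one possible port, not part of the original tutorial; it assumes the same globals (memory, Transition, model, optimizer, BATCH_SIZE, GAMMA) defined earlier in the tutorial.

import torch
import torch.nn.functional as F

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # A boolean tensor replaces the old ByteTensor mask.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a): gather the Q-value of each action actually taken.
    state_action_values = model(state_batch).gather(1, action_batch)

    # V(s_{t+1}): torch.no_grad() replaces volatile=True, so no gradient is
    # recorded for the target values.
    next_state_values = torch.zeros(BATCH_SIZE)
    with torch.no_grad():
        next_state_values[non_final_mask] = \
            model(non_final_next_states).max(1)[0]

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # unsqueeze(1) matches the (BATCH_SIZE, 1) shape of state_action_values.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

The behavior matches the version above; the only semantic difference is that gradient tracking for the target values is scoped by the no_grad context rather than by the volatile flag.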