How to train a Deep Q Network
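The listing below is not self-contained: it relies on the DQN network, Agent, ReplayBuffer and RLDataset helpers defined earlier in this tutorial. As a rough sketch, the imports it assumes look like this (names may differ slightly in your setup):

from typing import List, Tuple

import gym
import torch
from torch import Tensor, nn
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer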

class DQNLightning(LightningModule):
    """Basic DQN Model."""

    def __init__(
        self,
        batch_size: int = 16,
        lr: float = 1e-2,
        env: str = "CartPole-v0",
        gamma: float = 0.99,
        sync_rate: int = 10,
        replay_size: int = 1000,
        warm_start_size: int = 1000,
        eps_last_frame: int = 1000,
        eps_start: float = 1.0,
        eps_end: float = 0.01,
        episode_length: int = 200,
        warm_start_steps: int = 1000,
    ) -> None:
        """
        Args:
            batch_size: size of the batches
            lr: learning rate
            env: gym environment tag
            gamma: discount factor
            sync_rate: how many frames do we update the target network
            replay_size: capacity of the replay buffer
            warm_start_size: how many samples do we use to fill our buffer at the start of training
            eps_last_frame: what frame should epsilon stop decaying
            eps_start: starting value of epsilon
            eps_end: final value of epsilon
            episode_length: max length of an episode
            warm_start_steps: number of random steps used to fill the replay buffer before training starts
        """
        super().__init__()
        self.save_hyperparameters()

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Carries out several random steps through the environment to initially fill up the replay buffer with
        experiences.

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Calculates the mse loss using a mini batch from the replay buffer.

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)
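    # For reference: dqn_mse_loss above is the standard one-step TD objective of DQN,
    #   L = MSE( Q_net(s, a),  r + gamma * (1 - done) * max_a' Q_target(s', a') ),
    # where the bootstrapped target uses the frozen target network inside torch.no_grad().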

    def get_epsilon(self, start: float, end: float, frames: int) -> float:
        if self.global_step > frames:
            return end
        return start - (self.global_step / frames) * (start - end)
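    # Sanity check on the linear schedule (with the defaults above): at global_step = 500,
    # epsilon = 1.0 - (500 / 1000) * (1.0 - 0.01) = 0.505; after 1000 steps it stays at 0.01.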

    def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> Tensor:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch received.

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss
        """
        device = self.get_device(batch)
        epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)
        self.log("epsilon", epsilon)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward
        self.log("episode reward", self.episode_reward)

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # hard update of the target network every `sync_rate` steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        self.log_dict(
            {
                "reward": reward,
                "train_loss": loss,
            }
        )
        self.log("total_reward", self.total_reward, prog_bar=True)
        self.log("steps", self.global_step, logger=False, prog_bar=True)

        return loss
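    # Note: load_state_dict above copies the online weights wholesale every `sync_rate` steps
    # (a hard update). A soft / Polyak-averaged update is a common alternative; the method below
    # is a hypothetical sketch (with `tau` an assumed smoothing coefficient), not part of the tutorial:
    def soft_update_target(self, tau: float = 0.005) -> None:
        # theta_target <- tau * theta_online + (1 - tau) * theta_target
        for target_param, param in zip(self.target_net.parameters(), self.net.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)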

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0].device.index if self.on_gpu else "cpu"