Deep Q-network (DQN) reinforcement learning agent – MATLAB

Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both the cart and the pole) and a discrete one-dimensional action space consisting of the application of one of two possible forces, -10 N or 10 N.

Create the predefined environment.

env = rlPredefinedEnv("CartPole-Discrete");

Get the observation and action specification objects.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
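
If you want to confirm the sizes used later in this example, you can inspect the specification objects directly. Based on the environment description above, obsInfo.Dimension is expected to be [4 1] and actInfo.Elements is expected to list the two possible forces; the exact display depends on your toolbox release.

obsInfo.Dimension   % expected to return [4 1]
actInfo.Elements    % expected to list the two forces, -10 and 10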

To approximate the Q-value function within the critic, use a deep neural network. For DQN agents with a discrete action space, you have the option to create a multi-output Q-value function critic, which is generally more efficient than a comparable single-output critic.
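
For comparison, the sketch below shows how a single-output critic could be assembled with rlQValueFunction. Such a critic takes both the observation and the action as inputs and returns one value, so the agent must evaluate it once per candidate action. The layer and input names (obsIn, actIn, obsFC, actFC, add) and the variable names are illustrative choices, not part of this example.

% Illustrative single-output critic: observation and action paths
% joined by an addition layer, ending in a single Q value.
obsPath = [featureInputLayer(prod(obsInfo.Dimension),Name="obsIn")
           fullyConnectedLayer(24,Name="obsFC")];
actPath = [featureInputLayer(prod(actInfo.Dimension),Name="actIn")
           fullyConnectedLayer(24,Name="actFC")];
comPath = [additionLayer(2,Name="add")
           reluLayer
           fullyConnectedLayer(1)];
qNet = layerGraph(obsPath);
qNet = addLayers(qNet,actPath);
qNet = addLayers(qNet,comPath);
qNet = connectLayers(qNet,"obsFC","add/in1");
qNet = connectLayers(qNet,"actFC","add/in2");
singleCritic = rlQValueFunction(dlnetwork(qNet),obsInfo,actInfo, ...
    ObservationInputNames="obsIn",ActionInputNames="actIn");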

A network for this critic must take only the observation as input and return a vector containing one value for each action. Therefore, it must have an input layer with as many elements as the dimension of the observation space and an output layer with as many elements as the number of possible discrete actions. Each output element represents the expected cumulative long-term reward obtained when the corresponding action is taken, starting from the observation given as input.

Define the network as an array of layer objects. Obtain the number of observation elements (that is, prod(obsInfo.Dimension)) and the number of possible actions (that is, numel(actInfo.Elements)) directly from the environment specification objects.

dnn = [
    featureInputLayer(prod(obsInfo.Dimension))
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements))];

Convert the network to a dlnetwork object and display the number of weights.

dnn = dlnetwork(dnn);
summary(dnn)
   Initialized: true

   Number of learnables: 770

   Inputs:
      1   'input'   4 features

Create the critic using rlVectorQValueFunction, passing the network dnn and the observation and action specification objects as arguments.

critic = rlVectorQValueFunction(dnn,obsInfo,actInfo);

Check that the critic works with a random observation input.

getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector

   -0.0361
    0.0913

Create the DQN agent using the critic.

agent = rlDQNAgent(critic)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1

Specify agent options, including training options for the critic.

agent.AgentOptions.UseDoubleDQN=false;
agent.AgentOptions.TargetUpdateMethod="periodic";
agent.AgentOptions.TargetUpdateFrequency=4;
agent.AgentOptions.ExperienceBufferLength=100000;
agent.AgentOptions.DiscountFactor=0.99;
agent.AgentOptions.MiniBatchSize=256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate=1e-2;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold=1;
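
Equivalently, you can configure an rlDQNAgentOptions object first and pass it to rlDQNAgent together with the critic. The sketch below simply mirrors the settings above; the variable name opt is an illustrative choice.

% Alternative: configure the options object before creating the agent.
opt = rlDQNAgentOptions;
opt.UseDoubleDQN = false;
opt.TargetUpdateMethod = "periodic";
opt.TargetUpdateFrequency = 4;
opt.ExperienceBufferLength = 100000;
opt.DiscountFactor = 0.99;
opt.MiniBatchSize = 256;
opt.CriticOptimizerOptions.LearnRate = 1e-2;
opt.CriticOptimizerOptions.GradientThreshold = 1;
agent = rlDQNAgent(critic,opt);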

To check your agent, use getAction to return the action from a random observation.

getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[10]}

You can now test and train the agent within the environment.
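
For instance, a training and simulation run could look like the following sketch. The episode limits, averaging window, and stopping criterion are illustrative values, not settings taken from this example.

% Illustrative training options (values are example choices).
trainOpts = rlTrainingOptions;
trainOpts.MaxEpisodes = 1000;
trainOpts.MaxStepsPerEpisode = 500;
trainOpts.ScoreAveragingWindowLength = 20;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 480;

% Train the agent in the environment.
trainingStats = train(agent,env,trainOpts);

% Simulate the trained agent for one episode of at most 500 steps.
simOpts = rlSimulationOptions;
simOpts.MaxSteps = 500;
experience = sim(env,agent,simOpts);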