A simple Neural Network Module for Relational Reasoning

A relation module to equip CNN architectures with a notion of relational reasoning, particularly useful for tasks such as visual question answering and dynamics understanding.

  • Pros (+):

    Simple architecture, relies on small and flexible modules.

  • Cons (-):

    Still a black-box module, hard to quantify how much “reasoning” happens.

The authors propose a relation module to equip neural architectures with a notion of relational reasoning, which is particularly useful for tasks such as visual question answering and dynamics understanding.

Proposed Model

The main idea of Relation Networks (RN) is to constrain the functional form of convolutional neural networks so as to explicitly learn relations between entities, rather than hoping for this property to emerge in the representation during training. Formally, let \(O\) be a set of objects of interest, \(O = \{o_1 \dots o_n\}\); the Relation Network is trained to learn a representation that considers all pairwise relations across the objects:

\[\begin{align} \mbox{RN}(O) &= f_{\phi} \left(\sum_{i, j} g_{\theta}(o_i, o_j) \right) \end{align}\]


\(f_{\phi}\) and \(g_{\theta}\) are defined as multi-layer perceptrons (MLPs). By definition, the Relation Network (i) has to consider all pairs of objects, (ii) operates directly on the set of objects, hence is not constrained to a specific organization of the data, and (iii) is data-efficient in the sense that only one function, \(g_{\theta}\), is learned to capture all the possible relations: \(g\) and \(f\) are typically light modules, and most of the computational cost comes from the sum over the \(n^2\) pairwise terms.
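
As a rough illustration, here is a minimal PyTorch sketch of the formula above; the layer sizes, MLP depths and the absence of a question embedding are my own simplifications, not the paper's exact architecture:

```python
import torch
import torch.nn as nn

class RelationNetwork(nn.Module):
    """RN(O) = f_phi( sum_{i,j} g_theta(o_i, o_j) ) over a set of objects."""

    def __init__(self, object_dim, hidden_dim, output_dim):
        super().__init__()
        # g_theta: scores a single ordered pair of objects
        self.g = nn.Sequential(
            nn.Linear(2 * object_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # f_phi: maps the aggregated relation vector to the final output
        self.f = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, objects):
        # objects: (batch, n, object_dim)
        b, n, d = objects.shape
        # build all n^2 ordered pairs (o_i, o_j)
        o_i = objects.unsqueeze(2).expand(b, n, n, d)
        o_j = objects.unsqueeze(1).expand(b, n, n, d)
        pairs = torch.cat([o_i, o_j], dim=-1).reshape(b, n * n, 2 * d)
        # apply g_theta to every pair, then sum over all pairs
        relations = self.g(pairs).sum(dim=1)
        return self.f(relations)
```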

The objects are the basic elements of the relational process we want to model. They are defined with regard to the task at hand, for instance:

  • Attending to relations between objects in an image: the image is first processed through a fully-convolutional network. Each cell of the resulting feature map is taken as an object: a feature vector of dimension \(k\), additionally tagged with its position in the feature map (a small sketch of this conversion follows the list).

  • Sequence of images. In that case, each image is first fed through a feature extractor and the resulting embedding is used as an object. The goal is to model relations between images across the sequence.
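
For the image case, turning a CNN feature map into RN objects might look as follows; the normalized-grid coordinate tagging here is an illustrative choice of mine, not necessarily the paper's exact scheme:

```python
import torch

def feature_map_to_objects(feature_map):
    """Convert a CNN feature map (batch, k, H, W) into a set of objects
    (batch, H*W, k + 2): each cell's k-dim feature vector is tagged with
    its (row, col) position, normalized to [0, 1]."""
    b, k, h, w = feature_map.shape
    ys = torch.linspace(0.0, 1.0, h).view(1, h, 1).expand(b, h, w)
    xs = torch.linspace(0.0, 1.0, w).view(1, 1, w).expand(b, h, w)
    coords = torch.stack([ys, xs], dim=-1)        # (batch, H, W, 2)
    feats = feature_map.permute(0, 2, 3, 1)       # (batch, H, W, k)
    objects = torch.cat([feats, coords], dim=-1)  # (batch, H, W, k + 2)
    return objects.reshape(b, h * w, k + 2)
```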

Figure: Example of applying the Relation Network to Visual Question Answering. Questions are processed with an LSTM to produce a question embedding, and images are processed with a CNN to produce a set of objects for the RN.

Experiments

The main evaluation is done on the CLEVR dataset [2]. The main message seems to be that the proposed module is very simple and yet often improves the model accuracy when added to various architectures (CNN, CNN + LSTM etc.) introduced in [1]. The main baseline they compare to (and outperform) is Spatial Attention (SA) which is another simple method to integrate some form of relational reasoning in a neural architecture.

Closely related

Recurrent Relational Neural Networks

[3]

Palm et al, [link]

This paper builds on the Relation Network architecture and proposes to explore more complex relational structures, defined as a graph, using a message-passing approach. Formally, we are given a graph with vertices \(\mathcal V = \{v_i\}\) and edges \(\mathcal E = \{e_{i, j}\}\). By abuse of notation, \(v_i\) also denotes the embedding of vertex \(i\) (e.g., obtained via a CNN), and \(e_{i, j}\) is 1 if \(i\) and \(j\) are linked, 0 otherwise. To each node we associate a hidden state \(h_i^t\) at iteration \(t\), which is updated via message passing. After a few iterations, the resulting state is passed through an MLP \(r\) to output the result (either for each node or for the whole graph):

\[\begin{align} h_i^0 &= v_i\\ h_i^{t + 1} &= f_{\phi} \left( h_i^t, v_i, \sum_{j} e_{i, j} g_{\theta}(h^t_i, h^t_j) \right)\\ o_i &= r(h_i^T) \mbox{ or } o = r(\sum_i h_i^T) \end{align}\]
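
A minimal sketch of these update equations, assuming a dense 0/1 adjacency matrix and plain MLPs for \(f_\phi\) and \(g_\theta\) (the paper's exact parameterization may differ; in particular, \(h_i^0 = v_i\) in the equation, whereas here \(v_i\) is projected to the hidden dimension for convenience):

```python
import torch
import torch.nn as nn

class RecurrentRelationalNet(nn.Module):
    """Message passing over a fixed graph given by a dense (n, n) 0/1
    adjacency matrix `adj`; produces one output per node."""

    def __init__(self, node_dim, hidden_dim, output_dim, steps=3):
        super().__init__()
        self.steps = steps
        self.embed = nn.Linear(node_dim, hidden_dim)           # h_i^0 from v_i
        self.g = nn.Sequential(                                # message g_theta(h_i, h_j)
            nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU())
        self.f = nn.Sequential(                                # node update f_phi(h_i, v_i, messages)
            nn.Linear(2 * hidden_dim + node_dim, hidden_dim), nn.ReLU())
        self.r = nn.Linear(hidden_dim, output_dim)             # readout r

    def forward(self, v, adj):
        # v: (n, node_dim), adj: (n, n)
        n = v.size(0)
        h = self.embed(v)                                      # h_i^0 (projected v_i)
        for _ in range(self.steps):
            h_i = h.unsqueeze(1).expand(n, n, -1)
            h_j = h.unsqueeze(0).expand(n, n, -1)
            msgs = self.g(torch.cat([h_i, h_j], dim=-1))       # g_theta(h_i^t, h_j^t)
            agg = (adj.unsqueeze(-1) * msgs).sum(dim=1)        # sum_j e_{ij} g(...)
            h = self.f(torch.cat([h, v, agg], dim=-1))         # f_phi(h_i^t, v_i, agg)
        return self.r(h)                                       # o_i = r(h_i^T)
```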

Comparing to the original Relation Network:

  • Each update rule is a Relation Network that only looks at pairwise relations between linked vertices. The message passing scheme additionally introduces the notion of recurrence, and the dependency on the previous hidden state.
  • The dependence on \(h_i^t\) could in theory be avoided by adding self-edges from \(v_i\) to \(v_i\), to make it closer to the Relation Network formulation.
  • Adding \(v_i\) as input of \(f_\phi\) looks like a simple trick to avoid long-term memory problems.

The experiments essentially compare the proposed RRNN model to the Relation Network and to classical recurrent architectures such as LSTMs. They consider three datasets:

  • bAbI. An NLP question-answering task with some reasoning involved. The RRNN solves 19.7 of the 20 tasks on average, while the original RN reliably solved around 18 of them.
  • Pretty-CLEVR. A CLEVR-like dataset (only with simple 2D shapes) with questions involving a varying number of reasoning steps, e.g. “which is the shape \(n\) steps from the red circle?”
  • Sudoku. The graph contains 81 nodes (one for each cell of the Sudoku grid), with edges between cells belonging to the same row, column or block (a small helper constructing this adjacency is sketched after the list).
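
For illustration, the Sudoku adjacency can be built in a few lines (a sketch of my own, using 0-based cell indices \(i = 9 \cdot \mbox{row} + \mbox{col}\)):

```python
def sudoku_edges():
    """Directed edges of the 81-node Sudoku graph: two cells are linked
    if they share a row, a column, or a 3x3 block (no self-edges)."""
    def block(i):
        row, col = divmod(i, 9)
        return (row // 3, col // 3)

    edges = set()
    for i in range(81):
        for j in range(81):
            if i == j:
                continue
            same_row = i // 9 == j // 9
            same_col = i % 9 == j % 9
            if same_row or same_col or block(i) == block(j):
                edges.add((i, j))
    return edges  # each cell has 20 neighbours, i.e. 81 * 20 = 1620 edges
```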

Multi-Layer Relation Neural Networks

[4]


Jahrens and Martinetz, [link]

This paper presents a very simple trick to make the Relation Network consider higher-order relations than pairwise ones, while retaining some efficiency. Essentially, the model can be written as follows:

\[\begin{align} h_{i, j}^0 &= g^0_{\theta}(x_i, x_j) \\ h_{i, j}^{t + 1} &= g^{t + 1}_{\theta}\left(\sum_k h_{i, k}^{t}, \sum_k h_{j, k}^{t}\right) \\ \mbox{MLRN}(O) &= f_{\phi}\left(\sum_{i, j} h^T_{i, j}\right) \end{align}\]
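
A minimal sketch of this recursion, with illustrative hidden sizes and number of steps (not the paper's exact configuration):

```python
import torch
import torch.nn as nn

class MultiLayerRelationNet(nn.Module):
    """Pairwise features h_{i,j} are re-mixed for `steps` iterations by
    summing over one index, then aggregated as in a standard RN."""

    def __init__(self, object_dim, hidden_dim, output_dim, steps=2):
        super().__init__()
        self.g0 = nn.Sequential(nn.Linear(2 * object_dim, hidden_dim), nn.ReLU())
        self.gs = nn.ModuleList([
            nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU())
            for _ in range(steps)])
        self.f = nn.Linear(hidden_dim, output_dim)

    def forward(self, objects):
        # objects: (n, object_dim)
        n = objects.size(0)
        x_i = objects.unsqueeze(1).expand(n, n, -1)
        x_j = objects.unsqueeze(0).expand(n, n, -1)
        h = self.g0(torch.cat([x_i, x_j], dim=-1))      # h_{i,j}^0
        for g in self.gs:
            s = h.sum(dim=1)                            # s[i] = sum_k h_{i,k}
            a = s.unsqueeze(1).expand(n, n, -1)         # a[i,j] = sum_k h_{i,k}
            b = s.unsqueeze(0).expand(n, n, -1)         # b[i,j] = sum_k h_{j,k}
            h = g(torch.cat([a, b], dim=-1))            # h_{i,j}^{t+1}
        return self.f(h.sum(dim=(0, 1)))                # f_phi(sum_{i,j} h_{i,j}^T)
```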

It is not entirely clear why this model would be equivalent to explicitly considering higher-order relations (it rather combines pairwise terms for a finite number of steps). According to the experiments, this architecture does seem better suited to the studied tasks (e.g., compared to the Relation Network or the Recurrent Relational Network), but it also makes the model even harder to interpret.

References

  • [1]

    Inferring and executing programs for visual reasoning, Johnson et al, ICCV 2017

  • [2]

    CLEVR: A Diagnostic Dataset for Compositional Language and Elementary Visual Reasoning, Johnson et al, CVPR 2017

  • [3]

    Recurrent Relational Neural Networks, Palm et al, NeurIPS 2018

  • [4]

    Multi-Layer Relation Neural Networks, Jahrens and Martinetz, arXiv 2018
