
Measuring abstract reasoning in neural networks

David G.T. Barrett, Felix Hill, Adam Santoro, Ari S. Morcos, Timothy Lillicrap

Presented by Martin Liivak

Introduction

Humans can perform abstract reasoning

It can be measured using Raven’s Progressive Matrices (RPMs)

One of the goals of AI research is to develop machines with similar abstract reasoning capabilities to humans

Maybe RPMs could be used to measure learning and reasoning in machines

Raven-style Progressive Matrices (RPMs)

Approach

The authors contend that visual intelligence tests can help us better understand learning and reasoning in machines, provided they are coupled with a principled treatment of generalisation.

Clever separation of attributes between the training and test datasets could provide insight into generalisation and abstract understanding.

This requires a large dataset of abstract visual reasoning questions in which the underlying abstract semantics can be precisely controlled.

This approach makes it possible to address the following goals.

Goals

(1) Can state-of-the-art neural networks find solutions to complex, human-challenging abstract reasoning tasks if trained with plentiful training data?

(2) If so, how well does this capacity generalise when the abstract content of the training data is specifically controlled for?

Procedurally generated matrices (PGMs)

The RPMs need to be automatically generated

First, an abstract structure for each matrix needs to be built

The structure S is a set of triples S = {[r, o, a] : r ∈ R, o ∈ O, a ∈ A}

Up to four relations per matrix are permitted (1 ≤ |S| ≤ 4)
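
As an illustration (a sketch, not the authors' code), such a structure could be represented in Python as follows; the relation, object and attribute vocabularies are those the paper describes for R, O and A, though not every combination is viable in the real generator:

```python
import itertools
import random

# Vocabularies R, O, A as described in the paper; the actual generator
# only permits a subset of the 50 possible combinations.
RELATIONS = ["progression", "XOR", "OR", "AND", "consistent union"]
OBJECTS = ["shape", "line"]
ATTRIBUTES = ["size", "type", "colour", "position", "number"]

def sample_structure(max_triples=4):
    # Sample a structure S: a set of 1-4 distinct [r, o, a] triples.
    n = random.randint(1, max_triples)
    candidates = list(itertools.product(RELATIONS, OBJECTS, ATTRIBUTES))
    return random.sample(candidates, n)

S = sample_structure()
print(S)  # e.g. [('progression', 'shape', 'colour'), ...]
```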

PGM generation constraints

Each attribute type a ∈ A (e.g. colour) can take one of a finite number of discrete values v ∈ V (e.g. integers in [0, 255]).

The choice of r constrains the values of v that can be realised (e.g. with progression, the values v need to increase).

Sa denotes the set of attributes among the triples in S.

PGM generation process

(1) Sampling 1-4 triples depending on target difficulty

(2) Sampling values v ∈ V for each a ∈ Sa, adhering to the associated relation r

(3) Sampling values v ∈ V for each a ∉ Sa, ensuring no spurious relation is induced

(4) Rendering the symbolic form into pixels
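
A minimal sketch of steps (2) and (3) as rejection sampling; the helper names `satisfies` and `sample_values` are illustrative assumptions, not taken from the paper:

```python
import random

def satisfies(relation, values):
    # Only 'progression' is sketched; the real generator also covers
    # XOR, OR, AND and consistent union.
    if relation == "progression":
        return all(b > a for a, b in zip(values, values[1:]))
    return True  # placeholder for the remaining relations

def sample_values(relation, domain, k=3):
    # Step (2): resample until the drawn values realise relation r.
    while True:
        values = random.sample(domain, k)
        if satisfies(relation, values):
            return values

# Step (3) is the mirror image: values for attributes outside Sa are
# resampled until they satisfy NO relation (no spurious relation induced).
# Step (4) renders the final symbolic description into pixels.
print(sample_values("progression", list(range(256))))
```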

Difficult PGM example

Generalisation regimes

Neutral - S may contain any triples [r, o, a] for r ∈ R, o ∈ O and a ∈ A; the training and test sets are disjoint.

Interpolation - training set restricted to even-indexed elements of Va, test set to odd-indexed elements (see the sketch after this list).

Extrapolation - training set restricted to the lower half of the values of Va, test set to the upper half.

Held-out shape-colour - the training set contained no triples with o = shape and a = colour; the test set had at least one triple with these elements.

Generalisation regimes (cont.)

Held-out line-type - the training set contained no triples with o = line and a = type; the test set had at least one triple with these elements.

Held-out triples - seven unique triples never occurred in the training set but appeared in the test set.

Held-out pairs of triples - the held-out pairs of triples never occurred together in a training structure S, but did in the test set.

Held-out attribute pairs - held-out pairs of attributes never occurred together in a training S, while every test S contained such a pair.
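
A minimal sketch (assumed, not taken from the paper's code) of the interpolation and extrapolation value splits over an attribute's value set Va:

```python
def interpolation_split(values):
    # Training uses even-indexed values, test uses odd-indexed ones.
    ordered = sorted(values)
    return ordered[0::2], ordered[1::2]   # (train, test)

def extrapolation_split(values):
    # Training uses the lower half of the value range, test the upper half.
    ordered = sorted(values)
    mid = len(ordered) // 2
    return ordered[:mid], ordered[mid:]   # (train, test)

colours = list(range(256))
train_vals, test_vals = interpolation_split(colours)
```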

Model setup

The input consisted of eight context panels and eight multiple-choice panels.

Models were trained to produce the label of the correct missing panel.

All networks were trained by stochastic gradient descent using the Adam optimiser.

For each model, hyper-parameters were chosen with a grid sweep, selecting the configuration with the smallest loss on a held-out validation set.
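
Concretely, the task reduces to 8-way classification over the candidate panels. A hedged sketch of one training step in PyTorch; the `model` interface and shapes are assumptions, not the paper's code:

```python
import torch
import torch.nn.functional as F

def train_step(model, optimiser, context, choices, target):
    # context: (batch, 8, H, W), choices: (batch, 8, H, W)
    # target: (batch,) index of the correct missing panel
    scores = model(context, choices)        # (batch, 8), one score per choice
    loss = F.cross_entropy(scores, target)  # 8-way classification loss
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item()

# optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)  # lr assumed
```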

CNN-MLP

A standard four-layer convolutional neural network with batch normalisation and ReLU non-linearities.

The convolved output was passed through a two-layer, fully connected MLP.
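
A sketch of this baseline under assumed layer widths and input packing (the slide fixes only the depth: four conv layers plus a two-layer MLP; stacking the 16 panels channel-wise is an assumption):

```python
import torch.nn as nn

class CNNMLP(nn.Module):
    def __init__(self, in_channels=16, num_choices=8):
        super().__init__()
        blocks, c = [], in_channels
        for out_c in (32, 32, 32, 32):            # four conv layers
            blocks += [nn.Conv2d(c, out_c, 3, stride=2, padding=1),
                       nn.BatchNorm2d(out_c), nn.ReLU()]
            c = out_c
        self.cnn = nn.Sequential(*blocks)
        # LazyLinear infers the flattened feature size at first call.
        self.mlp = nn.Sequential(nn.LazyLinear(256), nn.ReLU(),
                                 nn.Linear(256, num_choices))

    def forward(self, x):                          # x: (batch, 16, H, W)
        return self.mlp(self.cnn(x).flatten(1))    # one score per choice
```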

ResNet

A standard implementation of the ResNet-50 architecture was used.

A selection of variants, including ResNet-101, ResNet-152 and several custom-built ones, was also trained.

LSTM

A standard LSTM module was implemented.

Each panel was passed sequentially and independently through a small 4-layer CNN, with its position tagged using a one-hot label.

The resulting sequence was passed to the LSTM.
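
A sketch of this pipeline; the embedding width, LSTM size and CNN channel counts are all assumptions, only the overall shape (per-panel 4-layer CNN, one-hot position tag, LSTM over the sequence) comes from the slide:

```python
import torch
import torch.nn as nn

class PanelLSTM(nn.Module):
    def __init__(self, embed_dim=64, num_panels=16, num_choices=8):
        super().__init__()
        # Small 4-layer CNN applied to each panel independently.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, embed_dim))
        self.lstm = nn.LSTM(embed_dim + num_panels, 96, batch_first=True)
        self.head = nn.Linear(96, num_choices)

    def forward(self, panels):                    # (batch, 16, 1, H, W)
        b, n = panels.shape[:2]
        feats = self.cnn(panels.flatten(0, 1)).view(b, n, -1)
        # One-hot position tag appended to each panel embedding.
        tags = torch.eye(n, device=panels.device).expand(b, n, n)
        out, _ = self.lstm(torch.cat([feats, tags], dim=-1))
        return self.head(out[:, -1])              # score per choice
```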

Wild Relation Network (WReN)

A novel model created by the authors.

Functions fφ and gθ are MLPs
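
For reference, the WReN builds on the Relation Network of Santoro et al. (2017), which scores a set of panel embeddings O = {o₁, …, oₙ} as RN(O) = fφ(Σᵢ,ⱼ gθ(oᵢ, oⱼ)): gθ is applied to every pair of embeddings and fφ to the summed result. The WReN computes one such score per candidate panel and picks the highest.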

Wild-ResNet

A novel variant of the ResNet architecture.

One multiple-choice candidate panel, along with the eight context panels, was provided as input.

The highest-scoring candidate is the output answer.
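
A sketch of that scoring loop; the `resnet` callable and the channel-stacked input are assumptions:

```python
import torch

def wild_resnet_answer(resnet, context, choices):
    # context: (batch, 8, H, W); choices: (batch, 8, H, W).
    # resnet is assumed to map a (batch, 9, H, W) stack of the 8 context
    # panels plus one candidate to a scalar score per example.
    scores = []
    for i in range(choices.shape[1]):
        stacked = torch.cat([context, choices[:, i:i + 1]], dim=1)
        scores.append(resnet(stacked))             # (batch, 1)
    return torch.cat(scores, dim=1).argmax(dim=1)  # best candidate index
```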

Context-blind ResNet

Sufficiently strong models can learn to exploit statistical regularities in multiple-choice problems using the choice inputs alone, without the context.

A ResNet-50 model was trained with only the eight multiple-choice panels as inputs.

Neutral regime experiment

All models were compared on the Neutral train/test split.

The table records the resulting accuracies.

WReN performance on different regimes

Training on auxiliary information

Auxiliary training was explored as a means of improving generalisation performance.

A model trained to predict the relevant relation, object and attribute types involved in each PGM might develop a better ability to generalise.

Meta-targets were constructed that encode the relation, object and attribute types present in a PGM as a binary string.

A scaling factor β determined the influence of the meta-target loss relative to the answer-panel target loss.
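
A minimal sketch of the combined objective this describes; β = 10 below is an assumed value, not taken from the slide:

```python
import torch.nn.functional as F

def total_loss(answer_logits, target, meta_logits, meta_target, beta=10.0):
    # Cross-entropy over the 8 candidate panels...
    answer_loss = F.cross_entropy(answer_logits, target)
    # ...plus binary cross-entropy on the meta-target bit string encoding
    # which relation/object/attribute types are present in the PGM.
    meta_loss = F.binary_cross_entropy_with_logits(meta_logits, meta_target)
    return answer_loss + beta * meta_loss
```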

WReN performance with auxiliary information

WReN meta-target certainty vs accuracy

Results

(1) With important caveats, neural networks can indeed learn to infer and apply abstract reasoning principles.

(2) When applying known abstract content in unfamiliar combinations, the models generalised notably well.

An important contribution of this work is the introduction of the PGM dataset as a tool for studying both abstract reasoning and generalisation in models.

Thanks for listening
