Matching Networks for One Shot Learning
Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Koray Kavukcuoglu, Daan Wierstra
Key Contributions
• Matching Nets (MN), which use attention and memory for rapid learning.
• A training procedure that trains on only a few examples per class, matching the conditions the model will face at test time.
One/Few Shot Learning
• Learning from one or few examples.
• E.g. a child can generalize the idea of a giraffe from a single picture in a book.
Special Case: Zero Shot Learning
• Instead of a support set, we have a meta-data vector for each class that gives a high-level description of the class.
• The meta-data may be determined in advance or learned.
• E.g. images with annotations or captions.
MN Model
• Non-parametric approach.
• Attention-memory mechanism.
• Training strategy tailored towards one-shot learning.
MN Model: Attention-Memory Mechanism
• Originally arose out of an attempt to make models function more like computers.
• Neural attention mechanism accesses a memory matrix that contains useful information.
• Map a support set of image-label pairs S to an attention-memory mechanism that specifies the classifier. Letting a be the attention mechanism, classify a test example \hat{x} as
\hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i) \, y_i
Attention Mechanism Details
• The authors use a softmax over the cosine distance, with embedding functions f and g defined as neural networks appropriate for the inputs (see the sketch after the equations).
c(z, z') = \frac{z \cdot z'}{\|z\| \, \|z'\|}

a(\hat{x}, x_i) = \frac{\exp\{c(f(\hat{x}), g(x_i))\}}{\sum_{j=1}^{k} \exp\{c(f(\hat{x}), g(x_j))\}}
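A minimal sketch of this classification rule in PyTorch, assuming the embeddings f(\hat{x}) and g(x_i) have already been computed by suitable networks (all names here are illustrative, not the authors' code):

```python
import torch
import torch.nn.functional as F

def matching_net_predict(f_query, g_support, y_support):
    """f_query: (d,) embedded query f(x_hat); g_support: (k, d) embedded
    support points g(x_i); y_support: (k, n_classes) one-hot labels y_i."""
    # Cosine similarity c(f(x_hat), g(x_i)) against each support point.
    cos = F.normalize(g_support, dim=-1) @ F.normalize(f_query, dim=-1)  # (k,)
    # Attention a(x_hat, x_i): softmax over the cosine similarities.
    a = F.softmax(cos, dim=0)                                            # (k,)
    # y_hat = sum_i a(x_hat, x_i) y_i: attention-weighted mix of labels.
    return a @ y_support                                                 # (n_classes,)
```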
MN Model: Full Context Embeddings
• Also known as Fully Conditional Embeddings (FCE).
• Idea: Embed the inputs in the context of the whole support set S, so that the distribution over labels is conditioned on S.
a(\hat{x}, x_i) = \frac{\exp\{c(f(\hat{x}, S), g(x_i, S))\}}{\sum_{j=1}^{k} \exp\{c(f(\hat{x}, S), g(x_j, S))\}}
FCE Details
• Q: How do we encode an input in the context of the support set?
• A: Treat S as a sequence: embed the support set with a bidirectional LSTM (g below), and embed the query with an LSTM that attends over the embedded support set for K processing steps (the attLSTM below); a sketch of g follows the equations.
f(\hat{x}, S) = \mathrm{attLSTM}(f'(\hat{x}), g(S), K)

g(x_i, S) = \overrightarrow{h}_i + \overleftarrow{h}_i + g'(x_i)
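A minimal sketch of the support-set embedding g in PyTorch, assuming the base embeddings g'(x_i) (e.g. CNN features) are given; dimensions and names are illustrative:

```python
import torch.nn as nn

class FCESupportEncoder(nn.Module):
    """Computes g(x_i, S) = h_fwd_i + h_bwd_i + g'(x_i) by running a
    bidirectional LSTM over the support set, treated as a sequence."""
    def __init__(self, dim):
        super().__init__()
        self.lstm = nn.LSTM(dim, dim, bidirectional=True, batch_first=True)

    def forward(self, g_prime):              # g_prime: (1, k, dim)
        h, _ = self.lstm(g_prime)            # (1, k, 2*dim)
        h_fwd, h_bwd = h.chunk(2, dim=-1)    # forward / backward states
        return h_fwd + h_bwd + g_prime       # skip connection, per the equation
```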
Related Work: Neighborhood Components Analysis (NCA)
• Goldberger et al., Neighborhood Components Analysis.
• Proposes a method for learning a Mahalanobis distance for KNN classification.
• Minimizes a pairwise negative log-likelihood over similar pairs of points, rather than conditioning on the whole support set S as MN do, which is more amenable to one-shot learning.
Related Work: Neural Turing Machines
• Graves et al., Neural Turing Machines.
• Couple a neural network to external memory resources accessed by attentional processes.
• Example of models with “computer-like” architectures.
Experiments
• Training set: k labelled examples from each of N classes.
• Testing: label each element of a disjoint set of data as one of the N classes.
• Designate a subset of labels, L', to test on in one-shot mode (a sketch of this episode construction follows).
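A minimal sketch of N-way, k-shot episode sampling consistent with this setup (the dataset layout and all names are my own illustration, not the authors' pipeline):

```python
import random

def sample_episode(dataset, n_way, k_shot, n_query):
    """dataset: dict mapping class label -> list of examples (assumed layout)."""
    classes = random.sample(list(dataset), n_way)        # pick N classes
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        examples = random.sample(dataset[cls], k_shot + n_query)
        support += [(x, episode_label) for x in examples[:k_shot]]  # k per class
        query += [(x, episode_label) for x in examples[k_shot:]]    # held out
    return support, query
```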
Experiments: Image Classification - Omniglot
• 1623 total classes, 20 samples per class.
• Use a CNN as the embedding function.
Experiments: Image Classification - ImageNet
• Construct miniImageNet: 60,000 color images of size 84×84 across 100 classes (600 examples each).
• Same setup as with Omniglot.
Experiments: One Shot Language Modeling
• Task: given a query sentence with a missing word and a support set of sentences with missing words and one-hot labels, choose the label that best matches the query sentence.
• MN with a "simple encoding model" achieve 32.4%, 36.1%, and 38.2% accuracy (for 1, 2, and 3 examples per class), compared to an upper bound of 72.8% achieved by an LSTM that is not trained in one-shot mode.
MN Concluding Points
• One shot learning is easier when you train your model specifically in a one shot setting.
• Non-parametric models can more easily adapt to and remember new training sets in a given task.
Prototypical Networks for Few-shot Learning
Jake Snell, Kevin Swersky, Richard S. Zemel
Key Contribution
• Propose Prototypical Networks (PN), which offer a simple, effective distance-metric-learning approach to the few-shot learning setting, with potential to generalize (e.g. to zero-shot learning).
PN Model
• Idea: There is an embedding in which points cluster around a single prototype representation for each class.
PN Model
• Notation: S is the support set, S_k is the set of examples labelled with class k, c_k is the prototype of class k, and f_\phi is the embedding function.
• The prototype is just the mean of the class's support points in embedding space; class probabilities are given by a softmax over (negative) distances to the prototypes.
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)

p_\phi(y = k \mid x) = \frac{\exp\{-d(f_\phi(x), c_k)\}}{\sum_{k'} \exp\{-d(f_\phi(x), c_{k'})\}}
PN Model
• Minimize the negative log-likelihood of the true class k via SGD (a sketch of one episode's loss follows the equation).
J(\phi) = -\log p_\phi(y = k \mid x)
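A minimal sketch of this episode loss in PyTorch, assuming the embeddings f_\phi(x) of the support and query points have already been computed (names are illustrative, not the authors' code):

```python
import torch
import torch.nn.functional as F

def prototypical_loss(f_support, y_support, f_query, y_query, n_classes):
    """f_support: (n_support, d) embeddings of support points; y_support:
    (n_support,) integer labels; f_query/y_query: same for query points."""
    # c_k = mean of the embedded support points labelled k.
    prototypes = torch.stack(
        [f_support[y_support == k].mean(dim=0) for k in range(n_classes)]
    )                                                   # (n_classes, d)
    # Logits: negative squared Euclidean distance to each prototype.
    logits = -torch.cdist(f_query, prototypes) ** 2     # (n_query, n_classes)
    # J(phi) = -log p_phi(y = k | x), averaged over the query set.
    return F.cross_entropy(logits, y_query)
```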
PN as Mixture Estimation
• Def: A distance function d_F is a Bregman divergence if we can write

d_F(z, z') = F(z) - F(z') - (z - z')^T \nabla F(z')

• where F is continuously differentiable, strictly convex, and defined on a closed convex set.
• E.g. squared Euclidean distance, Mahalanobis distance.
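As a quick check (a standard derivation, not from the slides), taking F(z) = \|z\|^2 recovers the squared Euclidean distance:

```latex
% With F(z) = \|z\|^2 we have \nabla F(z') = 2 z', so
d_F(z, z') = \|z\|^2 - \|z'\|^2 - (z - z')^T (2 z')
           = \|z\|^2 - 2 z^T z' + \|z'\|^2
           = \|z - z'\|^2
```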
PN as Mixture Estimation
• For distance functions in the class of Bregman divergences, prototypical networks algorithm is equivalent to mixture estimation with an exponential family density.
• The choice of distance function specifies modeling assumptions about the class-conditional data distribution in the embedding space (e.g. squared Euclidean distance corresponds to a spherical Gaussian).
p(y = k \mid z) = \frac{\pi_k \exp\{-d_\varphi(z, \mu(\theta_k))\}}{\sum_{k'} \pi_{k'} \exp\{-d_\varphi(z, \mu(\theta_{k'}))\}}
PN as Linear Model
• If you use the squared Euclidean distance, the softmax is taken over a linear function of the embedding:

-\|f_\phi(x) - c_k\|^2 = -f_\phi(x)^T f_\phi(x) + w_k^T f_\phi(x) + b_k

• where w_k = 2 c_k and b_k = -c_k^T c_k; the first term is constant in k, so it does not affect the softmax.
• The authors found that squared Euclidean distance worked well in practice, since the required non-linearity can be learned within the embedding function itself.
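The algebra behind this identity, spelled out (a standard expansion, included for completeness):

```latex
-\|f_\phi(x) - c_k\|^2
  = -f_\phi(x)^T f_\phi(x) + 2 c_k^T f_\phi(x) - c_k^T c_k
  = -f_\phi(x)^T f_\phi(x) + w_k^T f_\phi(x) + b_k
```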
Related Work: Matching Nets
• The attention mechanism in MN is comparable to a softmax over distances to prototypes.
• The two are equivalent in the one-shot case, where each prototype is just the embedding of the class's single support example.
• However, PN allow any distance metric rather than just the cosine distance.
• PN also elucidate the connection to mixture estimation.
Related Work: Distance-based Image Classification
• Mensink et al., Distance-based image classification: Generalizing to new classes at near-zero cost.
• Allows classes to have multiple prototypes, but this requires extra pre-processing, and the approach does not generalize as readily to other distance functions.
Experiments: Image Classification - Omniglot
• 1623 handwritten characters (classes) from 50 alphabets, 20 examples per character.
• Few shot learning.
Experiments: Image Classification - miniImageNet
• 60,000 color images, 100 classes, image size 84×84.
• Few shot learning.
Experiments: Image Classification - CU Birds
• 11,788 images of 200 bird species.
• Zero-shot learning task. Meta-data consists of the 312-dimensional continuous attribute vectors provided with the dataset.
PN Concluding Points
• PN perform at the state of the art on few-shot learning tasks despite being simpler and more efficient than recent approaches.
• PN offer further opportunities for improvement, e.g. via better-suited distance metrics.