Face recognition using a Siamese network with fastai

In this post I explore a simple face recognition system using a Siamese network. It is based on material from the fastai tutorials.

What makes a Siamese network attractive is that you don't need a ton of data. It works by learning whether two images are similar or not, rather than learning an explicit class label. As a demonstration I'll be using the Yale Face Database. This is an ancient dataset from 1997 consisting of 15 individuals, with 11 images per individual, each showing a different facial expression or lighting condition. Each image is a 320x243 grayscale GIF, giving 165 images in total. A dataset this size is laughably small for a typical image classification network. But with a Siamese network we're comparing pairs of images, and 165 images allow for 165x165 = 27225 ordered training pairs. This is a much more reasonable dataset size!

To recognize a new image I take a simple approach: compare it against all existing images and pick the label with the highest similarity score.
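
Here's the idea in a minimal sketch (not from the notebook; similarity stands in for the trained Siamese network and gallery for a list of (image, label) pairs, both hypothetical names):

def recognize(query, gallery, similarity):
    # Compare the query against every known (image, label) pair and
    # return the label of the best-scoring match.
    best_label, best_score = None, float("-inf")
    for img, lbl in gallery:
        score = similarity(query, img)
        if score > best_score:
            best_label, best_score = lbl, score
    return best_label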

This Jupyter notebook can be downloaded here https://github.com/nghiaho12/fastai_projects/blob/main/yale_face_siamese.ipynb

  1. Load the Yale Face dataset
  2. Set up custom transform and model
  3. Set up DataLoader
  4. Train the model
  5. Evaluate on test set
  6. Some thoughts

1. Load the Yale Face dataset

The dataset has a simple filename structure, with the exception of a redundant file and an incorrectly named file.

In [1]:
from fastai.vision.all import *
import os

if not torch.cuda.is_available():
    raise ValueError("No CUDA device found!")
    
#plt.style.use('dark_background')

path = untar_data("http://vision.ucsd.edu/datasets/yale_face_dataset_original/yalefaces.zip")

# redundant file
if os.path.exists(path/"subject01.glasses.gif"):
    os.remove(path/"subject01.glasses.gif") 
    
# incorrect naming    
if os.path.exists(path/"subject01.gif"):
    os.rename(path/"subject01.gif", path/"subject01.centerlight")

files = L(path.glob("subject*"))
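
Each file is named subject<ID>.<condition>, e.g. subject01.glasses, which is the pattern the label function in section 3 relies on. A quick peek to confirm:

# Filenames follow a subject<ID>.<condition> pattern, e.g. subject01.glasses.
print(sorted(f.name for f in files)[:3])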

2. Set up custom transform and model

This is mostly copy and paste from the fastai tutorials. I make a slight change to the model by using a simpler head after the encoder: it takes the two encoded vectors, computes their absolute difference, and feeds the result into a fully connected layer. This felt a bit more intuitive than using the resnet18 head.

In [6]:
class SiameseImage(fastuple):
    def show(self, ctx=None, **kwargs):
        img1, img2, same = self
        if not isinstance(img1, Tensor):
            if img2.size != img1.size: img2 = img2.resize(img1.size)
            t1, t2 = tensor(img1), tensor(img2)
            t1, t2 = t1.permute(2,0,1), t2.permute(2,0,1)
        else: t1, t2 = img1, img2
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1, line, t2], dim=2), title=same, ctx=ctx)
    
class SiameseTransform(Transform):
    def __init__(self, files, label_func, splits):
        self.labels = files.map(label_func).unique()        
        self.lbl2files = {l : L(f for f in files if label_func(f) == l) for l in self.labels}
        self.label_func = label_func
        self.valid = {f: self._draw(f) for f in files[splits[1]]}
        
    def encodes(self, f):
        f2, t = self.valid.get(f, self._draw(f)) # calls self._draw() if key f does not exist
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        return SiameseImage(img1, img2, t)
    
    def _draw(self, f):
        same = random.random() < 0.5
        cls = self.label_func(f)
        if not same:
            cls = random.choice(L(l for l in self.labels if l != cls))
        return random.choice(self.lbl2files[cls]), same
              
class SiameseModel(Module):
    def __init__(self, encoder):
        self.encoder = encoder
        self.fc = nn.Linear(1024, 1)
        
    def similarity(self, e1, e2):
        x = torch.abs(e1 - e2)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x
    
    def forward(self, x1, x2):
        e1 = self.encoder(x1)
        e2 = self.encoder(x2)
        return self.similarity(e1, e2)
    
@typedispatch
def show_batch(x:SiameseImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    if figsize is None: figsize = (ncols*6, max_n//ncols * 2)
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=None, ncols=ncols, figsize=figsize)
    for i,ctx in enumerate(ctxs): SiameseImage(x[0][i], x[1][i], ['Not similar','Similar'][x[2][i].item()]).show(ctx=ctx)    
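
As a quick sanity check on the custom head, feeding random 1024-dimensional "encodings" through an identity encoder should produce one probability per pair (1024 matches the feature size the resnet18 encoder outputs in section 4):

# Shape check: an identity "encoder" with random 1024-d features should
# yield one similarity probability per pair.
m = SiameseModel(encoder=nn.Identity())
e1, e2 = torch.randn(4, 1024), torch.randn(4, 1024)
print(m(e1, e2).shape)  # torch.Size([4, 1]), values in [0, 1]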

3. Set up DataLoader

This again is mostly copy and paste from the fastai tutorials. For the test set I decided to make things interesting by using only the faces with glasses. I wanted to see if the network could work on a type of face it had never seen before.

In [7]:
def label_func(fname):
    return re.match(r'^subject(.*)\.', fname.name).groups()[0]

train_files = L()
test_files = L()

for f in files:
    if ".glasses" in f.name:
        test_files.append(f)
    else:
        train_files.append(f)
            
splits = RandomSplitter()(train_files)
tfm = SiameseTransform(train_files, label_func, splits)
tls = TfmdLists(train_files, tfm, splits=splits)
dls = tls.dataloaders(after_item=[ToTensor], after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)], bs=16)

# sanity check
dls.show_batch()
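
It's also worth double checking that the glasses images really stayed out of the training files (a paranoia check, not in the original notebook):

# No glasses image should appear in the training files, and vice versa.
assert not any(".glasses" in f.name for f in train_files)
assert all(".glasses" in f.name for f in test_files)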

4. Train the model

I'm using resnet18 as the base model. There was a bit of a gotcha I came across when coding this up: initially I didn't call torch.squeeze, which caused a mismatch in tensor dimensions, but due to broadcasting it ran without errors. The results that came out were wrong though. So always check your dimensions! The snippet below shows the failure mode.
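
Here is a minimal standalone repro of what the broadcast does (depending on the PyTorch version, the mismatched call itself may only emit a warning rather than an error):

import torch

out = torch.sigmoid(torch.randn(4, 1))      # model output, shape (4, 1)
target = torch.randint(0, 2, (4,)).float()  # labels, shape (4,)

# Shapes (4, 1) and (4,) broadcast to (4, 4): every output gets compared
# against every label, so the loss computes without error but is meaningless.
wrong = torch.nn.BCELoss()(out.expand(4, 4), target.expand(4, 4))

# Squeezing the trailing dimension gives the intended element-wise loss.
right = torch.nn.BCELoss()(out.squeeze(1), target)
print(wrong.item(), right.item())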

The number of training epochs is set to 150 to make sure we sample enough of all the possible image pairs.

In [8]:
def my_loss(out, target):
    return nn.BCELoss()(torch.squeeze(out, 1), target.float())

def my_accuracy(input, target):
    label = input > 0.5
    return (label.squeeze(1) == target).float().mean()

encoder = nn.Sequential(
    create_body(resnet18, cut=-2),
    AdaptiveConcatPool2d(),
    nn.Flatten()
)

model = SiameseModel(encoder)
In [9]:
learn = Learner(dls, model, loss_func=my_loss, metrics=my_accuracy).to_fp16()
#learn.lr_find()
learn.fit(150)
learn.save("yale_face")
learn.recorder.plot_loss()
epoch train_loss valid_loss my_accuracy time
0 0.771358 0.798573 0.433333 00:04
1 0.762180 0.764269 0.633333 00:03
2 0.707662 0.418891 0.766667 00:03
3 0.690268 0.467284 0.866667 00:03
4 0.663904 0.192052 0.933333 00:03
5 0.631431 0.603939 0.666667 00:03
6 0.592009 0.485361 0.733333 00:03
7 0.559610 0.353569 0.866667 00:03
8 0.537616 0.321258 0.866667 00:03
9 0.515465 0.305035 0.866667 00:03
10 0.522608 0.439124 0.800000 00:03
11 0.504370 0.344908 0.833333 00:03
12 0.492477 0.415101 0.833333 00:03
13 0.469759 0.338476 0.833333 00:03
14 0.447328 0.312775 0.900000 00:03
15 0.429784 0.461406 0.866667 00:03
16 0.423695 0.259948 0.900000 00:03
17 0.427641 0.455844 0.800000 00:03
18 0.418350 0.265542 0.900000 00:03
19 0.407068 0.233716 0.966667 00:03
20 0.397152 0.219240 0.933333 00:03
21 0.391396 0.242315 0.933333 00:03
22 0.396777 0.389481 0.866667 00:03
23 0.399570 0.340503 0.933333 00:03
24 0.400220 0.274863 0.933333 00:03
25 0.388329 0.332037 0.933333 00:03
26 0.391222 0.215648 0.966667 00:03
27 0.377320 0.462584 0.800000 00:03
28 0.368654 0.350085 0.833333 00:03
29 0.355847 0.269538 0.900000 00:03
30 0.340884 0.203802 0.933333 00:03
31 0.333750 0.261518 0.900000 00:03
32 0.341796 0.292746 0.900000 00:03
33 0.354824 0.404484 0.800000 00:03
34 0.352173 0.476014 0.800000 00:03
35 0.347852 0.234246 0.900000 00:03
36 0.341717 0.148518 0.966667 00:03
37 0.329180 0.267744 0.866667 00:03
38 0.320899 0.234046 0.900000 00:03
39 0.320495 0.443201 0.866667 00:03
40 0.323765 0.179629 0.900000 00:03
41 0.311841 0.179645 0.900000 00:03
42 0.293914 0.219344 0.933333 00:03
43 0.284600 0.157972 0.966667 00:03
44 0.261732 0.232673 0.900000 00:03
45 0.251162 0.091173 0.966667 00:03
46 0.250668 0.141534 0.966667 00:03
47 0.235463 0.174219 0.933333 00:03
48 0.226625 0.228336 0.933333 00:03
49 0.215378 0.122686 0.966667 00:03
50 0.204091 0.363965 0.900000 00:03
51 0.205064 0.184772 0.900000 00:03
52 0.224805 0.148733 0.966667 00:03
53 0.212715 0.310556 0.900000 00:03
54 0.209553 0.119873 0.966667 00:03
55 0.203683 0.190742 0.900000 00:03
56 0.203018 0.174201 0.900000 00:03
57 0.199530 0.079015 1.000000 00:03
58 0.199540 0.064811 1.000000 00:03
59 0.202381 0.340615 0.900000 00:03
60 0.205328 0.252232 0.866667 00:03
61 0.203808 0.175128 0.900000 00:03
62 0.203636 0.153278 0.933333 00:03
63 0.194972 0.324574 0.900000 00:03
64 0.194273 0.173281 0.900000 00:03
65 0.185298 0.292468 0.933333 00:03
66 0.174574 0.121578 0.933333 00:03
67 0.169444 0.229488 0.866667 00:03
68 0.162427 0.127305 0.933333 00:03
69 0.154960 0.163160 0.966667 00:03
70 0.156632 0.096727 0.966667 00:03
71 0.151960 0.130506 0.933333 00:03
72 0.140487 0.061769 1.000000 00:03
73 0.138634 0.090673 0.966667 00:03
74 0.128519 0.322521 0.933333 00:03
75 0.124302 0.132279 0.966667 00:03
76 0.123788 0.169148 0.966667 00:03
77 0.122911 0.109776 0.966667 00:03
78 0.125244 0.061401 1.000000 00:03
79 0.129666 0.072725 1.000000 00:03
80 0.120455 0.067541 1.000000 00:03
81 0.120466 0.062071 1.000000 00:03
82 0.122316 0.045800 1.000000 00:03
83 0.119409 0.093577 0.966667 00:03
84 0.118797 0.108664 0.966667 00:03
85 0.115181 0.139946 0.966667 00:03
86 0.113063 0.189135 0.966667 00:03
87 0.123476 0.084278 1.000000 00:03
88 0.122245 0.057583 1.000000 00:03
89 0.127036 0.172101 0.966667 00:03
90 0.124319 0.071027 1.000000 00:03
91 0.132396 0.059073 1.000000 00:03
92 0.127176 0.275350 0.866667 00:03
93 0.130793 0.192220 0.933333 00:03
94 0.124894 0.364077 0.900000 00:03
95 0.133376 0.120467 0.966667 00:03
96 0.134325 0.096073 0.966667 00:03
97 0.128241 0.121352 0.966667 00:03
98 0.127081 0.065259 1.000000 00:03
99 0.136612 0.110917 0.966667 00:03
100 0.129069 0.061674 1.000000 00:03
101 0.117462 0.115273 0.966667 00:03
102 0.114762 0.059832 1.000000 00:03
103 0.106716 0.056188 1.000000 00:03
104 0.119044 0.053240 1.000000 00:03
105 0.111633 0.119676 0.966667 00:03
106 0.107212 0.077022 0.966667 00:03
107 0.105327 0.049430 1.000000 00:03
108 0.098985 0.058894 1.000000 00:03
109 0.101012 0.056524 1.000000 00:03
110 0.097253 0.058959 1.000000 00:03
111 0.093863 0.036646 1.000000 00:03
112 0.087012 0.039417 1.000000 00:03
113 0.089810 0.041918 1.000000 00:03
114 0.091238 0.028124 1.000000 00:03
115 0.086539 0.064753 0.966667 00:03
116 0.089725 0.078176 0.966667 00:03
117 0.088685 0.096591 0.966667 00:03
118 0.085691 0.078476 0.966667 00:03
119 0.089123 0.086907 0.966667 00:03
120 0.096020 0.108284 0.966667 00:03
121 0.096970 0.061336 0.966667 00:03
122 0.097451 0.036039 1.000000 00:03
123 0.093356 0.046852 1.000000 00:03
124 0.097766 0.033516 1.000000 00:03
125 0.094717 0.033016 1.000000 00:03
126 0.093153 0.043008 1.000000 00:03
127 0.086626 0.032254 1.000000 00:03
128 0.079253 0.031341 1.000000 00:03
129 0.076460 0.036718 1.000000 00:03
130 0.074800 0.033921 1.000000 00:03
131 0.069800 0.031310 1.000000 00:03
132 0.065554 0.055939 1.000000 00:03
133 0.064280 0.034122 1.000000 00:03
134 0.061899 0.049290 1.000000 00:03
135 0.064276 0.054888 1.000000 00:03
136 0.061236 0.103860 0.966667 00:03
137 0.061712 0.067840 1.000000 00:03
138 0.058570 0.065528 1.000000 00:03
139 0.060206 0.028810 1.000000 00:03
140 0.054501 0.040244 1.000000 00:03
141 0.058984 0.052320 1.000000 00:03
142 0.065623 0.058379 1.000000 00:03
143 0.070843 0.082804 1.000000 00:03
144 0.080594 0.046443 1.000000 00:03
145 0.094239 0.066486 1.000000 00:03
146 0.107337 0.052765 1.000000 00:03
147 0.104944 0.074229 1.000000 00:03
148 0.100248 0.036259 1.000000 00:03
149 0.091433 0.143146 0.933333 00:03

The validation loss is rather bumpy because we're validating on a very small fixed set of image pairs (20% of the training images by default).

5. Evaluate on test set

For each test image I compare it against all the training images, with each training image producing a similarity score from the network. The label of the highest-scoring training image is picked as the prediction. I'm doing this in a very slow loop to keep things simple. In practice you would cache the encoded training image vectors (done below) and also batch the test images to speed things up (see the sketch after the results).

I had issues with my laptop's GPU running out of memory when doing this loop. The solution is to turn off gradient calculation during inference by using torch.no_grad() as a context manager.

In [12]:
learn = Learner(dls, model, loss_func=my_loss, metrics=my_accuracy).to_fp16()
learn.load("yale_face")
learn.model.cuda()

correct = 0

pipe = Pipeline([IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)])

# cache the encoded training images
encoded = torch.zeros(len(train_files), learn.model.fc.in_features).cuda()

for idx, f in enumerate(train_files):
    img = PILImage.create(f)

    # disable gradients to save GPU memory!
    with torch.no_grad(): 
        x = pipe(ToTensor()(img).cuda())
        x = learn.model.encoder(x)
        encoded[idx, :] = x
        
fig, axs = plt.subplots(5, 3, figsize=(15,15))

for f, ax in zip(sorted(test_files), axs.flat):
    img = PILImage.create(f)
    
    # find the best match
    # disable gradients to save GPU memory!
    
    tally = {}
    
    with torch.no_grad(): 
        x = pipe(ToTensor()(img).cuda())
        x = learn.model.encoder(x)
        
        # this will use tensor broadcast
        y = learn.model.similarity(encoded, x)
                
        best_idx = torch.argmax(y)
        best_score = y[best_idx][0].item()        
        best_label = label_func(train_files[best_idx])
        best_img = PILImage.create(train_files[best_idx])
        
    if label_func(f) == best_label:
        correct += 1
        SiameseImage(img, best_img, f"correct ({best_score:.4f})").show(ctx=ax)
    else:
        SiameseImage(img, best_img, "incorrect").show(ctx=ax)
    
    print(f"{f.name} most similar to label ({best_label}, {best_score:.4f})")
    
print("")
print(f"correct matches {correct}/{len(test_files)}")
subject01.glasses most similar to label (01, 0.9998)
subject02.glasses most similar to label (02, 0.9999)
subject03.glasses most similar to label (03, 0.9996)
subject04.glasses most similar to label (04, 0.9995)
subject05.glasses most similar to label (05, 0.9993)
subject06.glasses most similar to label (06, 0.9998)
subject07.glasses most similar to label (07, 0.9998)
subject08.glasses most similar to label (08, 0.9998)
subject09.glasses most similar to label (09, 0.9960)
subject10.glasses most similar to label (10, 0.9816)
subject11.glasses most similar to label (11, 0.9997)
subject12.glasses most similar to label (12, 0.9994)
subject13.glasses most similar to label (13, 1.0000)
subject14.glasses most similar to label (14, 0.9988)
subject15.glasses most similar to label (15, 0.9999)

correct matches 15/15
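
As an aside, the batched prediction mentioned earlier could look roughly like this (a sketch reusing pipe, encoded and test_files from above; it assumes all test images share the same dimensions, which holds for this dataset):

with torch.no_grad():
    # Encode all test images in one batch instead of one at a time.
    batch = torch.stack([ToTensor()(PILImage.create(f)) for f in sorted(test_files)])
    test_enc = learn.model.encoder(pipe(batch.cuda()))           # (n_test, 1024)

    # Absolute difference between every test/train encoding pair, then the
    # trained head scores all pairs at once (nn.Linear handles extra batch dims).
    diff = (test_enc[:, None, :] - encoded[None, :, :]).abs()    # (n_test, n_train, 1024)
    scores = torch.sigmoid(learn.model.fc(diff)).squeeze(-1)     # (n_test, n_train)
    best_idx = scores.argmax(dim=1)  # best training match per test image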

The network managed to correctly classify all the test images. It turns out the Yale Face dataset isn't very consistent about whether people wear glasses or not, so the network had already learnt some faces with glasses during training. Oh well.

6. Some thoughts

I've shown how to create a simple face recognition system using a Siamese network with very little training data. Each face is encoded as a 1024-dimensional float vector, which is a fairly compact representation. Inference runtime scales linearly with the number of known faces.
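
To put the compactness claim in rough numbers (my own back-of-the-envelope, assuming fp32 storage):

# Storage cost of cached encodings: 1024 floats x 4 bytes = 4 KiB per face.
n_faces, dim, bytes_per_float = 1_000_000, 1024, 4
print(f"{n_faces * dim * bytes_per_float / 2**30:.1f} GiB")  # ~3.8 GiB for a million faces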