ぱそきいろのIT日記

ぱそきいろがITに関する記事を書いていきます。

深層学習で一枚の画像から人物認識をした話(モデル,学習の話)

こんにちは,ぱそきいろです.
ぱそきいろ (@takacpu55) | Twitter

今日は深層学習を使って人物認識で遊んだこの記事の続きになります.

www.takacpu55.xyz
f:id:takabsk55:20190802124957j:plain
この画像一枚で二人の人物認識をしていきます.
(パオパオチャンネルの「ぶんけい」と「あーずー」を見分けます.)
今回はモデルの構成や,学習について詳しく書いていきます.

結果は9割ほど正解したので,教師データが1枚だったことを考えたらまぁまぁな精度だと思います.
よろしくお願いします.

モデルの構成
#モデルの定義
class VGG16Model(chainer.Chain):

    def __init__(self, out_size):
        super(VGG16Model, self).__init__(
            base = L.VGG16Layers(),
            fc = L.Linear(None, out_size)
        )

    def __call__(self, x):
        h = self.base(x, layers=['fc7'])
        y = self.fc(h['fc7'])
        return y

model = L.Classifier(VGG16Model(out_size=2))

この構成のモデルを用いて学習させました.

これはVGG16というモデルが元となっています.

もともと学習済みもモデルに最後の層だけ学習させることで少ないデータセットでも精度を出すファインチューニングという技術があります.

今回はこれを使っていきます.

ファインチューニング

qiita.com

こちらの記事を参考にファインチューニングしていきます.
データ数が少ないので,バッチサイズを4にしています.
最後に詳しいソースコードを載せています.
 

結果

今回はぶんけいとあーずーの二人を見分けるので,それぞれ10枚ずつ画像を持ってきて,テストしていきます.
(ここでも手作業で10枚画像をダウンロードしてきました.
Webスクレイピング勉強しないと...)
結果はこちらです.
ぶんけいのテスト
f:id:takabsk55:20190813174709p:plain

/home/dl-box/PycharmProjects/face2/venv/bin/python /home/dl-box/PycharmProjects/face2/test.py datas/bun/image/G23.jpg
./bun9.jpg 0 0.5755364
./bunx.jpg 0 0.54191417
./bun1.jpg 0 0.9893457
./bun7.jpg 0 0.7951491
./bun5.jpg 0 0.54435015
./bun8.jpg 0 0.580492
./bun4.jpg 0 0.54191417
./bun6.jpg 0 0.86531574
./bun2.jpg 1 0.6155463
./bun3.jpg 0 0.6257619
./bun10.jpg 0 0.92617035

Process finished with exit code 0

あーずーのテスト
f:id:takabsk55:20190813174728p:plain

/home/dl-box/PycharmProjects/face2/venv/bin/python /home/dl-box/PycharmProjects/face2/test.py datas/bun/image/G23.jpg
./azu4.jpg 0 0.8545124
./azu5.jpg 1 0.70550936
./azu2.jpg 1 0.9799365
./azu.jpg 1 0.96837497
./azu7.jpg 1 0.9988709
./azu1.jpg 1 0.9227053
./azu6.jpg 1 0.6484426
./azu8.jpg 1 0.98450005
./azu9.jpg 1 0.86194366
./azu3.jpg 1 0.5503469

Process finished with exit code 0

どちらも1枚だけ間違えて判定していますが,それ以外は正解してます.
テスト数が少ないのですが,この場合で9割正解しています.

まとめ

今回は1枚の画像から人物認証をしてみました.
思ってたよりも少ないコード量で,学習させることができました.
もうちょっと画像の深層学習で遊んでみたいと思います.
ありがとうございました.

ソースコード

学習コード

# coding: utf-8

#ライブラリーのインポート
import glob
from tkinter import Image
import numpy as np
import chainer
from chainer import functions as F, serializers,links as L
import cv2
from chainer import training
from chainer.datasets import LabeledImageDataset
import os
from chainer.training import extensions
from itertools import chain
from chainer import iterators

#モデルの定義
class VGG16Model(chainer.Chain):

    def __init__(self, out_size):
        super(VGG16Model, self).__init__(
            base = L.VGG16Layers(),
            fc = L.Linear(None, out_size)
        )

    def __call__(self, x):
        h = self.base(x, layers=['fc7'])
        y = self.fc(h['fc7'])
        return y

model = L.Classifier(VGG16Model(out_size=2))

# 学習済みレイヤーの学習率を固定する

optimizer = chainer.optimizers.Adam(alpha=1e-4)
optimizer.setup(model)

model.predictor.base.disable_update()

gpu=1

if gpu >= 0:
    chainer.cuda.get_device(gpu).use()
    model.to_gpu(gpu)

#データセットの名前一覧を取得
path="./datas/bun/image"
namelisttemp=os.listdir(path)
namelist=[]
for i in namelisttemp:
    namelist.append("./datas/bun/image/"+i+" 0")

path = "./datas/kanta/image"
namelisttemp = os.listdir(path)
for i in namelisttemp:
    namelist.append("./datas/kanta/image/" + i + " 1")


pathw="./datas/namelist.txt"

with open(pathw,mode="w") as f:
    f.write("\n".join(namelist))


# 画像フォルダ
IMG_DIR = './datas'


train = LabeledImageDataset("./datas/namelist.txt")

train_iter=iterators.SerialIterator(train,4,shuffle=True)

updater=training.StandardUpdater(train_iter,optimizer,device=gpu)

max_epoch = 50

# TrainerにUpdaterを渡す
trainer = training.Trainer(
    updater, (max_epoch, 'epoch'), out='model/bunkan')


#trainerの追加
trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'l1/W/data/std', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['l1/W/data/std'], x_key='epoch', file_name='std.png'))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))

#学習開始
trainer.run()