機械学習手法を用いて画像認識してみたいです。
このような要望にお応えします。
今回は、ResNet学習済みモデルを用いて画像認識を行ってみます。
下記のサイトを参考にさせていただきました。
- https://deepage.net/deep_learning/2016/11/30/resnet.html
- https://github.com/creafz/pytorch-cnn-finetune/blob/master/examples/cifar10.py#L60
- https://github.com/creafz/pytorch-cnn-finetune
- https://vaaaaaanquish.hatenablog.com/entry/2018/09/15/213253
Google Colaboratoryの準備
Google Colaboratoryの準備は、下記の記事を参照ください。
学習準備
データは、動物画像を用います。
分類の対象は、以下になります。
・犬
・猫
・リス
以下のように、train_data.csv, test_data.csvを用意しました。
【train_data.csv】
【test_data.csv】
ソースコードは、以下になります。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# coding: utf-8 from PIL import Image from torch.utils.data import Dataset import torchvision.transforms as transforms import pandas as pd import os from cnn_finetune import make_model import torch import torch.nn as nn import torch.optim as optim import datetime from sklearn.metrics import classification_report import cv2 import matplotlib.pyplot as plt import numpy as np device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class TrainDataSet(Dataset): def __init__(self, csv_path): self.df = pd.read_csv(csv_path) self.images = self.df['image_name'] self.transform = transforms.Compose([transforms.ToTensor()]) def __len__(self): return len(self.images) def __getitem__(self, idx): image_name = self.images[idx] image_name = image_name.replace('\\', '/') image = Image.open(image_name) image = image.convert('RGB') label = self.df['image_label'][idx] return self.transform(image), int(label) def fn_train(model, epoch, optimizer, train_loader, criterion=nn.CrossEntropyLoss()): total_loss = 0 total_size = 0 model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) total_loss += loss.item() total_size += data.size(0) loss.backward() optimizer.step() if batch_idx % 1000 == 0: now = datetime.datetime.now() print('[{}] Train Epoch: {} [{}/{} ({:.0f}%)] Average loss: {:.6f}'.format(now, epoch, batch_idx*len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), total_loss / total_size)) def fn_start_learning(): model = make_model('resnet18', num_classes=3, pretrained=True, input_size=(120, 120)) model = model.to(device) train_set = TrainDataSet('train_data.csv') train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) for epoch in range(1, 4): fn_train(model=model, epoch=epoch, optimizer=optimizer, train_loader=train_loader, criterion=criterion) torch.save(model.state_dict(), 'cnn_dict.model') torch.save(model, 'cnn.model') param = torch.load('cnn_dict.model') model.load_state_dict(param) model = model.eval() test_set = TrainDataSet('test_data.csv') test_loader = torch.utils.data.DataLoader(test_set, batch_size=32) pred = [] y_res = [] for index, (x,y) in enumerate(test_loader): with torch.no_grad(): output = model(x) pred += [int(l.argmax()) for l in output] y_res += [int(l) for l in y] print(classification_report(y_res, pred)) df = pd.read_csv('test_data.csv') col = 5 row = int(len(df['image_name']) / col) + 1 plt.figure(figsize=(32, 32)) plt.subplots_adjust(wspace=0.2, hspace=0.7) for index in range(len(df['image_name'])): image = Image.open(df['image_name'][index].replace('\\', '/')) plt.subplot(row, col, index+1) plt.imshow(np.asarray(image)) # 正解の場合は青、不正解の場合は赤 if pred[index] == y_res[index]: color = 'blue' else: color = 'red' result = '' if pred[index] == 0: result = 'dog' elif pred[index] == 1: result = 'cat' elif pred[index] == 2: result = 'Squirrel' plt.xlabel("{}".format(result), color=color, fontsize=20) if __name__ == '__main__': fn_start_learning() |
ざっくりと処理内容を説明します。
以下の箇所でモデル定義と用意した画像データでの学習を行います。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
model = make_model('resnet18', num_classes=3, pretrained=True, input_size=(120, 120)) model = model.to(device) train_set = TrainDataSet('train_data.csv') train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) for epoch in range(1, 4): fn_train(model=model, epoch=epoch, optimizer=optimizer, train_loader=train_loader, criterion=criterion) torch.save(model.state_dict(), 'cnn_dict.model') torch.save(model, 'cnn.model') |
以下の箇所で学習モデルの保存を行い、テストデータでの検証をしています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
param = torch.load('cnn_dict.model') model.load_state_dict(param) model = model.eval() test_set = TrainDataSet('test_data.csv') test_loader = torch.utils.data.DataLoader(test_set, batch_size=32) pred = [] y_res = [] for index, (x,y) in enumerate(test_loader): with torch.no_grad(): output = model(x) pred += [int(l.argmax()) for l in output] y_res += [int(l) for l in y] print(classification_report(y_res, pred)) |
以下の箇所で予測結果と該当する画像を出力しています。
予測が正しい場合は、青色文字で出力し、不正解の場合は、赤文字で出力しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
df = pd.read_csv('test_data.csv') col = 5 row = int(len(df['image_name']) / col) + 1 plt.figure(figsize=(32, 32)) plt.subplots_adjust(wspace=0.2, hspace=0.7) for index in range(len(df['image_name'])): image = Image.open(df['image_name'][index].replace('\\', '/')) plt.subplot(row, col, index+1) plt.imshow(np.asarray(image)) # 正解の場合は青、不正解の場合は赤 if pred[index] == y_res[index]: color = 'blue' else: color = 'red' result = '' if pred[index] == 0: result = 'dog' elif pred[index] == 1: result = 'cat' elif pred[index] == 2: result = 'Squirrel' plt.xlabel("{}".format(result), color=color, fontsize=20) |
出力結果
出力結果は、以下になります。
1 2 3 4 5 6 7 8 9 |
precision recall f1-score support 0 0.67 0.83 0.74 12 1 0.93 0.88 0.90 16 2 0.95 0.86 0.90 22 accuracy 0.86 50 macro avg 0.85 0.86 0.85 50 weighted avg 0.88 0.86 0.86 50 |
どうでしょうか?
今回は、ResNetの学習済みモデルを使用しましたが、その他の学習済みモデルも公開されていますので試してみましょう。