1. 環境は、Window 10 Home (64bit) 上で行った。
2. Anaconda3 (64bit) – Spyder上で、動作確認を行った。
3. python の バージョンは、python 3.7.0 である。
4. pytorch の バージョンは、pytorch 0.4.1 である。
5. Flask の バージョンは、Flask 1.0.2 である。
6. GPU は, NVIDIA社 の GeForce GTX 1050 である。
7. CPU は, Intel社 の Core(TM) i7-7700HQ である。
今回確認した内容は、現場で使える! PyTorch開発入門 深層学習モデルの作成とアプリケーションへの実装 (AI & TECHNOLOGY) の 7.2 Flaskを用いたWebAPI化 (P.175 – P.182) である。
※1. プログラムの詳細は, 書籍を参考(P.175 – P.182)にして下さい.
※2. Unix環境で無いため, 原書のような Gunicorn は使用せず, 代案となるが,
参照URL① にある HTTPクライアント(Advanced REST client)を使う形で, 動作確認を行った.
原書の確認方法から, かなり逸脱してしまったが, 学習済みモデルに, テスト画像を読み込ませて,
予測した情報をレスポンスさせる形で確認できたので, とりあえず良しとした.
※3. 下記, ソースのコメントにも記載したが, “loadするモデル” に注意が必要である,
具体的には, resnet18 以外のネットワークを使った学習モデルを保存(ここでは, taco_and_burrito_09.pth)した場合,
Missing key(s) in state_dict: に関する RuntimeError が生じる.
これを回避するため, resnet18 のネットワークを改変せず, そのまま taco画像, burrito画像を訓練したモデル
(ここでは, resnet18_19.pth)を保存し, このモデルを loadする形で, 動作確認を行った.
[モデル/load error]
[モデル/再保存]
[モデル/load ok]
※4. 下記, ソースのコメントにも記載したが, 最初, 変数 img の取得が出来なかったので, 参照URL② を参考に, Debug Mode(SET FLASK_ENV=development を指定する形)で, 動作確認を行った.
[KeyError: ‘img’]
[Advanced REST client/setting 1]
[Advanced REST client/setting 2]
■フォルダ構成.
■対象プログラム(原書を一部改変).
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
# -*- coding: utf-8 -*- # 1. library import. from flask import Flask, request, jsonify from PIL import Image # 2. create an application. def create_app(classifier): ~(略)~ def predict(): # get handler of received file. # Debug Mode. # http://flask.pocoo.org/docs/1.0/quickstart/#debug-mode # SET FLASK_ENV=development # -> 以下の error を抽出できた. # raise exceptions.BadRequestKeyError(key) # werkzeug.exceptions.HTTPException.wrap.<locals>.newcls: # 400 Bad Request: KeyError: 'img' # print("request:" + str(request.files)) # ImmutableMultiDict([]) img_file = request.files["img"] # check if the file is empty. print("img_file.filename: " + str(img_file.filename) + " => OK!") ~(略)~ |
1 2 3 4 5 6 7 8 9 10 11 12 |
# -*- coding: utf-8 -*- # 1. library import. from torch import nn from torchvision import transforms, models # 2. create network. def create_network(): ~(略)~ # 3. declare a class. class Classifier(object): ~(略)~ |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# -*- coding: utf-8 -*- # 1. library import. import os, torch from flaskTest import create_app from flaskTest.classifier import Classifier # 2. set path. folder_path = os.path.expanduser('~') folder_path = folder_path + '\\.spyder-py3\\pytorch\\flaskTest\\' # pth_file = folder_path + "taco_and_burrito_09.pth" pth_file = folder_path + "resnet18_19.pth" # 3. load parameters. # RuntimeError: Error(s) in loading state_dict for ResNet: # Missing key(s) in state_dict: "conv1.weight", "bn1.weight", ... # "0.11.running_var", "0.11.num_batches_tracked", "1.weight", "1.bias". # -> このエラーは, load する pthファイルの不具合と推測される. # -> resnet18 以外のネットワークを使った学習モデルを保存したことに起因する, パラメータ不具合と推測. # -> resnet18 のネットワークを改変しないで, そのまま taco画像, burrito画像を訓練したモデルを保存して, 再確認. ~(略)~ |
■実行結果.
1. 訓練データ(taco)
画像:taco_000.jpg
予測:taco
⇒ 正解
2. 訓練データ(burrito)
画像:burrito_000.jpg
予測:burrito
⇒ 正解
3. テストデータ(taco)
画像:taco_380.jpg
予測:taco
⇒ 正解
4. テストデータ(burrito)
画像:burrito_383.jpg
予測:taco
⇒ 不正解
5. テストデータ(burrito)
画像:burrito_374.jpg
予測:burrito
⇒ 正解
■参照サイト
【参照URL①】Advanced REST client
【参照URL②】Debug Mode.
■参考書籍
現場で使える! PyTorch開発入門 深層学習モデルの作成とアプリケーションへの実装 (AI & TECHNOLOGY)