見出し画像

TensorFlow Lite入門 / iOSによる画像分類

1. iOSによる画像分類

iOSで「TensorFlow Lite」を使って画像分類を行ます。端末の背面カメラから見えるものをリアルタイムに画像分類し、可能性の高いラベル3つを表示します。

◎バージョン
・Xcode 10.3
・Swift 5
・TensorFlowLiteSwift 1.14.0

2. プロジェクトへのTensorFlow Liteフレームワークの追加

プロジェクトへのTensorFlow Liteフレームワークを追加するには、CocosPodsを使います。「pod init」でProfileを生成し、PodfileにTensorFlow Liteのフレームワークを追加し、最後に「pod install」でフレームワークの追加を実行します。

platform :ios, '12.0'

target 'CaptureClassificationEx' do
 use_frameworks!
 pod 'TensorFlowLiteSwift'
end

以降、プロジェクトを開く時、ImageClassification.xcworkspaceをダブルクリックします。

3. リソース

プロジェクトには、Image classificationからダウンロードしたTensorFlow Liteモデルとラベルを追加します。

・mobilenet_v1_1.0_224_quant.tflite
・labels_mobilenet_quant_v1_224.txt

4. 推論

「推論」を行っているのは、以下のコードです。
具体的には、interpreter.input()で入力バッファを取得し、interpreter.copy()で入力バッファを指定し、try interpreter.invoke()で推論し、interpreter.output()で出力バッファを取得するのみで簡単です。
難しい(めんどくさい)のは、「カメラの準備」と「入力バッファの生成」と「出力バッファの解析」になります。

   //予測
   func predict(_ sampleBuffer: CMSampleBuffer) {
       //CMSampleBufferをCVPixelBufferに変換
       let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer)!

       //Pixelフォーマットの確認
       let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
       assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
           sourcePixelFormat == kCVPixelFormatType_32BGRA ||
           sourcePixelFormat == kCVPixelFormatType_32RGBA)

       //画像のクロップとスケーリング
       let scaledSize = CGSize(width: INPUT_WIDTH, height: INPUT_HEIGHT)
       guard let cropPixelBuffer = pixelBuffer.centerThumbnail(ofSize: scaledSize) else {
           return
       }

       let outputTensor: Tensor
       do {
           //RGBデータの生成
           let inputTensor = try interpreter.input(at: 0)
           let rgbData = buffer2rgbData(
               cropPixelBuffer,
               byteCount: BATCH_SIZE * INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS,
               isModelQuantized: inputTensor.dataType == .uInt8)

           //推論の実行
           try interpreter.copy(rgbData!, toInputAt: 0)
           try interpreter.invoke()
           outputTensor = try interpreter.output(at: 0)
       } catch let error {
           print(error.localizedDescription)
           return
       }

       var results: [Float] = []
       switch outputTensor.dataType {
       //量子化モデル
       case .uInt8:
           let quantization = outputTensor.quantizationParameters!
           let quantizedResults = [UInt8](outputTensor.data)
           results = quantizedResults.map{
               quantization.scale * Float(Int($0) - quantization.zeroPoint)}
       //浮動少数モデル
       case .float32:
           results = [Float32](unsafeData: outputTensor.data) ?? []
       //その他
       default:
           return
       }

       //検出結果の取得
       var text: String = "\n"
       let zippedResults = zip(labels.indices, results)
       let sortedResults = zippedResults.sorted {$0.1 > $1.1}.prefix(3) //上位3件ソート
       for result in sortedResults {
           let probabillity = Int(result.1*100) //信頼度
           let label = labels[result.0] //ID
           text += "\(label) : \(probabillity)%\n"
       }

       //UIの更新
       DispatchQueue.main.async {
           self.lblText.text = text
       }
   }

5. ソースコード全体

ソースコード全体は次の通りです。

import UIKit
import AVFoundation
import TensorFlowLite
import Accelerate

//画像分類(カメラ映像)
class ViewController: UIViewController,
   AVCaptureVideoDataOutputSampleBufferDelegate {
   //UI
   @IBOutlet weak var lblText: UILabel!
   @IBOutlet weak var drawView: UIView!
   var previewLayer: AVCaptureVideoPreviewLayer!

   //パラメータ
   let BATCH_SIZE = 1 //バッチサイズ
   let INPUT_CHANNELS = 3 //入力チャンネル
   let INPUT_WIDTH = 224 //入力幅
   let INPUT_HEIGHT = 224 //入力高さ
   let THREAD_COUNT = 1 //スレッド数

   //参照
   var interpreter: Interpreter! //インタプリタ
   var labels: [String]! //ラベル


//====================
//ライフサイクル
//====================
   //ビュー表示時に呼ばれる
   override func viewDidAppear(_ animated: Bool) {
       do {
           //モデルパスの生成
           let modelPath = Bundle.main.path(
               forResource: "mobilenet_v1_1.0_224_quant",
               ofType: "tflite")!

           //インタプリタオプションの生成
           var options = InterpreterOptions()
           options.threadCount = THREAD_COUNT

           //インタプリタの生成
           interpreter = try Interpreter(modelPath: modelPath, options: options)
           try interpreter.allocateTensors()

           //ラベルURLの生成
           let labelURL = Bundle.main.url(
               forResource: "labels_mobilenet_quant_v1_224",
               withExtension: "txt")!

           //ラベルの読み込み
           let contents = try String(contentsOf: labelURL, encoding: .utf8)
           labels = contents.components(separatedBy: .newlines)
       } catch let error {
           print(error.localizedDescription)
       }

       //カメラキャプチャの開始
       startCapture()
   }


//====================
//カメラキャプチャ
//====================
   //カメラキャプチャの開始
   func startCapture() {
       //セッションの生成
       let captureSession = AVCaptureSession()
       captureSession.sessionPreset = AVCaptureSession.Preset.photo //プリセット
       let captureDevice: AVCaptureDevice! = self.device(false)

       //コンフィギュレーションの指定
       do {
           try captureDevice.lockForConfiguration()
           captureDevice.activeVideoMinFrameDuration = CMTimeMake(value: 1, timescale: 20) //FPS
           captureDevice.focusMode = .continuousAutoFocus //フォーカス
           captureDevice.exposureMode = .continuousAutoExposure //露出
           captureDevice.whiteBalanceMode = .continuousAutoWhiteBalance //ホワイトバランス
           captureDevice.unlockForConfiguration()
       } catch {
           return
       }

       //入力の生成
       guard let input = try? AVCaptureDeviceInput(device: captureDevice) else {return}
       guard captureSession.canAddInput(input) else {return}
       captureSession.addInput(input)

       //出力の生成
       let output: AVCaptureVideoDataOutput = AVCaptureVideoDataOutput()
       output.setSampleBufferDelegate(self, queue: DispatchQueue(label: "VideoQueue"))
       output.videoSettings = [String(kCVPixelBufferPixelFormatTypeKey) : kCMPixelFormat_32BGRA] //画像フォーマット
       output.alwaysDiscardsLateVideoFrames = true //出力の遅延フレームの破棄
       guard captureSession.canAddOutput(output) else {return}
       captureSession.addOutput(output)

       //画面の向き
       let videoConnection = output.connection(with: AVMediaType.video)
       videoConnection!.videoOrientation = .portrait
   
       //プレビューの指定
       previewLayer = AVCaptureVideoPreviewLayer(session: captureSession)
       previewLayer.videoGravity = AVLayerVideoGravity.resizeAspectFill
       previewLayer.frame = self.drawView.frame
       self.view.layer.insertSublayer(previewLayer, at: 0)

       //カメラキャプチャの開始
       captureSession.startRunning()
   }

   //デバイスの取得
   func device(_ frontCamera: Bool) -> AVCaptureDevice! {
       //AVCaptureDeviceのリストの取得
       let deviceDiscoverySession = AVCaptureDevice.DiscoverySession(
           deviceTypes: [AVCaptureDevice.DeviceType.builtInWideAngleCamera],
           mediaType: AVMediaType.video,
           position: AVCaptureDevice.Position.unspecified)
       let devices = deviceDiscoverySession.devices

       //指定したポジションを持つAVCaptureDeviceの検索
       let position: AVCaptureDevice.Position = frontCamera ? .front : .back
       for device in devices {
           if device.position == position {
               return device
           }
       }
       return nil
   }

   //カメラキャプチャの取得時に呼ばれる
   func captureOutput(_ output: AVCaptureOutput,
       didOutput sampleBuffer: CMSampleBuffer,
       from connection: AVCaptureConnection) {

       //予測
       predict(sampleBuffer)
   }


//====================
//画像分類(カメラ映像)
//====================
   //予測
   func predict(_ sampleBuffer: CMSampleBuffer) {
       //CMSampleBufferをCVPixelBufferに変換
       let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer)!

       //Pixelフォーマットの確認
       let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
       assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
           sourcePixelFormat == kCVPixelFormatType_32BGRA ||
           sourcePixelFormat == kCVPixelFormatType_32RGBA)


       //画像のクロップとスケーリング
       let scaledSize = CGSize(width: INPUT_WIDTH, height: INPUT_HEIGHT)
       guard let cropPixelBuffer = pixelBuffer.centerThumbnail(ofSize: scaledSize) else {
           return
       }

       let outputTensor: Tensor
       do {
           //RGBデータの生成
           let inputTensor = try interpreter.input(at: 0)
           let rgbData = buffer2rgbData(
               cropPixelBuffer,
               byteCount: BATCH_SIZE * INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS,
               isModelQuantized: inputTensor.dataType == .uInt8)

           //推論の実行
           try interpreter.copy(rgbData!, toInputAt: 0)
           try interpreter.invoke()
           outputTensor = try interpreter.output(at: 0)
       } catch let error {
           print(error.localizedDescription)
           return
       }
       var results: [Float] = []

       //量子化モデル
       if outputTensor.dataType == .uInt8 {
           let quantization = outputTensor.quantizationParameters!
           let quantizedResults = [UInt8](outputTensor.data)
           results = quantizedResults.map{
               quantization.scale * Float(Int($0) - quantization.zeroPoint)}
       }
       //浮動少数モデル
       else if outputTensor.dataType == .float32 {
           results = [Float32](unsafeData: outputTensor.data) ?? []
       }

       //検出結果の取得
       var text: String = "\n"
       let zippedResults = zip(labels.indices, results)
       let sortedResults = zippedResults.sorted {$0.1 > $1.1}.prefix(3) //上位3件ソート
       for result in sortedResults {
           let probabillity = Int(result.1*100) //信頼度
           let label = labels[result.0] //ID
           text += "\(label) : \(probabillity)%\n"
       }

       //UIの更新
       DispatchQueue.main.async {
           self.lblText.text = text
       }
   }

   //PixelBuffer→rgbData
   private func buffer2rgbData(_ buffer: CVPixelBuffer,
       byteCount: Int, isModelQuantized: Bool) -> Data? {
       //PixelBuffer→bufferData
       CVPixelBufferLockBaseAddress(buffer, .readOnly)
       defer { CVPixelBufferUnlockBaseAddress(buffer, .readOnly) }
       guard let mutableRawPointer = CVPixelBufferGetBaseAddress(buffer) else {
         return nil
       }
       let count = CVPixelBufferGetDataSize(buffer)
       let bufferData = Data(bytesNoCopy: mutableRawPointer,
           count: count, deallocator: .none)

       //bufferData→rgbBytes
       var rgbBytes = [UInt8](repeating: 0, count: byteCount)
       var index = 0
       for component in bufferData.enumerated() {
         let offset = component.offset
         let isAlphaComponent = (offset % 4) == 3
         guard !isAlphaComponent else {continue}
         rgbBytes[index] = component.element
         index += 1
       }

       //rgbBytes→rgbData
       if isModelQuantized {return Data(bytes: rgbBytes)}
       return Data(copyingBufferOf: rgbBytes.map{Float($0)/255.0})
   }
}

//====================
//拡張
//====================
//CVPixelBufferの拡張
extension CVPixelBuffer {
   //画像のトリミングとスケーリング
   func centerThumbnail(ofSize size: CGSize ) -> CVPixelBuffer? {
       let imageWidth = CVPixelBufferGetWidth(self)
       let imageHeight = CVPixelBufferGetHeight(self)
       let pixelBufferType = CVPixelBufferGetPixelFormatType(self)
       assert(pixelBufferType == kCVPixelFormatType_32BGRA)
       let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self)
       let imageChannels = 4
       let thumbnailSize = min(imageWidth, imageHeight)
       CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0))
       var originX = 0
       var originY = 0
       if imageWidth > imageHeight {
         originX = (imageWidth - imageHeight) / 2
       }
       else {
         originY = (imageHeight - imageWidth) / 2
       }
       
       //PixelBufferで最大の正方形をみつける
       guard let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced(
           by: originY * inputImageRowBytes + originX * imageChannels) else {
         return nil
       }
       
       //入力画像から画像バッファを取得
       var inputVImageBuffer = vImage_Buffer(
           data: inputBaseAddress, height: UInt(thumbnailSize), width: UInt(thumbnailSize),
           rowBytes: inputImageRowBytes)
       let thumbnailRowBytes = Int(size.width) * imageChannels
       guard  let thumbnailBytes = malloc(Int(size.height) * thumbnailRowBytes) else {
         return nil
       }
       
       //サムネイル画像にvImageバッファを割り当て
       var thumbnailVImageBuffer = vImage_Buffer(data: thumbnailBytes,
           height: UInt(size.height), width: UInt(size.width), rowBytes: thumbnailRowBytes)
       
       //入力画像バッファでスケール操作を実行し、サムネイル画像バッファに保存
       let scaleError = vImageScale_ARGB8888(&inputVImageBuffer, &thumbnailVImageBuffer, nil, vImage_Flags(0))
       CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0))
       guard scaleError == kvImageNoError else {
         return nil
       }
       let releaseCallBack: CVPixelBufferReleaseBytesCallback = {mutablePointer, pointer in
           if let pointer = pointer {
               free(UnsafeMutableRawPointer(mutating: pointer))
           }
       }

       //サムネイルのvImageバッファをCVPixelBufferに変換
       var thumbnailPixelBuffer: CVPixelBuffer?
       let conversionStatus = CVPixelBufferCreateWithBytes(
           nil, Int(size.width), Int(size.height), pixelBufferType, thumbnailBytes,
           thumbnailRowBytes, releaseCallBack, nil, nil, &thumbnailPixelBuffer)
       guard conversionStatus == kCVReturnSuccess else {
         free(thumbnailBytes)
         return nil
       }
       return thumbnailPixelBuffer
   }
}

//Dataの拡張
extension Data {
   //float配列→byte配列(長さ4倍)
   init<T>(copyingBufferOf array: [T]) {
       self = array.withUnsafeBufferPointer(Data.init)
   }
}

//Arrayの拡張
extension Array {
   //byte配列→float配列(長さ1/4倍)
   init?(unsafeData: Data) {
       guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
       #if swift(>=5.0)
       self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
       #else
       self = unsafeData.withUnsafeBytes {
           .init(UnsafeBufferPointer<Element>(
               start: $0,
               count: unsafeData.count / MemoryLayout<Element>.stride
           ))
       }
       #endif  // swift(>=5.0)
   }
}


この記事が気に入ったらサポートをしてみませんか?