// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Accelerate
import CoreImage
import TensorFlowLite
import UIKit

/// A result from invoking the `Interpreter`.
struct Result {
  let inferenceTime: Double
  let inferences: [Inference]
}

/// An inference from invoking the `Interpreter`.
struct Inference {
  let confidence: Float
  let label: String
}

/// Information about a model file or labels file.
typealias FileInfo = (name: String, extension: String)

/// Information about the bundled TensorFlow Lite models.
enum MobileNet {
  static let modelInfo: FileInfo = (name: "liveness", extension: "tflite")
  static let cardModel: FileInfo = (name: "valid_card_10102020", extension: "tflite")
  static let landMarkModel: FileInfo = (name: "face_detection_front", extension: "tflite")
}

/// This class handles all data preprocessing and makes calls to run inference on a given frame
/// by invoking the `Interpreter`. It then formats the inferences obtained and returns the top N
/// results for a successful inference.
class SBKModelDataHandler {

  // MARK: - Internal Properties

  /// The current thread count used by the TensorFlow Lite `Interpreter`.
  let threadCount: Int

  let resultCount = 3
  let threadCountLimit = 10

  // MARK: - Model Parameters

  let batchSize = 1
  let inputChannels = 3
  let inputWidth = 224
  let inputHeight = 224

  // MARK: - Private Properties

  /// List of labels from the given labels file.
  private var labels: [String] = []

  /// TensorFlow Lite `Interpreter` object for performing inference on a given model.
  private var interpreter: Interpreter

  /// Information about the alpha component in RGBA data.
  private let alphaComponent = (baseOffset: 4, moduloRemainder: 3)

  // MARK: - Initialization

  /// A failable initializer for `SBKModelDataHandler`. A new instance is created if the model
  /// file is successfully loaded from the framework bundle. Default `threadCount` is 1.
  init?(modelFileInfo: FileInfo, threadCount: Int = 1) {
    let modelFilename = modelFileInfo.name

    // Construct the path to the model file.
    let bundle = Bundle(for: SBKRecordFace.self)
    guard let modelPath = bundle.path(
      forResource: modelFilename,
      ofType: modelFileInfo.extension
    ) else {
      print("Failed to load the model file with name: \(modelFilename).")
      return nil
    }

    // Use the GPU (Metal) delegate; operations it does not support fall back to the CPU.
    let delegate = MetalDelegate()

    // Specify the options for the `Interpreter`.
    self.threadCount = threadCount
    var options = Interpreter.Options()
    options.threadCount = threadCount
    do {
      // Create the `Interpreter`.
      interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: [delegate])
      // Allocate memory for the model's input `Tensor`s.
      try interpreter.allocateTensors()
    } catch let error {
      print("Failed to create the interpreter with error: \(error.localizedDescription)")
      return nil
    }
  }
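
  // A minimal usage sketch (an illustration, not part of the original handler):
  // create the handler for the liveness model and run it on a camera frame.
  // `somePixelBuffer` is a hypothetical `CVPixelBuffer`, e.g. one delivered by an
  // `AVCaptureVideoDataOutput` callback.
  //
  //   guard let handler = SBKModelDataHandler(modelFileInfo: MobileNet.modelInfo,
  //                                           threadCount: 2) else {
  //     fatalError("Could not load the liveness model.")
  //   }
  //   let scores = handler.runModel(onFrame: somePixelBuffer)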

  /// Computes the average color of `image` by sampling every pixel. `pixelData` and
  /// `sourceImage` are forwarded to `getPixelColor(pos:dataImage:image:)` for the lookup.
  func fromImage(image: UIImage, pixelData: Data, sourceImage: UIImage) -> UIColor {
    var totalR: CGFloat = 0
    var totalG: CGFloat = 0
    var totalB: CGFloat = 0
    var count: CGFloat = 0

    for x in 0..<Int(image.size.width) {
      for y in 0..<Int(image.size.height) {
        count += 1
        var rF: CGFloat = 0, gF: CGFloat = 0, bF: CGFloat = 0, aF: CGFloat = 0
        image.getPixelColor(pos: CGPoint(x: x, y: y), dataImage: pixelData, image: sourceImage)
          .getRed(&rF, green: &gF, blue: &bF, alpha: &aF)
        totalR += rF
        totalG += gF
        totalB += bF
      }
    }

    let averageR = totalR / count
    let averageG = totalG / count
    let averageB = totalB / count
    return UIColor(red: averageR, green: averageG, blue: averageB, alpha: 1.0)
  }

  /// Renders a `CIImage` into a `UIImage` through a `CIContext`.
  func convert(ciImage: CIImage) -> UIImage {
    let context = CIContext(options: nil)
    guard let cgImage = context.createCGImage(ciImage, from: ciImage.extent) else {
      // Rendering can fail (e.g. for an infinite or empty extent); fall back to an
      // empty image instead of force-unwrapping and crashing.
      return UIImage()
    }
    return UIImage(cgImage: cgImage)
  }

  // MARK: - Internal Methods

  /// Performs image preprocessing, invokes the `Interpreter`, and processes the inference
  /// results. Returns the raw output scores, or `nil` if preprocessing or inference fails.
  func runModel(onFrame pixelBuffer: CVPixelBuffer) -> [Float]? {
    let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
    assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
           sourcePixelFormat == kCVPixelFormatType_32BGRA ||
           sourcePixelFormat == kCVPixelFormatType_32RGBA)

    let imageChannels = 4
    assert(imageChannels >= inputChannels)

    // Crop the image to the biggest square in the center and scale it down to model dimensions.
    let scaledSize = CGSize(width: inputWidth, height: inputHeight)
    guard let thumbnailPixelBuffer = pixelBuffer.centerThumbnail(ofSize: scaledSize) else {
      return nil
    }

    let interval: TimeInterval
    let outputTensor: Tensor
    do {
      let inputTensor = try interpreter.input(at: 0)

      // Remove the alpha component from the image buffer to get the RGB data.
      guard let rgbData = rgbDataFromBuffer(
        thumbnailPixelBuffer,
        byteCount: batchSize * inputWidth * inputHeight * inputChannels,
        isModelQuantized: inputTensor.dataType == .uInt8
      ) else {
        print("Failed to convert the image buffer to RGB data.")
        return nil
      }

      // Copy the RGB data to the input `Tensor`.
      try interpreter.copy(rgbData, toInputAt: 0)

      // Run inference by invoking the `Interpreter`.
      let startDate = Date()
      try interpreter.invoke()
      interval = Date().timeIntervalSince(startDate) * 1000

      // Get the output `Tensor` to process the inference results.
      outputTensor = try interpreter.output(at: 0)
    } catch let error {
      print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
      return nil
    }

    let results: [Float]
    switch outputTensor.dataType {
    case .uInt8:
      guard let quantization = outputTensor.quantizationParameters else {
        print("No results returned because the quantization values for the output tensor are nil.")
        return nil
      }
      // Dequantize: real_value = scale * (quantized_value - zero_point).
      let quantizedResults = [UInt8](outputTensor.data)
      results = quantizedResults.map {
        quantization.scale * Float(Int($0) - quantization.zeroPoint)
      }
    case .float32:
      results = [Float32](unsafeData: outputTensor.data) ?? []
    default:
      print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.")
      return nil
    }

    return results
  }
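
  // A minimal sketch (not in the original file) of how the raw scores from
  // `runModel(onFrame:)` could be paired with `labels` and wrapped into the
  // `Result`/`Inference` types declared above, which this file otherwise never
  // populates. Assumes `labels` has been loaded and is index-aligned with the
  // model's output vector.
  func result(fromScores scores: [Float], inferenceTime: Double) -> Result {
    // Pair each score with its label, sort descending, and keep the top `resultCount`.
    let topInferences = zip(scores, labels)
      .map { Inference(confidence: $0.0, label: $0.1) }
      .sorted { $0.confidence > $1.confidence }
      .prefix(resultCount)
    return Result(inferenceTime: inferenceTime, inferences: Array(topInferences))
  }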

  // MARK: - Private Methods

  /// Returns the RGB data representation of the given image buffer with the specified
  /// `byteCount`, i.e. `batchSize * inputWidth * inputHeight * inputChannels`.
  private func rgbDataFromBuffer(
    _ buffer: CVPixelBuffer,
    byteCount: Int,
    isModelQuantized: Bool
  ) -> Data? {
    CVPixelBufferLockBaseAddress(buffer, .readOnly)
    defer {
      CVPixelBufferUnlockBaseAddress(buffer, .readOnly)
    }
    guard let sourceData = CVPixelBufferGetBaseAddress(buffer) else {
      return nil
    }

    let width = CVPixelBufferGetWidth(buffer)
    let height = CVPixelBufferGetHeight(buffer)
    let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(buffer)
    let destinationChannelCount = 3
    let destinationBytesPerRow = destinationChannelCount * width

    var sourceBuffer = vImage_Buffer(data: sourceData,
                                     height: vImagePixelCount(height),
                                     width: vImagePixelCount(width),
                                     rowBytes: sourceBytesPerRow)

    guard let destinationData = malloc(height * destinationBytesPerRow) else {
      print("Error: out of memory")
      return nil
    }
    defer {
      free(destinationData)
    }

    var destinationBuffer = vImage_Buffer(data: destinationData,
                                          height: vImagePixelCount(height),
                                          width: vImagePixelCount(width),
                                          rowBytes: destinationBytesPerRow)

    // Strip the alpha channel using the vImage conversion that matches the source layout.
    let pixelBufferFormat = CVPixelBufferGetPixelFormatType(buffer)
    switch pixelBufferFormat {
    case kCVPixelFormatType_32BGRA:
      vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
    case kCVPixelFormatType_32ARGB:
      vImageConvert_ARGB8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
    case kCVPixelFormatType_32RGBA:
      vImageConvert_RGBA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
    default:
      // Unknown pixel format.
      return nil
    }

    let byteData = Data(bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height)
    if isModelQuantized {
      // Quantized models take the raw `UInt8` bytes directly.
      return byteData
    }

    // Not quantized: normalize each byte to a `Float` in [0, 1].
    guard let bytes = Array<UInt8>(unsafeData: byteData) else { return nil }
    let floats = bytes.map { Float($0) / 255.0 }
    return Data(copyingBufferOf: floats)
  }
}

// MARK: - Extensions

extension Data {
  /// Creates a new buffer by copying the buffer pointer of the given array.
  ///
  /// - Warning: The array's element type `T` must be trivial, i.e. copyable bit for bit
  ///     with no indirection or reference-counting operations.
  init<T>(copyingBufferOf array: [T]) {
    self = array.withUnsafeBufferPointer(Data.init)
  }
}

extension Array {
  /// Creates a new array from the bytes of the given unsafe data.
  ///
  /// Returns `nil` if `unsafeData.count` is not a multiple of `MemoryLayout<Element>.stride`.
  init?(unsafeData: Data) {
    guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
    #if swift(>=5.0)
    self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
    #else
    self = unsafeData.withUnsafeBytes {
      .init(UnsafeBufferPointer<Element>(
        start: $0,
        count: unsafeData.count / MemoryLayout<Element>.stride
      ))
    }
    #endif  // swift(>=5.0)
  }
}
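
// A quick sketch (illustrative, not from the original file) of how the two
// extensions above round-trip a numeric array through `Data`, which is how
// float input is handed to `interpreter.copy(_:toInputAt:)`:
//
//   let floats: [Float] = [0.25, 0.5, 0.75]
//   let data = Data(copyingBufferOf: floats)   // 12 bytes (3 * 4-byte floats)
//   let restored = [Float](unsafeData: data)   // Optional([0.25, 0.5, 0.75])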