Created
May 26, 2021 04:50
-
-
Save anupamchugh/702cecd741703caa6e0b0975b60f575d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import UIKit | |
import CoreML | |
/// The two classes the cat/dog model can report.
enum Animal {
    case cat, dog
}
/// Lets the user pick a photo from the library, shows it, and labels it with the
/// result of the `catdogcoreml` model.
class ViewController: UIViewController, UIImagePickerControllerDelegate, UINavigationControllerDelegate {
    @IBOutlet weak var modelOutputLabel: UILabel!

    /// Core ML model class declared elsewhere in the project.
    private let model = catdogcoreml()

    @IBOutlet weak var imageView: UIImageView!

    /// Size every image is resized to before being fed to the model.
    private let trainedImageSize = CGSize(width: 150, height: 150)

    override func viewDidLoad() {
        super.viewDidLoad()
    }

    /// Presents the photo-library picker; the result arrives via the delegate methods below.
    @IBAction func takePhotoClicked(_ sender: Any) {
        let imagePicker = UIImagePickerController()
        imagePicker.sourceType = .photoLibrary
        imagePicker.delegate = self
        present(imagePicker, animated: true, completion: nil)
    }

    /// Runs the model on `image`.
    ///
    /// - Parameter image: Image to classify; it is resized to `trainedImageSize` first.
    /// - Returns: `.dog` when the first element of the model output has integer value 1,
    ///   `.cat` for any other value, or `nil` if preprocessing or the prediction fails,
    ///   or the model output is empty.
    func predict(image: UIImage) -> Animal? {
        guard let resizedImage = resize(image: image, newSize: trainedImageSize),
              let pixelBuffer = resizedImage.toCVPixelBuffer() else {
            print("Error while preparing the image for prediction.")
            return nil
        }
        do {
            let output = try model.prediction(image: pixelBuffer).output
            // Avoid crashing on `output[0]` if the model returns an empty array.
            guard output.count > 0 else {
                print("Error while doing predictions: empty model output")
                return nil
            }
            // NOTE(review): assumes the model emits an integer class label with 1 == dog.
            // If it emits a probability instead, `intValue` truncates and nearly every
            // image becomes `.cat` — confirm against the model's output description.
            return output[0].intValue == 1 ? .dog : .cat
        } catch {
            print("Error while doing predictions: \(error)")
            return nil
        }
    }

    /// Redraws `image` into a bitmap of exactly `newSize` pixels.
    ///
    /// Scale 1 is used so the result is `newSize` pixels rather than
    /// `newSize` × screen scale; the model only needs `trainedImageSize` pixels.
    func resize(image: UIImage, newSize: CGSize) -> UIImage? {
        UIGraphicsBeginImageContextWithOptions(newSize, false, 1.0)
        // Always balance the begin call, whatever happens below.
        defer { UIGraphicsEndImageContext() }
        image.draw(in: CGRect(origin: .zero, size: newSize))
        return UIGraphicsGetImageFromCurrentImageContext()
    }

    func imagePickerControllerDidCancel(_ picker: UIImagePickerController) {
        picker.dismiss(animated: true, completion: nil)
    }

    /// Dismisses the picker, then shows the chosen image and the model's verdict.
    func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) {
        picker.dismiss(animated: true) {
            guard let image = info[.originalImage] as? UIImage else { return }
            let animal = self.predict(image: image)
            self.imageView.image = image
            switch animal {
            case .dog?:
                self.modelOutputLabel.text = "Dog"
            case .cat?:
                self.modelOutputLabel.text = "Cat"
            case nil:
                // NOTE(review): `nil` means the prediction failed, not that the
                // model saw something else; consider a more accurate message.
                self.modelOutputLabel.text = "Neither dog nor cat."
            }
        }
    }
}
extension UIImage {
    /// Renders the image into a newly allocated 32-bit ARGB `CVPixelBuffer`.
    ///
    /// The buffer is `size.width` × `size.height` truncated to whole numbers, and
    /// the image is drawn scaled to fill exactly that area.
    ///
    /// - Returns: The filled buffer, or `nil` if the size is empty, or if the
    ///   buffer or the drawing context cannot be created.
    func toCVPixelBuffer() -> CVPixelBuffer? {
        let width = Int(size.width)
        let height = Int(size.height)
        guard width > 0, height > 0 else { return nil }

        let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue, kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
        var pixelBuffer: CVPixelBuffer?
        let status = CVPixelBufferCreate(kCFAllocatorDefault, width, height, kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
        guard status == kCVReturnSuccess, let buffer = pixelBuffer else {
            return nil
        }

        CVPixelBufferLockBaseAddress(buffer, CVPixelBufferLockFlags(rawValue: 0))
        // Unlock on every exit path, including when the context cannot be created.
        defer { CVPixelBufferUnlockBaseAddress(buffer, CVPixelBufferLockFlags(rawValue: 0)) }

        guard let context = CGContext(data: CVPixelBufferGetBaseAddress(buffer),
                                      width: width,
                                      height: height,
                                      bitsPerComponent: 8,
                                      bytesPerRow: CVPixelBufferGetBytesPerRow(buffer),
                                      space: CGColorSpaceCreateDeviceRGB(),
                                      bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue) else {
            return nil
        }

        // Core Graphics has a bottom-left origin; flip so UIKit drawing lands upright.
        // Use the buffer's integer height so the flip matches the buffer exactly.
        context.translateBy(x: 0, y: CGFloat(height))
        context.scaleBy(x: 1.0, y: -1.0)

        UIGraphicsPushContext(context)
        draw(in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
        UIGraphicsPopContext()

        return buffer
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment