diff --git a/.idea/misc.xml b/.idea/misc.xml index e1c3f7e..ef46287 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -4,7 +4,7 @@ - + \ No newline at end of file diff --git a/.idea/uiDesigner.xml b/.idea/uiDesigner.xml new file mode 100644 index 0000000..2b63946 --- /dev/null +++ b/.idea/uiDesigner.xml @@ -0,0 +1,124 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/build.gradle b/build.gradle index c94ea93..cc4e76f 100644 --- a/build.gradle +++ b/build.gradle @@ -25,7 +25,7 @@ tasks.withType(JavaCompile) { application { mainModule = 'com.nuculabs.dev.imagetagger.ui' - mainClass = 'com.nuculabs.dev.imagetagger.ui.HelloApplication' + mainClass = 'com.nuculabs.dev.imagetagger.ui.MainApplication' } kotlin { jvmToolchain( 17 ) @@ -45,7 +45,7 @@ dependencies { exclude(group: 'org.openjfx') } implementation('org.kordamp.ikonli:ikonli-javafx:12.3.1') - + implementation('com.microsoft.onnxruntime:onnxruntime:1.17.1') testImplementation("org.junit.jupiter:junit-jupiter-api:${junitVersion}") testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${junitVersion}") } diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java index a2a95e8..abf62b6 100644 --- a/src/main/java/module-info.java +++ b/src/main/java/module-info.java @@ -7,6 +7,9 @@ module com.nuculabs.dev.imagetagger.ui { requires com.dlsc.formsfx; requires net.synedra.validatorfx; requires org.kordamp.ikonli.javafx; + requires com.microsoft.onnxruntime; + requires java.logging; + requires java.desktop; opens com.nuculabs.dev.imagetagger.ui to javafx.fxml; exports com.nuculabs.dev.imagetagger.ui; diff --git a/src/main/kotlin/com/nuculabs/dev/imagetagger/tag_prediction/ImageTagsPrediction.kt 
/**
 * Predicts the tags of an image by running it through an ONNX model.
 *
 * The model and its label list are loaded once, when the singleton returned by [getInstance]
 * is first created. Call [close] to release the native ONNX Runtime session.
 */
class ImageTagsPrediction private constructor() {
    private val logger: Logger = Logger.getLogger("InfoLogging")
    private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()

    // Kept as a field so it can be released together with the session in close().
    private val sessionOptions = OrtSession.SessionOptions()
    private val ortSession: OrtSession

    // Label for output index i of the model, one label per line of the categories file.
    private val modelClasses: MutableList<String> = mutableListOf()

    init {
        ortSession = ortEnv.createSession(readAsset(MODEL_RESOURCE, MODEL_FALLBACK_PATH), sessionOptions)
        logger.info("Loaded ML model.")

        modelClasses.addAll(
            readAsset(CLASSES_RESOURCE, CLASSES_FALLBACK_PATH).inputStream().bufferedReader().readLines()
        )
        logger.info("Loaded ${modelClasses.size} model classes.")
    }

    /**
     * Converts [bufferedImage] into the `[1][3][224][224]` float tensor fed to the model.
     *
     * The image is cropped to its centered square, scaled to 224x224 and then, per pixel,
     * each of the red, green and blue bytes is mapped to `(byte / 255 - mean) / std` using the
     * per-channel constants [CHANNEL_MEAN] and [CHANNEL_STD].
     *
     * @param bufferedImage the decoded source image.
     * @return the tensor data, indexed as `[batch][channel][y][x]` with channels in R, G, B order.
     */
    private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
        val tensorData = Array(1) {
            Array(CHANNEL_COUNT) {
                Array(INPUT_SIZE) { FloatArray(INPUT_SIZE) }
            }
        }

        // Crop to the largest centered square so that scaling does not distort the aspect ratio.
        val side = minOf(bufferedImage.width, bufferedImage.height)
        val startX = (bufferedImage.width - side) / 2
        val startY = (bufferedImage.height - side) / 2
        val cropped = bufferedImage.getSubimage(startX, startY, side, side)

        // Scale to the model input size and materialize the result into a BufferedImage
        // so that individual pixels can be read back.
        val resized = cropped.getScaledInstance(INPUT_SIZE, INPUT_SIZE, java.awt.Image.SCALE_SMOOTH)
        val scaledImage = BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_4BYTE_ABGR)
        val graphics = scaledImage.createGraphics()
        try {
            graphics.drawImage(resized, 0, 0, null)
        } finally {
            // Graphics contexts hold native resources and are not released automatically.
            graphics.dispose()
        }

        for (y in 0 until INPUT_SIZE) {
            for (x in 0 until INPUT_SIZE) {
                // getRGB packs the pixel as 0xAARRGGBB.
                val pixel: Int = scaledImage.getRGB(x, y)
                for (channel in 0 until CHANNEL_COUNT) {
                    // channel 0 = red (bits 16-23), 1 = green (bits 8-15), 2 = blue (bits 0-7)
                    val shift = 16 - 8 * channel
                    val value = (pixel shr shift) and 0xFF
                    tensorData[0][channel][y][x] =
                        (value / 255f - CHANNEL_MEAN[channel]) / CHANNEL_STD[channel]
                }
            }
        }
        return tensorData
    }

    /**
     * Runs the model on [bufferedImage] and returns the labels whose score exceeds
     * [SCORE_THRESHOLD].
     *
     * The first input and first output declared by the model are used. The input tensor and
     * the run result are closed before returning.
     *
     * @return the matching labels in model output order, or an empty list when the model
     * result does not contain the expected output.
     */
    @Suppress("UNCHECKED_CAST")
    private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
        val inputName: String = ortSession.inputNames.first()
        val outputName: String = ortSession.outputNames.first()

        return OnnxTensor.createTensor(ortEnv, processImage(bufferedImage)).use { inputTensor ->
            ortSession.run(mapOf(inputName to inputTensor)).use { results ->
                val output = results.get(outputName)
                if (!output.isPresent) {
                    emptyList()
                } else {
                    // Row 0 holds the scores for the single image in the batch.
                    val scores = (output.get().value as Array<FloatArray>)[0]
                    scores.indices
                        .filter { scores[it] > SCORE_THRESHOLD }
                        // getOrNull guards against a label file shorter than the model output.
                        .mapNotNull { modelClasses.getOrNull(it) }
                }
            }
        }
    }

    /**
     * Predicts tags for an already decoded image.
     *
     * @param image the image to tag.
     * @return the predicted tags; empty when no score passes the threshold.
     */
    fun predictTags(image: BufferedImage): List<String> = predictTagsInternal(image)

    /**
     * Predicts tags for an image read from [input].
     *
     * The stream is not closed by this method; the caller remains responsible for it.
     *
     * @param input the encoded image data, or `null`.
     * @return the predicted tags; empty when [input] is `null` or cannot be decoded as an image.
     * @throws IOException if reading from [input] fails.
     */
    fun predictTags(input: InputStream?): List<String> {
        if (input == null) {
            return emptyList()
        }

        // ImageIO.read returns null (rather than throwing) when no registered reader can decode the data.
        val image: BufferedImage? = ImageIO.read(input)
        if (image == null) {
            logger.warning("Could not decode image from input stream; no tags predicted.")
            return emptyList()
        }
        return predictTagsInternal(image)
    }

    /**
     * Closes the session, its options and the environment, and clears the loaded labels.
     *
     * The cached singleton is dropped as well, so the next [getInstance] call builds a fresh,
     * usable instance instead of handing out this closed one.
     */
    fun close() {
        ortSession.close()
        sessionOptions.close()
        ortEnv.close()
        modelClasses.clear()
        synchronized(Companion) {
            if (instance === this) {
                instance = null
            }
        }
    }

    // Singleton Pattern
    companion object {
        // Classpath locations of the bundled model files (src/main/resources/AIModels).
        private const val MODEL_RESOURCE = "/AIModels/prediction.onnx"
        private const val CLASSES_RESOURCE = "/AIModels/prediction_categories.txt"

        // Original development paths, used only when the classpath resource is missing.
        private const val MODEL_FALLBACK_PATH =
            "/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction.onnx"
        private const val CLASSES_FALLBACK_PATH =
            "/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction_categories.txt"

        // Width and height, in pixels, of the tensor built for the model.
        private const val INPUT_SIZE = 224
        private const val CHANNEL_COUNT = 3

        // Per-channel normalization constants in R, G, B order.
        // NOTE(review): these look like the usual ImageNet statistics; confirm against the model's training setup.
        private val CHANNEL_MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
        private val CHANNEL_STD = floatArrayOf(0.229f, 0.224f, 0.225f)

        // Scores strictly above this value are reported as tags.
        // NOTE(review): a negative cut-off suggests the model emits raw logits; confirm.
        private const val SCORE_THRESHOLD = -0.5f

        @Volatile
        private var instance: ImageTagsPrediction? = null

        /**
         * Returns the shared instance, creating it on first use.
         *
         * Creation happens inside a lock on the companion object and the field is re-checked
         * there, so concurrent first calls build only one instance.
         */
        fun getInstance(): ImageTagsPrediction =
            instance ?: synchronized(this) {
                instance ?: ImageTagsPrediction().also { instance = it }
            }

        /**
         * Reads [resourcePath] from the classpath, falling back to the file at [fallbackPath]
         * when the resource is not present.
         */
        private fun readAsset(resourcePath: String, fallbackPath: String): ByteArray =
            ImageTagsPrediction::class.java.getResourceAsStream(resourcePath)?.use { it.readBytes() }
                ?: File(fallbackPath).readBytes()
    }
}
+ val imageTagsPrediction = ImageTagsPrediction.getInstance() + val testTags = imageTagsPrediction.predictTags(File("/home/denis/Pictures/not_in_train/0a1a1e8bafbcdb00d34d87f35f0f4b9f.jpg").inputStream()) + + welcomeText.text = testTags.joinToString { it } } } \ No newline at end of file diff --git a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt similarity index 52% rename from src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt rename to src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt index cdddb32..32b59d9 100644 --- a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt +++ b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt @@ -1,13 +1,16 @@ package com.nuculabs.dev.imagetagger.ui +import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction import javafx.application.Application import javafx.fxml.FXMLLoader import javafx.scene.Scene import javafx.stage.Stage -class HelloApplication : Application() { +class MainApplication : Application() { override fun start(stage: Stage) { - val fxmlLoader = FXMLLoader(HelloApplication::class.java.getResource("hello-view.fxml")) + val imageTagsPrediction = ImageTagsPrediction.getInstance() + + val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml")) val scene = Scene(fxmlLoader.load(), 320.0, 240.0) stage.title = "Hello!" stage.scene = scene @@ -16,5 +19,5 @@ class HelloApplication : Application() { } fun main() { - Application.launch(HelloApplication::class.java) + Application.launch(MainApplication::class.java) } \ No newline at end of file