diff --git a/.idea/misc.xml b/.idea/misc.xml
index e1c3f7e..ef46287 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -4,7 +4,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/uiDesigner.xml b/.idea/uiDesigner.xml
new file mode 100644
index 0000000..2b63946
--- /dev/null
+++ b/.idea/uiDesigner.xml
@@ -0,0 +1,124 @@
+
+
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+
+
+
+ -
+
+
+
+
+
+ -
+
+
+
+
+
+ -
+
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+
+
+ -
+
+
+ -
+
+
+
+
+
\ No newline at end of file
diff --git a/build.gradle b/build.gradle
index c94ea93..cc4e76f 100644
--- a/build.gradle
+++ b/build.gradle
@@ -25,7 +25,7 @@ tasks.withType(JavaCompile) {
application {
mainModule = 'com.nuculabs.dev.imagetagger.ui'
- mainClass = 'com.nuculabs.dev.imagetagger.ui.HelloApplication'
+ mainClass = 'com.nuculabs.dev.imagetagger.ui.MainApplication'
}
kotlin {
jvmToolchain( 17 )
@@ -45,7 +45,7 @@ dependencies {
exclude(group: 'org.openjfx')
}
implementation('org.kordamp.ikonli:ikonli-javafx:12.3.1')
-
+ implementation('com.microsoft.onnxruntime:onnxruntime:1.17.1')
testImplementation("org.junit.jupiter:junit-jupiter-api:${junitVersion}")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${junitVersion}")
}
diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java
index a2a95e8..abf62b6 100644
--- a/src/main/java/module-info.java
+++ b/src/main/java/module-info.java
@@ -7,6 +7,9 @@ module com.nuculabs.dev.imagetagger.ui {
requires com.dlsc.formsfx;
requires net.synedra.validatorfx;
requires org.kordamp.ikonli.javafx;
+ requires com.microsoft.onnxruntime;
+ requires java.logging;
+ requires java.desktop;
opens com.nuculabs.dev.imagetagger.ui to javafx.fxml;
exports com.nuculabs.dev.imagetagger.ui;
diff --git a/src/main/kotlin/com/nuculabs/dev/imagetagger/tag_prediction/ImageTagsPrediction.kt b/src/main/kotlin/com/nuculabs/dev/imagetagger/tag_prediction/ImageTagsPrediction.kt
new file mode 100644
index 0000000..c0a4e44
--- /dev/null
+++ b/src/main/kotlin/com/nuculabs/dev/imagetagger/tag_prediction/ImageTagsPrediction.kt
@@ -0,0 +1,160 @@
+package com.nuculabs.dev.imagetagger.tag_prediction
+
+import ai.onnxruntime.OnnxTensor
+import ai.onnxruntime.OrtEnvironment
+import ai.onnxruntime.OrtSession
+import java.awt.image.BufferedImage
+import java.io.File
+import java.io.IOException
+import java.io.InputStream
+import java.util.logging.Logger
+import javax.imageio.ImageIO
+
+/**
+ * Predicts descriptive tags for an image by running it through an ONNX model.
+ *
+ * The model and its class-name list are loaded from the `AIModels` classpath resources
+ * (`src/main/resources/AIModels`). Use [getInstance] for the shared instance and [close]
+ * to release the native session when it is no longer needed.
+ */
+class ImageTagsPrediction private constructor() {
+    private val logger: Logger = Logger.getLogger("InfoLogging")
+    private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
+    private val ortSession: OrtSession
+    private val modelClasses: List<String>
+
+    init {
+        // Load from the classpath rather than an absolute path so the app runs outside one checkout.
+        val modelBytes = openResource(MODEL_RESOURCE).use { it.readBytes() }
+        ortSession = ortEnv.createSession(modelBytes, OrtSession.SessionOptions())
+        logger.info("Loaded ML model.")
+        modelClasses = openResource(CLASSES_RESOURCE).bufferedReader().use { it.readLines() }
+        logger.info("Loaded ${modelClasses.size} model classes.")
+    }
+
+    /**
+     * Builds the model input for [bufferedImage].
+     *
+     * The image is centre-cropped to a square, scaled to [INPUT_SIZE] x [INPUT_SIZE] and normalised
+     * per channel as `(value / 255 - mean) / std`.
+     *
+     * @return a tensor indexed as `[0][channel][y][x]` with channels ordered red, green, blue.
+     */
+    private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
+        // Centre-crop to the largest square so scaling does not distort the aspect ratio.
+        val side = minOf(bufferedImage.width, bufferedImage.height)
+        val startX = (bufferedImage.width - side) / 2
+        val startY = (bufferedImage.height - side) / 2
+        val cropped = bufferedImage.getSubimage(startX, startY, side, side)
+
+        val scaledImage = BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_4BYTE_ABGR)
+        val graphics = scaledImage.createGraphics()
+        try {
+            val resized = cropped.getScaledInstance(INPUT_SIZE, INPUT_SIZE, java.awt.Image.SCALE_SMOOTH)
+            graphics.drawImage(resized, 0, 0, null)
+        } finally {
+            // Graphics contexts hold native resources; release them deterministically.
+            graphics.dispose()
+        }
+
+        val tensorData = Array(1) { Array(3) { Array(INPUT_SIZE) { FloatArray(INPUT_SIZE) } } }
+        for (y in 0 until INPUT_SIZE) {
+            for (x in 0 until INPUT_SIZE) {
+                // getRGB packs the pixel as 0xAARRGGBB: red in bits 16-23, green 8-15, blue 0-7.
+                val pixel: Int = scaledImage.getRGB(x, y)
+                tensorData[0][0][y][x] = normalize((pixel shr 16) and 0xFF, 0)
+                tensorData[0][1][y][x] = normalize((pixel shr 8) and 0xFF, 1)
+                tensorData[0][2][y][x] = normalize(pixel and 0xFF, 2)
+            }
+        }
+        return tensorData
+    }
+
+    /**
+     * Runs the model on [bufferedImage] and returns the class names whose score exceeds
+     * [SCORE_THRESHOLD]; returns an empty list when the model produces no output tensor.
+     */
+    @Suppress("UNCHECKED_CAST")
+    private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
+        // 1. Get input and output names
+        val inputName: String = ortSession.inputNames.iterator().next()
+        val outputName: String = ortSession.outputNames.iterator().next()
+
+        // 2. Create the input tensor and run the model. Both the tensor and the result are
+        //    AutoCloseable, so `use` closes them once the scores have been read.
+        return OnnxTensor.createTensor(ortEnv, processImage(bufferedImage)).use { inputTensor ->
+            ortSession.run(mapOf(inputName to inputTensor)).use { results ->
+                // 3. Read the output tensor and keep the classes above the threshold.
+                val outputTensor = results.get(outputName)
+                if (outputTensor.isPresent) {
+                    val scores = (outputTensor.get().value as Array<FloatArray>)[0]
+                    scores.indices.filter { scores[it] > SCORE_THRESHOLD }.map { modelClasses[it] }
+                } else {
+                    emptyList<String>()
+                }
+            }
+        }
+    }
+
+    /** Predicts tags for an already decoded image. */
+    fun predictTags(image: BufferedImage): List<String> = predictTagsInternal(image)
+
+    /**
+     * Predicts tags for a given image input stream.
+     *
+     * Returns an empty list when [input] is null or cannot be decoded as an image. The stream
+     * is not closed here; the caller remains responsible for closing it.
+     */
+    fun predictTags(input: InputStream?): List<String> {
+        if (input == null) return emptyList()
+        // ImageIO.read returns null when no registered reader can decode the stream.
+        val image: BufferedImage = ImageIO.read(input) ?: return emptyList()
+        return predictTagsInternal(image)
+    }
+
+    /**
+     * Closes the session and environment and forgets this instance as the shared singleton,
+     * so a later [getInstance] call creates a fresh predictor instead of returning a closed one.
+     */
+    fun close() {
+        synchronized(Companion) {
+            if (instance === this) {
+                instance = null
+            }
+        }
+        ortSession.close()
+        ortEnv.close()
+    }
+
+    // Singleton Pattern
+    companion object {
+        private const val MODEL_RESOURCE = "/AIModels/prediction.onnx"
+        private const val CLASSES_RESOURCE = "/AIModels/prediction_categories.txt"
+
+        /** Width and height, in pixels, of the image fed to the model. */
+        private const val INPUT_SIZE = 224
+
+        // NOTE(review): a negative cut-off suggests the model emits raw logits - confirm.
+        private const val SCORE_THRESHOLD = -0.5f
+
+        // Per-channel (red, green, blue) normalisation constants.
+        private val CHANNEL_MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
+        private val CHANNEL_STD = floatArrayOf(0.229f, 0.224f, 0.225f)
+
+        @Volatile
+        private var instance: ImageTagsPrediction? = null
+
+        fun getInstance(): ImageTagsPrediction =
+            instance ?: synchronized(this) {
+                instance ?: ImageTagsPrediction().also { instance = it }
+            }
+
+        private fun normalize(value: Int, channel: Int): Float =
+            (value / 255f - CHANNEL_MEAN[channel]) / CHANNEL_STD[channel]
+
+        private fun openResource(path: String): InputStream =
+            checkNotNull(ImageTagsPrediction::class.java.getResourceAsStream(path)) {
+                "Missing classpath resource: $path"
+            }
+    }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloController.kt b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloController.kt
index 78744cd..fa1a337 100644
--- a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloController.kt
+++ b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloController.kt
@@ -1,7 +1,9 @@
package com.nuculabs.dev.imagetagger.ui
+import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
import javafx.fxml.FXML
import javafx.scene.control.Label
+import java.io.File
class HelloController {
@FXML
@@ -9,6 +11,9 @@ class HelloController {
@FXML
private fun onHelloButtonClick() {
-        welcomeText.text = "Welcome to JavaFX Application!"
+        // TODO(review): developer-specific sample image; `use` closes its stream after prediction.
+        val sampleImage = File("/home/denis/Pictures/not_in_train/0a1a1e8bafbcdb00d34d87f35f0f4b9f.jpg")
+        val testTags = sampleImage.inputStream().use { ImageTagsPrediction.getInstance().predictTags(it) }
+        welcomeText.text = testTags.joinToString()
}
}
\ No newline at end of file
diff --git a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt
similarity index 52%
rename from src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt
rename to src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt
index cdddb32..32b59d9 100644
--- a/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/HelloApplication.kt
+++ b/src/main/kotlin/com/nuculabs/dev/imagetagger/ui/MainApplication.kt
@@ -1,13 +1,16 @@
package com.nuculabs.dev.imagetagger.ui
+import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
import javafx.application.Application
import javafx.fxml.FXMLLoader
import javafx.scene.Scene
import javafx.stage.Stage
-class HelloApplication : Application() {
+class MainApplication : Application() {
override fun start(stage: Stage) {
- val fxmlLoader = FXMLLoader(HelloApplication::class.java.getResource("hello-view.fxml"))
+ val imageTagsPrediction = ImageTagsPrediction.getInstance()
+
+ val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml"))
val scene = Scene(fxmlLoader.load(), 320.0, 240.0)
stage.title = "Hello!"
stage.scene = scene
@@ -16,5 +19,5 @@ class HelloApplication : Application() {
}
fun main() {
- Application.launch(HelloApplication::class.java)
+ Application.launch(MainApplication::class.java)
}
\ No newline at end of file