predict sample image

This commit is contained in:
Denis-Cosmin Nutiu 2024-03-19 15:37:18 +02:00
parent 0c345ce398
commit cf7b4bd6eb
7 changed files with 302 additions and 7 deletions

View file

@ -4,7 +4,7 @@
<component name="FrameworkDetectionExcludesConfiguration">
<file type="web" url="file://$PROJECT_DIR$" />
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" project-jdk-name="corretto-21" project-jdk-type="JavaSDK">
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" project-jdk-name="21" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

124
.idea/uiDesigner.xml Normal file
View file

@ -0,0 +1,124 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Palette2">
<group name="Swing">
<item class="com.intellij.uiDesigner.HSpacer" tooltip-text="Horizontal Spacer" icon="/com/intellij/uiDesigner/icons/hspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="1" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="com.intellij.uiDesigner.VSpacer" tooltip-text="Vertical Spacer" icon="/com/intellij/uiDesigner/icons/vspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="1" anchor="0" fill="2" />
</item>
<item class="javax.swing.JPanel" icon="/com/intellij/uiDesigner/icons/panel.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3" />
</item>
<item class="javax.swing.JScrollPane" icon="/com/intellij/uiDesigner/icons/scrollPane.svg" removable="false" auto-create-binding="false" can-attach-label="true">
<default-constraints vsize-policy="7" hsize-policy="7" anchor="0" fill="3" />
</item>
<item class="javax.swing.JButton" icon="/com/intellij/uiDesigner/icons/button.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="0" fill="1" />
<initial-values>
<property name="text" value="Button" />
</initial-values>
</item>
<item class="javax.swing.JRadioButton" icon="/com/intellij/uiDesigner/icons/radioButton.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="RadioButton" />
</initial-values>
</item>
<item class="javax.swing.JCheckBox" icon="/com/intellij/uiDesigner/icons/checkBox.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="CheckBox" />
</initial-values>
</item>
<item class="javax.swing.JLabel" icon="/com/intellij/uiDesigner/icons/label.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="8" fill="0" />
<initial-values>
<property name="text" value="Label" />
</initial-values>
</item>
<item class="javax.swing.JTextField" icon="/com/intellij/uiDesigner/icons/textField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JPasswordField" icon="/com/intellij/uiDesigner/icons/passwordField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JFormattedTextField" icon="/com/intellij/uiDesigner/icons/formattedTextField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JTextArea" icon="/com/intellij/uiDesigner/icons/textArea.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTextPane" icon="/com/intellij/uiDesigner/icons/textPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JEditorPane" icon="/com/intellij/uiDesigner/icons/editorPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JComboBox" icon="/com/intellij/uiDesigner/icons/comboBox.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="2" anchor="8" fill="1" />
</item>
<item class="javax.swing.JTable" icon="/com/intellij/uiDesigner/icons/table.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JList" icon="/com/intellij/uiDesigner/icons/list.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="2" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTree" icon="/com/intellij/uiDesigner/icons/tree.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTabbedPane" icon="/com/intellij/uiDesigner/icons/tabbedPane.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSplitPane" icon="/com/intellij/uiDesigner/icons/splitPane.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSpinner" icon="/com/intellij/uiDesigner/icons/spinner.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSlider" icon="/com/intellij/uiDesigner/icons/slider.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSeparator" icon="/com/intellij/uiDesigner/icons/separator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3" />
</item>
<item class="javax.swing.JProgressBar" icon="/com/intellij/uiDesigner/icons/progressbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="javax.swing.JToolBar" icon="/com/intellij/uiDesigner/icons/toolbar.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1">
<preferred-size width="-1" height="20" />
</default-constraints>
</item>
<item class="javax.swing.JToolBar$Separator" icon="/com/intellij/uiDesigner/icons/toolbarSeparator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="0" fill="1" />
</item>
<item class="javax.swing.JScrollBar" icon="/com/intellij/uiDesigner/icons/scrollbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="0" anchor="0" fill="2" />
</item>
</group>
</component>
</project>

View file

@ -25,7 +25,7 @@ tasks.withType(JavaCompile) {
application {
mainModule = 'com.nuculabs.dev.imagetagger.ui'
mainClass = 'com.nuculabs.dev.imagetagger.ui.HelloApplication'
mainClass = 'com.nuculabs.dev.imagetagger.ui.MainApplication'
}
kotlin {
jvmToolchain( 17 )
@ -45,7 +45,7 @@ dependencies {
exclude(group: 'org.openjfx')
}
implementation('org.kordamp.ikonli:ikonli-javafx:12.3.1')
implementation('com.microsoft.onnxruntime:onnxruntime:1.17.1')
testImplementation("org.junit.jupiter:junit-jupiter-api:${junitVersion}")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${junitVersion}")
}

View file

@ -7,6 +7,9 @@ module com.nuculabs.dev.imagetagger.ui {
requires com.dlsc.formsfx;
requires net.synedra.validatorfx;
requires org.kordamp.ikonli.javafx;
requires com.microsoft.onnxruntime;
requires java.logging;
requires java.desktop;
opens com.nuculabs.dev.imagetagger.ui to javafx.fxml;
exports com.nuculabs.dev.imagetagger.ui;

View file

@ -0,0 +1,160 @@
package com.nuculabs.dev.imagetagger.tag_prediction
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.awt.image.BufferedImage
import java.io.File
import java.io.IOException
import java.io.InputStream
import java.util.logging.Logger
import javax.imageio.ImageIO
/**
* ImageTagsPrediction is a specialized class that predicts an Image's tags
*/
class ImageTagsPrediction private constructor() {
private val logger: Logger = Logger.getLogger("InfoLogging")
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
private var ortSession: OrtSession;
private var modelClasses: MutableList<String> = mutableListOf()
init {
logger.info("Loaded ML model.")
val modelFile = File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction.onnx")
val classesFile =
File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction_categories.txt")
ortSession = ortEnv.createSession(
modelFile.readBytes(),
OrtSession.SessionOptions()
)
modelClasses.addAll(0, classesFile.bufferedReader().readLines())
logger.info("Loaded ${modelClasses.size} model classes.")
}
private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
try {
val tensorData = Array(1) {
Array(3) {
Array(224) {
FloatArray(224)
}
}
}
val mean = floatArrayOf(0.485f, 0.456f, 0.406f)
val standardDeviation = floatArrayOf(0.229f, 0.224f, 0.225f)
// crop image to 224x224
var width: Int = bufferedImage.width
var height: Int = bufferedImage.height
var startX = 0
var startY = 0
if (width > height) {
startX = (width - height) / 2
width = height
} else {
startY = (height - width) / 2
height = width
}
val image = bufferedImage.getSubimage(startX, startY, width, height);
val resizedImage = image.getScaledInstance(224, 224, 4);
val scaledImage = BufferedImage(224, 224, BufferedImage.TYPE_4BYTE_ABGR)
scaledImage.graphics.drawImage(resizedImage, 0, 0, null)
// Process image
for (y in 0 until scaledImage.height) {
for (x in 0 until scaledImage.width) {
val pixel: Int = scaledImage.getRGB(x, y)
// Get RGB values
tensorData[0][0][y][x] =
((pixel shr 16 and 0xFF) / 255f - mean[0]) / standardDeviation[0]
tensorData[0][1][y][x] =
((pixel shr 16 and 0xFF) / 255f - mean[1]) / standardDeviation[1]
tensorData[0][2][y][x] =
((pixel shr 16 and 0xFF) / 255f - mean[2]) / standardDeviation[2]
}
}
return tensorData
} catch (e: IOException) {
throw RuntimeException(e)
}
}
/**
* Uses the ML model to predict tags for a given bitmap.
*/
@Suppress("UNCHECKED_CAST")
private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
// 1. Get input and output names
val inputName: String = ortSession.inputNames.iterator().next()
val outputName: String = ortSession.outputNames.iterator().next()
// 2. Create input tensor
val inputTensor = OnnxTensor.createTensor(ortEnv, processImage(bufferedImage))
// 3. Run the model.
val inputs = mapOf(inputName to inputTensor);
val results = ortSession.run(inputs);
// 4. Get output tensor
val outputTensor = results.get(outputName);
if (outputTensor.isPresent) {
// 5. Get prediction results
val floatBuffer = outputTensor.get().value as Array<FloatArray>
val predictions = ArrayList<String>()
// filter buffer by threshold
for (i in floatBuffer[0].indices) {
if (floatBuffer[0][i] > -0.5) {
predictions.add(modelClasses[i])
}
}
return predictions;
} else {
return ArrayList()
}
}
/**
* Predicts tags for a Bitmap.
*/
fun predictTags(image: BufferedImage): List<String> {
return predictTagsInternal(image)
}
/**
* Predicts tags for a given image input stream.
*/
fun predictTags(input: InputStream?): List<String> {
if (input == null) {
return ArrayList()
}
return predictTagsInternal(ImageIO.read(input))
}
/**
* Close the session and environment.
*/
fun close() {
ortSession.close()
ortEnv.close()
modelClasses.clear()
}
// Singleton Pattern
companion object {
@Volatile
private var instance: ImageTagsPrediction? = null
fun getInstance() =
instance ?: synchronized(this) {
instance ?: ImageTagsPrediction().also { instance = it }
}
}
}

View file

@ -1,7 +1,9 @@
package com.nuculabs.dev.imagetagger.ui
import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
import javafx.fxml.FXML
import javafx.scene.control.Label
import java.io.File
class HelloController {
@FXML
@ -9,6 +11,9 @@ class HelloController {
@FXML
private fun onHelloButtonClick() {
welcomeText.text = "Welcome to JavaFX Application!"
val imageTagsPrediction = ImageTagsPrediction.getInstance()
val testTags = imageTagsPrediction.predictTags(File("/home/denis/Pictures/not_in_train/0a1a1e8bafbcdb00d34d87f35f0f4b9f.jpg").inputStream())
welcomeText.text = testTags.joinToString { it }
}
}

View file

@ -1,13 +1,16 @@
package com.nuculabs.dev.imagetagger.ui
import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
import javafx.application.Application
import javafx.fxml.FXMLLoader
import javafx.scene.Scene
import javafx.stage.Stage
class HelloApplication : Application() {
class MainApplication : Application() {
override fun start(stage: Stage) {
val fxmlLoader = FXMLLoader(HelloApplication::class.java.getResource("hello-view.fxml"))
val imageTagsPrediction = ImageTagsPrediction.getInstance()
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml"))
val scene = Scene(fxmlLoader.load(), 320.0, 240.0)
stage.title = "Hello!"
stage.scene = scene
@ -16,5 +19,5 @@ class HelloApplication : Application() {
}
fun main() {
Application.launch(HelloApplication::class.java)
Application.launch(MainApplication::class.java)
}