predict sample image
This commit is contained in:
parent
0c345ce398
commit
cf7b4bd6eb
7 changed files with 302 additions and 7 deletions
|
@ -4,7 +4,7 @@
|
|||
<component name="FrameworkDetectionExcludesConfiguration">
|
||||
<file type="web" url="file://$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" project-jdk-name="corretto-21" project-jdk-type="JavaSDK">
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" project-jdk-name="21" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/out" />
|
||||
</component>
|
||||
</project>
|
124
.idea/uiDesigner.xml
Normal file
124
.idea/uiDesigner.xml
Normal file
|
@ -0,0 +1,124 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Palette2">
|
||||
<group name="Swing">
|
||||
<item class="com.intellij.uiDesigner.HSpacer" tooltip-text="Horizontal Spacer" icon="/com/intellij/uiDesigner/icons/hspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="1" hsize-policy="6" anchor="0" fill="1" />
|
||||
</item>
|
||||
<item class="com.intellij.uiDesigner.VSpacer" tooltip-text="Vertical Spacer" icon="/com/intellij/uiDesigner/icons/vspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="1" anchor="0" fill="2" />
|
||||
</item>
|
||||
<item class="javax.swing.JPanel" icon="/com/intellij/uiDesigner/icons/panel.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3" />
|
||||
</item>
|
||||
<item class="javax.swing.JScrollPane" icon="/com/intellij/uiDesigner/icons/scrollPane.svg" removable="false" auto-create-binding="false" can-attach-label="true">
|
||||
<default-constraints vsize-policy="7" hsize-policy="7" anchor="0" fill="3" />
|
||||
</item>
|
||||
<item class="javax.swing.JButton" icon="/com/intellij/uiDesigner/icons/button.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="3" anchor="0" fill="1" />
|
||||
<initial-values>
|
||||
<property name="text" value="Button" />
|
||||
</initial-values>
|
||||
</item>
|
||||
<item class="javax.swing.JRadioButton" icon="/com/intellij/uiDesigner/icons/radioButton.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
|
||||
<initial-values>
|
||||
<property name="text" value="RadioButton" />
|
||||
</initial-values>
|
||||
</item>
|
||||
<item class="javax.swing.JCheckBox" icon="/com/intellij/uiDesigner/icons/checkBox.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
|
||||
<initial-values>
|
||||
<property name="text" value="CheckBox" />
|
||||
</initial-values>
|
||||
</item>
|
||||
<item class="javax.swing.JLabel" icon="/com/intellij/uiDesigner/icons/label.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="0" anchor="8" fill="0" />
|
||||
<initial-values>
|
||||
<property name="text" value="Label" />
|
||||
</initial-values>
|
||||
</item>
|
||||
<item class="javax.swing.JTextField" icon="/com/intellij/uiDesigner/icons/textField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
|
||||
<preferred-size width="150" height="-1" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JPasswordField" icon="/com/intellij/uiDesigner/icons/passwordField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
|
||||
<preferred-size width="150" height="-1" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JFormattedTextField" icon="/com/intellij/uiDesigner/icons/formattedTextField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
|
||||
<preferred-size width="150" height="-1" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JTextArea" icon="/com/intellij/uiDesigner/icons/textArea.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JTextPane" icon="/com/intellij/uiDesigner/icons/textPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JEditorPane" icon="/com/intellij/uiDesigner/icons/editorPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JComboBox" icon="/com/intellij/uiDesigner/icons/comboBox.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="0" hsize-policy="2" anchor="8" fill="1" />
|
||||
</item>
|
||||
<item class="javax.swing.JTable" icon="/com/intellij/uiDesigner/icons/table.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JList" icon="/com/intellij/uiDesigner/icons/list.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="2" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JTree" icon="/com/intellij/uiDesigner/icons/tree.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
|
||||
<preferred-size width="150" height="50" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JTabbedPane" icon="/com/intellij/uiDesigner/icons/tabbedPane.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
|
||||
<preferred-size width="200" height="200" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JSplitPane" icon="/com/intellij/uiDesigner/icons/splitPane.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
|
||||
<preferred-size width="200" height="200" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JSpinner" icon="/com/intellij/uiDesigner/icons/spinner.svg" removable="false" auto-create-binding="true" can-attach-label="true">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
|
||||
</item>
|
||||
<item class="javax.swing.JSlider" icon="/com/intellij/uiDesigner/icons/slider.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
|
||||
</item>
|
||||
<item class="javax.swing.JSeparator" icon="/com/intellij/uiDesigner/icons/separator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3" />
|
||||
</item>
|
||||
<item class="javax.swing.JProgressBar" icon="/com/intellij/uiDesigner/icons/progressbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1" />
|
||||
</item>
|
||||
<item class="javax.swing.JToolBar" icon="/com/intellij/uiDesigner/icons/toolbar.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1">
|
||||
<preferred-size width="-1" height="20" />
|
||||
</default-constraints>
|
||||
</item>
|
||||
<item class="javax.swing.JToolBar$Separator" icon="/com/intellij/uiDesigner/icons/toolbarSeparator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
|
||||
<default-constraints vsize-policy="0" hsize-policy="0" anchor="0" fill="1" />
|
||||
</item>
|
||||
<item class="javax.swing.JScrollBar" icon="/com/intellij/uiDesigner/icons/scrollbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
|
||||
<default-constraints vsize-policy="6" hsize-policy="0" anchor="0" fill="2" />
|
||||
</item>
|
||||
</group>
|
||||
</component>
|
||||
</project>
|
|
@ -25,7 +25,7 @@ tasks.withType(JavaCompile) {
|
|||
|
||||
application {
|
||||
mainModule = 'com.nuculabs.dev.imagetagger.ui'
|
||||
mainClass = 'com.nuculabs.dev.imagetagger.ui.HelloApplication'
|
||||
mainClass = 'com.nuculabs.dev.imagetagger.ui.MainApplication'
|
||||
}
|
||||
kotlin {
|
||||
jvmToolchain( 17 )
|
||||
|
@ -45,7 +45,7 @@ dependencies {
|
|||
exclude(group: 'org.openjfx')
|
||||
}
|
||||
implementation('org.kordamp.ikonli:ikonli-javafx:12.3.1')
|
||||
|
||||
implementation('com.microsoft.onnxruntime:onnxruntime:1.17.1')
|
||||
testImplementation("org.junit.jupiter:junit-jupiter-api:${junitVersion}")
|
||||
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:${junitVersion}")
|
||||
}
|
||||
|
|
|
@ -7,6 +7,9 @@ module com.nuculabs.dev.imagetagger.ui {
|
|||
requires com.dlsc.formsfx;
|
||||
requires net.synedra.validatorfx;
|
||||
requires org.kordamp.ikonli.javafx;
|
||||
requires com.microsoft.onnxruntime;
|
||||
requires java.logging;
|
||||
requires java.desktop;
|
||||
|
||||
opens com.nuculabs.dev.imagetagger.ui to javafx.fxml;
|
||||
exports com.nuculabs.dev.imagetagger.ui;
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
package com.nuculabs.dev.imagetagger.tag_prediction
|
||||
|
||||
import ai.onnxruntime.OnnxTensor
|
||||
import ai.onnxruntime.OrtEnvironment
|
||||
import ai.onnxruntime.OrtSession
|
||||
import java.awt.image.BufferedImage
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.util.logging.Logger
|
||||
import javax.imageio.ImageIO
|
||||
|
||||
/**
|
||||
* ImageTagsPrediction is a specialized class that predicts an Image's tags
|
||||
*/
|
||||
class ImageTagsPrediction private constructor() {
|
||||
private val logger: Logger = Logger.getLogger("InfoLogging")
|
||||
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
|
||||
private var ortSession: OrtSession;
|
||||
private var modelClasses: MutableList<String> = mutableListOf()
|
||||
|
||||
init {
|
||||
logger.info("Loaded ML model.")
|
||||
val modelFile = File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction.onnx")
|
||||
val classesFile =
|
||||
File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction_categories.txt")
|
||||
|
||||
ortSession = ortEnv.createSession(
|
||||
modelFile.readBytes(),
|
||||
OrtSession.SessionOptions()
|
||||
)
|
||||
modelClasses.addAll(0, classesFile.bufferedReader().readLines())
|
||||
logger.info("Loaded ${modelClasses.size} model classes.")
|
||||
}
|
||||
|
||||
private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
|
||||
try {
|
||||
val tensorData = Array(1) {
|
||||
Array(3) {
|
||||
Array(224) {
|
||||
FloatArray(224)
|
||||
}
|
||||
}
|
||||
}
|
||||
val mean = floatArrayOf(0.485f, 0.456f, 0.406f)
|
||||
val standardDeviation = floatArrayOf(0.229f, 0.224f, 0.225f)
|
||||
|
||||
// crop image to 224x224
|
||||
var width: Int = bufferedImage.width
|
||||
var height: Int = bufferedImage.height
|
||||
var startX = 0
|
||||
var startY = 0
|
||||
if (width > height) {
|
||||
startX = (width - height) / 2
|
||||
width = height
|
||||
} else {
|
||||
startY = (height - width) / 2
|
||||
height = width
|
||||
}
|
||||
|
||||
val image = bufferedImage.getSubimage(startX, startY, width, height);
|
||||
val resizedImage = image.getScaledInstance(224, 224, 4);
|
||||
val scaledImage = BufferedImage(224, 224, BufferedImage.TYPE_4BYTE_ABGR)
|
||||
scaledImage.graphics.drawImage(resizedImage, 0, 0, null)
|
||||
|
||||
|
||||
// Process image
|
||||
for (y in 0 until scaledImage.height) {
|
||||
for (x in 0 until scaledImage.width) {
|
||||
val pixel: Int = scaledImage.getRGB(x, y)
|
||||
|
||||
// Get RGB values
|
||||
tensorData[0][0][y][x] =
|
||||
((pixel shr 16 and 0xFF) / 255f - mean[0]) / standardDeviation[0]
|
||||
tensorData[0][1][y][x] =
|
||||
((pixel shr 16 and 0xFF) / 255f - mean[1]) / standardDeviation[1]
|
||||
tensorData[0][2][y][x] =
|
||||
((pixel shr 16 and 0xFF) / 255f - mean[2]) / standardDeviation[2]
|
||||
}
|
||||
}
|
||||
return tensorData
|
||||
} catch (e: IOException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uses the ML model to predict tags for a given bitmap.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
|
||||
// 1. Get input and output names
|
||||
val inputName: String = ortSession.inputNames.iterator().next()
|
||||
val outputName: String = ortSession.outputNames.iterator().next()
|
||||
|
||||
// 2. Create input tensor
|
||||
val inputTensor = OnnxTensor.createTensor(ortEnv, processImage(bufferedImage))
|
||||
|
||||
// 3. Run the model.
|
||||
val inputs = mapOf(inputName to inputTensor);
|
||||
val results = ortSession.run(inputs);
|
||||
|
||||
// 4. Get output tensor
|
||||
val outputTensor = results.get(outputName);
|
||||
if (outputTensor.isPresent) {
|
||||
// 5. Get prediction results
|
||||
val floatBuffer = outputTensor.get().value as Array<FloatArray>
|
||||
val predictions = ArrayList<String>()
|
||||
|
||||
// filter buffer by threshold
|
||||
for (i in floatBuffer[0].indices) {
|
||||
if (floatBuffer[0][i] > -0.5) {
|
||||
predictions.add(modelClasses[i])
|
||||
}
|
||||
}
|
||||
|
||||
return predictions;
|
||||
} else {
|
||||
return ArrayList()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predicts tags for a Bitmap.
|
||||
*/
|
||||
fun predictTags(image: BufferedImage): List<String> {
|
||||
return predictTagsInternal(image)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predicts tags for a given image input stream.
|
||||
*/
|
||||
fun predictTags(input: InputStream?): List<String> {
|
||||
if (input == null) {
|
||||
return ArrayList()
|
||||
}
|
||||
|
||||
return predictTagsInternal(ImageIO.read(input))
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the session and environment.
|
||||
*/
|
||||
fun close() {
|
||||
ortSession.close()
|
||||
ortEnv.close()
|
||||
modelClasses.clear()
|
||||
}
|
||||
|
||||
// Singleton Pattern
|
||||
companion object {
|
||||
@Volatile
|
||||
private var instance: ImageTagsPrediction? = null
|
||||
|
||||
fun getInstance() =
|
||||
instance ?: synchronized(this) {
|
||||
instance ?: ImageTagsPrediction().also { instance = it }
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
package com.nuculabs.dev.imagetagger.ui
|
||||
|
||||
import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
|
||||
import javafx.fxml.FXML
|
||||
import javafx.scene.control.Label
|
||||
import java.io.File
|
||||
|
||||
class HelloController {
|
||||
@FXML
|
||||
|
@ -9,6 +11,9 @@ class HelloController {
|
|||
|
||||
@FXML
|
||||
private fun onHelloButtonClick() {
|
||||
welcomeText.text = "Welcome to JavaFX Application!"
|
||||
val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
val testTags = imageTagsPrediction.predictTags(File("/home/denis/Pictures/not_in_train/0a1a1e8bafbcdb00d34d87f35f0f4b9f.jpg").inputStream())
|
||||
|
||||
welcomeText.text = testTags.joinToString { it }
|
||||
}
|
||||
}
|
|
@ -1,13 +1,16 @@
|
|||
package com.nuculabs.dev.imagetagger.ui
|
||||
|
||||
import com.nuculabs.dev.imagetagger.tag_prediction.ImageTagsPrediction
|
||||
import javafx.application.Application
|
||||
import javafx.fxml.FXMLLoader
|
||||
import javafx.scene.Scene
|
||||
import javafx.stage.Stage
|
||||
|
||||
class HelloApplication : Application() {
|
||||
class MainApplication : Application() {
|
||||
override fun start(stage: Stage) {
|
||||
val fxmlLoader = FXMLLoader(HelloApplication::class.java.getResource("hello-view.fxml"))
|
||||
val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
|
||||
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml"))
|
||||
val scene = Scene(fxmlLoader.load(), 320.0, 240.0)
|
||||
stage.title = "Hello!"
|
||||
stage.scene = scene
|
||||
|
@ -16,5 +19,5 @@ class HelloApplication : Application() {
|
|||
}
|
||||
|
||||
fun main() {
|
||||
Application.launch(HelloApplication::class.java)
|
||||
Application.launch(MainApplication::class.java)
|
||||
}
|
Loading…
Reference in a new issue