implement BasicServiceLocator

This commit is contained in:
Denis-Cosmin NUTIU 2024-04-08 23:10:27 +03:00
parent b3ff128641
commit d971ff591b
5 changed files with 72 additions and 28 deletions

View file

@ -0,0 +1,16 @@
package dev.nuculabs.imagetagger.ai
import java.awt.image.BufferedImage
import java.io.InputStream
interface IImageTagsPrediction {
/**
* Predicts tags for a Bitmap.
*/
fun predictTags(image: BufferedImage): List<String>
/**
* Predicts tags for a given image input stream.
*/
fun predictTags(input: InputStream?): List<String>
}

View file

@ -4,6 +4,7 @@ import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.awt.image.BufferedImage
import java.io.Closeable
import java.io.IOException
import java.io.InputStream
import java.util.logging.Logger
@ -12,7 +13,7 @@ import javax.imageio.ImageIO
/**
* ImageTagsPrediction is a specialized class that predicts an Image's tags
*/
class ImageTagsPrediction {
class ImageTagsPrediction : IImageTagsPrediction, Closeable {
private val logger: Logger = Logger.getLogger("InfoLogging")
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
private var ortSession: OrtSession
@ -21,12 +22,13 @@ class ImageTagsPrediction {
init {
try {
logger.info("Loading ML model. Please wait.")
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction.onnx").let { modelFile ->
ortSession = ortEnv.createSession(
modelFile!!.readBytes(),
OrtSession.SessionOptions()
)
}
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction.onnx")
.let { modelFile ->
ortSession = ortEnv.createSession(
modelFile!!.readBytes(),
OrtSession.SessionOptions()
)
}
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction_categories.txt")
.let { classesFile ->
modelClasses.addAll(0, classesFile!!.bufferedReader().readLines())
@ -46,7 +48,7 @@ class ImageTagsPrediction {
/**
* Processes an image into an ONNX Tensor.
*/
private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
try {
val tensorData = Array(1) {
Array(3) {
@ -101,7 +103,7 @@ class ImageTagsPrediction {
* Uses the ML model to predict tags for a given bitmap.
*/
@Suppress("UNCHECKED_CAST")
private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
// 1. Get input and output names
val inputName: String = ortSession.inputNames.iterator().next()
val outputName: String = ortSession.outputNames.iterator().next()
@ -136,14 +138,14 @@ class ImageTagsPrediction {
/**
* Predicts tags for a Bitmap.
*/
fun predictTags(image: BufferedImage): List<String> {
override fun predictTags(image: BufferedImage): List<String> {
return predictTagsInternal(image)
}
/**
* Predicts tags for a given image input stream.
*/
fun predictTags(input: InputStream?): List<String> {
override fun predictTags(input: InputStream?): List<String> {
if (input == null) {
return ArrayList()
}
@ -154,20 +156,9 @@ class ImageTagsPrediction {
/**
* Close the session and environment.
*/
fun close() {
override fun close() {
ortSession.close()
ortEnv.close()
modelClasses.clear()
}
// Singleton Pattern
companion object {
@Volatile
private var instance: ImageTagsPrediction? = null
fun getInstance() =
instance ?: synchronized(this) {
instance ?: ImageTagsPrediction().also { instance = it }
}
}
}

View file

@ -0,0 +1,27 @@
package dev.nuculabs.imagetagger.ui
import dev.nuculabs.imagetagger.ai.IImageTagsPrediction
/**
* BasicServiceLocator is implemented to avoid polluting the apps with singletons.
* It's verbose but gluon-ignite was updated 2 years ago, and I've tried to get it running with
* Spring, Guice and Dagger, and it got me more trouble than solutions.
* This basically a compromise for simplicity.
*/
class BasicServiceLocator private constructor() {
internal lateinit var imageTagsPrediction: IImageTagsPrediction
// Singleton Pattern
companion object {
@Volatile
private var instance: BasicServiceLocator? = null
/**
* Returns a new BasicServiceLocator singleton instance.
*/
fun getInstance() =
instance ?: synchronized(this) {
instance ?: BasicServiceLocator().also { instance = it }
}
}
}

View file

@ -1,5 +1,6 @@
package dev.nuculabs.imagetagger.ui
import dev.nuculabs.imagetagger.ai.IImageTagsPrediction
import dev.nuculabs.imagetagger.ai.ImageTagsPrediction
import dev.nuculabs.imagetagger.ui.controls.programatic.ApplicationMenuBar
import javafx.application.Application
@ -11,22 +12,25 @@ import javafx.scene.layout.BorderPane
import javafx.stage.Stage
import java.awt.Taskbar
import java.awt.Toolkit
import java.io.Closeable
import java.util.logging.Logger
class MainPage : Application() {
private val serviceLocator = BasicServiceLocator.getInstance()
private val logger: Logger = Logger.getLogger("MainPage")
private var imageTagger: ImageTagsPrediction? = null
private lateinit var imageTagger: IImageTagsPrediction
private lateinit var fxmlLoader: FXMLLoader
private lateinit var mainStage: Stage
override fun start(stage: Stage) {
configureServiceLocator()
// Initial resource loading
fxmlLoader = FXMLLoader(MainPage::class.java.getResource("main-window-view.fxml"))
mainStage = stage
imageTagger = ImageTagsPrediction.getInstance()
imageTagger = BasicServiceLocator.getInstance().imageTagsPrediction
setUpApplicationIcon()
// Load the FXML.
@ -52,6 +56,13 @@ class MainPage : Application() {
stage.show()
}
/**
* Configures the service locator.
*/
private fun configureServiceLocator() {
serviceLocator.imageTagsPrediction = ImageTagsPrediction()
}
/**
* Loads and sets up the main application icon.
*/
@ -74,7 +85,7 @@ class MainPage : Application() {
logger.info("Stop called.")
val controller = fxmlLoader.getController<MainPageController>()
controller.shutdown()
imageTagger?.close()
(imageTagger as Closeable).close()
}
}

View file

@ -1,6 +1,5 @@
package dev.nuculabs.imagetagger.ui
import dev.nuculabs.imagetagger.ai.ImageTagsPrediction
import dev.nuculabs.imagetagger.ui.controls.ImageTagsEntryControl
import javafx.application.Platform
import javafx.fxml.FXML
@ -51,7 +50,7 @@ class MainPageController {
/**
* The ImageTagsPrediction service instance.
*/
private val imageTagsPrediction = ImageTagsPrediction.getInstance()
private val imageTagsPrediction = BasicServiceLocator.getInstance().imageTagsPrediction
/**
* A boolean that when set to true it will stop the current image tagging operation.