implement BasicServiceLocator
This commit is contained in:
parent
b3ff128641
commit
d971ff591b
5 changed files with 72 additions and 28 deletions
|
@ -0,0 +1,16 @@
|
|||
package dev.nuculabs.imagetagger.ai
|
||||
|
||||
import java.awt.image.BufferedImage
|
||||
import java.io.InputStream
|
||||
|
||||
interface IImageTagsPrediction {
|
||||
/**
|
||||
* Predicts tags for a Bitmap.
|
||||
*/
|
||||
fun predictTags(image: BufferedImage): List<String>
|
||||
|
||||
/**
|
||||
* Predicts tags for a given image input stream.
|
||||
*/
|
||||
fun predictTags(input: InputStream?): List<String>
|
||||
}
|
|
@ -4,6 +4,7 @@ import ai.onnxruntime.OnnxTensor
|
|||
import ai.onnxruntime.OrtEnvironment
|
||||
import ai.onnxruntime.OrtSession
|
||||
import java.awt.image.BufferedImage
|
||||
import java.io.Closeable
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.util.logging.Logger
|
||||
|
@ -12,7 +13,7 @@ import javax.imageio.ImageIO
|
|||
/**
|
||||
* ImageTagsPrediction is a specialized class that predicts an Image's tags
|
||||
*/
|
||||
class ImageTagsPrediction {
|
||||
class ImageTagsPrediction : IImageTagsPrediction, Closeable {
|
||||
private val logger: Logger = Logger.getLogger("InfoLogging")
|
||||
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
|
||||
private var ortSession: OrtSession
|
||||
|
@ -21,12 +22,13 @@ class ImageTagsPrediction {
|
|||
init {
|
||||
try {
|
||||
logger.info("Loading ML model. Please wait.")
|
||||
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction.onnx").let { modelFile ->
|
||||
ortSession = ortEnv.createSession(
|
||||
modelFile!!.readBytes(),
|
||||
OrtSession.SessionOptions()
|
||||
)
|
||||
}
|
||||
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction.onnx")
|
||||
.let { modelFile ->
|
||||
ortSession = ortEnv.createSession(
|
||||
modelFile!!.readBytes(),
|
||||
OrtSession.SessionOptions()
|
||||
)
|
||||
}
|
||||
ImageTagsPrediction::class.java.getResourceAsStream("/dev/nuculabs/imagetagger/ai/prediction_categories.txt")
|
||||
.let { classesFile ->
|
||||
modelClasses.addAll(0, classesFile!!.bufferedReader().readLines())
|
||||
|
@ -46,7 +48,7 @@ class ImageTagsPrediction {
|
|||
/**
|
||||
* Processes an image into an ONNX Tensor.
|
||||
*/
|
||||
private fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
|
||||
fun processImage(bufferedImage: BufferedImage): Array<Array<Array<FloatArray>>> {
|
||||
try {
|
||||
val tensorData = Array(1) {
|
||||
Array(3) {
|
||||
|
@ -101,7 +103,7 @@ class ImageTagsPrediction {
|
|||
* Uses the ML model to predict tags for a given bitmap.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
|
||||
fun predictTagsInternal(bufferedImage: BufferedImage): List<String> {
|
||||
// 1. Get input and output names
|
||||
val inputName: String = ortSession.inputNames.iterator().next()
|
||||
val outputName: String = ortSession.outputNames.iterator().next()
|
||||
|
@ -136,14 +138,14 @@ class ImageTagsPrediction {
|
|||
/**
|
||||
* Predicts tags for a Bitmap.
|
||||
*/
|
||||
fun predictTags(image: BufferedImage): List<String> {
|
||||
override fun predictTags(image: BufferedImage): List<String> {
|
||||
return predictTagsInternal(image)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predicts tags for a given image input stream.
|
||||
*/
|
||||
fun predictTags(input: InputStream?): List<String> {
|
||||
override fun predictTags(input: InputStream?): List<String> {
|
||||
if (input == null) {
|
||||
return ArrayList()
|
||||
}
|
||||
|
@ -154,20 +156,9 @@ class ImageTagsPrediction {
|
|||
/**
|
||||
* Close the session and environment.
|
||||
*/
|
||||
fun close() {
|
||||
override fun close() {
|
||||
ortSession.close()
|
||||
ortEnv.close()
|
||||
modelClasses.clear()
|
||||
}
|
||||
|
||||
// Singleton Pattern
|
||||
companion object {
|
||||
@Volatile
|
||||
private var instance: ImageTagsPrediction? = null
|
||||
|
||||
fun getInstance() =
|
||||
instance ?: synchronized(this) {
|
||||
instance ?: ImageTagsPrediction().also { instance = it }
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package dev.nuculabs.imagetagger.ui
|
||||
|
||||
import dev.nuculabs.imagetagger.ai.IImageTagsPrediction
|
||||
|
||||
/**
|
||||
* BasicServiceLocator is implemented to avoid polluting the apps with singletons.
|
||||
* It's verbose but gluon-ignite was updated 2 years ago, and I've tried to get it running with
|
||||
* Spring, Guice and Dagger, and it got me more trouble than solutions.
|
||||
* This basically a compromise for simplicity.
|
||||
*/
|
||||
class BasicServiceLocator private constructor() {
|
||||
internal lateinit var imageTagsPrediction: IImageTagsPrediction
|
||||
|
||||
// Singleton Pattern
|
||||
companion object {
|
||||
@Volatile
|
||||
private var instance: BasicServiceLocator? = null
|
||||
|
||||
/**
|
||||
* Returns a new BasicServiceLocator singleton instance.
|
||||
*/
|
||||
fun getInstance() =
|
||||
instance ?: synchronized(this) {
|
||||
instance ?: BasicServiceLocator().also { instance = it }
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package dev.nuculabs.imagetagger.ui
|
||||
|
||||
import dev.nuculabs.imagetagger.ai.IImageTagsPrediction
|
||||
import dev.nuculabs.imagetagger.ai.ImageTagsPrediction
|
||||
import dev.nuculabs.imagetagger.ui.controls.programatic.ApplicationMenuBar
|
||||
import javafx.application.Application
|
||||
|
@ -11,22 +12,25 @@ import javafx.scene.layout.BorderPane
|
|||
import javafx.stage.Stage
|
||||
import java.awt.Taskbar
|
||||
import java.awt.Toolkit
|
||||
import java.io.Closeable
|
||||
import java.util.logging.Logger
|
||||
|
||||
class MainPage : Application() {
|
||||
private val serviceLocator = BasicServiceLocator.getInstance()
|
||||
private val logger: Logger = Logger.getLogger("MainPage")
|
||||
private var imageTagger: ImageTagsPrediction? = null
|
||||
private lateinit var imageTagger: IImageTagsPrediction
|
||||
|
||||
private lateinit var fxmlLoader: FXMLLoader
|
||||
|
||||
private lateinit var mainStage: Stage
|
||||
|
||||
override fun start(stage: Stage) {
|
||||
configureServiceLocator()
|
||||
// Initial resource loading
|
||||
fxmlLoader = FXMLLoader(MainPage::class.java.getResource("main-window-view.fxml"))
|
||||
mainStage = stage
|
||||
|
||||
imageTagger = ImageTagsPrediction.getInstance()
|
||||
imageTagger = BasicServiceLocator.getInstance().imageTagsPrediction
|
||||
setUpApplicationIcon()
|
||||
|
||||
// Load the FXML.
|
||||
|
@ -52,6 +56,13 @@ class MainPage : Application() {
|
|||
stage.show()
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the service locator.
|
||||
*/
|
||||
private fun configureServiceLocator() {
|
||||
serviceLocator.imageTagsPrediction = ImageTagsPrediction()
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads and sets up the main application icon.
|
||||
*/
|
||||
|
@ -74,7 +85,7 @@ class MainPage : Application() {
|
|||
logger.info("Stop called.")
|
||||
val controller = fxmlLoader.getController<MainPageController>()
|
||||
controller.shutdown()
|
||||
imageTagger?.close()
|
||||
(imageTagger as Closeable).close()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
package dev.nuculabs.imagetagger.ui
|
||||
|
||||
import dev.nuculabs.imagetagger.ai.ImageTagsPrediction
|
||||
import dev.nuculabs.imagetagger.ui.controls.ImageTagsEntryControl
|
||||
import javafx.application.Platform
|
||||
import javafx.fxml.FXML
|
||||
|
@ -51,7 +50,7 @@ class MainPageController {
|
|||
/**
|
||||
* The ImageTagsPrediction service instance.
|
||||
*/
|
||||
private val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
private val imageTagsPrediction = BasicServiceLocator.getInstance().imageTagsPrediction
|
||||
|
||||
/**
|
||||
* A boolean that when set to true it will stop the current image tagging operation.
|
||||
|
|
Loading…
Reference in a new issue