load model from resource files

This commit is contained in:
Denis-Cosmin Nutiu 2024-03-19 18:18:29 +02:00
parent cf7b4bd6eb
commit f66413710b
6 changed files with 1732 additions and 22 deletions

View file

@ -1,5 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>

View file

@ -4,7 +4,7 @@
<component name="FrameworkDetectionExcludesConfiguration">
<file type="web" url="file://$PROJECT_DIR$" />
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" project-jdk-name="21" project-jdk-type="JavaSDK">
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" project-jdk-name="21" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

View file

@ -4,7 +4,6 @@ import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.awt.image.BufferedImage
import java.io.File
import java.io.IOException
import java.io.InputStream
import java.util.logging.Logger
@ -16,20 +15,20 @@ import javax.imageio.ImageIO
class ImageTagsPrediction private constructor() {
private val logger: Logger = Logger.getLogger("InfoLogging")
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
private var ortSession: OrtSession;
private var ortSession: OrtSession
private var modelClasses: MutableList<String> = mutableListOf()
init {
logger.info("Loaded ML model.")
val modelFile = File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction.onnx")
val classesFile =
File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction_categories.txt")
ortSession = ortEnv.createSession(
modelFile.readBytes(),
OrtSession.SessionOptions()
)
modelClasses.addAll(0, classesFile.bufferedReader().readLines())
ImageTagsPrediction::class.java.getResourceAsStream("/AIModels/prediction.onnx").let { modelFile ->
ortSession = ortEnv.createSession(
modelFile!!.readBytes(),
OrtSession.SessionOptions()
)
}
ImageTagsPrediction::class.java.getResourceAsStream("/AIModels/prediction_categories.txt").let { classesFile ->
modelClasses.addAll(0, classesFile!!.bufferedReader().readLines())
}
logger.info("Loaded ${modelClasses.size} model classes.")
}
@ -58,8 +57,8 @@ class ImageTagsPrediction private constructor() {
height = width
}
val image = bufferedImage.getSubimage(startX, startY, width, height);
val resizedImage = image.getScaledInstance(224, 224, 4);
val image = bufferedImage.getSubimage(startX, startY, width, height)
val resizedImage = image.getScaledInstance(224, 224, 4)
val scaledImage = BufferedImage(224, 224, BufferedImage.TYPE_4BYTE_ABGR)
scaledImage.graphics.drawImage(resizedImage, 0, 0, null)
@ -97,11 +96,11 @@ class ImageTagsPrediction private constructor() {
val inputTensor = OnnxTensor.createTensor(ortEnv, processImage(bufferedImage))
// 3. Run the model.
val inputs = mapOf(inputName to inputTensor);
val results = ortSession.run(inputs);
val inputs = mapOf(inputName to inputTensor)
val results = ortSession.run(inputs)
// 4. Get output tensor
val outputTensor = results.get(outputName);
val outputTensor = results.get(outputName)
if (outputTensor.isPresent) {
// 5. Get prediction results
val floatBuffer = outputTensor.get().value as Array<FloatArray>
@ -114,7 +113,7 @@ class ImageTagsPrediction private constructor() {
}
}
return predictions;
return predictions
} else {
return ArrayList()
}

View file

@ -8,12 +8,14 @@ import javafx.stage.Stage
class MainApplication : Application() {
override fun start(stage: Stage) {
val imageTagsPrediction = ImageTagsPrediction.getInstance()
ImageTagsPrediction.getInstance()
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml"))
val scene = Scene(fxmlLoader.load(), 320.0, 240.0)
stage.title = "Hello!"
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("main-window-view.fxml"))
val scene = Scene(fxmlLoader.load(), 640.0, 760.0)
stage.title = "Image Tagger"
stage.scene = scene
stage.minWidth = 640.0
stage.minHeight = 760.0
stage.show()
}
}

File diff suppressed because it is too large Load diff