load model from resource files
This commit is contained in:
parent
cf7b4bd6eb
commit
f66413710b
6 changed files with 1732 additions and 22 deletions
|
@ -1,5 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<component name="FrameworkDetectionExcludesConfiguration">
|
||||
<file type="web" url="file://$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" project-jdk-name="21" project-jdk-type="JavaSDK">
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" project-jdk-name="21" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/out" />
|
||||
</component>
|
||||
</project>
|
|
@ -4,7 +4,6 @@ import ai.onnxruntime.OnnxTensor
|
|||
import ai.onnxruntime.OrtEnvironment
|
||||
import ai.onnxruntime.OrtSession
|
||||
import java.awt.image.BufferedImage
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.util.logging.Logger
|
||||
|
@ -16,20 +15,20 @@ import javax.imageio.ImageIO
|
|||
class ImageTagsPrediction private constructor() {
|
||||
private val logger: Logger = Logger.getLogger("InfoLogging")
|
||||
private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
|
||||
private var ortSession: OrtSession;
|
||||
private var ortSession: OrtSession
|
||||
private var modelClasses: MutableList<String> = mutableListOf()
|
||||
|
||||
init {
|
||||
logger.info("Loaded ML model.")
|
||||
val modelFile = File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction.onnx")
|
||||
val classesFile =
|
||||
File("/home/denis/IdeaProjects/ImageTagger/src/main/resources/AIModels/prediction_categories.txt")
|
||||
|
||||
ortSession = ortEnv.createSession(
|
||||
modelFile.readBytes(),
|
||||
OrtSession.SessionOptions()
|
||||
)
|
||||
modelClasses.addAll(0, classesFile.bufferedReader().readLines())
|
||||
ImageTagsPrediction::class.java.getResourceAsStream("/AIModels/prediction.onnx").let { modelFile ->
|
||||
ortSession = ortEnv.createSession(
|
||||
modelFile!!.readBytes(),
|
||||
OrtSession.SessionOptions()
|
||||
)
|
||||
}
|
||||
ImageTagsPrediction::class.java.getResourceAsStream("/AIModels/prediction_categories.txt").let { classesFile ->
|
||||
modelClasses.addAll(0, classesFile!!.bufferedReader().readLines())
|
||||
}
|
||||
logger.info("Loaded ${modelClasses.size} model classes.")
|
||||
}
|
||||
|
||||
|
@ -58,8 +57,8 @@ class ImageTagsPrediction private constructor() {
|
|||
height = width
|
||||
}
|
||||
|
||||
val image = bufferedImage.getSubimage(startX, startY, width, height);
|
||||
val resizedImage = image.getScaledInstance(224, 224, 4);
|
||||
val image = bufferedImage.getSubimage(startX, startY, width, height)
|
||||
val resizedImage = image.getScaledInstance(224, 224, 4)
|
||||
val scaledImage = BufferedImage(224, 224, BufferedImage.TYPE_4BYTE_ABGR)
|
||||
scaledImage.graphics.drawImage(resizedImage, 0, 0, null)
|
||||
|
||||
|
@ -97,11 +96,11 @@ class ImageTagsPrediction private constructor() {
|
|||
val inputTensor = OnnxTensor.createTensor(ortEnv, processImage(bufferedImage))
|
||||
|
||||
// 3. Run the model.
|
||||
val inputs = mapOf(inputName to inputTensor);
|
||||
val results = ortSession.run(inputs);
|
||||
val inputs = mapOf(inputName to inputTensor)
|
||||
val results = ortSession.run(inputs)
|
||||
|
||||
// 4. Get output tensor
|
||||
val outputTensor = results.get(outputName);
|
||||
val outputTensor = results.get(outputName)
|
||||
if (outputTensor.isPresent) {
|
||||
// 5. Get prediction results
|
||||
val floatBuffer = outputTensor.get().value as Array<FloatArray>
|
||||
|
@ -114,7 +113,7 @@ class ImageTagsPrediction private constructor() {
|
|||
}
|
||||
}
|
||||
|
||||
return predictions;
|
||||
return predictions
|
||||
} else {
|
||||
return ArrayList()
|
||||
}
|
||||
|
|
|
@ -8,12 +8,14 @@ import javafx.stage.Stage
|
|||
|
||||
class MainApplication : Application() {
|
||||
override fun start(stage: Stage) {
|
||||
val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
ImageTagsPrediction.getInstance()
|
||||
|
||||
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("hello-view.fxml"))
|
||||
val scene = Scene(fxmlLoader.load(), 320.0, 240.0)
|
||||
stage.title = "Hello!"
|
||||
val fxmlLoader = FXMLLoader(MainApplication::class.java.getResource("main-window-view.fxml"))
|
||||
val scene = Scene(fxmlLoader.load(), 640.0, 760.0)
|
||||
stage.title = "Image Tagger"
|
||||
stage.scene = scene
|
||||
stage.minWidth = 640.0
|
||||
stage.minHeight = 760.0
|
||||
stage.show()
|
||||
}
|
||||
}
|
||||
|
|
1708
src/main/resources/AIModels/prediction_categories.txt
Normal file
1708
src/main/resources/AIModels/prediction_categories.txt
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue