refactor image prediction to use thread pool

This commit is contained in:
Denis-Cosmin Nutiu 2024-03-22 23:16:56 +02:00
parent 9831b7acb7
commit 007ce37031

View file

@ -9,11 +9,41 @@ import javafx.scene.control.Separator
import javafx.scene.layout.VBox
import javafx.stage.FileChooser
import java.io.File
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Logger
import javax.imageio.ImageIO
class MainPageController {
/**
* The thread pool worker pool.
*/
private val workerPool: ExecutorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors())
/**
* Number of maximum prediction operations of an image tags that's that is allowed to run in parallel.
*/
private val maxImagesPredictionInProgress = Runtime.getRuntime().availableProcessors()
/**
* Semaphore to limit the maximum amount of predictions submitted to the tread pool.
*/
private val workerSemaphore: Semaphore = Semaphore(maxImagesPredictionInProgress)
/**
* A counter to keep track of the current image prediction.
*/
private val processedImageFilesCount = AtomicInteger(0)
private var imageFilesTotal = 0
/**
* The ImageTagsPrediction service instance.
*/
private val imageTagsPrediction = ImageTagsPrediction.getInstance()
private val logger: Logger = Logger.getLogger("MainPageController")
@ -27,40 +57,79 @@ class MainPageController {
* Prompts the user to select files then predicts tags for the selected image files.
*/
@FXML
private fun onTagImagesButtonClick() {
val imageTagsPrediction = ImageTagsPrediction.getInstance()
fun onTagImagesButtonClick() {
synchronized(this) {
val fileChooser = FileChooser().apply { title = "Choose images" }
val filePaths = fileChooser.showOpenMultipleDialog(null) ?: return
val fileChooser = FileChooser()
fileChooser.title = "Choose images"
val filePaths = fileChooser.showOpenMultipleDialog(null) ?: return
progressBar.isVisible = true
progressBar.progress = 0.0
// Create a new thread to predict the images.
val thread = Thread {
val filePathsTotal = filePaths.count()
logger.info("Analyzing $filePathsTotal files")
filePaths.forEachIndexed { index, filePath ->
try {
// Get predictions for the image.
val imageFile = ImageIO.read(File(filePath.absolutePath))
val tags: List<String> = imageTagsPrediction.predictTags(imageFile)
Platform.runLater {
// Add image and prediction to the view.
verticalBox.children.add(ImageTagsEntryControl(filePath.absolutePath, tags))
verticalBox.children.add(Separator())
progressBar.progress = (((index + 1) * 100) / filePathsTotal).toDouble() / 100.0
logger.info("Progress ${progressBar.progress}")
progressBar.isVisible = true
progressBar.progress = 0.0
processedImageFilesCount.set(0)
// Create a new thread to predict the images.
Thread {
imageFilesTotal = filePaths.count()
logger.info("Analyzing $imageFilesTotal files")
filePaths.forEach { filePath ->
workerSemaphore.acquire()
workerPool.submit {
predictImageTags(
filePath,
onError = {
workerSemaphore.release()
}
) { imagePath, imageTags ->
// Add newly predicted tags to UI.
Platform.runLater {
// Add image and prediction to the view.
addNewImagePredictionEntry(imagePath, imageTags)
workerSemaphore.release()
}
}
}
} catch (e: Exception) {
logger.warning("Error while predicting images $e")
}
}
Platform.runLater {
progressBar.isVisible = false
}
}.start()
}
}
/**
* Predicts an image tags and executes an action with it.
*
* @param filePath - The image file's absolute path.
*/
fun predictImageTags(
filePath: File,
onError: (Exception) -> Unit,
onSuccess: (String, List<String>) -> Unit
) {
try {
// Get predictions for the image.
val imageFile = ImageIO.read(File(filePath.absolutePath))
val tags: List<String> = imageTagsPrediction.predictTags(imageFile)
onSuccess(filePath.absolutePath, tags)
} catch (e: Exception) {
logger.warning("Error while predicting images $e")
onError(e)
}
}
/**
* Updates the UI with a new ImagePredictionEntry.
*
* @param imagePath - The image path.
* @param imageTags - The image's tags.
*/
fun addNewImagePredictionEntry(
imagePath: String,
imageTags: List<String>,
) {
verticalBox.children.add(ImageTagsEntryControl(imagePath, imageTags))
verticalBox.children.add(Separator())
progressBar.progress =
((processedImageFilesCount.incrementAndGet() * 100) / imageFilesTotal).toDouble() / 100.0
logger.info("Progress ${processedImageFilesCount.get()}/${imageFilesTotal} ${progressBar.progress}")
if (processedImageFilesCount.get() == imageFilesTotal) {
progressBar.isVisible = false
logger.info("Finished processing images.")
}
thread.start()
}
}