refactor image prediction to use thread pool
This commit is contained in:
parent
9831b7acb7
commit
007ce37031
1 changed files with 100 additions and 31 deletions
|
@ -9,11 +9,41 @@ import javafx.scene.control.Separator
|
|||
import javafx.scene.layout.VBox
|
||||
import javafx.stage.FileChooser
|
||||
import java.io.File
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.logging.Logger
|
||||
import javax.imageio.ImageIO
|
||||
|
||||
|
||||
class MainPageController {
|
||||
/**
|
||||
* The thread pool worker pool.
|
||||
*/
|
||||
private val workerPool: ExecutorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors())
|
||||
|
||||
/**
|
||||
* Number of maximum prediction operations of an image tags that's that is allowed to run in parallel.
|
||||
*/
|
||||
private val maxImagesPredictionInProgress = Runtime.getRuntime().availableProcessors()
|
||||
|
||||
/**
|
||||
* Semaphore to limit the maximum amount of predictions submitted to the tread pool.
|
||||
*/
|
||||
private val workerSemaphore: Semaphore = Semaphore(maxImagesPredictionInProgress)
|
||||
|
||||
/**
|
||||
* A counter to keep track of the current image prediction.
|
||||
*/
|
||||
private val processedImageFilesCount = AtomicInteger(0)
|
||||
private var imageFilesTotal = 0
|
||||
|
||||
/**
|
||||
* The ImageTagsPrediction service instance.
|
||||
*/
|
||||
private val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
|
||||
|
||||
private val logger: Logger = Logger.getLogger("MainPageController")
|
||||
|
||||
|
@ -27,40 +57,79 @@ class MainPageController {
|
|||
* Prompts the user to select files then predicts tags for the selected image files.
|
||||
*/
|
||||
@FXML
|
||||
private fun onTagImagesButtonClick() {
|
||||
val imageTagsPrediction = ImageTagsPrediction.getInstance()
|
||||
fun onTagImagesButtonClick() {
|
||||
synchronized(this) {
|
||||
val fileChooser = FileChooser().apply { title = "Choose images" }
|
||||
val filePaths = fileChooser.showOpenMultipleDialog(null) ?: return
|
||||
|
||||
val fileChooser = FileChooser()
|
||||
fileChooser.title = "Choose images"
|
||||
val filePaths = fileChooser.showOpenMultipleDialog(null) ?: return
|
||||
|
||||
progressBar.isVisible = true
|
||||
|
||||
progressBar.progress = 0.0
|
||||
// Create a new thread to predict the images.
|
||||
val thread = Thread {
|
||||
val filePathsTotal = filePaths.count()
|
||||
logger.info("Analyzing $filePathsTotal files")
|
||||
filePaths.forEachIndexed { index, filePath ->
|
||||
try {
|
||||
// Get predictions for the image.
|
||||
val imageFile = ImageIO.read(File(filePath.absolutePath))
|
||||
val tags: List<String> = imageTagsPrediction.predictTags(imageFile)
|
||||
Platform.runLater {
|
||||
// Add image and prediction to the view.
|
||||
verticalBox.children.add(ImageTagsEntryControl(filePath.absolutePath, tags))
|
||||
verticalBox.children.add(Separator())
|
||||
progressBar.progress = (((index + 1) * 100) / filePathsTotal).toDouble() / 100.0
|
||||
logger.info("Progress ${progressBar.progress}")
|
||||
progressBar.isVisible = true
|
||||
progressBar.progress = 0.0
|
||||
processedImageFilesCount.set(0)
|
||||
// Create a new thread to predict the images.
|
||||
Thread {
|
||||
imageFilesTotal = filePaths.count()
|
||||
logger.info("Analyzing $imageFilesTotal files")
|
||||
filePaths.forEach { filePath ->
|
||||
workerSemaphore.acquire()
|
||||
workerPool.submit {
|
||||
predictImageTags(
|
||||
filePath,
|
||||
onError = {
|
||||
workerSemaphore.release()
|
||||
}
|
||||
) { imagePath, imageTags ->
|
||||
// Add newly predicted tags to UI.
|
||||
Platform.runLater {
|
||||
// Add image and prediction to the view.
|
||||
addNewImagePredictionEntry(imagePath, imageTags)
|
||||
workerSemaphore.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
logger.warning("Error while predicting images $e")
|
||||
}
|
||||
}
|
||||
Platform.runLater {
|
||||
progressBar.isVisible = false
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predicts an image tags and executes an action with it.
|
||||
*
|
||||
* @param filePath - The image file's absolute path.
|
||||
*/
|
||||
fun predictImageTags(
|
||||
filePath: File,
|
||||
onError: (Exception) -> Unit,
|
||||
onSuccess: (String, List<String>) -> Unit
|
||||
) {
|
||||
try {
|
||||
// Get predictions for the image.
|
||||
val imageFile = ImageIO.read(File(filePath.absolutePath))
|
||||
val tags: List<String> = imageTagsPrediction.predictTags(imageFile)
|
||||
onSuccess(filePath.absolutePath, tags)
|
||||
} catch (e: Exception) {
|
||||
logger.warning("Error while predicting images $e")
|
||||
onError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the UI with a new ImagePredictionEntry.
|
||||
*
|
||||
* @param imagePath - The image path.
|
||||
* @param imageTags - The image's tags.
|
||||
*/
|
||||
fun addNewImagePredictionEntry(
|
||||
imagePath: String,
|
||||
imageTags: List<String>,
|
||||
) {
|
||||
verticalBox.children.add(ImageTagsEntryControl(imagePath, imageTags))
|
||||
verticalBox.children.add(Separator())
|
||||
progressBar.progress =
|
||||
((processedImageFilesCount.incrementAndGet() * 100) / imageFilesTotal).toDouble() / 100.0
|
||||
logger.info("Progress ${processedImageFilesCount.get()}/${imageFilesTotal} ${progressBar.progress}")
|
||||
if (processedImageFilesCount.get() == imageFilesTotal) {
|
||||
progressBar.isVisible = false
|
||||
logger.info("Finished processing images.")
|
||||
}
|
||||
thread.start()
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue