Implement BM25+ ranking algorithm

This commit is contained in:
Denis-Cosmin Nutiu 2024-06-02 23:07:58 +03:00
parent 25a47d62d6
commit ab0dea9404
3 changed files with 262 additions and 0 deletions

1
.idea/.name Normal file
View file

@ -0,0 +1 @@
DSA

View file

@ -0,0 +1,187 @@
package dev.nuculabs.dsa.ranking.bm25
import java.lang.Double.isFinite
import java.util.HashMap
import kotlin.math.log10
/**
 * A minimal document model: a numeric identifier paired with the document's text.
 *
 * @property id numeric identifier of the document.
 * @property text the text of the document.
 */
data class Document(
    val id: Int,
    val text: String,
)
/**
 * TokenizedDocument pairs a [Document] with the tokens extracted from its text using a simple
 * tokenization strategy.
 *
 * Tokenization splits [Document.text] on runs of whitespace (regex `\s+`), removes every character
 * that is neither a letter nor a digit from each piece, lowercases the remainder and drops pieces
 * that end up empty.
 *
 * Equality and hashing are based solely on [document].
 *
 * @property document the document whose text is tokenized.
 * @param text not used for tokenization (tokens are always derived from the text of [document]);
 * the parameter is retained so that existing callers keep compiling.
 */
data class TokenizedDocument(val document: Document, private val text: String) {
    // Computed once at construction time. Splitting on any whitespace run (rather than on a single
    // space) keeps words separated by tabs or newlines from being fused into one token once the
    // non-alphanumeric characters are stripped.
    private val tokens: List<String> = document.text
        .split(WHITESPACE)
        .map { piece -> piece.filter { it.isLetterOrDigit() }.lowercase() }
        .filter { it.isNotEmpty() }

    /**
     * Returns the tokens of the document in their order of appearance in the text.
     */
    fun getTokens(): List<String> = tokens

    /**
     * Two tokenized documents are equal when their underlying [document]s are equal.
     */
    override fun equals(other: Any?): Boolean =
        this === other || (other is TokenizedDocument && document == other.document)

    override fun hashCode(): Int = document.hashCode()

    private companion object {
        /** Matches one or more consecutive whitespace characters; hoisted so it is compiled once. */
        val WHITESPACE = Regex("\\s+")
    }
}
/**
 * BM25+ is a variation of the BM25 ranking algorithm.
 *
 * The algorithm is implemented using the following paper as a reference.
 * http://www.cs.otago.ac.nz/homepages/andrew/papers/2014-2.pdf
 *
 * Documents are added with [index] / [indexAll] and ranked with [termQuery] / [termsQuery].
 */
class Bm25Plus {
    /**
     * The storage holds a mapping of document id -> tokenized document.
     */
    private val storage: MutableMap<Int, TokenizedDocument> = HashMap()

    /**
     * The term frequency index holds a mapping of term -> ids of the documents in which the term occurs.
     */
    private val termFrequencyIndex: MutableMap<String, HashSet<Int>> = HashMap()

    /**
     * The tuning parameters are used to tune the result of the algorithm.
     *
     * These values were taken directly from the paper.
     */
    private val tuningParameterB: Double = 0.3
    private val tuningParameterK1: Double = 1.6
    private val tuningParameterDelta: Double = 0.7

    /** Total number of tokens over all indexed documents. */
    private var totalTokens: Int = 0

    /** Mean number of tokens per indexed document; refreshed every time a document is indexed. */
    private var meanDocumentLengths: Double = 0.0

    /**
     * Returns the number of indexed documents.
     */
    fun indexSize(): Int = storage.size

    /**
     * Indexes a document.
     *
     * A document whose id is already present in the index is ignored; the stored document is not replaced.
     */
    fun index(document: Document) {
        if (storage.containsKey(document.id)) return

        // Tokenize the document, for educational purposes and simplicity we will consider tokens only
        // the words delimited by whitespace and transform them into lowercase.
        val tokenizedDocument = TokenizedDocument(document, document.text)
        val tokens = tokenizedDocument.getTokens()

        storage[document.id] = tokenizedDocument
        totalTokens += tokens.size
        // Convert before dividing: Int / Int would truncate the mean towards zero.
        meanDocumentLengths = totalTokens.toDouble() / storage.size

        // Record the document id under every token it contains.
        tokens.forEach { token ->
            termFrequencyIndex.getOrPut(token) { HashSet() }.add(document.id)
        }
    }

    /**
     * Indexes all the documents.
     */
    fun indexAll(vararg documents: Document) {
        documents.forEach { index(it) }
    }

    /**
     * Queries documents using the given term and returns the documents which contain the term, paired with
     * their score and ordered by descending score.
     *
     * The term is lowercased before the lookup. An empty list is returned when no indexed document
     * contains the term.
     */
    fun termQuery(term: String): List<Pair<Double, Document>> {
        val normalizedTerm = term.lowercase()
        val documentIds = termFrequencyIndex[normalizedTerm] ?: return emptyList()
        return rank(documentIds) { document -> computeRsv(normalizedTerm, document) }
    }

    /**
     * Queries documents using the given terms and returns the documents which contain at least one of the
     * terms, paired with their score and ordered by descending score.
     *
     * The score of a document is the sum of its RSV for every given term; terms are lowercased first.
     *
     * NOTE: a term that occurs in no indexed document has an infinite IDF, which makes the score of every
     * candidate non-finite; such scores are filtered out, so the result is empty in that case.
     * Calling this function without terms returns an empty list.
     */
    fun termsQuery(vararg terms: String): List<Pair<Double, Document>> {
        val normalizedTerms = terms.map { it.lowercase() }
        // Collect the candidates into a fresh set: the sets stored in the index must never be modified
        // by a query, otherwise document frequencies (and therefore scores) would be corrupted.
        val documentIds = normalizedTerms.flatMapTo(HashSet<Int>()) { term ->
            termFrequencyIndex[term] ?: emptySet()
        }
        return rank(documentIds) { document ->
            normalizedTerms.sumOf { term -> computeRsv(term, document) }
        }
    }

    /**
     * Scores every stored document referenced by [documentIds] with [score], drops non-finite scores and
     * returns the remaining (score, document) pairs ordered by descending score.
     */
    private fun rank(
        documentIds: Set<Int>,
        score: (TokenizedDocument) -> Double,
    ): List<Pair<Double, Document>> =
        documentIds
            .mapNotNull { id -> storage[id]?.let { document -> score(document) to document.document } }
            .filter { (rsv, _) -> isFinite(rsv) }
            .sortedByDescending { (rsv, _) -> rsv }

    /**
     * Computes the inverse document frequency for a given (already lowercased) term.
     *
     * The IDF is defined as the total number of documents (N) divided by the documents that contain the term (dft).
     * In the BM25+ version the IDF is the (N+1)/(dft).
     *
     * When no document contains the term the floating-point division by zero yields positive infinity.
     */
    private fun computeInverseDocumentFrequency(term: String): Double {
        val numberOfDocumentsContainingTheTerm = termFrequencyIndex[term]?.size ?: 0
        return (storage.size + 1) / numberOfDocumentsContainingTheTerm.toDouble()
    }

    /**
     * Computes the RSV for the given (already lowercased) term and document.
     * The RSV (Retrieval Status Value) is computed for every document using the BM25+ formula from the paper:
     *
     * `log10(idf) * (((k1 + 1) * tf) / (k1 * ((1 - b) + b * (docLength / meanDocLength)) + tf) + delta)`
     */
    private fun computeRsv(
        term: String,
        document: TokenizedDocument
    ): Double {
        val inverseDocumentFrequencyLog: Double = log10(computeInverseDocumentFrequency(term))
        val tokens = document.getTokens()
        val termOccurringInDocumentFrequency: Double = tokens.count { token -> token == term }.toDouble()
        val documentLength: Double = tokens.size.toDouble()
        // Length normalisation: documents longer than the mean are penalised proportionally to b.
        val lengthNormalization = (1 - tuningParameterB) + tuningParameterB * (documentLength / meanDocumentLengths)
        return inverseDocumentFrequencyLog *
            (
                ((tuningParameterK1 + 1) * termOccurringInDocumentFrequency) /
                    ((tuningParameterK1 * lengthNormalization) + termOccurringInDocumentFrequency) +
                    tuningParameterDelta
                )
    }
}

View file

@ -0,0 +1,74 @@
package ranking.bm25
import dev.nuculabs.dsa.ranking.bm25.Bm25Plus
import dev.nuculabs.dsa.ranking.bm25.Document
import kotlin.test.Test
import kotlin.test.assertEquals
/**
 * Tests for [Bm25Plus] covering indexing, single-term queries and multi-term queries.
 *
 * Query results are verified through the ranking contract (which documents are returned, in which
 * order, with finite scores sorted from highest to lowest) instead of comparing scores to hard-coded
 * literals: exact equality on 16-digit doubles is brittle and ties the tests to incidental numeric
 * details rather than to the ranking behaviour.
 */
class BM25PlusTest {
    @Test
    fun test_index_and_indexSize() {
        // Setup
        val bm25Plus = Bm25Plus()
        val document1 = Document(1, "Ana are mere")
        val document2 = Document(2, "Ana Ana Ana Ana Ana Ana Ana Ana")

        // Test
        bm25Plus.indexAll(document1, document2)

        // Assert
        assertEquals(2, bm25Plus.indexSize())

        // Indexing an id that is already present must not add a new entry.
        bm25Plus.index(document1)
        assertEquals(2, bm25Plus.indexSize())
    }

    @Test
    fun test_termQuery() {
        // Given
        val bm25Plus = Bm25Plus()
        val document1 = Document(1, "Ana are mere")
        val document2 = Document(2, "Ana Ana Ana Ana Ana Ana Ana Ana")
        bm25Plus.index(document1)
        bm25Plus.index(document2)

        // When
        val anaResults = bm25Plus.termQuery("Ana")
        val mereResults = bm25Plus.termQuery("mere")

        // Then
        // The document that repeats the term ranks above the one that mentions it once.
        assertEquals(listOf(document2, document1), anaResults.map { it.second })
        assertScoresAreFiniteAndDescending(anaResults)

        assertEquals(listOf(document1), mereResults.map { it.second })
        assertScoresAreFiniteAndDescending(mereResults)

        // A term that is not indexed yields no results.
        assertEquals(emptyList<Pair<Double, Document>>(), bm25Plus.termQuery("batman"))

        // Queries are case-insensitive.
        assertEquals(anaResults, bm25Plus.termQuery("ana"))
    }

    @Test
    fun test_termsQuery() {
        // Setup
        val bm25Plus = Bm25Plus()
        val document1 = Document(
            1,
            "A linked list is a fundamental data structure which consists of Nodes that are connected to each other."
        )
        val document2 =
            Document(2, "The Linked List data structure permits the storage of data in an efficient manner.")
        val document3 =
            Document(3, "The space and time complexity of the linked list operations depends on the implementation.")
        val document4 = Document(
            4,
            "The operations that take O(N) time takes this much because you have to traverse the lists for at least N nodes in order to perform it successfully. On the other hand, operations that take O(1) time do not require any traversals because the list holds pointers to the head first Node and tail last Node."
        )
        bm25Plus.indexAll(document1, document2, document3, document4)

        // Test
        val results = bm25Plus.termsQuery("linked", "list", "complexity")

        // Assert
        // Only document3 contains all three terms, so it must be ranked first.
        assertEquals(document3, results.first().second)
        // Every document contains at least one of the terms.
        assertEquals(
            setOf(document1, document2, document3, document4),
            results.map { it.second }.toSet()
        )
        assertScoresAreFiniteAndDescending(results)
    }

    /**
     * Asserts that every score in [results] is finite and that the results are ordered from the
     * highest to the lowest score.
     */
    private fun assertScoresAreFiniteAndDescending(results: List<Pair<Double, Document>>) {
        val scores = results.map { it.first }
        assertEquals(emptyList<Double>(), scores.filterNot { it.isFinite() }, "all scores must be finite")
        assertEquals(scores.sortedDescending(), scores, "results must be ordered by descending score")
    }
}