Implement BM25+ ranking algorithm
This commit is contained in:
parent
25a47d62d6
commit
ab0dea9404
3 changed files with 262 additions and 0 deletions
1
.idea/.name
Normal file
1
.idea/.name
Normal file
|
@ -0,0 +1 @@
|
||||||
|
DSA
|
187
src/main/kotlin/ranking/bm25/Bm25Plus.kt
Normal file
187
src/main/kotlin/ranking/bm25/Bm25Plus.kt
Normal file
|
@ -0,0 +1,187 @@
|
||||||
|
package dev.nuculabs.dsa.ranking.bm25
|
||||||
|
|
||||||
|
import java.lang.Double.isFinite
|
||||||
|
import java.util.HashMap
|
||||||
|
import kotlin.math.log10
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Document models a simple document which contains a numeric id and text.
|
||||||
|
*/
|
||||||
|
data class Document(val id: Int, val text: String)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TokenizedDocument models a document which is tokenized using a simple strategy tokenization strategy.
|
||||||
|
*/
|
||||||
|
data class TokenizedDocument(val document: Document, private val text: String) {
|
||||||
|
private var tokens: List<String> = document.text.split(" ").map { token ->
|
||||||
|
token.filter { it.isLetterOrDigit() }.lowercase()
|
||||||
|
}.filter {
|
||||||
|
it.isNotEmpty()
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getTokens(): List<String> {
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (javaClass != other?.javaClass) return false
|
||||||
|
|
||||||
|
other as TokenizedDocument
|
||||||
|
|
||||||
|
return document == other.document
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
return document.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* BM25+ is a variation of the BM25 ranking algorithm.
|
||||||
|
*
|
||||||
|
* The algorithm is implemented using the following paper as a reference.
|
||||||
|
* http://www.cs.otago.ac.nz/homepages/andrew/papers/2014-2.pdf
|
||||||
|
*/
|
||||||
|
class Bm25Plus {
|
||||||
|
/**
|
||||||
|
* The storage holds a mapping of document id -> document.
|
||||||
|
*/
|
||||||
|
private var storage: MutableMap<Int, TokenizedDocument> = HashMap()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The term frequency index holds a mapping of term -> list of documents in which the term occurs.
|
||||||
|
*/
|
||||||
|
private var termFrequencyIndex: MutableMap<String, HashSet<Int>> = HashMap()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The tuning parameters are used to tune the result of the algorithm.
|
||||||
|
*
|
||||||
|
* These values were taken directly from the paper.
|
||||||
|
*/
|
||||||
|
private var tuningParameterB: Double = 0.3
|
||||||
|
private var tuningParameterK1: Double = 1.6
|
||||||
|
private var tuningParameterDelta: Double = 0.7
|
||||||
|
|
||||||
|
private var totalTokens: Int = 0
|
||||||
|
private var meanDocumentLengths: Double = 0.0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the number of indexed documents.
|
||||||
|
*/
|
||||||
|
fun indexSize(): Int {
|
||||||
|
return storage.size
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indexes a document.
|
||||||
|
*/
|
||||||
|
fun index(document: Document) {
|
||||||
|
// Tokenize the document, for educational purposes and simplicity we will consider tokens only
|
||||||
|
// the words delimited by a space and transform them into lowercase.
|
||||||
|
val tokenizedDocument = TokenizedDocument(document, document.text)
|
||||||
|
|
||||||
|
// Document does not exist in index
|
||||||
|
if (!storage.containsKey(document.id)) {
|
||||||
|
storage[document.id] = tokenizedDocument
|
||||||
|
|
||||||
|
totalTokens += tokenizedDocument.getTokens().size
|
||||||
|
meanDocumentLengths = (totalTokens / storage.size).toDouble()
|
||||||
|
|
||||||
|
// Index all tokens
|
||||||
|
tokenizedDocument.getTokens().forEach {
|
||||||
|
if (termFrequencyIndex.containsKey(it)) {
|
||||||
|
termFrequencyIndex[it]?.add(document.id)
|
||||||
|
} else {
|
||||||
|
termFrequencyIndex[it] = HashSet()
|
||||||
|
termFrequencyIndex[it]?.add(document.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indexes all the documents.
|
||||||
|
*/
|
||||||
|
fun indexAll(vararg documents: Document) {
|
||||||
|
documents.forEach {
|
||||||
|
index(it)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Queries documents using the given term and returns a list of documents which contain the term ordered by
|
||||||
|
* relevance.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
fun termQuery(term: String): List<Pair<Double, Document>> {
|
||||||
|
val documentIds = termFrequencyIndex[term.lowercase()] ?: return emptyList()
|
||||||
|
// Compute the RSV for each document.
|
||||||
|
return documentIds.map {
|
||||||
|
val document = storage[it] ?: return@map null
|
||||||
|
val documentRsv = computeRsv(term.lowercase(), document)
|
||||||
|
return@map documentRsv to document.document
|
||||||
|
// Sort results by highest score and filter out Infinity scores, which mean that the term does not exist.
|
||||||
|
}.filterNotNull().filter { isFinite(it.first) }.sortedByDescending { it.first }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Queries documents using the given terms and returns a list of documents which contain the terms ordered by
|
||||||
|
* relevance.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
fun termsQuery(vararg terms: String): List<Pair<Double, Document>> {
|
||||||
|
|
||||||
|
val documentIds = terms.map { term ->
|
||||||
|
Pair(term, termFrequencyIndex[term.lowercase()] ?: mutableSetOf())
|
||||||
|
}.reduce { acc, pair ->
|
||||||
|
// add all documents which contain them terms to the documents set.
|
||||||
|
acc.second.addAll(pair.second)
|
||||||
|
// return
|
||||||
|
acc
|
||||||
|
}.second
|
||||||
|
|
||||||
|
// Compute the terms RSV sum for each document.
|
||||||
|
return documentIds.map {
|
||||||
|
val document = storage[it] ?: return@map null
|
||||||
|
val documentRsv: Double = terms.sumOf { term -> computeRsv(term.lowercase(), document) }
|
||||||
|
return@map documentRsv to document.document
|
||||||
|
// Sort results by highest score and filter out Infinity scores, which mean that the term does not exist.
|
||||||
|
}.filterNotNull().filter { isFinite(it.first) }.sortedByDescending { it.first }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse document frequency for a given term.
|
||||||
|
*
|
||||||
|
* THe IDF is defined as the total number of documents (N) divided by the documents that contain the term (dft).
|
||||||
|
* In the BM25+ version the IDF is the (N+1)/(dft)
|
||||||
|
*/
|
||||||
|
private fun computeInverseDocumentFrequency(term: String): Double {
|
||||||
|
val numberOfDocumentsContainingTheTerm = termFrequencyIndex[term]?.size ?: 0
|
||||||
|
return (storage.size + 1) / numberOfDocumentsContainingTheTerm.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the RSV for the given term and document.
|
||||||
|
* The RSV (Retrieval Status Value) is computed for every document using the BM25+ formula from the paper.
|
||||||
|
*/
|
||||||
|
private fun computeRsv(
|
||||||
|
term: String,
|
||||||
|
document: TokenizedDocument
|
||||||
|
): Double {
|
||||||
|
val inverseDocumentFrequencyLog: Double = log10(computeInverseDocumentFrequency(term.lowercase()))
|
||||||
|
val termOccurringInDocumentFrequency: Double =
|
||||||
|
document.getTokens().filter { token -> token == term.lowercase() }.size.toDouble()
|
||||||
|
val documentLength: Double = document.getTokens().size.toDouble()
|
||||||
|
|
||||||
|
val score =
|
||||||
|
inverseDocumentFrequencyLog *
|
||||||
|
(
|
||||||
|
((tuningParameterK1 + 1) * termOccurringInDocumentFrequency) /
|
||||||
|
((tuningParameterK1 * ((1 - tuningParameterB) + tuningParameterB * (documentLength / meanDocumentLengths))) + termOccurringInDocumentFrequency)
|
||||||
|
+ tuningParameterDelta
|
||||||
|
)
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
}
|
74
src/test/kotlin/ranking/bm25/BM25PlusTest.kt
Normal file
74
src/test/kotlin/ranking/bm25/BM25PlusTest.kt
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package ranking.bm25
|
||||||
|
|
||||||
|
import dev.nuculabs.dsa.ranking.bm25.Bm25Plus
|
||||||
|
import dev.nuculabs.dsa.ranking.bm25.Document
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class BM25PlusTest {
|
||||||
|
@Test
|
||||||
|
fun test_index_and_indexSize() {
|
||||||
|
// Setup
|
||||||
|
val bm25Plus = Bm25Plus()
|
||||||
|
|
||||||
|
val document1 = Document(1, "Ana are mere")
|
||||||
|
val document2 = Document(2, "Ana Ana Ana Ana Ana Ana Ana Ana")
|
||||||
|
|
||||||
|
// Test
|
||||||
|
bm25Plus.indexAll(document1, document2)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, bm25Plus.indexSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test_termQuery() {
|
||||||
|
// Given
|
||||||
|
val bm25Plus = Bm25Plus()
|
||||||
|
|
||||||
|
val document1 = Document(1, "Ana are mere")
|
||||||
|
val document2 = Document(2, "Ana Ana Ana Ana Ana Ana Ana Ana")
|
||||||
|
|
||||||
|
// Then
|
||||||
|
bm25Plus.index(document1)
|
||||||
|
bm25Plus.index(document2)
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
listOf(0.4936823874431607 to document2, 0.3133956394555762 to document1),
|
||||||
|
bm25Plus.termQuery("Ana")
|
||||||
|
)
|
||||||
|
assertEquals(listOf(0.8491490237651933 to document1), bm25Plus.termQuery("mere"))
|
||||||
|
assertEquals(listOf(), bm25Plus.termQuery("batman"))
|
||||||
|
assertEquals(
|
||||||
|
listOf(0.4936823874431607 to document2, 0.3133956394555762 to document1),
|
||||||
|
bm25Plus.termQuery("ana")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test_termsQuery() {
|
||||||
|
// Setup
|
||||||
|
val bm25Plus = Bm25Plus()
|
||||||
|
|
||||||
|
val document1 = Document(
|
||||||
|
1,
|
||||||
|
"A linked list is a fundamental data structure which consists of Nodes that are connected to each other."
|
||||||
|
)
|
||||||
|
val document2 =
|
||||||
|
Document(2, "The Linked List data structure permits the storage of data in an efficient manner.")
|
||||||
|
val document3 =
|
||||||
|
Document(3, "The space and time complexity of the linked list operations depends on the implementation.")
|
||||||
|
val document4 = Document(
|
||||||
|
4,
|
||||||
|
"The operations that take O(N) time takes this much because you have to traverse the list’s for at least N nodes in order to perform it successfully. On the other hand, operations that take O(1) time do not require any traversals because the list holds pointers to the head first Node and tail last Node."
|
||||||
|
)
|
||||||
|
|
||||||
|
bm25Plus.indexAll(document1, document2, document3, document4)
|
||||||
|
|
||||||
|
// Test
|
||||||
|
val results = bm25Plus.termsQuery("linked", "list", "complexity")
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1.5966769323799244 to document3, results.first())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue