This commit is contained in:
Timothy Jaeryang Baek 2025-08-21 21:48:21 +04:00
parent 5a66f69460
commit 60b8cfb9fa

View file

@ -952,6 +952,7 @@ class RerankCompressor(BaseDocumentCompressor):
) -> Sequence[Document]:
reranking = self.reranking_function is not None
scores = None
if reranking:
scores = self.reranking_function(
[(query, doc.page_content) for doc in documents]
@ -965,22 +966,31 @@ class RerankCompressor(BaseDocumentCompressor):
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
docs_with_scores = list(
zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
)
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score
]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
final_results = []
for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata
metadata["score"] = doc_score
doc = Document(
page_content=doc.page_content,
metadata=metadata,
if scores:
docs_with_scores = list(
zip(
documents,
scores.tolist() if not isinstance(scores, list) else scores,
)
)
final_results.append(doc)
return final_results
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score
]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
final_results = []
for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata
metadata["score"] = doc_score
doc = Document(
page_content=doc.page_content,
metadata=metadata,
)
final_results.append(doc)
return final_results
else:
log.warning(
"No valid scores found, check your reranking function. Returning original documents."
)
return documents