langchain 공부

CrossEncoderReranker 내부 코드 분석

필만이 2024. 11. 14. 17:41

배경

  • CrossEncoderReranker의 내부 코드를 보고 활용 방법을 찾아본다.

코드

from __future__ import annotations

import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict

from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder


class CrossEncoderReranker(BaseDocumentCompressor):
    """CrossEncoder를 사용하여 문서를 재정렬(rerank)하는 Document 압축기 클래스."""

    model: BaseCrossEncoder
    """유사도를 평가하기 위해 CrossEncoder 모델을 사용하여 쿼리와 문서 간의 점수를 계산."""
    top_n: int = 3
    """반환할 상위 문서의 개수."""

    model_config = ConfigDict(
        # 기본적으로 pydantic 모델에서는 int, str, list 같은 기본 타입이나 pydantic으로 정의된 객체 타입만을 허용.
        # arbitrary_types_allowed=True로 설정하면, pydantic이 기본 타입 외에 임의의 사용자 정의 객체 타입도 허용
        arbitrary_types_allowed=True,  # 임의 타입의 필드를 허용
        # "forbid" 설정은 pydantic 모델에 정의되지 않은 필드가 포함될 경우 예외를 발생
        extra="forbid",
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        # 특정 이벤트 (압축(compress)과정)에서 실행할 수 있는 함수나 객체를 참조, 진행 상황을 추적하거나 오류를 로깅하는 등의 작업을 수행
        # 선택 사항이며, 기본값으로 None이 설정
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        CrossEncoder를 사용하여 문서를 재정렬합니다.

        Args:
            documents: 압축할 문서의 시퀀스.
            query: 문서 압축에 사용할 쿼리.
            callbacks: 압축 과정 중 실행할 콜백 함수들.

        Returns:
            압축된 문서의 시퀀스.
        """
        # 각 문서의 page_content와 쿼리 쌍을 만들고 이를 모델에 전달하여 유사도 점수를 계산
        # 예: [(query, "문서1 내용"), (query, "문서2 내용"), ...]
        scores = self.model.score([(query, doc.page_content) for doc in documents])

        # 문서와 해당 유사도 점수를 zip을 통해 결합하여 리스트로 변환
        docs_with_scores = list(zip(documents, scores))

        # 유사도 점수를 기준으로 문서를 내림차순으로 정렬하여 가장 유사한 문서가 상위에 위치하도록 함
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # 상위 top_n개의 문서를 선택하여 반환. 각 요소에서 문서만 추출하여 리스트로 반환
        return [doc for doc, _ in result[: self.top_n]]

결론

  • 관련이 별로 없는데도, top_n개의 문서를 선택해야해서 저품질 문서가 포함되는 문제 발생
  • 상위 top_n개의 문서를 반환하는게 아닌 점수로 필터링을 한다면, 상대적으로 가치가 있는 정보를 뽑을 수 있을듯.
  • 다음 문서에서 그렇게 시행 하고자 함