오류 해결 과정

리랭커(Reranker) 사용시 여러 문서 한번에 재정렬 시키기

필만이 2024. 10. 18. 13:16

배경

  1. 한가지 질문에 여러 문서에서 자료를 검색기로 가져오고
  2. 그 자료들을 한번에 리랭커 하고 싶다.
  3. 그런데 ContextualCompressionRetriever를 사용하려면 기본 검색기(base_retriever) 값을 꼭 넣어줘야한다.
  4. ContextualCompressionRetriever를 사용 목적에 맞게 수정해, 문제를 해결하고자 한다.

해결과정

  1. 여러 책을 돌리면서, 검색기로 찾은 문서들이 점점 dense_docs, sparse_docs에 쌓이게 만듦
  2. base_retriever에 값이 없어도 되게(Obtional) 기존 모듈 커스터 마이징(바꿈)
  • 외부 코드
    for book_name in book_names:
        # 검색기 설정 (리랭커 호출 없이 검색만 수행)
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={'k': 40, "score_threshold": 0.30}
        )
        dense_docs.extend(retriever.invoke(query))

        # BM25 검색기 설정 (리랭커 호출 없이 검색만 수행)
        bm25_retriever = KiwiBM25Retriever.from_documents(documents)
        bm25_retriever.k = 40  # 검색 결과 개수 설정
        sparse_docs.extend(bm25_retriever.invoke(query))

    # 수집한 모든 문서 합치기
    retrieved_docs = sparse_docs + dense_docs

    # 리랭커 설정
    reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=reranker_model, top_n=40)

    ## 원래 base_retriever에 None 값은 입력이 안되서 ContextualCompressionRetriever를 직접수정해야한다.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=None  # 이미 수집된 문서들에 리랭커만 적용
    )

    # 문서 압축 및 리랭킹 수행
    ranked_docs = compression_retriever._get_relevant_documents(
        query,
        collected_docs=filtered_docs
    )
  • 커스터 마이징한 모듈(ContextualCompressionRetriever) 코드
class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and compresses the results."""

    base_compressor: BaseDocumentCompressor
    """Compressor for compressing retrieved documents."""

    base_retriever: Optional[RetrieverLike] = None
    """Base Retriever to use for getting relevant documents. Can be None if documents are already collected."""

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        collected_docs: Optional[List[Document]] = None,  # 이미 수집된 문서를 인자로 추가
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query, or use already collected documents.

        Args:
            query: string to find relevant documents for
            collected_docs: 이미 수집된 문서가 있는 경우 전달됨

        Returns:
            Sequence of relevant documents
        """
        # 이미 수집된 문서가 있을 경우, 수집 단계를 생략하고 압축만 수행
        if collected_docs is not None:
            docs = collected_docs
        else:
            # base_retriever가 있을 경우 문서를 검색
            if self.base_retriever:
                docs = self.base_retriever.invoke(
                    query, **kwargs
                )
            else:
                raise ValueError("Either base_retriever must be set or collected_docs must be provided.")

        # 문서가 있으면 압축 수행
        if docs:
            compressed_docs = self.base_compressor.compress_documents(docs, query)
            return list(compressed_docs)
        else:
            return []

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        collected_docs: Optional[List[Document]] = None,  # 이미 수집된 문서에 대해 비동기 방식으로 처리
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query, or use already collected documents asynchronously.

        Args:
            query: string to find relevant documents for
            collected_docs: 이미 수집된 문서가 있는 경우 전달됨

        Returns:
            List of relevant documents
        """
        # 이미 수집된 문서가 있을 경우, 수집 단계를 생략하고 압축만 수행
        if collected_docs is not None:
            docs = collected_docs
        else:
            # base_retriever가 있을 경우 문서를 검색
            if self.base_retriever:
                docs = await self.base_retriever.ainvoke(query, **kwargs)
            else:
                raise ValueError("Either base_retriever must be set or collected_docs must be provided.")

        # 문서가 있으면 비동기 압축 수행
        if docs:
            compressed_docs = await self.base_compressor.acompress_documents(docs, query)
            return list(compressed_docs)
        else:
            return []

결론

  1. 그 전까지는 모듈 코드가 맘에 안들어도 외부 코드를 괴랄스럽게 고치며 원하는데로 사용해왔다.
  2. 처음으로 다른 사람이 만든 모듈 코드를 직접 고쳐서 해봤다.
  3. 이런 방식이면 기존에 어렵게 풀어왔던 문제들이 대부분 쉽게 풀릴것 같다.
  4. 뭔가 보이지 않은 차원을 넘은거 같다.