langchain 공부

BM25Retriever1 내부 코드 분석

필만이 2024. 11. 19. 21:51

배경

  • BM25Retriever의 bm25.py 의 원리를 알기 위해 내부코드를 분석하고자함
  • 진짜 bm25의 원리가 되는건 rank_bm25.py로 그건 내일 포스트 예정

코드

  • bm25.py
from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun  
from langchain_core.documents import Document  
from langchain_core.retrievers import BaseRetriever  
from pydantic import ConfigDict, Field

def default_preprocessing_func(text: str) -> List[str]:  
"""  
기본 전처리 함수: 입력 텍스트를 공백 단위로 분할하여 토큰 리스트로 반환.

Args:
    text: 입력 텍스트.

Returns:
    공백을 기준으로 분할된 문자열 리스트.
"""
return text.split()


class BM25Retriever(BaseRetriever):  
"""  
BM25 알고리즘을 기반으로 문서를 검색하는 검색기.

Elasticsearch 없이 rank_bm25 라이브러리를 사용하여 동작.
"""

vectorizer: Any = None
""" BM25 벡터라이저 객체. 문서 검색 알고리즘의 핵심 역할."""

# Field(repr=False): 문자열 출력에서 해당 필드를 제외하여 간결성과 보안을 확보
docs: List[Document] = Field(repr=False)
""" 검색 대상이 되는 문서 리스트. 각 문서는 `Document` 객체로 표현됨."""

k: int = 4
""" 반환할 문서의 최대 개수."""

# Callable: 호출 가능한 객체를 의미합니다. 함수 또는 메서드일 가능성이 큽니다.
# [str]: 입력 값의 타입. 이 함수는 **문자열(str)**을 입력으로 받습니다.
# List[str]: 반환 값의 타입. 이 함수는 **문자열의 리스트(List[str])**를 반환
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func

# Pydantic은 일반적으로 기본 타입(str, int, List, Dict 등)과 Pydantic 모델만 검증가능함
# 애플리케이션이 특정 알고리즘 객체(BM25, TF-IDF 등)나 데이터베이스 연결 객체를 관리해야 할 때, 
# 해당 객체를 Pydantic 모델에 검증 없이 포함
model_config = ConfigDict(
    arbitrary_types_allowed=True,  # Pydantic에서 임의의 객체 타입(BM25 등)을 허용.
)

@classmethod
def from_texts(
    cls,
    texts: Iterable[str],  # 검색 대상 텍스트 리스트
    metadatas: Optional[Iterable[dict]] = None,  # 각 텍스트에 연결된 선택적 메타데이터 리스트
    bm25_params: Optional[Dict[str, Any]] = None,  # BM25 벡터라이저의 추가 설정 파라미터
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,  # 텍스트 전처리 함수
    **kwargs: Any,  # 추가적인 키워드 인자
) -> BM25Retriever:
    """
    텍스트 리스트로부터 BM25Retriever 인스턴스를 생성하는 클래스 메서드.

    Args:
        texts: 검색할 텍스트 리스트.
        metadatas: 각 텍스트에 연결할 선택적 메타데이터.
        bm25_params: BM25 벡터라이저에 전달할 추가 파라미터.
        preprocess_func: 벡터화 전에 텍스트를 처리하는 함수.
        **kwargs: 추가적으로 전달할 인자.

    Returns:
        BM25Retriever 객체.
    """
    try:
        # rank_bm25 라이브러리를 가져옴. 설치되지 않은 경우 ImportError 발생.
        from rank_bm25 import BM25Okapi
    except ImportError:
        raise ImportError(
            "rank_bm25를 가져올 수 없습니다. `pip install rank_bm25`로 설치하세요."
        )

    # 각 텍스트를 전처리하여 토큰 리스트로 변환
    texts_processed = [preprocess_func(t) for t in texts]

    # BM25 벡터라이저를 초기화. 추가 설정이 없으면 기본 설정 사용.
    bm25_params = bm25_params or {}
    vectorizer = BM25Okapi(texts_processed, **bm25_params)

    # 메타데이터가 제공되지 않으면 빈 딕셔너리 생성
    metadatas = metadatas or ({} for _ in texts)

    # 각 텍스트와 메타데이터를 결합하여 Document 객체 생성
    docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]

    # BM25Retriever 객체 생성 후 반환
    return cls(
        vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
    )

@classmethod
def from_documents(
    cls,
    documents: Iterable[Document],  # Document 객체 리스트
    *,
    bm25_params: Optional[Dict[str, Any]] = None,  # BM25 벡터라이저의 추가 설정 파라미터
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,  # 텍스트 전처리 함수
    **kwargs: Any,  # 추가 키워드 인자
) -> BM25Retriever:
    """
    Document 객체 리스트로부터 BM25Retriever 인스턴스를 생성하는 클래스 메서드.

    Args:
        documents: 검색할 Document 객체 리스트.
        bm25_params: BM25 벡터라이저에 전달할 추가 파라미터.
        preprocess_func: 벡터화 전에 텍스트를 처리하는 함수.
        **kwargs: 추가적으로 전달할 인자.

    Returns:
        BM25Retriever 객체.
    """
    # Document 객체에서 텍스트와 메타데이터를 추출
    # zip(*) 언패킹 연산을 통해 각 튜플의 첫 번째와 두 번째 요소를 각각 묶어서 두 개의 새로운 튜플로 변환
    # 예> pairs = [(1, 'one'), (2, 'two'), (3, 'three')]
    # 언패킹으로 첫 번째와 두 번째 요소를 각각 분리
    # numbers, words = zip(*pairs)
    # print(numbers)  # (1, 2, 3)
    # print(words)    # ('one', 'two', 'three')

    texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))

    # 추출한 텍스트와 메타데이터를 사용하여 BM25Retriever 생성
    # cls: 현재 클래스(해당 클래스 메서드가 정의된 클래스)를 나타내는 예약어
    # 현재 클래스의 from_texts 메서드를 호출
    return cls.from_texts(
        texts=texts,
        bm25_params=bm25_params,
        metadatas=metadatas,
        preprocess_func=preprocess_func,
        **kwargs,
    )

def _get_relevant_documents(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
    """
    주어진 쿼리에 대해 BM25를 사용하여 관련 문서를 검색하는 메서드.

    Args:
        query: 검색할 쿼리 문자열.
        run_manager: 콜백 매니저 객체.

    Returns:
        관련 문서 리스트.
    """
    # 쿼리를 전처리하여 토큰 리스트로 변환
    processed_query = self.preprocess_func(query)

    # BM25를 사용하여 상위 k개의 문서를 검색
    return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)

    # 관련 문서 리스트 반환
    return return_docs