배경
- BM25Retriever의 bm25.py 의 원리를 알기 위해 내부코드를 분석하고자함
- 진짜 bm25의 원리가 되는건 rank_bm25.py로 그건 내일 포스트 예정
코드
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field
def default_preprocessing_func(text: str) -> List[str]:
"""
기본 전처리 함수: 입력 텍스트를 공백 단위로 분할하여 토큰 리스트로 반환.
Args:
text: 입력 텍스트.
Returns:
공백을 기준으로 분할된 문자열 리스트.
"""
return text.split()
class BM25Retriever(BaseRetriever):
"""
BM25 알고리즘을 기반으로 문서를 검색하는 검색기.
Elasticsearch 없이 rank_bm25 라이브러리를 사용하여 동작.
"""
vectorizer: Any = None
""" BM25 벡터라이저 객체. 문서 검색 알고리즘의 핵심 역할."""
# Field(repr=False): 문자열 출력에서 해당 필드를 제외하여 간결성과 보안을 확보
docs: List[Document] = Field(repr=False)
""" 검색 대상이 되는 문서 리스트. 각 문서는 `Document` 객체로 표현됨."""
k: int = 4
""" 반환할 문서의 최대 개수."""
# Callable: 호출 가능한 객체를 의미합니다. 함수 또는 메서드일 가능성이 큽니다.
# [str]: 입력 값의 타입. 이 함수는 **문자열(str)**을 입력으로 받습니다.
# List[str]: 반환 값의 타입. 이 함수는 **문자열의 리스트(List[str])**를 반환
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
# Pydantic은 일반적으로 기본 타입(str, int, List, Dict 등)과 Pydantic 모델만 검증가능함
# 애플리케이션이 특정 알고리즘 객체(BM25, TF-IDF 등)나 데이터베이스 연결 객체를 관리해야 할 때,
# 해당 객체를 Pydantic 모델에 검증 없이 포함
model_config = ConfigDict(
arbitrary_types_allowed=True, # Pydantic에서 임의의 객체 타입(BM25 등)을 허용.
)
@classmethod
def from_texts(
cls,
texts: Iterable[str], # 검색 대상 텍스트 리스트
metadatas: Optional[Iterable[dict]] = None, # 각 텍스트에 연결된 선택적 메타데이터 리스트
bm25_params: Optional[Dict[str, Any]] = None, # BM25 벡터라이저의 추가 설정 파라미터
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, # 텍스트 전처리 함수
**kwargs: Any, # 추가적인 키워드 인자
) -> BM25Retriever:
"""
텍스트 리스트로부터 BM25Retriever 인스턴스를 생성하는 클래스 메서드.
Args:
texts: 검색할 텍스트 리스트.
metadatas: 각 텍스트에 연결할 선택적 메타데이터.
bm25_params: BM25 벡터라이저에 전달할 추가 파라미터.
preprocess_func: 벡터화 전에 텍스트를 처리하는 함수.
**kwargs: 추가적으로 전달할 인자.
Returns:
BM25Retriever 객체.
"""
try:
# rank_bm25 라이브러리를 가져옴. 설치되지 않은 경우 ImportError 발생.
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"rank_bm25를 가져올 수 없습니다. `pip install rank_bm25`로 설치하세요."
)
# 각 텍스트를 전처리하여 토큰 리스트로 변환
texts_processed = [preprocess_func(t) for t in texts]
# BM25 벡터라이저를 초기화. 추가 설정이 없으면 기본 설정 사용.
bm25_params = bm25_params or {}
vectorizer = BM25Okapi(texts_processed, **bm25_params)
# 메타데이터가 제공되지 않으면 빈 딕셔너리 생성
metadatas = metadatas or ({} for _ in texts)
# 각 텍스트와 메타데이터를 결합하여 Document 객체 생성
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
# BM25Retriever 객체 생성 후 반환
return cls(
vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document], # Document 객체 리스트
*,
bm25_params: Optional[Dict[str, Any]] = None, # BM25 벡터라이저의 추가 설정 파라미터
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, # 텍스트 전처리 함수
**kwargs: Any, # 추가 키워드 인자
) -> BM25Retriever:
"""
Document 객체 리스트로부터 BM25Retriever 인스턴스를 생성하는 클래스 메서드.
Args:
documents: 검색할 Document 객체 리스트.
bm25_params: BM25 벡터라이저에 전달할 추가 파라미터.
preprocess_func: 벡터화 전에 텍스트를 처리하는 함수.
**kwargs: 추가적으로 전달할 인자.
Returns:
BM25Retriever 객체.
"""
# Document 객체에서 텍스트와 메타데이터를 추출
# zip(*) 언패킹 연산을 통해 각 튜플의 첫 번째와 두 번째 요소를 각각 묶어서 두 개의 새로운 튜플로 변환
# 예> pairs = [(1, 'one'), (2, 'two'), (3, 'three')]
# 언패킹으로 첫 번째와 두 번째 요소를 각각 분리
# numbers, words = zip(*pairs)
# print(numbers) # (1, 2, 3)
# print(words) # ('one', 'two', 'three')
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
# 추출한 텍스트와 메타데이터를 사용하여 BM25Retriever 생성
# cls: 현재 클래스(해당 클래스 메서드가 정의된 클래스)를 나타내는 예약어
# 현재 클래스의 from_texts 메서드를 호출
return cls.from_texts(
texts=texts,
bm25_params=bm25_params,
metadatas=metadatas,
preprocess_func=preprocess_func,
**kwargs,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""
주어진 쿼리에 대해 BM25를 사용하여 관련 문서를 검색하는 메서드.
Args:
query: 검색할 쿼리 문자열.
run_manager: 콜백 매니저 객체.
Returns:
관련 문서 리스트.
"""
# 쿼리를 전처리하여 토큰 리스트로 변환
processed_query = self.preprocess_func(query)
# BM25를 사용하여 상위 k개의 문서를 검색
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
# 관련 문서 리스트 반환
return return_docs