langchain 공부

PromptTemplate 내부코드 분석

필만이 2024. 11. 24. 19:02

배경

  • PromptTemplate의 내부 코드를 분석해 원리를 알아본다.

코드

from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

from langchain_core.prompts.string import (
    DEFAULT_FORMATTER_MAPPING,  # 각 템플릿 형식에 대한 포맷터 매핑
    StringPromptTemplate,  # 부모 클래스 정의
    check_valid_template,  # 템플릿 유효성 검사 함수
    get_template_variables,  # 템플릿에서 변수 추출 함수
    mustache_schema,  # mustache 템플릿을 위한 스키마 생성
)
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig

class PromptTemplate(StringPromptTemplate):
    """
    언어 모델을 위한 프롬프트 템플릿 클래스.

    프롬프트 템플릿은 문자열 템플릿으로 구성되며, 사용자로부터 전달받은
    여러 매개변수를 사용해 언어 모델의 프롬프트를 생성할 수 있습니다.

    템플릿은 기본적으로 f-string 형식으로 포맷팅되며, jinja2 문법도 지원합니다.

    *보안 경고*:
        템플릿 형식으로 `template_format="f-string"`을 사용하는 것이 권장됩니다.
        `template_format="jinja2"`를 사용하는 경우, 외부에서 제공되는 jinja2 템플릿은
        절대 사용하지 마세요. 임의의 Python 코드를 실행하게 될 위험이 있습니다.

        LangChain 버전 0.0.329부터는 Jinja2 템플릿을 SandboxedEnvironment를 사용해
        렌더링합니다. 이는 최선의 노력으로 보안을 강화하는 방법이지만,
        보안을 완전히 보장할 수는 없습니다. (Opt-out 방식 적용됨)

        신뢰할 수 없는 출처에서 jinja2 템플릿을 절대 사용하지 않는 것이 좋습니다.

    예제:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            # from_template 메서드를 이용해 인스턴스 생성 (권장 방식)
            prompt = PromptTemplate.from_template("Say {foo}")
            prompt.format(foo="bar")

            # 초기화 메서드를 통해 직접 인스턴스 생성
            prompt = PromptTemplate(template="Say {foo}")
    """

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """
        프롬프트 템플릿의 속성을 반환합니다.

        이 속성은 템플릿의 주요 설정을 포함하며, 직렬화 및 디버깅에 유용합니다.

        Returns:
            Dict[str, Any]: 템플릿 속성을 포함한 딕셔너리.
        """
        return {
            "template_format": self.template_format,
        }

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """
        LangChain 객체의 네임스페이스를 반환합니다.

        Returns:
            List[str]: 네임스페이스 구성 요소 리스트.
        """
        return ["langchain", "prompts", "prompt"]

    # 프롬프트 템플릿을 담고 있는 문자열.
    template: str

    # 프롬프트 템플릿의 형식입니다. 지원 옵션: 'f-string', 'mustache', 'jinja2'. 기본값은 'f-string'.
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"

    # 템플릿 유효성 검사를 수행할지 여부를 나타냅니다. 기본값은 False.
    validate_template: bool = False

    # @root_validator(pre=True)는 Python의 데이터 유효성 검사 라이브러리인 Pydantic에서 사용되는 데코레이터입니다. 
    # 이를 통해 데이터 모델에 입력되기 전에 전처리를 하거나, 여러 필드 간의 상호 관계를 기반으로 데이터를 검증
    @root_validator(pre=True)
    def pre_init_validation(cls, values: Dict) -> Dict:
        """
        템플릿과 입력 변수의 일관성을 확인하는 유효성 검사 메서드.

        초기화 시 입력된 값들을 확인하고, 유효성 검사를 수행하며 기본값을 설정합니다.

        Args:
            values (Dict): 템플릿과 관련된 초기화 값들.

        Returns:
            Dict: 유효성 검사를 통과한 값들.
        """
        # 템플릿이 제공되지 않은 경우 ValidationError를 발생시키도록 함
        if values.get("template") is None:
            return values  # Pydantic이 ValidationError를 처리함

        # 템플릿 형식 기본값을 f-string으로 설정
        values.setdefault("template_format", "f-string")
        values.setdefault("partial_variables", {})

        # validate_template이 True일 경우 템플릿 유효성 검사 수행
        if values.get("validate_template"):
            # Mustache 형식은 유효성 검사를 지원하지 않음
            if values["template_format"] == "mustache":
                raise ValueError("Mustache templates cannot be validated.")

            # input_variables가 없으면 예외 처리
            if "input_variables" not in values:
                raise ValueError(
                    "Input variables must be provided to validate the template."
                )

            # 모든 입력 변수 목록을 구성하여 유효성 검사 함수에 전달
            all_inputs = values["input_variables"] + list(values["partial_variables"])
            check_valid_template(
                values["template"], values["template_format"], all_inputs
            )

        # 템플릿 형식에 따라 필요한 변수 추출
        if values["template_format"]:  # 템플릿 형식이 제공되었는지 확인
            # 1. 템플릿과 템플릿 형식에서 요구되는 변수 리스트 가져오기
            #    - get_template_variables 함수는 템플릿에서 사용해야 할 변수들을 분석하여 반환합니다.
            #    - 예: 템플릿이 "Hello, {name}!"이고 템플릿 형식이 문자열이라면, 반환값은 ["name"]입니다.
            template_variables = get_template_variables(
                values["template"], values["template_format"]
            )

            # 2. 필수 입력 변수 정의
            #    - 부분 변수(partial_variables)는 이미 입력값으로 제공된 변수들의 리스트입니다.
            #    - template_variables 중 partial_variables에 포함되지 않은 변수들만 필수 입력 변수로 간주합니다.
            #    - 즉, 아직 입력값으로 제공되지 않은 변수들을 input_variables에 저장합니다.
            values["input_variables"] = [
                var  # template_variables에서 반복적으로 변수 확인
                for var in template_variables  # 템플릿에 정의된 변수 목록
                if var not in values["partial_variables"]  # 이미 제공된 변수는 제외
        ]


        return values

    def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
        """
        프롬프트의 입력 스키마를 반환.

        프롬프트 템플릿에 필요한 입력 변수를 정의한 스키마를 생성합니다.

        Args:
            config (RunnableConfig | None): 실행 가능한 구성 객체 (선택적).

        Returns:
            type[BaseModel]: 입력 스키마 클래스.
        """
        # 템플릿 형식이 'mustache'가 아닌 경우 부모 클래스의 스키마 반환
        if self.template_format != "mustache":
            return super().get_input_schema(config)
        # 템플릿 형식이 'mustache'인 경우 mustache 전용 스키마 생성
        return mustache_schema(self.template)


    def __add__(self, other: Any) -> PromptTemplate:
        """프롬프트 템플릿을 결합하기 위해 + 연산자 오버라이딩."""
        # 두 템플릿을 쉽게 결합할 수 있게 허용
        if isinstance(other, PromptTemplate):
            # 템플릿 형식이 f-string이 아닌 경우 예외 발생
            if self.template_format != "f-string":
                raise ValueError(
                    "Adding prompt templates only supported for f-strings."
                )
            if other.template_format != "f-string":
                raise ValueError(
                    "Adding prompt templates only supported for f-strings."
                )
            # 두 템플릿의 입력 변수를 결합
            input_variables = list(
                set(self.input_variables) | set(other.input_variables)
            )
            # 템플릿 문자열을 결합
            template = self.template + other.template
            # 두 템플릿 모두 유효성 검사 설정이 True인 경우에만 유효성 검사 수행
            validate_template = self.validate_template and other.validate_template
            # 부분 변수 결합 (중복 검사)
            partial_variables = {k: v for k, v in self.partial_variables.items()}
            for k, v in other.partial_variables.items():
                if k in partial_variables:
                    raise ValueError("Cannot have same variable partialed twice.")
                else:
                    partial_variables[k] = v
            # 결합된 템플릿 반환
            return PromptTemplate(
                template=template,
                input_variables=input_variables,
                partial_variables=partial_variables,
                template_format="f-string",
                validate_template=validate_template,
            )
        elif isinstance(other, str):
            # 문자열을 템플릿으로 변환하여 재귀적으로 + 연산 수행
            prompt = PromptTemplate.from_template(other)
            return self + prompt
        else:
            raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")

    @property
    def _prompt_type(self) -> str:
        """프롬프트 타입 키를 반환."""
        return "prompt"

    def format(self, **kwargs: Any) -> str:
        """사용자가 제공한 입력으로 프롬프트를 포맷팅.

        Args:
            kwargs: 프롬프트 템플릿에 전달할 인수들.

        Returns:
            포맷팅된 문자열.
        """
        # 사용자 변수와 부분 변수를 병합
        # **kwargs로 사용하면 여러 개의 키워드 인수를 딕셔너리로 묶어서 받을 수 있습니다.
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        # 설정된 템플릿 형식에 맞게 템플릿 포맷팅 후 반환
        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    @classmethod
    def from_examples(
        cls,
        examples: List[str],
        suffix: str,
        input_variables: List[str],
        example_separator: str = "\n\n",
        prefix: str = "",
        **kwargs: Any,
    ) -> PromptTemplate:
        """예제 리스트를 기반으로 프롬프트를 생성.

        Args:
            examples: 프롬프트에 사용할 예제 리스트.
            suffix: 예제 리스트 끝에 추가할 문자열.
            input_variables: 최종 프롬프트 템플릿에서 사용할 변수 이름 리스트.
            example_separator: 예제들 사이에 넣을 구분자. 기본값은 두 줄 바꿈 문자.
            prefix: 예제 리스트 앞에 추가할 문자열. 기본값은 빈 문자열.

        Returns:
            생성된 프롬프트 템플릿 객체.
        """
        # prefix와 suffix를 포함하여 예제들을 연결하여 최종 템플릿 생성
        template = example_separator.join([prefix, *examples, suffix])
        # cls를 호출해 PromptTemplate 또는 서브클래스의 인스턴스를 생성
        return cls(input_variables=input_variables, template=template, **kwargs)

    @classmethod
    def from_file(
        cls,
        template_file: Union[str, Path],
        input_variables: Optional[List[str]] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
        """파일에서 프롬프트를 불러와서 생성.

        Args:
            template_file: 프롬프트 템플릿을 포함한 파일 경로.
            input_variables: (사용되지 않음) 최종 템플릿에 필요한 변수 이름 리스트.
            encoding: 파일을 열 때 사용할 인코딩. OS 기본값 사용 가능.

        Returns:
            파일에서 로드한 프롬프트 템플릿 객체.
        """
        # 파일을 열어 템플릿 문자열을 읽기
        with open(str(template_file), "r", encoding=encoding) as f:
            template = f.read()
        # input_variables는 더 이상 사용되지 않음
        if input_variables:
            warnings.warn(
                "`input_variables' is deprecated and ignored.", DeprecationWarning
            )
        return cls.from_template(template=template, **kwargs)

    @classmethod
    def from_template(
        cls,
        template: str,
        *,
        template_format: str = "f-string",
        partial_variables: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
        """문자열 템플릿으로부터 프롬프트 템플릿을 생성.

        보안 경고:
            가급적 `template_format="f-string"`을 사용하세요. `jinja2`를 사용하려면 
            신뢰할 수 없는 소스에서 제공된 템플릿을 사용하지 않아야 합니다.

        Args:
            template: 생성할 템플릿 문자열.
            template_format: 템플릿 형식. 기본값은 `f-string`.
            partial_variables: 일부 변수만 채운 상태로 템플릿을 생성할 수 있는 딕셔너리.
            kwargs: 프롬프트 템플릿에 전달할 추가 인수.

        Returns:
            생성된 프롬프트 템플릿 객체.
        """
        # 템플릿 변수 추출

        input_variables = get_template_variables(template, template_format)
        _partial_variables = partial_variables or {}
        # 부분 변수가 존재하면 해당 변수들을 입력 변수에서 제외
        if _partial_variables:
            input_variables = [
                var for var in input_variables if var not in _partial_variables
            ]
        # 템플릿 객체 생성 후 반환
        return cls(
            input_variables=input_variables,
            template=template,
            template_format=template_format,  # type: ignore[arg-type]
            partial_variables=_partial_variables,
            **kwargs,
        )