배경
- PromptTemplate의 내부 코드를 분석해 원리를 알아본다.
코드
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING, # 각 템플릿 형식에 대한 포맷터 매핑
StringPromptTemplate, # 부모 클래스 정의
check_valid_template, # 템플릿 유효성 검사 함수
get_template_variables, # 템플릿에서 변수 추출 함수
mustache_schema, # mustache 템플릿을 위한 스키마 생성
)
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig
class PromptTemplate(StringPromptTemplate):
"""
언어 모델을 위한 프롬프트 템플릿 클래스.
프롬프트 템플릿은 문자열 템플릿으로 구성되며, 사용자로부터 전달받은
여러 매개변수를 사용해 언어 모델의 프롬프트를 생성할 수 있습니다.
템플릿은 기본적으로 f-string 형식으로 포맷팅되며, jinja2 문법도 지원합니다.
*보안 경고*:
템플릿 형식으로 `template_format="f-string"`을 사용하는 것이 권장됩니다.
`template_format="jinja2"`를 사용하는 경우, 외부에서 제공되는 jinja2 템플릿은
절대 사용하지 마세요. 임의의 Python 코드를 실행하게 될 위험이 있습니다.
LangChain 버전 0.0.329부터는 Jinja2 템플릿을 SandboxedEnvironment를 사용해
렌더링합니다. 이는 최선의 노력으로 보안을 강화하는 방법이지만,
보안을 완전히 보장할 수는 없습니다. (Opt-out 방식 적용됨)
신뢰할 수 없는 출처에서 jinja2 템플릿을 절대 사용하지 않는 것이 좋습니다.
예제:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
# from_template 메서드를 이용해 인스턴스 생성 (권장 방식)
prompt = PromptTemplate.from_template("Say {foo}")
prompt.format(foo="bar")
# 초기화 메서드를 통해 직접 인스턴스 생성
prompt = PromptTemplate(template="Say {foo}")
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
"""
프롬프트 템플릿의 속성을 반환합니다.
이 속성은 템플릿의 주요 설정을 포함하며, 직렬화 및 디버깅에 유용합니다.
Returns:
Dict[str, Any]: 템플릿 속성을 포함한 딕셔너리.
"""
return {
"template_format": self.template_format,
}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""
LangChain 객체의 네임스페이스를 반환합니다.
Returns:
List[str]: 네임스페이스 구성 요소 리스트.
"""
return ["langchain", "prompts", "prompt"]
# 프롬프트 템플릿을 담고 있는 문자열.
template: str
# 프롬프트 템플릿의 형식입니다. 지원 옵션: 'f-string', 'mustache', 'jinja2'. 기본값은 'f-string'.
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
# 템플릿 유효성 검사를 수행할지 여부를 나타냅니다. 기본값은 False.
validate_template: bool = False
# @root_validator(pre=True)는 Python의 데이터 유효성 검사 라이브러리인 Pydantic에서 사용되는 데코레이터입니다.
# 이를 통해 데이터 모델에 입력되기 전에 전처리를 하거나, 여러 필드 간의 상호 관계를 기반으로 데이터를 검증
@root_validator(pre=True)
def pre_init_validation(cls, values: Dict) -> Dict:
"""
템플릿과 입력 변수의 일관성을 확인하는 유효성 검사 메서드.
초기화 시 입력된 값들을 확인하고, 유효성 검사를 수행하며 기본값을 설정합니다.
Args:
values (Dict): 템플릿과 관련된 초기화 값들.
Returns:
Dict: 유효성 검사를 통과한 값들.
"""
# 템플릿이 제공되지 않은 경우 ValidationError를 발생시키도록 함
if values.get("template") is None:
return values # Pydantic이 ValidationError를 처리함
# 템플릿 형식 기본값을 f-string으로 설정
values.setdefault("template_format", "f-string")
values.setdefault("partial_variables", {})
# validate_template이 True일 경우 템플릿 유효성 검사 수행
if values.get("validate_template"):
# Mustache 형식은 유효성 검사를 지원하지 않음
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
# input_variables가 없으면 예외 처리
if "input_variables" not in values:
raise ValueError(
"Input variables must be provided to validate the template."
)
# 모든 입력 변수 목록을 구성하여 유효성 검사 함수에 전달
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
)
# 템플릿 형식에 따라 필요한 변수 추출
if values["template_format"]: # 템플릿 형식이 제공되었는지 확인
# 1. 템플릿과 템플릿 형식에서 요구되는 변수 리스트 가져오기
# - get_template_variables 함수는 템플릿에서 사용해야 할 변수들을 분석하여 반환합니다.
# - 예: 템플릿이 "Hello, {name}!"이고 템플릿 형식이 문자열이라면, 반환값은 ["name"]입니다.
template_variables = get_template_variables(
values["template"], values["template_format"]
)
# 2. 필수 입력 변수 정의
# - 부분 변수(partial_variables)는 이미 입력값으로 제공된 변수들의 리스트입니다.
# - template_variables 중 partial_variables에 포함되지 않은 변수들만 필수 입력 변수로 간주합니다.
# - 즉, 아직 입력값으로 제공되지 않은 변수들을 input_variables에 저장합니다.
values["input_variables"] = [
var # template_variables에서 반복적으로 변수 확인
for var in template_variables # 템플릿에 정의된 변수 목록
if var not in values["partial_variables"] # 이미 제공된 변수는 제외
]
return values
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
"""
프롬프트의 입력 스키마를 반환.
프롬프트 템플릿에 필요한 입력 변수를 정의한 스키마를 생성합니다.
Args:
config (RunnableConfig | None): 실행 가능한 구성 객체 (선택적).
Returns:
type[BaseModel]: 입력 스키마 클래스.
"""
# 템플릿 형식이 'mustache'가 아닌 경우 부모 클래스의 스키마 반환
if self.template_format != "mustache":
return super().get_input_schema(config)
# 템플릿 형식이 'mustache'인 경우 mustache 전용 스키마 생성
return mustache_schema(self.template)
def __add__(self, other: Any) -> PromptTemplate:
"""프롬프트 템플릿을 결합하기 위해 + 연산자 오버라이딩."""
# 두 템플릿을 쉽게 결합할 수 있게 허용
if isinstance(other, PromptTemplate):
# 템플릿 형식이 f-string이 아닌 경우 예외 발생
if self.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
if other.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
# 두 템플릿의 입력 변수를 결합
input_variables = list(
set(self.input_variables) | set(other.input_variables)
)
# 템플릿 문자열을 결합
template = self.template + other.template
# 두 템플릿 모두 유효성 검사 설정이 True인 경우에만 유효성 검사 수행
validate_template = self.validate_template and other.validate_template
# 부분 변수 결합 (중복 검사)
partial_variables = {k: v for k, v in self.partial_variables.items()}
for k, v in other.partial_variables.items():
if k in partial_variables:
raise ValueError("Cannot have same variable partialed twice.")
else:
partial_variables[k] = v
# 결합된 템플릿 반환
return PromptTemplate(
template=template,
input_variables=input_variables,
partial_variables=partial_variables,
template_format="f-string",
validate_template=validate_template,
)
elif isinstance(other, str):
# 문자열을 템플릿으로 변환하여 재귀적으로 + 연산 수행
prompt = PromptTemplate.from_template(other)
return self + prompt
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@property
def _prompt_type(self) -> str:
"""프롬프트 타입 키를 반환."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""사용자가 제공한 입력으로 프롬프트를 포맷팅.
Args:
kwargs: 프롬프트 템플릿에 전달할 인수들.
Returns:
포맷팅된 문자열.
"""
# 사용자 변수와 부분 변수를 병합
# **kwargs로 사용하면 여러 개의 키워드 인수를 딕셔너리로 묶어서 받을 수 있습니다.
kwargs = self._merge_partial_and_user_variables(**kwargs)
# 설정된 템플릿 형식에 맞게 템플릿 포맷팅 후 반환
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@classmethod
def from_examples(
cls,
examples: List[str],
suffix: str,
input_variables: List[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
) -> PromptTemplate:
"""예제 리스트를 기반으로 프롬프트를 생성.
Args:
examples: 프롬프트에 사용할 예제 리스트.
suffix: 예제 리스트 끝에 추가할 문자열.
input_variables: 최종 프롬프트 템플릿에서 사용할 변수 이름 리스트.
example_separator: 예제들 사이에 넣을 구분자. 기본값은 두 줄 바꿈 문자.
prefix: 예제 리스트 앞에 추가할 문자열. 기본값은 빈 문자열.
Returns:
생성된 프롬프트 템플릿 객체.
"""
# prefix와 suffix를 포함하여 예제들을 연결하여 최종 템플릿 생성
template = example_separator.join([prefix, *examples, suffix])
# cls를 호출해 PromptTemplate 또는 서브클래스의 인스턴스를 생성
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_file(
cls,
template_file: Union[str, Path],
input_variables: Optional[List[str]] = None,
encoding: Optional[str] = None,
**kwargs: Any,
) -> PromptTemplate:
"""파일에서 프롬프트를 불러와서 생성.
Args:
template_file: 프롬프트 템플릿을 포함한 파일 경로.
input_variables: (사용되지 않음) 최종 템플릿에 필요한 변수 이름 리스트.
encoding: 파일을 열 때 사용할 인코딩. OS 기본값 사용 가능.
Returns:
파일에서 로드한 프롬프트 템플릿 객체.
"""
# 파일을 열어 템플릿 문자열을 읽기
with open(str(template_file), "r", encoding=encoding) as f:
template = f.read()
# input_variables는 더 이상 사용되지 않음
if input_variables:
warnings.warn(
"`input_variables' is deprecated and ignored.", DeprecationWarning
)
return cls.from_template(template=template, **kwargs)
@classmethod
def from_template(
cls,
template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""문자열 템플릿으로부터 프롬프트 템플릿을 생성.
보안 경고:
가급적 `template_format="f-string"`을 사용하세요. `jinja2`를 사용하려면
신뢰할 수 없는 소스에서 제공된 템플릿을 사용하지 않아야 합니다.
Args:
template: 생성할 템플릿 문자열.
template_format: 템플릿 형식. 기본값은 `f-string`.
partial_variables: 일부 변수만 채운 상태로 템플릿을 생성할 수 있는 딕셔너리.
kwargs: 프롬프트 템플릿에 전달할 추가 인수.
Returns:
생성된 프롬프트 템플릿 객체.
"""
# 템플릿 변수 추출
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}
# 부분 변수가 존재하면 해당 변수들을 입력 변수에서 제외
if _partial_variables:
input_variables = [
var for var in input_variables if var not in _partial_variables
]
# 템플릿 객체 생성 후 반환
return cls(
input_variables=input_variables,
template=template,
template_format=template_format, # type: ignore[arg-type]
partial_variables=_partial_variables,
**kwargs,
)