在Langchain中同时使用Chain和Parser

11

LangChain文档中包含了配置和调用PydanticOutputParser的示例。

# Define your desired data structure.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")
    
    # You can add custom validation logic easily with Pydantic.
    @validator('setup')
    def question_ends_with_question_mark(cls, field):
        if field[-1] != '?':
            raise ValueError("Badly formed question!")
        return field

# And a query intented to prompt a language model to populate the data structure.
joke_query = "Tell me a joke."

# Set up a parser + inject instructions into the prompt template.
parser = PydanticOutputParser(pydantic_object=Joke)

prompt = PromptTemplate(
    template="Answer the user query.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()}
)

然而,实际进行API调用的代码有点奇怪:

model_name = 'text-davinci-003'
temperature = 0.0
my_llm = OpenAI(model_name=model_name, temperature=temperature)

_input = prompt.format_prompt(query=joke_query)
output = my_llm(_input.to_string())

parser.parse(output)

这将返回我们想要的内容:Joke(setup='为什么鸡会过马路?', punchline='为了到达另一边!')

然而,不使用Chains似乎有些奇怪。

我可以通过以下方式接近:

chain = LLMChain(llm=my_llm, prompt=prompt)
chain.run(query=joke_query)

但是这会返回原始的、未解析的文本:'\n{"setup": "为什么小鸡过马路?", "punchline": "去到另一边!"}'

有没有一种首选的方法来让Chain类充分利用Parser,并返回解析后的对象?我可以创建子类并扩展LLMChain,但如果这个功能已经存在,我会感到惊讶。


完全同意,这不符合LangChainic的风格。此外,由于某种原因,PromptTemplate类需要一个可选的输出解析器参数... - Alexandre Dumont
2个回答

5
你可以使用 TransformChain 来实现这个功能!
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, TransformChain
from langchain.chains import SequentialChain


llm = ChatOpenAI(temperature=0.5)
llm_chain = LLMChain(
    prompt=prompt,
    llm=llm,
    output_key="json_string",
)

def parse_output(inputs: dict) -> dict:
    text = inputs["json_string"]
    return {"result": parser.parse(text)}

transform_chain = TransformChain(
    input_variables=["json_string"],
    output_variables=["result"],
    transform=parse_output
)

chain = SequentialChain(
    input_variables=["joke_query"],
    output_variables=["result"],
    chains=[llm_chain, transform_chain],
)

chain.run(query="Tell me a joke.")

0
你应该能够使用解析器来解析链的输出。无需创建子类:
output = chain.run(query=joke_query)
bad_joke = parser.parse(output)

我使用langchainjs,所以语法不是很确定,但这应该可以让你接近。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接