from pprint import pprint

from pydantic import BaseModel
from transformers import pipeline

# Zero-shot classification pipeline, built once at import time and shared by
# every call to classify().
classifier = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
)


# Structured container for a single classification result.
class ClassifyResult(BaseModel):
    # The input text that was classified.
    sequence: str
    # The candidate labels, in the order the pipeline returned them
    # (presumably sorted by score — confirm against the transformers docs).
    rank: list
    # The scores from the pipeline output, index-aligned with `rank`.
    scores: list
    # The label whose score is the highest.
    prediction: str


def classify(text: str, labels: list) -> ClassifyResult:
    """Classify *text* against the candidate *labels*.

    Runs the module-level zero-shot ``classifier`` and wraps its output in a
    ``ClassifyResult``.

    Args:
        text: The text to classify.
        labels: Candidate labels handed to the classifier.

    Returns:
        A ``ClassifyResult`` holding the input text, the labels and scores
        taken from the pipeline output, and the label with the highest score.
    """
    output = classifier(text, labels)
    # Debug aid: pprint(output) shows the raw pipeline result.

    returned_labels = output["labels"]
    label_scores = output["scores"]

    # Locate the highest score; on ties the first occurrence wins.
    top_score = max(label_scores)
    top_position = label_scores.index(top_score)

    return ClassifyResult(
        sequence=text,
        rank=returned_labels,
        scores=label_scores,
        prediction=returned_labels[top_position],
    )