# Zero-shot text classification built on a Hugging Face transformers pipeline.

from pprint import pprint

from pydantic import BaseModel
from transformers import pipeline

from config import MODEL_ARGS

# One shared zero-shot classification pipeline, constructed at import time
# from the keyword arguments supplied by the project's ``config.MODEL_ARGS``.
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)


class ClassifyResult(BaseModel):
    """Outcome of scoring one text against a set of candidate labels."""

    sequence: str  # the text that was classified
    prediction: str  # the label with the highest score
    prediction_score: float  # the score belonging to ``prediction``
    ranks: list[str]  # labels in the order the pipeline returned them
    labels: list[str]  # candidate labels exactly as the caller passed them
    scores: list[float]  # scores aligned index-for-index with ``ranks``


def classify(text: str, labels: list):
    """Run zero-shot classification of *text* over the candidate *labels*.

    Args:
        text: The input text to classify.
        labels: Candidate label strings handed straight to the pipeline.

    Returns:
        A ``ClassifyResult`` containing the following:

        * the winning label and its score;
        * the pipeline's label ordering (``ranks``) with matching ``scores``;
        * the caller's original ``labels``.

    Raises:
        ValueError: from ``max`` if the pipeline reports no scores.
    """
    result = classifier(text, labels)
    ranked_labels = result['labels']
    ranked_scores = result['scores']

    # Pick the position of the highest score explicitly. This avoids relying
    # on the pipeline's ordering (NOTE(review): transformers documents labels
    # as most-likely first; confirm before simplifying this to index 0).
    # ``list.index`` returns the first position on ties.
    best = ranked_scores.index(max(ranked_scores))

    return ClassifyResult(
        sequence=text,
        prediction=ranked_labels[best],
        prediction_score=ranked_scores[best],
        ranks=ranked_labels,
        labels=labels,
        scores=ranked_scores,
    )