2024-10-24 03:12:23 +00:00
|
|
|
|
from pprint import pprint
from typing import List

from pydantic import BaseModel
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
# Zero-shot classification pipeline. It is built once, when this module is
# imported, and the same instance is reused by every classify() call.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
)
|
2024-10-24 03:12:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 返回一个结构化的内容
|
|
|
|
|
class ClassifyResult(BaseModel):
|
|
|
|
|
sequence: str
|
2024-10-24 05:07:11 +00:00
|
|
|
|
rank: list
|
2024-10-24 03:12:23 +00:00
|
|
|
|
scores: list
|
|
|
|
|
prediction: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify(text: str, labels: List[str]) -> ClassifyResult:
    """Classify *text* against candidate *labels* with the zero-shot pipeline.

    Args:
        text: The text to classify.
        labels: Candidate labels to score *text* against.

    Returns:
        A ClassifyResult holding the input text, the labels and scores exactly
        as returned by the pipeline, and the single highest-scoring label.

    Raises:
        ValueError: if the pipeline returns no scores (``max`` of an empty
            sequence).
    """
    output = classifier(text, labels)
    scores = output['scores']
    ranked_labels = output['labels']

    # Pick the highest-scoring label explicitly instead of relying on the
    # pipeline's ordering; on ties the first occurrence wins.
    best_index = max(range(len(scores)), key=scores.__getitem__)

    return ClassifyResult(
        sequence=text,
        rank=ranked_labels,
        scores=scores,
        prediction=ranked_labels[best_index],
    )
|