zero-shot-classification-zh/classification.py

32 lines
739 B
Python
Raw Normal View History

2024-10-24 03:12:23 +00:00
from pprint import pprint
from pydantic import BaseModel
from transformers import pipeline
# Zero-shot classification pipeline built on the
# "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" checkpoint. It is constructed
# once, when this module is imported, and reused by every call to classify().
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
)
# 返回一个结构化的内容
class ClassifyResult(BaseModel):
sequence: str
2024-10-24 05:07:11 +00:00
rank: list
2024-10-24 03:12:23 +00:00
scores: list
prediction: str
def classify(text: str, labels: list):
output = classifier(text, labels)
2024-10-24 05:07:11 +00:00
# pprint(output)
2024-10-24 03:12:23 +00:00
# 根据 score寻找最高的 label
prediction_rank = output['scores'].index(max(output['scores']))
return ClassifyResult(
sequence=text,
2024-10-24 05:07:11 +00:00
rank=output['labels'],
2024-10-24 03:12:23 +00:00
scores=output['scores'],
prediction=output['labels'][prediction_rank]
)