zero-shot-classification-zh/classification.py

34 lines
806 B
Python
Raw Permalink Normal View History

2024-11-05 18:01:25 +00:00
from pprint import pprint
2024-10-24 03:12:23 +00:00
from pydantic import BaseModel
from transformers import pipeline
from config import MODEL_ARGS
2024-10-24 03:12:23 +00:00
# Module-level zero-shot-classification pipeline, constructed once when this
# module is imported, using the keyword arguments from config.MODEL_ARGS.
# classify() below calls it for every request.
# NOTE(review): constructing the pipeline presumably loads model weights as an
# import-time side effect — confirm that is acceptable for every importer.
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
2024-10-24 03:12:23 +00:00
class ClassifyResult(BaseModel):
    """Result of a single zero-shot classification call, as built by ``classify``.

    ``ranks`` and ``scores`` are the pipeline's label and score lists and are
    aligned index-for-index; ``labels`` echoes the caller's candidate labels
    in their original order.
    """

    # The input text that was classified.
    sequence: str
    # The label at the position of the highest score in ``scores``.
    prediction: str
    # The score belonging to ``prediction``.
    prediction_score: float
    # Labels in the order the pipeline returned them.
    # NOTE(review): presumably sorted by descending score by transformers — confirm.
    ranks: list[str]
    # The candidate labels exactly as passed by the caller.
    labels: list[str]
    # Scores in the pipeline's order; scores[i] belongs to ranks[i].
    scores: list[float]
2024-10-24 03:12:23 +00:00
def classify(text: str, labels: list):
    """Classify ``text`` against the candidate ``labels`` with zero-shot inference.

    Args:
        text: The sequence to classify.
        labels: Candidate label strings, forwarded unchanged to the
            module-level ``classifier`` pipeline.

    Returns:
        A ``ClassifyResult`` in which ``prediction``/``prediction_score`` are
        the label and score at the first position holding the highest score,
        ``ranks``/``scores`` are the pipeline's label and score lists as
        returned, and ``labels`` echoes the caller's input.
    """
    raw = classifier(text, labels)
    ranked_labels = raw['labels']
    ranked_scores = raw['scores']

    # Pick the label with the highest score (first occurrence on ties).
    top = ranked_scores.index(max(ranked_scores))

    return ClassifyResult(
        sequence=text,
        prediction=ranked_labels[top],
        prediction_score=ranked_scores[top],
        ranks=ranked_labels,
        labels=labels,
        scores=ranked_scores,
    )