reader-lm/classification.py

30 lines
685 B
Python
Raw Permalink Normal View History

2024-10-24 03:12:23 +00:00
from pydantic import BaseModel
from transformers import pipeline
from config import MODEL_ARGS
2024-10-24 03:12:23 +00:00
# Zero-shot classification pipeline shared by every classify() call. It is
# built once, at import time; MODEL_ARGS (from config) is unpacked as the
# remaining keyword arguments to pipeline().
classifier = pipeline(task="zero-shot-classification", **MODEL_ARGS)
2024-10-24 03:12:23 +00:00
# Return a structured result.
class ClassifyResult(BaseModel):
    """Structured zero-shot classification result, as built by classify()."""

    # The input text that was classified.
    sequence: str
    # Candidate labels taken unchanged from the pipeline output's 'labels'
    # entry (presumably ordered by descending score, as the name suggests —
    # TODO confirm against the transformers zero-shot pipeline docs).
    rank: list
    # Scores taken unchanged from the pipeline output's 'scores' entry;
    # assumed index-aligned with `rank` (classify() indexes both with the
    # same position).
    scores: list
    # The label whose score is highest.
    prediction: str
def classify(text: str, labels: list):
    """Classify ``text`` against candidate ``labels`` with the zero-shot pipeline.

    Args:
        text: The input sequence to classify; echoed back as ``sequence``.
        labels: Candidate labels handed to the module-level ``classifier``.

    Returns:
        ClassifyResult holding the pipeline's ``labels`` (as ``rank``) and
        ``scores`` passed through unchanged, plus ``prediction``: the label at
        the position of the highest score (the first such position on ties).
    """
    result = classifier(text, labels)

    label_scores = result['scores']
    # Find the label with the highest score. max() returns the first maximal
    # position, matching list.index(max(...)); labels and scores are assumed
    # to be index-aligned in the pipeline output.
    best_position = max(range(len(label_scores)), key=label_scores.__getitem__)

    ranked_labels = result['labels']
    return ClassifyResult(
        sequence=text,
        rank=ranked_labels,
        scores=label_scores,
        prediction=ranked_labels[best_position],
    )