34 lines
806 B
Python
34 lines
806 B
Python
from pprint import pprint
|
||
|
||
from pydantic import BaseModel
|
||
from transformers import pipeline
|
||
from config import MODEL_ARGS
|
||
|
||
# Build the zero-shot pipeline a single time, when the module is imported, so
# that every classify() call reuses the same loaded pipeline object.
# MODEL_ARGS (from config) is forwarded as extra keyword arguments to
# pipeline().
classifier = pipeline(task="zero-shot-classification", **MODEL_ARGS)
|
||
|
||
|
||
class ClassifyResult(BaseModel):
    """Result of a single zero-shot classification.

    Built by ``classify``. It bundles the classified text, the winning label
    and its score, and the full per-label score breakdown returned by the
    pipeline.
    """

    # The text that was classified.
    sequence: str
    # The label that received the highest score.
    prediction: str
    # The score attached to ``prediction``.
    prediction_score: float
    # Labels in the order the pipeline returned them; index i here pairs with
    # index i of ``scores``.
    ranks: list[str]
    # The candidate labels that were submitted to the classifier.
    labels: list[str]
    # Pipeline scores, aligned index-by-index with ``ranks`` (not ``labels``).
    scores: list[float]
|
||
|
||
|
||
|
||
def classify(text: str, labels: list[str]) -> ClassifyResult:
    """Run zero-shot classification of *text* against candidate *labels*.

    Args:
        text: The sequence to classify.
        labels: Candidate labels. Any iterable of strings is accepted and is
            materialised into a list exactly once, so the same labels are
            sent to the pipeline and recorded in the result. A single string
            is split on commas (blank entries dropped), mirroring how the
            Hugging Face zero-shot pipeline treats a string of labels.

    Returns:
        A ``ClassifyResult`` with the top-scoring label and its score, the
        pipeline's label ordering (``ranks``) with the aligned ``scores``, and
        the candidate labels that were submitted.

    Raises:
        ValueError: if the pipeline returns no scores. The pipeline itself
            presumably rejects an empty label list first; confirm against the
            installed transformers version.
    """
    if isinstance(labels, str):
        # Without this, the pipeline would accept the string but
        # ClassifyResult.labels (list[str]) would reject it on validation.
        candidate_labels = [part.strip() for part in labels.split(",") if part.strip()]
    else:
        # Materialise once so generators, sets, tuples, etc. yield one
        # concrete list shared by the pipeline call and the result model.
        candidate_labels = list(labels)

    output = classifier(text, candidate_labels)
    scores = output["scores"]
    ranked_labels = output["labels"]

    # Find the label with the highest score. max() returns the first maximal
    # index, so ties resolve exactly as scores.index(max(scores)) would.
    best = max(range(len(scores)), key=scores.__getitem__)

    return ClassifyResult(
        sequence=text,
        prediction=ranked_labels[best],
        prediction_score=scores[best],
        ranks=ranked_labels,
        labels=candidate_labels,
        scores=scores,
    )
|