32 lines
739 B
Python
32 lines
739 B
Python
from pprint import pprint
|
||
|
||
from pydantic import BaseModel
|
||
from transformers import pipeline
|
||
|
||
# Zero-shot classification pipeline, constructed once when this module is
# imported and reused by every call to classify() below.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
)
|
||
|
||
|
||
class ClassifyResult(BaseModel):
    """Structured result returned by a zero-shot classification call.

    Bundles the classified text, the candidate labels with their scores, and
    the single highest-scoring label into one validated object.
    """

    # The input text that was classified.
    sequence: str
    # Candidate labels in the order the pipeline returned them; index-aligned
    # with ``scores``. NOTE(review): presumably sorted by descending score
    # (hence the name) — confirm against the pipeline version in use.
    rank: list[str]
    # Score for each label in ``rank`` (same index).
    scores: list[float]
    # The label with the highest score.
    prediction: str
|
||
|
||
|
||
def classify(text: str, labels: list, *, multi_label: bool = False) -> ClassifyResult:
    """Classify *text* against the candidate *labels* with the zero-shot pipeline.

    Args:
        text: The input sequence to classify.
        labels: Candidate label strings handed to the pipeline.
        multi_label: Forwarded unchanged to the pipeline's ``multi_label``
            option. The default False matches the pipeline's own default, so
            calls that omit it behave exactly as before.

    Returns:
        A ``ClassifyResult`` with the input text, the labels and scores exactly
        as the pipeline returned them, and the highest-scoring label.

    Raises:
        ValueError: if the pipeline returns no scores (``max`` of an empty
            sequence).
    """
    output = classifier(text, labels, multi_label=multi_label)
    ranked_labels = output["labels"]
    scores = output["scores"]

    # Pick the label with the highest score in a single pass. On ties, max()
    # keeps the first maximal pair, the same choice as
    # scores.index(max(scores)).
    prediction, _ = max(zip(ranked_labels, scores), key=lambda pair: pair[1])

    return ClassifyResult(
        sequence=text,
        rank=ranked_labels,
        scores=scores,
        prediction=prediction,
    )
|