32 lines
724 B
Python
32 lines
724 B
Python
|
from pprint import pprint
|
|||
|
|
|||
|
from pydantic import BaseModel
|
|||
|
from transformers import pipeline
|
|||
|
|
|||
|
# Zero-shot classification pipeline backed by the "morit/chinese_xlm_xnli"
# checkpoint. It is built once, at import time, and reused by classify().
classifier = pipeline(
    task="zero-shot-classification",
    model="morit/chinese_xlm_xnli",
)
|
|||
|
|
|||
|
|
|||
|
# Structured result returned by classify().
class ClassifyResult(BaseModel):
    """Structured outcome of a single zero-shot classification call.

    Fields mirror the keys of the pipeline output, plus the chosen label.
    """

    # The input text that was classified.
    sequence: str
    # Candidate labels, in the order the pipeline returned them.
    labels: list[str]
    # Score for each entry of ``labels`` (same order, same length).
    scores: list[float]
    # The label with the highest score.
    prediction: str
|
|||
|
|
|||
|
|
|||
|
def classify(text: str, labels: list, *, verbose: bool = True) -> ClassifyResult:
    """Classify *text* against candidate *labels* with the zero-shot pipeline.

    Args:
        text: The input sequence to classify.
        labels: Candidate labels forwarded to the module-level ``classifier``.
        verbose: When True (the default, preserving the historical behaviour),
            pretty-print the raw pipeline output to stdout. Pass False to
            suppress this debug output.

    Returns:
        A ClassifyResult containing the input text, the labels and scores as
        returned by the pipeline, and the label whose score is highest.
    """
    output = classifier(text, labels)

    if verbose:
        pprint(output)

    scores = output['scores']
    # Index of the highest score. On ties the earliest index wins, which
    # matches the previous scores.index(max(scores)) behaviour.
    best = max(range(len(scores)), key=scores.__getitem__)

    return ClassifyResult(
        sequence=text,
        labels=output['labels'],
        scores=scores,
        prediction=output['labels'][best],
    )
|