32 lines
724 B
Python
32 lines
724 B
Python
from pprint import pprint
|
||
|
||
from pydantic import BaseModel
|
||
from transformers import pipeline
|
||
|
||
# Zero-shot classification pipeline using the "morit/chinese_xlm_xnli" model.
# It is constructed once, when this module is imported, and reused afterwards.
classifier = pipeline(
    "zero-shot-classification",
    model="morit/chinese_xlm_xnli",
)
|
||
|
||
|
||
# Return the classification result as structured data
|
||
class ClassifyResult(BaseModel):
    """Structured result of a single classification call."""

    # The text that was classified.
    sequence: str
    # Candidate labels as reported by the classifier.
    labels: list
    # Scores reported by the classifier; presumably scores[i] belongs to
    # labels[i] (classify() indexes the two lists in parallel) -- TODO confirm.
    scores: list
    # The label chosen as the final answer (the one with the highest score).
    prediction: str
|
||
|
||
|
||
def classify(text: str, labels: list):
    """Classify ``text`` against ``labels`` and return a ``ClassifyResult``.

    The module-level ``classifier`` is called with the text and the candidate
    labels, and its raw output is pretty-printed to stdout. The label whose
    score is highest becomes the ``prediction`` of the returned result.

    Args:
        text: The text to classify.
        labels: Candidate labels to score the text against.

    Returns:
        A ``ClassifyResult`` holding the input text, the labels and scores
        taken from the classifier output, and the top-scoring label.
    """
    raw_result = classifier(text, labels)

    pprint(raw_result)

    # Find the label with the highest score; if several scores tie for the
    # maximum, the earliest position wins.
    all_scores = raw_result['scores']
    top_score = max(all_scores)
    best_position = all_scores.index(top_score)

    ranked_labels = raw_result['labels']
    return ClassifyResult(
        sequence=text,
        labels=ranked_labels,
        scores=all_scores,
        prediction=ranked_labels[best_position],
    )
|