reader-lm/classification.py

32 lines
724 B
Python
Raw Normal View History

2024-10-24 03:12:23 +00:00
from pprint import pprint
from pydantic import BaseModel
from transformers import pipeline
# Module-level zero-shot classification pipeline using the
# "morit/chinese_xlm_xnli" checkpoint. It is built once, when this module is
# imported, and then reused by classify().
classifier = pipeline(
    task="zero-shot-classification",
    model="morit/chinese_xlm_xnli",
)
# Structured result type returned by classify().
class ClassifyResult(BaseModel):
    """Structured result of a zero-shot classification call.

    Holds the raw labels/scores produced by the pipeline together with the
    single top-scoring label.
    """

    # The input text that was classified.
    sequence: str
    # Candidate labels, in the order the pipeline returned them.
    labels: list
    # Scores aligned index-for-index with ``labels``.
    scores: list
    # The label whose score is highest.
    prediction: str
def classify(text: str, labels: list, *, verbose: bool = True,
             **pipeline_kwargs) -> "ClassifyResult":
    """Classify *text* against candidate *labels* with the zero-shot pipeline.

    Args:
        text: The input sequence to classify.
        labels: Candidate labels passed to the module-level ``classifier``.
        verbose: If true (the default, preserving the original behaviour),
            pretty-print the raw pipeline output to stdout. Pass ``False``
            to silence it, e.g. when running inside a server.
        **pipeline_kwargs: Extra options forwarded unchanged to the
            ``classifier`` call (for a transformers zero-shot pipeline this
            includes options such as ``multi_label`` and
            ``hypothesis_template``). With none given, the call is identical
            to the original ``classifier(text, labels)``.

    Returns:
        A ``ClassifyResult`` containing the input text, the pipeline's labels
        and scores in the order it returned them, and the highest-scoring
        label as ``prediction``.
    """
    output = classifier(text, labels, **pipeline_kwargs)
    if verbose:
        pprint(output)

    scores = output['scores']
    ranked_labels = output['labels']
    # Pick the label with the highest score; on ties the first occurrence wins.
    best_index = scores.index(max(scores))
    return ClassifyResult(
        sequence=text,
        labels=ranked_labels,
        scores=scores,
        prediction=ranked_labels[best_index],
    )