This commit is contained in:
ivamp 2024-11-06 02:01:25 +08:00
parent fabe3c908e
commit 69067d948f
4 changed files with 19 additions and 16 deletions

View File

@ -1,4 +1,4 @@
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.3
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.4
FROM python:3.12.7
#

View File

@ -1,3 +1,5 @@
from pprint import pprint
from pydantic import BaseModel
from transformers import pipeline
from config import MODEL_ARGS
@ -5,25 +7,27 @@ from config import MODEL_ARGS
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
# 返回一个结构化的内容
class ClassifyResult(BaseModel):
sequence: str
rank: list
scores: list
prediction: str
prediction_score: float
ranks: list[str]
labels: list[str]
scores: list[float]
def classify(text: str, labels: list):
output = classifier(text, labels)
# pprint(output)
# 根据 score寻找最高的 label
prediction_rank = output['scores'].index(max(output['scores']))
return ClassifyResult(
sequence=text,
rank=output['labels'],
prediction=output['labels'][prediction_rank],
prediction_score=output['scores'][prediction_rank],
ranks=output['labels'],
labels=labels,
scores=output['scores'],
prediction=output['labels'][prediction_rank]
)

13
main.py
View File

@ -1,7 +1,7 @@
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
import classification
from classification import ClassifyResult, classify as classifier
from config import HTTP_ARGS
app = FastAPI()
@ -12,16 +12,15 @@ class TextClassificationRequest(BaseModel):
labels: list[str]
class TextClassificationResponse(BaseModel):
prediction: str
ranks: list[str]
# class TextClassificationResponse(BaseModel):
# prediction: str
# ranks: list[str]
@app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
result = classification.classify(req.text, req.labels)
def classify(req: TextClassificationRequest) -> ClassifyResult:
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
return classifier(req.text, req.labels)
if __name__ == "__main__":

View File

@ -33,7 +33,7 @@ spec:
claimName: text-classification-pvc
containers:
- name: text-classification
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.4
env:
- name: HF_DATASETS_CACHE
value: "/app/models"