This commit is contained in:
ivamp 2024-11-06 02:01:25 +08:00
parent fabe3c908e
commit 69067d948f
4 changed files with 19 additions and 16 deletions

View File

@ -1,4 +1,4 @@
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.3 # docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.4
FROM python:3.12.7 FROM python:3.12.7
# #

View File

@ -1,3 +1,5 @@
from pprint import pprint
from pydantic import BaseModel from pydantic import BaseModel
from transformers import pipeline from transformers import pipeline
from config import MODEL_ARGS from config import MODEL_ARGS
@ -5,25 +7,27 @@ from config import MODEL_ARGS
classifier = pipeline("zero-shot-classification", **MODEL_ARGS) classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
# 返回一个结构化的内容
class ClassifyResult(BaseModel): class ClassifyResult(BaseModel):
sequence: str sequence: str
rank: list
scores: list
prediction: str prediction: str
prediction_score: float
ranks: list[str]
labels: list[str]
scores: list[float]
def classify(text: str, labels: list): def classify(text: str, labels: list):
output = classifier(text, labels) output = classifier(text, labels)
# pprint(output)
# 根据 score寻找最高的 label # 根据 score寻找最高的 label
prediction_rank = output['scores'].index(max(output['scores'])) prediction_rank = output['scores'].index(max(output['scores']))
return ClassifyResult( return ClassifyResult(
sequence=text, sequence=text,
rank=output['labels'], prediction=output['labels'][prediction_rank],
prediction_score=output['scores'][prediction_rank],
ranks=output['labels'],
labels=labels,
scores=output['scores'], scores=output['scores'],
prediction=output['labels'][prediction_rank]
) )

13
main.py
View File

@ -1,7 +1,7 @@
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
import classification from classification import ClassifyResult, classify as classifier
from config import HTTP_ARGS from config import HTTP_ARGS
app = FastAPI() app = FastAPI()
@ -12,16 +12,15 @@ class TextClassificationRequest(BaseModel):
labels: list[str] labels: list[str]
class TextClassificationResponse(BaseModel): # class TextClassificationResponse(BaseModel):
prediction: str # prediction: str
ranks: list[str] # ranks: list[str]
@app.post("/classify") @app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse: def classify(req: TextClassificationRequest) -> ClassifyResult:
result = classification.classify(req.text, req.labels)
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank) return classifier(req.text, req.labels)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -33,7 +33,7 @@ spec:
claimName: text-classification-pvc claimName: text-classification-pvc
containers: containers:
- name: text-classification - name: text-classification
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1 image: leafdev.top/ecosystem/zero-shot-classification:v0.0.4
env: env:
- name: HF_DATASETS_CACHE - name: HF_DATASETS_CACHE
value: "/app/models" value: "/app/models"