diff --git a/Dockerfile b/Dockerfile index d53add5..dc7b7ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.3 +# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.4 FROM python:3.12.7 # diff --git a/classification.py b/classification.py index b003faf..a484ceb 100644 --- a/classification.py +++ b/classification.py @@ -1,3 +1,5 @@ +from pprint import pprint + from pydantic import BaseModel from transformers import pipeline from config import MODEL_ARGS @@ -5,25 +7,27 @@ from config import MODEL_ARGS classifier = pipeline("zero-shot-classification", **MODEL_ARGS) -# 返回一个结构化的内容 class ClassifyResult(BaseModel): sequence: str - rank: list - scores: list prediction: str + prediction_score: float + ranks: list[str] + labels: list[str] + scores: list[float] + def classify(text: str, labels: list): output = classifier(text, labels) - # pprint(output) - # 根据 score,寻找最高的 label prediction_rank = output['scores'].index(max(output['scores'])) return ClassifyResult( sequence=text, - rank=output['labels'], + prediction=output['labels'][prediction_rank], + prediction_score=output['scores'][prediction_rank], + ranks=output['labels'], + labels=labels, scores=output['scores'], - prediction=output['labels'][prediction_rank] ) diff --git a/main.py b/main.py index 3d9a603..681165b 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import uvicorn from fastapi import FastAPI from pydantic import BaseModel -import classification +from classification import ClassifyResult, classify as classifier from config import HTTP_ARGS app = FastAPI() @@ -12,16 +12,15 @@ class TextClassificationRequest(BaseModel): labels: list[str] -class TextClassificationResponse(BaseModel): - prediction: str - ranks: list[str] +# class TextClassificationResponse(BaseModel): +# prediction: str +# ranks: list[str] @app.post("/classify") -def classify(req: TextClassificationRequest) -> TextClassificationResponse: - result = classification.classify(req.text, req.labels) +def classify(req: TextClassificationRequest) -> ClassifyResult: - return TextClassificationResponse(prediction=result.prediction, ranks=result.rank) + return classifier(req.text, req.labels) if __name__ == "__main__": diff --git a/manifest.yaml b/manifest.yaml index f478162..7784f95 100644 --- a/manifest.yaml +++ b/manifest.yaml @@ -33,7 +33,7 @@ spec: claimName: text-classification-pvc containers: - name: text-classification - image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1 + image: leafdev.top/ecosystem/zero-shot-classification:v0.0.4 env: - name: HF_DATASETS_CACHE value: "/app/models"