改进
This commit is contained in:
parent
fabe3c908e
commit
69067d948f
@ -1,4 +1,4 @@
|
||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.3
|
||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.4
|
||||
FROM python:3.12.7
|
||||
|
||||
#
|
||||
|
@ -1,3 +1,5 @@
|
||||
from pprint import pprint
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import pipeline
|
||||
from config import MODEL_ARGS
|
||||
@ -5,25 +7,27 @@ from config import MODEL_ARGS
|
||||
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
|
||||
|
||||
|
||||
# 返回一个结构化的内容
|
||||
class ClassifyResult(BaseModel):
|
||||
sequence: str
|
||||
rank: list
|
||||
scores: list
|
||||
prediction: str
|
||||
prediction_score: float
|
||||
ranks: list[str]
|
||||
labels: list[str]
|
||||
scores: list[float]
|
||||
|
||||
|
||||
|
||||
def classify(text: str, labels: list):
|
||||
output = classifier(text, labels)
|
||||
|
||||
# pprint(output)
|
||||
|
||||
# 根据 score,寻找最高的 label
|
||||
prediction_rank = output['scores'].index(max(output['scores']))
|
||||
|
||||
return ClassifyResult(
|
||||
sequence=text,
|
||||
rank=output['labels'],
|
||||
prediction=output['labels'][prediction_rank],
|
||||
prediction_score=output['scores'][prediction_rank],
|
||||
ranks=output['labels'],
|
||||
labels=labels,
|
||||
scores=output['scores'],
|
||||
prediction=output['labels'][prediction_rank]
|
||||
)
|
||||
|
13
main.py
13
main.py
@ -1,7 +1,7 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
import classification
|
||||
from classification import ClassifyResult, classify as classifier
|
||||
from config import HTTP_ARGS
|
||||
|
||||
app = FastAPI()
|
||||
@ -12,16 +12,15 @@ class TextClassificationRequest(BaseModel):
|
||||
labels: list[str]
|
||||
|
||||
|
||||
class TextClassificationResponse(BaseModel):
|
||||
prediction: str
|
||||
ranks: list[str]
|
||||
# class TextClassificationResponse(BaseModel):
|
||||
# prediction: str
|
||||
# ranks: list[str]
|
||||
|
||||
|
||||
@app.post("/classify")
|
||||
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
||||
result = classification.classify(req.text, req.labels)
|
||||
def classify(req: TextClassificationRequest) -> ClassifyResult:
|
||||
|
||||
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
|
||||
return classifier(req.text, req.labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -33,7 +33,7 @@ spec:
|
||||
claimName: text-classification-pvc
|
||||
containers:
|
||||
- name: text-classification
|
||||
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.4
|
||||
env:
|
||||
- name: HF_DATASETS_CACHE
|
||||
value: "/app/models"
|
||||
|
Loading…
Reference in New Issue
Block a user