This commit is contained in:
ivamp 2024-10-24 13:07:11 +08:00
parent de1e78dbbb
commit 5f53274f26
5 changed files with 117 additions and 11 deletions

4
.dockerignore Normal file
View File

@ -0,0 +1,4 @@
/models
/.venv
/.idea
/.vscode

14
Dockerfile Normal file
View File

@ -0,0 +1,14 @@
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1
FROM python:3.12.7
#
WORKDIR /app
COPY requirements.txt /app
RUN pip install --no-cache-dir --upgrade -r requirements.txt
#
COPY . /app
#
CMD ["uvicorn", "main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"]

View File

@ -4,13 +4,13 @@ from pydantic import BaseModel
from transformers import pipeline
classifier = pipeline("zero-shot-classification",
model="morit/chinese_xlm_xnli")
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
# 返回一个结构化的内容
class ClassifyResult(BaseModel):
sequence: str
labels: list
rank: list
scores: list
prediction: str
@ -18,14 +18,14 @@ class ClassifyResult(BaseModel):
def classify(text: str, labels: list):
output = classifier(text, labels)
pprint(output)
# pprint(output)
# 根据 score寻找最高的 label
prediction_rank = output['scores'].index(max(output['scores']))
return ClassifyResult(
sequence=text,
labels=output['labels'],
rank=output['labels'],
scores=output['scores'],
prediction=output['labels'][prediction_rank]
)

12
main.py
View File

@ -1,5 +1,3 @@
from typing import Union
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
@ -7,23 +5,23 @@ import classification
app = FastAPI()
class TextClassificationRequest(BaseModel):
text: str
labels: list[str]
class TextClassificationResponse(BaseModel):
prediction: str
labels: list[str]
ranks: list[str]
@app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
result = classification.classify(req.text, req.labels)
return TextClassificationResponse(prediction=result.labels[0], labels=result.labels)
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)

90
manifest.yaml Normal file
View File

@ -0,0 +1,90 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: text-classification-pvc
namespace: ecosystem
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 2Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: text-classification
namespace: ecosystem
spec:
selector:
matchLabels:
app: text-classification
tier: backend
template:
metadata:
labels:
app: text-classification
tier: backend
spec:
volumes:
- name: text-classification-models
persistentVolumeClaim:
claimName: text-classification-pvc
containers:
- name: text-classification
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
env:
- name: HF_DATASETS_CACHE
value: "/app/models"
- name: HF_HOME
value: "/app/models"
- name: HUGGINGFACE_HUB_CACHE
value: "/app/models"
- name: TRANSFORMERS_CACHE
value: "/app/models"
# - name: CUDA_VISIBLE_DEVICES
# value: "1"
ports:
- containerPort: 80
protocol: TCP
name: http
# resources:
# requests:
# cpu: 1000m
# memory: 1024Mi
volumeMounts:
- mountPath: /app/models
name: text-classification-models
---
apiVersion: v1
kind: Service
metadata:
name: text-classification
namespace: ecosystem
spec:
selector:
app: text-classification
tier: backend
type: ClusterIP
ports:
- port: 80
targetPort: 80
protocol: TCP
name: http
---
apiVersion: gateway.networking.k8s.io/v1
kind: HTTPRoute
metadata:
name: zero-shot-classification-http
namespace: ecosystem
spec:
hostnames:
- text-classification-api-testing.leaflow.cn
parentRefs:
- name: http-gw
namespace: networking
rules:
- backendRefs:
- name: text-classification
port: 80