update
This commit is contained in:
parent
de1e78dbbb
commit
5f53274f26
4
.dockerignore
Normal file
4
.dockerignore
Normal file
@ -0,0 +1,4 @@
|
||||
/models
|
||||
/.venv
|
||||
/.idea
|
||||
/.vscode
|
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@ -0,0 +1,14 @@
|
||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||
FROM python:3.12.7
|
||||
|
||||
#
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt /app
|
||||
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
||||
|
||||
#
|
||||
COPY . /app
|
||||
|
||||
#
|
||||
CMD ["uvicorn", "main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"]
|
@ -4,13 +4,13 @@ from pydantic import BaseModel
|
||||
from transformers import pipeline
|
||||
|
||||
classifier = pipeline("zero-shot-classification",
|
||||
model="morit/chinese_xlm_xnli")
|
||||
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
|
||||
|
||||
|
||||
# 返回一个结构化的内容
|
||||
class ClassifyResult(BaseModel):
|
||||
sequence: str
|
||||
labels: list
|
||||
rank: list
|
||||
scores: list
|
||||
prediction: str
|
||||
|
||||
@ -18,14 +18,14 @@ class ClassifyResult(BaseModel):
|
||||
def classify(text: str, labels: list):
|
||||
output = classifier(text, labels)
|
||||
|
||||
pprint(output)
|
||||
# pprint(output)
|
||||
|
||||
# 根据 score,寻找最高的 label
|
||||
prediction_rank = output['scores'].index(max(output['scores']))
|
||||
|
||||
return ClassifyResult(
|
||||
sequence=text,
|
||||
labels=output['labels'],
|
||||
rank=output['labels'],
|
||||
scores=output['scores'],
|
||||
prediction=output['labels'][prediction_rank]
|
||||
)
|
||||
|
12
main.py
12
main.py
@ -1,5 +1,3 @@
|
||||
from typing import Union
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
@ -7,23 +5,23 @@ import classification
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class TextClassificationRequest(BaseModel):
|
||||
text: str
|
||||
labels: list[str]
|
||||
|
||||
|
||||
class TextClassificationResponse(BaseModel):
|
||||
prediction: str
|
||||
labels: list[str]
|
||||
|
||||
ranks: list[str]
|
||||
|
||||
|
||||
@app.post("/classify")
|
||||
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
||||
result = classification.classify(req.text, req.labels)
|
||||
|
||||
return TextClassificationResponse(prediction=result.labels[0], labels=result.labels)
|
||||
|
||||
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
90
manifest.yaml
Normal file
90
manifest.yaml
Normal file
@ -0,0 +1,90 @@
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: text-classification-pvc
|
||||
namespace: ecosystem
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
resources:
|
||||
requests:
|
||||
storage: 2Gi
|
||||
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: text-classification
|
||||
namespace: ecosystem
|
||||
spec:
|
||||
selector:
|
||||
matchLabels:
|
||||
app: text-classification
|
||||
tier: backend
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: text-classification
|
||||
tier: backend
|
||||
spec:
|
||||
volumes:
|
||||
- name: text-classification-models
|
||||
persistentVolumeClaim:
|
||||
claimName: text-classification-pvc
|
||||
containers:
|
||||
- name: text-classification
|
||||
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||
env:
|
||||
- name: HF_DATASETS_CACHE
|
||||
value: "/app/models"
|
||||
- name: HF_HOME
|
||||
value: "/app/models"
|
||||
- name: HUGGINGFACE_HUB_CACHE
|
||||
value: "/app/models"
|
||||
- name: TRANSFORMERS_CACHE
|
||||
value: "/app/models"
|
||||
# - name: CUDA_VISIBLE_DEVICES
|
||||
# value: "1"
|
||||
ports:
|
||||
- containerPort: 80
|
||||
protocol: TCP
|
||||
name: http
|
||||
# resources:
|
||||
# requests:
|
||||
# cpu: 1000m
|
||||
# memory: 1024Mi
|
||||
volumeMounts:
|
||||
- mountPath: /app/models
|
||||
name: text-classification-models
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: text-classification
|
||||
namespace: ecosystem
|
||||
spec:
|
||||
selector:
|
||||
app: text-classification
|
||||
tier: backend
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 80
|
||||
protocol: TCP
|
||||
name: http
|
||||
---
|
||||
apiVersion: gateway.networking.k8s.io/v1
|
||||
kind: HTTPRoute
|
||||
metadata:
|
||||
name: zero-shot-classification-http
|
||||
namespace: ecosystem
|
||||
spec:
|
||||
hostnames:
|
||||
- text-classification-api-testing.leaflow.cn
|
||||
parentRefs:
|
||||
- name: http-gw
|
||||
namespace: networking
|
||||
rules:
|
||||
- backendRefs:
|
||||
- name: text-classification
|
||||
port: 80
|
Loading…
Reference in New Issue
Block a user