update
This commit is contained in:
parent
de1e78dbbb
commit
5f53274f26
4
.dockerignore
Normal file
4
.dockerignore
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
/models
|
||||||
|
/.venv
|
||||||
|
/.idea
|
||||||
|
/.vscode
|
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||||
|
FROM python:3.12.7
|
||||||
|
|
||||||
|
#
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt /app
|
||||||
|
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
||||||
|
|
||||||
|
#
|
||||||
|
COPY . /app
|
||||||
|
|
||||||
|
#
|
||||||
|
CMD ["uvicorn", "main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"]
|
@ -4,13 +4,13 @@ from pydantic import BaseModel
|
|||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
|
||||||
classifier = pipeline("zero-shot-classification",
|
classifier = pipeline("zero-shot-classification",
|
||||||
model="morit/chinese_xlm_xnli")
|
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
|
||||||
|
|
||||||
|
|
||||||
# 返回一个结构化的内容
|
# 返回一个结构化的内容
|
||||||
class ClassifyResult(BaseModel):
|
class ClassifyResult(BaseModel):
|
||||||
sequence: str
|
sequence: str
|
||||||
labels: list
|
rank: list
|
||||||
scores: list
|
scores: list
|
||||||
prediction: str
|
prediction: str
|
||||||
|
|
||||||
@ -18,14 +18,14 @@ class ClassifyResult(BaseModel):
|
|||||||
def classify(text: str, labels: list):
|
def classify(text: str, labels: list):
|
||||||
output = classifier(text, labels)
|
output = classifier(text, labels)
|
||||||
|
|
||||||
pprint(output)
|
# pprint(output)
|
||||||
|
|
||||||
# 根据 score,寻找最高的 label
|
# 根据 score,寻找最高的 label
|
||||||
prediction_rank = output['scores'].index(max(output['scores']))
|
prediction_rank = output['scores'].index(max(output['scores']))
|
||||||
|
|
||||||
return ClassifyResult(
|
return ClassifyResult(
|
||||||
sequence=text,
|
sequence=text,
|
||||||
labels=output['labels'],
|
rank=output['labels'],
|
||||||
scores=output['scores'],
|
scores=output['scores'],
|
||||||
prediction=output['labels'][prediction_rank]
|
prediction=output['labels'][prediction_rank]
|
||||||
)
|
)
|
||||||
|
10
main.py
10
main.py
@ -1,5 +1,3 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -7,22 +5,22 @@ import classification
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
class TextClassificationRequest(BaseModel):
|
class TextClassificationRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
labels: list[str]
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
class TextClassificationResponse(BaseModel):
|
class TextClassificationResponse(BaseModel):
|
||||||
prediction: str
|
prediction: str
|
||||||
labels: list[str]
|
ranks: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/classify")
|
@app.post("/classify")
|
||||||
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
||||||
result = classification.classify(req.text, req.labels)
|
result = classification.classify(req.text, req.labels)
|
||||||
|
|
||||||
return TextClassificationResponse(prediction=result.labels[0], labels=result.labels)
|
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
90
manifest.yaml
Normal file
90
manifest.yaml
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: PersistentVolumeClaim
|
||||||
|
metadata:
|
||||||
|
name: text-classification-pvc
|
||||||
|
namespace: ecosystem
|
||||||
|
spec:
|
||||||
|
accessModes:
|
||||||
|
- ReadWriteMany
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
storage: 2Gi
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: text-classification
|
||||||
|
namespace: ecosystem
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: text-classification
|
||||||
|
tier: backend
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: text-classification
|
||||||
|
tier: backend
|
||||||
|
spec:
|
||||||
|
volumes:
|
||||||
|
- name: text-classification-models
|
||||||
|
persistentVolumeClaim:
|
||||||
|
claimName: text-classification-pvc
|
||||||
|
containers:
|
||||||
|
- name: text-classification
|
||||||
|
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||||
|
env:
|
||||||
|
- name: HF_DATASETS_CACHE
|
||||||
|
value: "/app/models"
|
||||||
|
- name: HF_HOME
|
||||||
|
value: "/app/models"
|
||||||
|
- name: HUGGINGFACE_HUB_CACHE
|
||||||
|
value: "/app/models"
|
||||||
|
- name: TRANSFORMERS_CACHE
|
||||||
|
value: "/app/models"
|
||||||
|
# - name: CUDA_VISIBLE_DEVICES
|
||||||
|
# value: "1"
|
||||||
|
ports:
|
||||||
|
- containerPort: 80
|
||||||
|
protocol: TCP
|
||||||
|
name: http
|
||||||
|
# resources:
|
||||||
|
# requests:
|
||||||
|
# cpu: 1000m
|
||||||
|
# memory: 1024Mi
|
||||||
|
volumeMounts:
|
||||||
|
- mountPath: /app/models
|
||||||
|
name: text-classification-models
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: text-classification
|
||||||
|
namespace: ecosystem
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: text-classification
|
||||||
|
tier: backend
|
||||||
|
type: ClusterIP
|
||||||
|
ports:
|
||||||
|
- port: 80
|
||||||
|
targetPort: 80
|
||||||
|
protocol: TCP
|
||||||
|
name: http
|
||||||
|
---
|
||||||
|
apiVersion: gateway.networking.k8s.io/v1
|
||||||
|
kind: HTTPRoute
|
||||||
|
metadata:
|
||||||
|
name: zero-shot-classification-http
|
||||||
|
namespace: ecosystem
|
||||||
|
spec:
|
||||||
|
hostnames:
|
||||||
|
- text-classification-api-testing.leaflow.cn
|
||||||
|
parentRefs:
|
||||||
|
- name: http-gw
|
||||||
|
namespace: networking
|
||||||
|
rules:
|
||||||
|
- backendRefs:
|
||||||
|
- name: text-classification
|
||||||
|
port: 80
|
Loading…
Reference in New Issue
Block a user