diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0a68b83 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +/models +/.venv +/.idea +/.vscode \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..f8e4b74 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,14 @@ +# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1 +FROM python:3.12.7 + +# +WORKDIR /app + +COPY requirements.txt /app +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# +COPY . /app + +# +CMD ["uvicorn", "main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"] \ No newline at end of file diff --git a/classification.py b/classification.py index a4b8eee..8b585c5 100644 --- a/classification.py +++ b/classification.py @@ -4,13 +4,13 @@ from pydantic import BaseModel from transformers import pipeline classifier = pipeline("zero-shot-classification", - model="morit/chinese_xlm_xnli") + model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") # 返回一个结构化的内容 class ClassifyResult(BaseModel): sequence: str - labels: list + rank: list scores: list prediction: str @@ -18,14 +18,14 @@ class ClassifyResult(BaseModel): def classify(text: str, labels: list): output = classifier(text, labels) - pprint(output) + # pprint(output) # 根据 score,寻找最高的 label prediction_rank = output['scores'].index(max(output['scores'])) return ClassifyResult( sequence=text, - labels=output['labels'], + rank=output['labels'], scores=output['scores'], prediction=output['labels'][prediction_rank] ) diff --git a/main.py b/main.py index cc9f306..91c4bf8 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,3 @@ -from typing import Union - import uvicorn from fastapi import FastAPI from pydantic import BaseModel @@ -7,23 +5,23 @@ import classification app = FastAPI() + class TextClassificationRequest(BaseModel): text: str labels: list[str] + class TextClassificationResponse(BaseModel): prediction: str - labels: list[str] - + ranks: list[str] @app.post("/classify") def classify(req: TextClassificationRequest) -> TextClassificationResponse: result = classification.classify(req.text, req.labels) - return TextClassificationResponse(prediction=result.labels[0], labels=result.labels) - + return TextClassificationResponse(prediction=result.prediction, ranks=result.rank) if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/manifest.yaml b/manifest.yaml new file mode 100644 index 0000000..ae9540d --- /dev/null +++ b/manifest.yaml @@ -0,0 +1,90 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: text-classification-pvc + namespace: ecosystem +spec: + accessModes: + - ReadWriteMany + resources: + requests: + storage: 2Gi + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: text-classification + namespace: ecosystem +spec: + selector: + matchLabels: + app: text-classification + tier: backend + template: + metadata: + labels: + app: text-classification + tier: backend + spec: + volumes: + - name: text-classification-models + persistentVolumeClaim: + claimName: text-classification-pvc + containers: + - name: text-classification + image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1 + env: + - name: HF_DATASETS_CACHE + value: "/app/models" + - name: HF_HOME + value: "/app/models" + - name: HUGGINGFACE_HUB_CACHE + value: "/app/models" + - name: TRANSFORMERS_CACHE + value: "/app/models" +# - name: CUDA_VISIBLE_DEVICES +# value: "1" + ports: + - containerPort: 80 + protocol: TCP + name: http + # resources: + # requests: + # cpu: 1000m + # memory: 1024Mi + volumeMounts: + - mountPath: /app/models + name: text-classification-models +--- +apiVersion: v1 +kind: Service +metadata: + name: text-classification + namespace: ecosystem +spec: + selector: + app: text-classification + tier: backend + type: ClusterIP + ports: + - port: 80 + targetPort: 80 + protocol: TCP + name: http +--- +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: zero-shot-classification-http + namespace: ecosystem +spec: + hostnames: + - text-classification-api-testing.leaflow.cn + parentRefs: + - name: http-gw + namespace: networking + rules: + - backendRefs: + - name: text-classification + port: 80 \ No newline at end of file