From 7641757e452911f6b31979dec8e442172a27ed07 Mon Sep 17 00:00:00 2001 From: ivamp Date: Mon, 4 Nov 2024 16:27:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E4=BB=A5=E5=8F=8A=E5=A2=9E=E5=8A=A0=20GPU=20?= =?UTF-8?q?=E9=80=89=E6=8B=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 2 +- classification.py | 6 ++---- config.py | 15 +++++++++++++++ main.py | 3 ++- manifest.yaml | 2 ++ 5 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 config.py diff --git a/Dockerfile b/Dockerfile index f8e4b74..2c7b366 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1 +# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.2 FROM python:3.12.7 # diff --git a/classification.py b/classification.py index 8b585c5..b003faf 100644 --- a/classification.py +++ b/classification.py @@ -1,10 +1,8 @@ -from pprint import pprint - from pydantic import BaseModel from transformers import pipeline +from config import MODEL_ARGS -classifier = pipeline("zero-shot-classification", - model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") +classifier = pipeline("zero-shot-classification", **MODEL_ARGS) # 返回一个结构化的内容 diff --git a/config.py b/config.py new file mode 100644 index 0000000..6641459 --- /dev/null +++ b/config.py @@ -0,0 +1,15 @@ +import os +from pprint import pprint + +MODEL_ARGS = { + "model": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", + "device": int(os.environ.get("GPU_DEVICE", -1)) +} + +HTTP_ARGS = { + "host": os.environ.get("HOST", "0.0.0.0"), + "port": int(os.environ.get("PORT", 8000)) +} + +pprint(MODEL_ARGS) +pprint(HTTP_ARGS) \ No newline at end of file diff --git a/main.py b/main.py index 91c4bf8..3d9a603 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import uvicorn from fastapi import FastAPI from pydantic import BaseModel import classification +from config import HTTP_ARGS app = FastAPI() @@ -24,4 +25,4 @@ def classify(req: TextClassificationRequest) -> TextClassificationResponse: if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, **HTTP_ARGS) diff --git a/manifest.yaml b/manifest.yaml index 9465bd3..f478162 100644 --- a/manifest.yaml +++ b/manifest.yaml @@ -45,6 +45,8 @@ spec: value: "/app/models" - name: HF_ENDPOINT value: "https://hf-mirror.com" + - name: GPU_DEVICE + value: "-1" # - name: CUDA_VISIBLE_DEVICES # value: "1" ports: