改进 模型加载以及增加 GPU 选择
This commit is contained in:
parent
5498c82c6c
commit
7641757e45
@ -1,4 +1,4 @@
|
||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.2
|
||||
FROM python:3.12.7
|
||||
|
||||
#
|
||||
|
@ -1,10 +1,8 @@
|
||||
from pprint import pprint
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import pipeline
|
||||
from config import MODEL_ARGS
|
||||
|
||||
classifier = pipeline("zero-shot-classification",
|
||||
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
|
||||
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
|
||||
|
||||
|
||||
# 返回一个结构化的内容
|
||||
|
15
config.py
Normal file
15
config.py
Normal file
@ -0,0 +1,15 @@
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
MODEL_ARGS = {
|
||||
"model": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
|
||||
"device": int(os.environ.get("GPU_DEVICE", -1))
|
||||
}
|
||||
|
||||
HTTP_ARGS = {
|
||||
"host": os.environ.get("HOST", "0.0.0.0"),
|
||||
"port": int(os.environ.get("PORT", 8000))
|
||||
}
|
||||
|
||||
pprint(MODEL_ARGS)
|
||||
pprint(HTTP_ARGS)
|
3
main.py
3
main.py
@ -2,6 +2,7 @@ import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
import classification
|
||||
from config import HTTP_ARGS
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@ -24,4 +25,4 @@ def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, **HTTP_ARGS)
|
||||
|
@ -45,6 +45,8 @@ spec:
|
||||
value: "/app/models"
|
||||
- name: HF_ENDPOINT
|
||||
value: "https://hf-mirror.com"
|
||||
- name: GPU_DEVICE
|
||||
value: "-1"
|
||||
# - name: CUDA_VISIBLE_DEVICES
|
||||
# value: "1"
|
||||
ports:
|
||||
|
Loading…
Reference in New Issue
Block a user