改进 模型加载以及增加 GPU 选择
This commit is contained in:
parent
5498c82c6c
commit
7641757e45
@ -1,4 +1,4 @@
|
|||||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.2
|
||||||
FROM python:3.12.7
|
FROM python:3.12.7
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
from pprint import pprint
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
from config import MODEL_ARGS
|
||||||
|
|
||||||
classifier = pipeline("zero-shot-classification",
|
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
|
||||||
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
|
|
||||||
|
|
||||||
|
|
||||||
# 返回一个结构化的内容
|
# 返回一个结构化的内容
|
||||||
|
15
config.py
Normal file
15
config.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
MODEL_ARGS = {
|
||||||
|
"model": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
|
||||||
|
"device": int(os.environ.get("GPU_DEVICE", -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
HTTP_ARGS = {
|
||||||
|
"host": os.environ.get("HOST", "0.0.0.0"),
|
||||||
|
"port": int(os.environ.get("PORT", 8000))
|
||||||
|
}
|
||||||
|
|
||||||
|
pprint(MODEL_ARGS)
|
||||||
|
pprint(HTTP_ARGS)
|
3
main.py
3
main.py
@ -2,6 +2,7 @@ import uvicorn
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import classification
|
import classification
|
||||||
|
from config import HTTP_ARGS
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@ -24,4 +25,4 @@ def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, **HTTP_ARGS)
|
||||||
|
@ -45,6 +45,8 @@ spec:
|
|||||||
value: "/app/models"
|
value: "/app/models"
|
||||||
- name: HF_ENDPOINT
|
- name: HF_ENDPOINT
|
||||||
value: "https://hf-mirror.com"
|
value: "https://hf-mirror.com"
|
||||||
|
- name: GPU_DEVICE
|
||||||
|
value: "-1"
|
||||||
# - name: CUDA_VISIBLE_DEVICES
|
# - name: CUDA_VISIBLE_DEVICES
|
||||||
# value: "1"
|
# value: "1"
|
||||||
ports:
|
ports:
|
||||||
|
Loading…
Reference in New Issue
Block a user