改进
This commit is contained in:
parent
fabe3c908e
commit
69067d948f
@ -1,4 +1,4 @@
|
|||||||
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.3
|
# docker build . --platform linux/amd64 --push -t leafdev.top/ecosystem/zero-shot-classification:v0.0.4
|
||||||
FROM python:3.12.7
|
FROM python:3.12.7
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from pprint import pprint
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from config import MODEL_ARGS
|
from config import MODEL_ARGS
|
||||||
@ -5,25 +7,27 @@ from config import MODEL_ARGS
|
|||||||
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
|
classifier = pipeline("zero-shot-classification", **MODEL_ARGS)
|
||||||
|
|
||||||
|
|
||||||
# 返回一个结构化的内容
|
|
||||||
class ClassifyResult(BaseModel):
|
class ClassifyResult(BaseModel):
|
||||||
sequence: str
|
sequence: str
|
||||||
rank: list
|
|
||||||
scores: list
|
|
||||||
prediction: str
|
prediction: str
|
||||||
|
prediction_score: float
|
||||||
|
ranks: list[str]
|
||||||
|
labels: list[str]
|
||||||
|
scores: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def classify(text: str, labels: list):
|
def classify(text: str, labels: list):
|
||||||
output = classifier(text, labels)
|
output = classifier(text, labels)
|
||||||
|
|
||||||
# pprint(output)
|
|
||||||
|
|
||||||
# 根据 score,寻找最高的 label
|
# 根据 score,寻找最高的 label
|
||||||
prediction_rank = output['scores'].index(max(output['scores']))
|
prediction_rank = output['scores'].index(max(output['scores']))
|
||||||
|
|
||||||
return ClassifyResult(
|
return ClassifyResult(
|
||||||
sequence=text,
|
sequence=text,
|
||||||
rank=output['labels'],
|
prediction=output['labels'][prediction_rank],
|
||||||
|
prediction_score=output['scores'][prediction_rank],
|
||||||
|
ranks=output['labels'],
|
||||||
|
labels=labels,
|
||||||
scores=output['scores'],
|
scores=output['scores'],
|
||||||
prediction=output['labels'][prediction_rank]
|
|
||||||
)
|
)
|
||||||
|
13
main.py
13
main.py
@ -1,7 +1,7 @@
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import classification
|
from classification import ClassifyResult, classify as classifier
|
||||||
from config import HTTP_ARGS
|
from config import HTTP_ARGS
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -12,16 +12,15 @@ class TextClassificationRequest(BaseModel):
|
|||||||
labels: list[str]
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
class TextClassificationResponse(BaseModel):
|
# class TextClassificationResponse(BaseModel):
|
||||||
prediction: str
|
# prediction: str
|
||||||
ranks: list[str]
|
# ranks: list[str]
|
||||||
|
|
||||||
|
|
||||||
@app.post("/classify")
|
@app.post("/classify")
|
||||||
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
|
def classify(req: TextClassificationRequest) -> ClassifyResult:
|
||||||
result = classification.classify(req.text, req.labels)
|
|
||||||
|
|
||||||
return TextClassificationResponse(prediction=result.prediction, ranks=result.rank)
|
return classifier(req.text, req.labels)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -33,7 +33,7 @@ spec:
|
|||||||
claimName: text-classification-pvc
|
claimName: text-classification-pvc
|
||||||
containers:
|
containers:
|
||||||
- name: text-classification
|
- name: text-classification
|
||||||
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.1
|
image: leafdev.top/ecosystem/zero-shot-classification:v0.0.4
|
||||||
env:
|
env:
|
||||||
- name: HF_DATASETS_CACHE
|
- name: HF_DATASETS_CACHE
|
||||||
value: "/app/models"
|
value: "/app/models"
|
||||||
|
Loading…
Reference in New Issue
Block a user