zero-shot-classification-zh/main.py

29 lines
636 B
Python
Raw Normal View History

2024-10-24 03:12:23 +00:00
import os
from typing import Union

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import classification
# Application object: the /classify route in this module is registered on it,
# and the __main__ guard passes it to uvicorn.run.
app: FastAPI = FastAPI()
class TextClassificationRequest(BaseModel):
    """Request body accepted by the ``POST /classify`` endpoint."""

    # Input text to be classified.
    text: str
    # Candidate labels; passed through unchanged to classification.classify.
    labels: list[str]
class TextClassificationResponse(BaseModel):
    """Response body returned by the ``POST /classify`` endpoint."""

    # First entry of the label list returned by the classifier.
    prediction: str
    # Label list exactly as returned by the classifier (its ordering, which may
    # differ from the order the client sent).
    labels: list[str]
@app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
    """Classify ``req.text`` against the caller-supplied candidate labels.

    Delegates to ``classification.classify`` and returns its first label as
    ``prediction``, together with the full label list it produced.

    Raises:
        HTTPException: 422 when ``req.labels`` is empty. With no candidates
            there is nothing to predict, and without this guard the request
            would fail inside the classifier or at ``result.labels[0]``
            (IndexError), surfacing as an opaque 500.
    """
    if not req.labels:
        raise HTTPException(
            status_code=422,
            detail="labels must contain at least one candidate label",
        )

    result = classification.classify(req.text, req.labels)
    # NOTE(review): assumes result.labels is ordered most-likely first, so that
    # index 0 is the prediction -- confirm against classification.classify.
    return TextClassificationResponse(prediction=result.labels[0], labels=result.labels)
if __name__ == "__main__":
    # Defaults are the previously hard-coded values. The optional
    # UVICORN_HOST / UVICORN_PORT overrides (the names the uvicorn CLI reads)
    # let a deployment change the bind address without editing this file.
    # A non-integer UVICORN_PORT fails fast with ValueError at startup.
    uvicorn.run(
        app,
        host=os.environ.get("UVICORN_HOST", "0.0.0.0"),
        port=int(os.environ.get("UVICORN_PORT", "8000")),
    )