# zero-shot-classification-zh/main.py
#
# (Source-viewer metadata: 29 lines, 623 B, Python)

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
import classification
from config import HTTP_ARGS
# Module-level FastAPI application; the route decorators below register on it.
app: FastAPI = FastAPI()
class TextClassificationRequest(BaseModel):
    """Request body accepted by ``POST /classify``.

    Bundles the text to classify with the list of labels that are forwarded
    to the classifier alongside it.
    """

    # Input text handed to the classifier.
    text: str
    # Labels handed to the classifier together with ``text``.
    labels: list[str]
class TextClassificationResponse(BaseModel):
    """Response body returned by ``POST /classify``."""

    # Single label reported by the classifier as its prediction.
    prediction: str
    # Labels in the order reported by the classifier; presumably best-first —
    # TODO confirm against the classification module.
    ranks: list[str]
@app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
    """Classify ``req.text`` against ``req.labels`` and return the outcome.

    Delegates the actual work to ``classification.classify`` and repackages
    its result into the HTTP response model.

    Args:
        req: Parsed request body holding the text and the labels.

    Returns:
        A ``TextClassificationResponse`` built from the classifier's
        ``prediction`` and ``rank`` attributes.
    """
    outcome = classification.classify(req.text, req.labels)
    # NOTE(review): the classifier result is read via ``rank`` (singular) while
    # the response field is ``ranks`` — confirm the attribute name in the
    # classification module.
    payload = {"prediction": outcome.prediction, "ranks": outcome.rank}
    return TextClassificationResponse(**payload)
if __name__ == "__main__":
    # Start the server only when the file is executed directly, not on import.
    # The keyword arguments for uvicorn come from config.HTTP_ARGS.
    server_kwargs = HTTP_ARGS
    uvicorn.run(app, **server_kwargs)