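"""HTTP API for text classification.

Exposes ``POST /classify``, which forwards the request text and candidate
labels to ``classification.classify`` and returns the predicted label.
Run this module directly to serve the app with uvicorn.
"""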
import os
from typing import Union

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import classification


# ASGI application instance. Route handlers register on it via decorators,
# and the ``__main__`` guard at the bottom of the module serves it with uvicorn.
app = FastAPI()


class TextClassificationRequest(BaseModel):
    """Request body for ``POST /classify``."""

    # The text to classify.
    text: str
    # Candidate labels, passed to the classifier together with ``text``.
    labels: list[str]


class TextClassificationResponse(BaseModel):
    """Response body for ``POST /classify``."""

    # The first label in the list returned by the classifier.
    prediction: str
    # The label list exactly as returned by the classifier.
    # NOTE(review): presumably ranked best-first (the endpoint treats element 0
    # as the prediction) -- confirm against the classification module.
    labels: list[str]


@app.post("/classify")
def classify(req: TextClassificationRequest) -> TextClassificationResponse:
    """Classify ``req.text`` against the caller-supplied candidate labels.

    Delegates to ``classification.classify`` and reports the first label it
    returns as the prediction, alongside the full label list it returned.
    (Presumably that list is ranked best-first -- confirm in ``classification``.)

    Args:
        req: Request body carrying the text and its candidate labels.

    Returns:
        TextClassificationResponse with ``prediction`` and ``labels``.

    Raises:
        HTTPException: 422 if the classifier returned no labels because the
            client sent an empty ``labels`` list; 500 if it returned no labels
            for any other reason.
    """
    result = classification.classify(req.text, req.labels)

    if not result.labels:
        # Indexing an empty list below would raise IndexError, which surfaces
        # as an opaque 500. Report the problem explicitly instead.
        if not req.labels:
            raise HTTPException(
                status_code=422,
                detail="'labels' must contain at least one candidate label",
            )
        raise HTTPException(
            status_code=500,
            detail="classifier returned no labels",
        )

    return TextClassificationResponse(
        prediction=result.labels[0],
        labels=result.labels,
    )


if __name__ == "__main__":
    # Development entry point: serve ``app`` with uvicorn. The defaults match the
    # original hard-coded values ("0.0.0.0" = all network interfaces, port 8000).
    # The HOST / PORT environment variables override them, e.g. when a container
    # platform assigns the port.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )