改进 响应

This commit is contained in:
ivamp 2024-10-25 00:23:31 +08:00
parent 941b5a04d5
commit 22c3795f9f
2 changed files with 82 additions and 34 deletions

101
main.py
View File

@ -1,13 +1,18 @@
import asyncio import asyncio
import os import os
import pprint
from typing import Optional from typing import Optional
import aiohttp # 用于异步 HTTP 请求 import aiohttp # 用于异步 HTTP 请求
import easyocr import easyocr
import torch import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Query from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import response
from response import OCRResponse
MODEL_DIR = "./models" MODEL_DIR = "./models"
USE_GPU = os.environ.get('USE_GPU', 'false') == 'true' USE_GPU = os.environ.get('USE_GPU', 'false') == 'true'
NNPACK = os.environ.get('NNPACK', 'false') == 'true' NNPACK = os.environ.get('NNPACK', 'false') == 'true'
@ -83,51 +88,70 @@ async def process_ocr(reader, image_data):
# 成功处理后返回数据 # 成功处理后返回数据
return { return {
"status": "success", "status": bool,
"message": "OCR 处理成功",
"data": output "data": output
}, 200 }, 200
except Exception as e: except Exception as e:
# 处理失败,返回错误信息 # 处理失败,返回错误信息
return { return {
"status": "error", "status": bool,
"message": "OCR 处理失败",
"error": str(e), "error": str(e),
"data": None "data": None
}, 500 }, 500
@app.post("/ocr")
async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
                    image: UploadFile = File(...)) -> OCRResponse:
    """Run OCR on an uploaded image file.

    Args:
        lang_1: Optional first language code forwarded to ``get_reader``.
        lang_2: Optional second language code forwarded to ``get_reader``.
        image: The uploaded image file.

    Returns:
        An ``OCRResponse``. On success ``success`` is True and ``data`` holds
        one ``OCRData`` per recognized entry. On any failure (reader creation,
        reading the upload, or the OCR pass itself) ``success`` is False,
        ``data`` is None and ``message``/``error`` describe the problem.
    """
    # Obtain the OCR reader; get_reader reports problems via its second value.
    reader, error = await get_reader(lang_1, lang_2)
    if error:
        return OCRResponse(
            success=False,
            message="OCR 处理失败",
            data=None,
            error=error,
        )

    try:
        # Read the raw bytes of the uploaded image.
        image_data = await image.read()
    except Exception as e:
        # Report the exception that was actually raised here; the reader
        # error from above is always falsy at this point.
        return OCRResponse(
            success=False,
            message="无法读取图片文件",
            data=None,
            error=str(e),
        )

    # Delegate to the shared OCR routine, which returns (payload, status_code).
    ocr_data, status_code = await process_ocr(reader, image_data)

    # A failed OCR pass carries "data": None, so it must not be iterated.
    entries = ocr_data.get("data")
    if status_code != 200 or entries is None:
        return OCRResponse(
            success=False,
            message=ocr_data.get("message", "OCR 处理失败"),
            data=None,
            error=ocr_data.get("error"),
        )

    # Convert each raw entry into the typed response model.
    ocr_data_list = [
        response.OCRData(
            text=entry["text"],
            bbox=entry["bbox"],
            confidence=entry["confidence"],
        )
        for entry in entries
    ]
    return OCRResponse(
        success=True,
        data=ocr_data_list,
        error=None,
        message=None,
    )
@app.get("/ocr") @app.get("/ocr")
async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None, url: str = Query(...)): async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
url: str = Query(...)) -> OCRResponse:
# 使用 aiohttp 进行异步 HTTP 请求 # 使用 aiohttp 进行异步 HTTP 请求
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
try: try:
@ -136,25 +160,34 @@ async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str]
raise Exception(f"无法获取图片, HTTP 状态码: {response.status}") raise Exception(f"无法获取图片, HTTP 状态码: {response.status}")
# 读取图片数据 # 读取图片数据
image_data = await response.read() image_data = await response.read()
except Exception as e: except Exception as e:
return JSONResponse(status_code=500, content={ return OCRResponse(
"status": "error", success=False,
"message": "无法从 URL 获取图片", message="无法从 URL 获取图片",
"error": str(e), data=None,
"data": None error=str(e),
}) )
# 获取 OCR 读取器并捕获可能的错误 # 获取 OCR 读取器并捕获可能的错误
reader, error = await get_reader(lang_1, lang_2) reader, error = await get_reader(lang_1, lang_2)
if error: if error:
return JSONResponse(status_code=500, content={ return OCRResponse(
"status": "error", success=False,
"message": "OCR 处理失败", message="OCR 处理失败",
"error": error, data=None,
"data": None error=error,
}) )
# 调用通用的 OCR 处理函数 # 调用通用的 OCR 处理函数
response_data, status_code = await process_ocr(reader, image_data) ocr_data, status_code = await process_ocr(reader, image_data)
return JSONResponse(status_code=status_code, content=response_data)
return OCRResponse(
success=True,
data=ocr_data,
)
if __name__ == "__main__":
    # Direct-execution entry point: serve the FastAPI app with uvicorn,
    # listening on all interfaces (0.0.0.0) at port 8000.
    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )

15
response.py Normal file
View File

@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List, Optional
class OCRData(BaseModel):
    """One recognized text entry produced by an OCR pass."""

    # The recognized text.
    text: str

    # Several coordinate points, each given as a list of integers.
    bbox: List[List[int]]

    # Confidence value reported for this entry.
    confidence: float
class OCRResponse(BaseModel):
    """Envelope returned by the OCR endpoints.

    Only ``success`` is required. ``error`` and ``message`` default to None so
    that callers may omit them; without an explicit default, pydantic v2
    treats an ``Optional[...]`` field as required.
    """

    # Whether the OCR request completed successfully.
    success: bool

    # Error detail when the request failed, otherwise None.
    error: Optional[str] = None

    # Human-readable status message, otherwise None.
    message: Optional[str] = None

    # Recognized entries; None on failure, empty list when not supplied.
    data: Optional[List[OCRData]] = []