From 22c3795f9f5e14e494517444f2912e98533c3349 Mon Sep 17 00:00:00 2001 From: ivamp Date: Fri, 25 Oct 2024 00:23:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=20=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 101 ++++++++++++++++++++++++++++++++++------------------ response.py | 15 ++++++++ 2 files changed, 82 insertions(+), 34 deletions(-) create mode 100644 response.py diff --git a/main.py b/main.py index 78a35d9..4dc790f 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,18 @@ import asyncio import os +import pprint from typing import Optional import aiohttp # 用于异步 HTTP 请求 import easyocr import torch +import uvicorn from fastapi import FastAPI, UploadFile, File, Query from fastapi.responses import JSONResponse +import response +from response import OCRResponse + MODEL_DIR = "./models" USE_GPU = os.environ.get('USE_GPU', 'false') == 'true' NNPACK = os.environ.get('NNPACK', 'false') == 'true' @@ -83,51 +88,70 @@ async def process_ocr(reader, image_data): # 成功处理后返回数据 return { - "status": "success", - "message": "OCR 处理成功", + "status": bool, "data": output }, 200 except Exception as e: # 处理失败,返回错误信息 return { - "status": "error", - "message": "OCR 处理失败", + "status": bool, "error": str(e), "data": None }, 500 @app.post("/ocr") -async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None, image: UploadFile = File(...)): +async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None, image: UploadFile = File(...)) \ + -> OCRResponse: # 获取 OCR 读取器并捕获可能的错误 reader, error = await get_reader(lang_1, lang_2) if error: - return JSONResponse(status_code=500, content={ - "status": "error", - "message": "OCR 处理失败", - "error": error, - "data": None - }) + return OCRResponse( + success=False, + message="OCR 处理失败", + data=None, + error=error + ) try: # 获取图片的二进制数据 image_data = await image.read() except Exception as e: - return JSONResponse(status_code=500, content={ - "status": "error", - "message": "无法读取图片文件", - "error": str(e), - "data": None - }) + return OCRResponse( + success=False, + message="无法读取图片文件", + data=None, + error=error + ) # 调用通用的 OCR 处理函数 - response_data, status_code = await process_ocr(reader, image_data) - return JSONResponse(status_code=status_code, content=response_data) + ocr_data, status_code = await process_ocr(reader, image_data) + full_text = "" + ocr_data_list = list() + for entry in ocr_data["data"]: + full_text += entry["text"] + r = response.OCRData( + text=entry["text"], + bbox=entry["bbox"], + confidence=entry["confidence"] + ) + + ocr_data_list.append(r) + # 返回 OCR 处理结果 + pprint.pprint(ocr_data) + + return OCRResponse( + success=True, + data=ocr_data_list, + error=None, + message=None, + ) @app.get("/ocr") -async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None, url: str = Query(...)): +async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None, + url: str = Query(...)) -> OCRResponse: # 使用 aiohttp 进行异步 HTTP 请求 async with aiohttp.ClientSession() as session: try: @@ -136,25 +160,34 @@ async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] raise Exception(f"无法获取图片, HTTP 状态码: {response.status}") # 读取图片数据 image_data = await response.read() + except Exception as e: - return JSONResponse(status_code=500, content={ - "status": "error", - "message": "无法从 URL 获取图片", - "error": str(e), - "data": None - }) + return OCRResponse( + success=False, + message="无法从 URL 获取图片", + data=None, + error=str(e), + ) # 获取 OCR 读取器并捕获可能的错误 reader, error = await get_reader(lang_1, lang_2) if error: - return JSONResponse(status_code=500, content={ - "status": "error", - "message": "OCR 处理失败", - "error": error, - "data": None - }) + return OCRResponse( + success=False, + message="OCR 处理失败", + data=None, + error=error, + ) # 调用通用的 OCR 处理函数 - response_data, status_code = await process_ocr(reader, image_data) - return JSONResponse(status_code=status_code, content=response_data) + ocr_data, status_code = await process_ocr(reader, image_data) + + return OCRResponse( + success=True, + data=ocr_data, + ) + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/response.py b/response.py new file mode 100644 index 0000000..9712028 --- /dev/null +++ b/response.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from typing import List, Optional + + +class OCRData(BaseModel): + text: str + bbox: List[List[int]] # 这里使用 List[List[int]] 来表示多个坐标点 + confidence: float + + +class OCRResponse(BaseModel): + success: bool + error: Optional[str] + message: Optional[str] + data: Optional[List[OCRData]] = []