"""FastAPI service that runs EasyOCR on an uploaded image or an image URL.

Routes:
    POST /ocr -- OCR an uploaded image file (multipart field ``image``).
    GET  /ocr -- OCR an image downloaded from the ``url`` query parameter.

Both routes take optional ``lang_1`` / ``lang_2`` query parameters naming
EasyOCR language codes; with neither given, the preloaded simplified-Chinese
+ English reader is used. Responses are JSON objects with ``status``,
``message`` and ``data`` keys, plus ``error`` on failure.
"""

import asyncio
import os
from typing import Optional

import aiohttp  # Async HTTP client used to download images by URL.
import easyocr
import torch
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import JSONResponse

MODEL_DIR = "./models"


def _env_flag(name: str) -> bool:
    """Return True when env var *name* equals "true", ignoring case and whitespace."""
    return os.environ.get(name, 'false').strip().lower() == 'true'


USE_GPU = _env_flag('USE_GPU')
NNPACK = _env_flag('NNPACK')
# NOTE(review): confirm that assigning this attribute actually toggles NNPACK in
# the installed torch version; torch.backends.nnpack also exposes set_flags().
torch.backends.nnpack.enabled = NNPACK

app = FastAPI()

# Language pair served when the caller specifies no language. Kept in sorted
# order so it matches the cache keys that get_reader builds.
_DEFAULT_LANGS = ('ch_sim', 'en')


def _build_reader(lang_list):
    """Construct an EasyOCR reader for *lang_list* (blocking: loads model weights)."""
    return easyocr.Reader(lang_list=list(lang_list),
                          model_storage_directory=MODEL_DIR,
                          gpu=USE_GPU)


# Preloaded reader for the default simplified-Chinese + English pair.
reader_sim = _build_reader(_DEFAULT_LANGS)

# Cache of readers keyed by a sorted tuple of unique language codes. It is
# seeded with the default reader so an explicit "ch_sim" + "en" request does not
# load a second copy of the same models.
readers = {_DEFAULT_LANGS: reader_sim}

# Serializes creation of new readers so concurrent requests for the same
# language pair build it only once.
lock = asyncio.Lock()


async def get_reader(lang_1: Optional[str], lang_2: Optional[str]):
    """Return a cached or newly built OCR reader for the requested languages.

    Args:
        lang_1: Optional EasyOCR language code.
        lang_2: Optional second EasyOCR language code.

    Returns:
        ``(reader, None)`` on success, or ``(None, message)`` when the reader
        could not be created (e.g. an unsupported language code). ``message``
        is never empty.
    """
    # Empty strings and None are both treated as "not given"; a set drops a
    # duplicated language such as lang_1 == lang_2 == "en".
    requested = {lang for lang in (lang_1, lang_2) if lang}
    if not requested:
        return reader_sim, None

    lang_combination = tuple(sorted(requested))

    # Fast path: no lock needed to read an already-cached reader.
    cached = readers.get(lang_combination)
    if cached is not None:
        return cached, None

    async with lock:
        # Re-check: another request may have built this reader while we waited.
        if lang_combination not in readers:
            loop = asyncio.get_running_loop()
            try:
                # Model loading is blocking; run it off the event loop so other
                # requests keep being served while the weights load.
                readers[lang_combination] = await loop.run_in_executor(
                    None, _build_reader, lang_combination)
            except ValueError as ve:
                return None, str(ve) or repr(ve)
            except Exception as e:
                return None, f"Unexpected error: {str(e)}"
    return readers[lang_combination], None


async def process_ocr(reader, image_data):
    """Run OCR on raw image bytes and build the JSON payload.

    Args:
        reader: An EasyOCR reader as returned by get_reader.
        image_data: Encoded image bytes.

    Returns:
        ``(payload, http_status)``. ``payload["data"]`` is a list of
        ``{"text", "bbox", "confidence"}`` entries on success (status 200),
        or None with an ``"error"`` message on failure (status 500).
    """
    try:
        loop = asyncio.get_running_loop()
        # Inference is blocking, so run it in the default thread pool to keep
        # the event loop responsive.
        # NOTE(review): this allows concurrent readtext calls on the same
        # reader; confirm EasyOCR/torch tolerate that, or add a per-reader lock.
        result = await loop.run_in_executor(None, reader.readtext, image_data)

        # Coordinates and scores may be numpy scalars, which the JSON encoder
        # cannot always serialize, so convert them to native Python types.
        output = [
            {
                "text": text,
                "bbox": [[int(coord) for coord in point] for point in bbox],
                "confidence": float(confidence),
            }
            for bbox, text, confidence in result
        ]
        return {
            "status": "success",
            "message": "OCR 处理成功",
            "data": output
        }, 200
    except Exception as e:
        return {
            "status": "error",
            "message": "OCR 处理失败",
            "error": str(e),
            "data": None
        }, 500


def _error_response(message: str, error: str) -> JSONResponse:
    """Build the HTTP 500 JSON error response shared by both /ocr routes."""
    return JSONResponse(status_code=500, content={
        "status": "error",
        "message": message,
        "error": error,
        "data": None
    })


@app.post("/ocr")
async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
                    image: UploadFile = File(...)):
    """Run OCR on an uploaded image file."""
    reader, error = await get_reader(lang_1, lang_2)
    # Test against None (not truthiness) so a failure is never mistaken for
    # success and a None reader is never passed on.
    if error is not None:
        return _error_response("OCR 处理失败", error)

    try:
        image_data = await image.read()
    except Exception as e:
        return _error_response("无法读取图片文件", str(e))

    response_data, status_code = await process_ocr(reader, image_data)
    return JSONResponse(status_code=status_code, content=response_data)


@app.get("/ocr")
async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
                             url: str = Query(...)):
    """Download the image at *url* and run OCR on it."""
    # SECURITY(review): `url` is caller-controlled, so the server will fetch any
    # address it can reach, including internal hosts (SSRF). Consider a
    # scheme/host allow-list if this service is publicly exposed.
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(url) as response:
                if response.status != 200:
                    raise Exception(f"无法获取图片, HTTP 状态码: {response.status}")
                image_data = await response.read()
        except Exception as e:
            return _error_response("无法从 URL 获取图片", str(e))

    reader, error = await get_reader(lang_1, lang_2)
    if error is not None:
        return _error_response("OCR 处理失败", error)

    response_data, status_code = await process_ocr(reader, image_data)
    return JSONResponse(status_code=status_code, content=response_data)