ocr/main.py
2024-10-01 18:26:31 +08:00

148 lines
4.8 KiB
Python

from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import JSONResponse
import easyocr
from typing import Optional
import asyncio
import aiohttp # 用于异步 HTTP 请求
app = FastAPI()
# 初始化常用的语言组合 OCR 读取器
reader_sim = easyocr.Reader(['ch_sim', 'en'])
# 存储 OCR 读取器的缓存字典
readers = {}
# 使用 asyncio 的异步锁
lock = asyncio.Lock()
# 获取或初始化 OCR 读取器
async def get_reader(lang_1: Optional[str], lang_2: Optional[str]):
global readers, reader_sim, lock
# 语言组合
lang_combination = tuple(sorted([lang_1, lang_2] if lang_1 and lang_2 else [lang_1 or lang_2]))
# 默认使用中文简体和英文的 OCR 读取器
if not lang_1 and not lang_2:
return reader_sim, None # 没有错误
# 检查缓存中是否已经存在该语言组合的读取器
if lang_combination in readers:
return readers[lang_combination], None # 没有错误
# 使用锁避免并发问题,仅在初始化新读取器时需要锁
async with lock:
# 双重检查,确保其他请求在等待锁时未初始化该读取器
if lang_combination not in readers:
try:
# 尝试初始化 OCR 读取器
readers[lang_combination] = easyocr.Reader(list(lang_combination))
except ValueError as ve:
return None, str(ve) # 返回错误信息
except Exception as e:
return None, f"Unexpected error: {str(e)}" # 返回意外错误信息
return readers[lang_combination], None # 没有错误
# 通用的 OCR 处理函数
async def process_ocr(reader, image_data):
try:
# 执行 OCR 处理
result = reader.readtext(image_data)
# 创建一个用于存储提取信息的列表
output = []
# 遍历识别结果
for (bbox, text, confidence) in result:
# 将 bbox 中的每个坐标点转换为 Python 原生 int 类型
bbox = [[int(coord) for coord in point] for point in bbox]
entry = {
"text": text,
"bbox": bbox, # bounding box 位置信息
"confidence": confidence # 置信度
}
output.append(entry)
# 成功处理后返回数据
return {
"status": "success",
"message": "OCR 处理成功",
"data": output
}, 200
except Exception as e:
# 处理失败,返回错误信息
return {
"status": "error",
"message": "OCR 处理失败",
"error": str(e),
"data": None
}, 500
@app.post("/ocr")
async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None, image: UploadFile = File(...)):
# 获取 OCR 读取器并捕获可能的错误
reader, error = await get_reader(lang_1, lang_2)
if error:
return JSONResponse(status_code=500, content={
"status": "error",
"message": "OCR 处理失败",
"error": error,
"data": None
})
try:
# 获取图片的二进制数据
image_data = await image.read()
except Exception as e:
return JSONResponse(status_code=500, content={
"status": "error",
"message": "无法读取图片文件",
"error": str(e),
"data": None
})
# 调用通用的 OCR 处理函数
response_data, status_code = await process_ocr(reader, image_data)
return JSONResponse(status_code=status_code, content=response_data)
@app.get("/ocr")
async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None, url: str = Query(...)):
# 使用 aiohttp 进行异步 HTTP 请求
async with aiohttp.ClientSession() as session:
try:
async with session.get(url) as response:
if response.status != 200:
raise Exception(f"无法获取图片, HTTP 状态码: {response.status}")
# 读取图片数据
image_data = await response.read()
except Exception as e:
return JSONResponse(status_code=500, content={
"status": "error",
"message": "无法从 URL 获取图片",
"error": str(e),
"data": None
})
# 获取 OCR 读取器并捕获可能的错误
reader, error = await get_reader(lang_1, lang_2)
if error:
return JSONResponse(status_code=500, content={
"status": "error",
"message": "OCR 处理失败",
"error": error,
"data": None
})
# 调用通用的 OCR 处理函数
response_data, status_code = await process_ocr(reader, image_data)
return JSONResponse(status_code=status_code, content=response_data)