ocr/main.py

194 lines
5.6 KiB
Python
Raw Normal View History

2024-10-01 10:26:31 +00:00
import asyncio
2024-10-02 13:20:06 +00:00
import os
2024-10-24 16:23:31 +00:00
import pprint
2024-10-02 13:20:06 +00:00
from typing import Optional
2024-10-01 10:26:31 +00:00
import aiohttp # 用于异步 HTTP 请求
2024-10-02 13:20:06 +00:00
import easyocr
import torch
2024-10-24 16:23:31 +00:00
import uvicorn
2024-10-02 13:20:06 +00:00
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import JSONResponse
2024-10-01 10:26:31 +00:00
2024-10-24 16:23:31 +00:00
import response
from response import OCRResponse
2024-10-02 09:17:30 +00:00
# Directory handed to EasyOCR as its model storage location.
MODEL_DIR = "./models"

# Environment switches: only the exact string "true" turns a flag on.
USE_GPU = os.getenv('USE_GPU', 'false') == 'true'
NNPACK = os.getenv('NNPACK', 'false') == 'true'
# NOTE(review): confirm that assigning this attribute really toggles NNPACK
# in the installed torch version.
torch.backends.nnpack.enabled = NNPACK

app = FastAPI()

# Reader for the default language pair (Simplified Chinese + English),
# constructed once when the module is imported.
reader_sim = easyocr.Reader(
    lang_list=['ch_sim', 'en'],
    model_storage_directory=MODEL_DIR,
    gpu=USE_GPU,
)

# Cache of additional readers, keyed by a tuple of language codes.
readers = {}

# asyncio lock held while a new reader is being created.
lock = asyncio.Lock()
async def get_reader(lang_1: Optional[str], lang_2: Optional[str]):
    """Return an OCR reader for the requested language(s), creating it on demand.

    Args:
        lang_1: First language code; ``None`` or empty means "not given".
        lang_2: Second language code; ``None`` or empty means "not given".

    Returns:
        A ``(reader, error)`` tuple. On success ``error`` is ``None``. If the
        reader cannot be constructed, ``reader`` is ``None`` and ``error`` is
        a human-readable message.
    """
    # Drop missing values and duplicates, then sort, so that ("en", "en"),
    # ("en", None) and (None, "en") all map to the same cache key instead of
    # loading the same model more than once.
    lang_combination = tuple(sorted({lang for lang in (lang_1, lang_2) if lang}))

    # No language requested: fall back to the default reader.
    if not lang_combination:
        return reader_sim, None

    # Fast path: the reader for this combination already exists.
    cached = readers.get(lang_combination)
    if cached is not None:
        return cached, None

    # Only the creation of a new reader is done while holding the lock.
    # NOTE(review): easyocr.Reader(...) is a synchronous call made from a
    # coroutine; consider moving it off the event loop if start-up is slow.
    async with lock:
        # Re-check: another request may have created it while we waited.
        if lang_combination not in readers:
            try:
                readers[lang_combination] = easyocr.Reader(
                    lang_list=list(lang_combination),
                    model_storage_directory=MODEL_DIR,
                    gpu=USE_GPU,
                )
            except ValueError as ve:
                return None, str(ve)
            except Exception as e:
                return None, f"Unexpected error: {str(e)}"
        return readers[lang_combination], None
async def process_ocr(reader, image_data):
    """Run OCR on ``image_data`` and package the outcome.

    Args:
        reader: Object exposing ``readtext(image_data)`` that yields
            ``(bbox, text, confidence)`` triples.
        image_data: Image payload forwarded unchanged to ``reader.readtext``.

    Returns:
        A ``(payload, status_code)`` tuple.

        * Success: ``({"status": True, "data": [entry, ...]}, 200)`` where each
          entry is ``{"text", "bbox", "confidence"}``.
        * Failure: ``({"status": False, "error": <message>, "data": None}, 500)``.
    """
    try:
        # NOTE(review): readtext is called synchronously inside a coroutine.
        result = reader.readtext(image_data)
        output = [
            {
                "text": text,
                # Convert every coordinate to a built-in int.
                "bbox": [[int(coord) for coord in point] for point in bbox],
                # Normalise the score to a built-in float.
                "confidence": float(confidence),
            }
            for bbox, text, confidence in result
        ]
    except Exception as e:
        # Any failure is reported to the caller instead of being raised.
        return {
            "status": False,
            "error": str(e),
            "data": None,
        }, 500

    return {
        "status": True,
        "data": output,
    }, 200
@app.post("/ocr")
async def ocr_image(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
                    image: UploadFile = File(...)) -> OCRResponse:
    """Run OCR on an uploaded image.

    Args:
        lang_1: Optional first language code passed to ``get_reader``.
        lang_2: Optional second language code passed to ``get_reader``.
        image: The uploaded image file.

    Returns:
        An ``OCRResponse``. On success ``data`` is a list of
        ``response.OCRData`` items. On any failure ``success`` is ``False``
        and ``error`` carries the reason.
    """
    # Obtain the reader; a non-empty error means it could not be created.
    reader, error = await get_reader(lang_1, lang_2)
    if error:
        return OCRResponse(
            success=False,
            message="OCR 处理失败",
            data=None,
            error=error
        )

    try:
        # Raw bytes of the uploaded file.
        image_data = await image.read()
    except Exception as e:
        # Report the actual read failure (the reader error is None here).
        return OCRResponse(
            success=False,
            message="无法读取图片文件",
            data=None,
            error=str(e)
        )

    ocr_data, status_code = await process_ocr(reader, image_data)

    # process_ocr signals failure through its status code and a None "data";
    # surface that as a failed response instead of iterating over None.
    if status_code != 200 or ocr_data.get("data") is None:
        return OCRResponse(
            success=False,
            message="OCR 处理失败",
            data=None,
            error=ocr_data.get("error")
        )

    ocr_data_list = [
        response.OCRData(
            text=entry["text"],
            bbox=entry["bbox"],
            confidence=entry["confidence"]
        )
        for entry in ocr_data["data"]
    ]

    # Debug dump of the raw OCR payload.
    pprint.pprint(ocr_data)
    return OCRResponse(
        success=True,
        data=ocr_data_list,
        error=None,
        message=None,
    )
2024-10-01 10:26:31 +00:00
@app.get("/ocr")
async def ocr_image_from_url(lang_1: Optional[str] = None, lang_2: Optional[str] = None,
                             url: str = Query(...)) -> OCRResponse:
    """Download an image from ``url`` and run OCR on it.

    Args:
        lang_1: Optional first language code passed to ``get_reader``.
        lang_2: Optional second language code passed to ``get_reader``.
        url: Address of the image to download (required query parameter).

    Returns:
        An ``OCRResponse``. On success ``data`` is a list of
        ``response.OCRData`` items, the same shape the POST handler returns.
        On any failure ``success`` is ``False`` and ``error`` carries the
        reason.
    """
    # SECURITY NOTE(review): ``url`` is caller-supplied and fetched from the
    # server without any restriction (scheme, host, size). Consider an
    # allow-list / private-address check to prevent SSRF.
    async with aiohttp.ClientSession() as session:
        try:
            # ``resp`` (not ``response``) so the imported ``response`` module
            # stays reachable inside this function.
            async with session.get(url) as resp:
                if resp.status != 200:
                    raise Exception(f"无法获取图片, HTTP 状态码: {resp.status}")
                # Read the image bytes.
                image_data = await resp.read()
        except Exception as e:
            return OCRResponse(
                success=False,
                message="无法从 URL 获取图片",
                data=None,
                error=str(e),
            )

    # Obtain the reader; a non-empty error means it could not be created.
    reader, error = await get_reader(lang_1, lang_2)
    if error:
        return OCRResponse(
            success=False,
            message="OCR 处理失败",
            data=None,
            error=error,
        )

    ocr_data, status_code = await process_ocr(reader, image_data)

    # process_ocr signals failure through its status code and a None "data";
    # do not report that as a success.
    if status_code != 200 or ocr_data.get("data") is None:
        return OCRResponse(
            success=False,
            message="OCR 处理失败",
            data=None,
            error=ocr_data.get("error"),
        )

    # Convert the raw entries to OCRData, mirroring the POST /ocr handler.
    ocr_data_list = [
        response.OCRData(
            text=entry["text"],
            bbox=entry["bbox"],
            confidence=entry["confidence"],
        )
        for entry in ocr_data["data"]
    ]
    return OCRResponse(
        success=True,
        data=ocr_data_list,
        error=None,
        message=None,
    )
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)