Remove Freeform and Find from UI. Allow Description to be added to Reviewed job
This commit is contained in:
488
backend/main.py
488
backend/main.py
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import tempfile
|
||||
import shutil
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -12,8 +10,6 @@ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from PIL import Image
|
||||
import uvicorn
|
||||
from decouple import config as env_config
|
||||
@@ -28,20 +24,29 @@ from pdf_utils import (
|
||||
)
|
||||
from format_converter import DocumentConverter
|
||||
from database import init_db, get_db
|
||||
from providers import (
|
||||
build_registry,
|
||||
parse_detections,
|
||||
clean_grounding_text,
|
||||
ProviderError,
|
||||
GROUNDING_MODES,
|
||||
)
|
||||
|
||||
OCR_IMAGES_DIR = env_config("OCR_IMAGES_DIR", default="/data/ocr_images")
|
||||
|
||||
# -----------------------------
|
||||
# Lifespan context for model loading
|
||||
# Lifespan context
|
||||
# -----------------------------
|
||||
model = None
|
||||
tokenizer = None
|
||||
# The model registry holds all available OCR providers. Local models (e.g.
|
||||
# DeepSeek-OCR) are loaded lazily on first use so an Ollama-only deployment
|
||||
# starts instantly and never touches the GPU.
|
||||
registry = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model on startup, cleanup on shutdown"""
|
||||
global model, tokenizer
|
||||
|
||||
"""Build the model registry on startup."""
|
||||
global registry
|
||||
|
||||
# Image storage directory
|
||||
os.makedirs(OCR_IMAGES_DIR, exist_ok=True)
|
||||
|
||||
@@ -51,42 +56,11 @@ async def lifespan(app: FastAPI):
|
||||
except Exception as exc:
|
||||
print(f"Warning: database initialization failed: {exc}")
|
||||
|
||||
# Environment setup
|
||||
os.environ.pop("TRANSFORMERS_CACHE", None)
|
||||
MODEL_NAME = env_config("MODEL_NAME", default="deepseek-ai/DeepSeek-OCR")
|
||||
HF_HOME = env_config("HF_HOME", default="/models")
|
||||
os.makedirs(HF_HOME, exist_ok=True)
|
||||
|
||||
# Load model
|
||||
print(f"🚀 Loading {MODEL_NAME}...")
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
MODEL_NAME,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
MODEL_NAME,
|
||||
trust_remote_code=True,
|
||||
use_safetensors=True,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=torch_dtype,
|
||||
).eval().to("cuda")
|
||||
|
||||
# Pad token setup
|
||||
try:
|
||||
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if getattr(model.config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("✅ Model loaded and ready!")
|
||||
|
||||
# OCR model registry (providers load their models lazily)
|
||||
registry = build_registry()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# Cleanup
|
||||
print("🛑 Shutting down...")
|
||||
|
||||
@@ -112,155 +86,6 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Prompt builder
|
||||
# -----------------------------
|
||||
def build_prompt(
    mode: str,
    user_prompt: str,
    grounding: bool,
    find_term: Optional[str],
    schema: Optional[str],
    include_caption: bool,
) -> str:
    """Assemble the model prompt for a single OCR request.

    The prompt is a newline-joined sequence of the ``<image>`` placeholder,
    an optional ``<|grounding|>`` tag and a mode-specific instruction.

    Args:
        mode: OCR mode name (``plain_ocr``, ``markdown``, ``kv_json``, ...).
            Unrecognised modes fall back to a generic OCR instruction.
        user_prompt: Custom instruction; only used in ``freeform`` mode.
            Missing or blank values fall back to the generic instruction.
        grounding: Explicitly request the grounding tag. The tag is always
            added for ``find_ref``, ``layout_map`` and ``pii_redact``.
        find_term: Term to locate in ``find_ref`` mode; blank means "Total".
        schema: JSON schema text for ``kv_json`` mode; blank means "{}".
        include_caption: Append a request for a one-paragraph description
            (skipped in ``describe`` mode, which already is a description).

    Returns:
        The complete prompt string.
    """
    parts: List[str] = ["<image>"]
    mode_requires_grounding = mode in {"find_ref", "layout_map", "pii_redact"}
    if grounding or mode_requires_grounding:
        parts.append("<|grounding|>")

    if mode == "plain_ocr":
        instruction = "Free OCR."
    elif mode == "markdown":
        instruction = "Convert the document to markdown."
    elif mode == "tables_csv":
        instruction = (
            "Extract every table and output CSV only. "
            "Use commas, minimal quoting. If multiple tables, separate with a line containing '---'."
        )
    elif mode == "tables_md":
        instruction = "Extract every table as GitHub-flavored Markdown tables. Output only the tables."
    elif mode == "kv_json":
        # A whitespace-only schema must fall back too, otherwise the model is
        # told to use an empty schema.
        schema_text = (schema or "").strip() or "{}"
        instruction = (
            "Extract key fields and return strict JSON only. "
            f"Use this schema (fill the values): {schema_text}"
        )
    elif mode == "figure_chart":
        instruction = (
            "Parse the figure. First extract any numeric series as a two-column table (x,y). "
            "Then summarize the chart in 2 sentences. Output the table, then a line '---', then the summary."
        )
    elif mode == "find_ref":
        key = (find_term or "").strip() or "Total"
        instruction = f"Locate <|ref|>{key}<|/ref|> in the image."
    elif mode == "layout_map":
        instruction = (
            'Return a JSON array of blocks with fields {"type":["title","paragraph","table","figure"],'
            '"box":[x1,y1,x2,y2]}. Do not include any text content.'
        )
    elif mode == "pii_redact":
        instruction = (
            'Find all occurrences of emails, phone numbers, postal addresses, and IBANs. '
            'Return a JSON array of objects {label, text, box:[x1,y1,x2,y2]}.'
        )
    elif mode == "multilingual":
        instruction = "Free OCR. Detect the language automatically and output in the same script."
    elif mode == "describe":
        instruction = "Describe this image. Focus on visible key elements."
    elif mode == "freeform":
        # Same fallback rule as find_term: blank input is treated as missing.
        instruction = (user_prompt or "").strip() or "OCR this image."
    else:
        instruction = "OCR this image."

    if include_caption and mode != "describe":
        instruction += "\nThen add a one-paragraph description of the image."

    parts.append(instruction)
    return "\n".join(parts)
|
||||
|
||||
# -----------------------------
|
||||
# Grounding parser
|
||||
# -----------------------------
|
||||
# Match one full detection block and capture the coordinates as the entire
# list expression, outer brackets included. Examples of captured coords:
#   - [[312, 339, 480, 681]]
#   - [[504, 700, 625, 910], [771, 570, 996, 996]]
# The bracket capture is lazy but anchored on the closing <|/det|> tag, so it
# still extends over every inner list of ONE block while never running on into
# the next block (a greedy capture would swallow everything up to the last
# <|/det|> in the text).
DET_BLOCK = re.compile(
    r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*(?P<coords>\[.*?\])\s*<\|/det\|>",
    re.DOTALL,
)


def clean_grounding_text(text: str) -> str:
    """Strip grounding markup from model output for display, keeping labels.

    Every ``<|ref|>label<|/ref|><|det|>[...]<|/det|>`` block is replaced by
    its label, standalone ``<|grounding|>`` tags are removed, and the result
    is stripped of surrounding whitespace. Text between two detection blocks
    is preserved.
    """
    cleaned = DET_BLOCK.sub(lambda m: m.group("label"), text)
    return cleaned.replace("<|grounding|>", "").strip()


def parse_detections(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
    """Parse grounding boxes from model output and scale them to pixel space.

    Coordinates in the text are normalized to the 0-999 range and are scaled
    to ``image_width`` x ``image_height``. Both forms are handled:

    - Single:   <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2]]<|/det|>
    - Multiple: <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2], [x1,y1,x2,y2], ...]<|/det|>

    A flat ``[x1,y1,x2,y2]`` list is accepted as a single box as well.

    Args:
        text: Raw model output; ``None`` or empty yields no boxes.
        image_width: Width in pixels of the source image.
        image_height: Height in pixels of the source image.

    Returns:
        One ``{"label": str, "box": [x1, y1, x2, y2]}`` dict per box, in the
        order found. Detection blocks whose coordinates cannot be parsed are
        logged and skipped.
    """
    import ast

    boxes: List[Dict[str, Any]] = []
    for m in DET_BLOCK.finditer(text or ""):
        label = m.group("label").strip()
        coords_str = m.group("coords").strip()

        print(f"🔍 DEBUG: Found detection for '{label}'")
        print(f"📦 Raw coords string (with brackets): {coords_str}")

        try:
            # literal_eval only accepts Python literals, so model output is
            # never executed.
            parsed = ast.literal_eval(coords_str)

            # Normalize to a list of boxes
            if (
                isinstance(parsed, list)
                and len(parsed) == 4
                and all(isinstance(n, (int, float)) for n in parsed)
            ):
                # Single box provided as [x1,y1,x2,y2]
                box_coords = [parsed]
                print("📦 Single box (flat list) detected")
            elif isinstance(parsed, list):
                box_coords = parsed
                print(f"📦 Boxes detected: {len(box_coords)}")
            else:
                raise ValueError("Unsupported coords structure")

            for idx, box in enumerate(box_coords):
                if isinstance(box, (list, tuple)) and len(box) >= 4:
                    x1 = int(float(box[0]) / 999 * image_width)
                    y1 = int(float(box[1]) / 999 * image_height)
                    x2 = int(float(box[2]) / 999 * image_width)
                    y2 = int(float(box[3]) / 999 * image_height)
                    print(f" Box {idx+1}: {box} → [{x1}, {y1}, {x2}, {y2}]")
                    boxes.append({"label": label, "box": [x1, y1, x2, y2]})
                else:
                    print(f" ⚠️ Skipping invalid box: {box}")
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            # Malformed coordinates from the model: skip this block only.
            print(f"❌ Parsing failed: {e}")
            continue

    print(f"🎯 Total boxes parsed: {len(boxes)}")
    return boxes
|
||||
|
||||
# -----------------------------
|
||||
# Routes
|
||||
# -----------------------------
|
||||
@@ -270,11 +95,38 @@ async def root():
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy", "model_loaded": model is not None}
|
||||
return {"status": "healthy", "models": registry.list_models() if registry else []}
|
||||
|
||||
|
||||
@app.get("/api/models")
async def list_models():
    """Return the OCR models that can be selected in the UI.

    Raises:
        HTTPException: 503 while the model registry has not been built.
    """
    if registry is not None:
        available = registry.list_models()
        return JSONResponse({"models": available})
    raise HTTPException(status_code=503, detail="Model registry not ready.")
|
||||
|
||||
|
||||
def _resolve_provider(model_id: Optional[str], mode: str):
    """Look up the provider for ``model_id`` and reject capability mismatches.

    Args:
        model_id: Model id passed to ``registry.get``; may be ``None``.
        mode: Requested OCR mode, checked against ``GROUNDING_MODES``.

    Returns:
        The provider object returned by the registry.

    Raises:
        HTTPException: 503 if the registry is not built yet; 400 if the
            registry rejects ``model_id`` or the provider lacks the
            ``grounding`` capability that ``mode`` requires.
    """
    if registry is None:
        raise HTTPException(status_code=503, detail="Model registry not ready.")

    try:
        provider = registry.get(model_id)
    except ProviderError as exc:
        # Chain the cause so the traceback shows the registry failure as the
        # origin rather than an "exception during handling" artifact.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if mode in GROUNDING_MODES and not provider.capabilities.get("grounding"):
        raise HTTPException(
            status_code=400,
            detail=f"Model '{provider.label}' does not support grounding modes (e.g. {mode}).",
        )
    return provider
|
||||
|
||||
|
||||
@app.post("/api/ocr")
|
||||
async def ocr_inference(
|
||||
image: UploadFile = File(...),
|
||||
model: Optional[str] = Form(None),
|
||||
mode: str = Form("plain_ocr"),
|
||||
prompt: str = Form(""),
|
||||
grounding: bool = Form(False),
|
||||
@@ -288,93 +140,64 @@ async def ocr_inference(
|
||||
):
|
||||
"""
|
||||
Perform OCR inference on uploaded image
|
||||
|
||||
|
||||
- **image**: Image file to process
|
||||
- **model**: OCR model id (see GET /api/models); defaults to the registry default
|
||||
- **mode**: OCR mode (plain_ocr, markdown, tables_csv, etc.)
|
||||
- **prompt**: Custom prompt for freeform mode
|
||||
- **grounding**: Enable grounding boxes
|
||||
- **grounding**: Enable grounding boxes (DeepSeek only)
|
||||
- **include_caption**: Add image description
|
||||
- **find_term**: Term to find (for find_ref mode)
|
||||
- **schema**: JSON schema (for kv_json mode)
|
||||
- **base_size**: Base processing size
|
||||
- **image_size**: Image size parameter
|
||||
- **crop_mode**: Enable crop mode
|
||||
- **test_compress**: Test compression
|
||||
- **base_size/image_size/crop_mode/test_compress**: DeepSeek processing options
|
||||
"""
|
||||
if model is None or tokenizer is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded yet")
|
||||
|
||||
# Build prompt
|
||||
prompt_text = build_prompt(
|
||||
mode=mode,
|
||||
user_prompt=prompt,
|
||||
grounding=grounding,
|
||||
find_term=find_term,
|
||||
schema=schema,
|
||||
include_caption=include_caption,
|
||||
)
|
||||
|
||||
provider = _resolve_provider(model, mode)
|
||||
|
||||
tmp_img = None
|
||||
out_dir = None
|
||||
try:
|
||||
# Save uploaded file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
|
||||
content = await image.read()
|
||||
tmp.write(content)
|
||||
tmp_img = tmp.name
|
||||
|
||||
|
||||
# Get original dimensions
|
||||
try:
|
||||
with Image.open(tmp_img) as im:
|
||||
orig_w, orig_h = im.size
|
||||
except Exception:
|
||||
orig_w = orig_h = None
|
||||
|
||||
out_dir = tempfile.mkdtemp(prefix="dsocr_")
|
||||
|
||||
# Run inference
|
||||
res = model.infer(
|
||||
tokenizer,
|
||||
prompt=prompt_text,
|
||||
image_file=tmp_img,
|
||||
output_path=out_dir,
|
||||
base_size=base_size,
|
||||
image_size=image_size,
|
||||
crop_mode=crop_mode,
|
||||
save_results=False,
|
||||
test_compress=test_compress,
|
||||
eval_mode=True,
|
||||
|
||||
# Run inference through the selected provider
|
||||
text = provider.run(
|
||||
tmp_img,
|
||||
mode=mode,
|
||||
prompt=prompt,
|
||||
grounding=grounding,
|
||||
find_term=find_term,
|
||||
schema=schema,
|
||||
include_caption=include_caption,
|
||||
options={
|
||||
"base_size": base_size,
|
||||
"image_size": image_size,
|
||||
"crop_mode": crop_mode,
|
||||
"test_compress": test_compress,
|
||||
},
|
||||
)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(res, str):
|
||||
text = res.strip()
|
||||
elif isinstance(res, dict) and "text" in res:
|
||||
text = str(res["text"]).strip()
|
||||
elif isinstance(res, (list, tuple)):
|
||||
text = "\n".join(map(str, res)).strip()
|
||||
else:
|
||||
text = ""
|
||||
|
||||
# Fallback: check output file
|
||||
if not text:
|
||||
mmd = os.path.join(out_dir, "result.mmd")
|
||||
if os.path.exists(mmd):
|
||||
with open(mmd, "r", encoding="utf-8") as fh:
|
||||
text = fh.read().strip()
|
||||
|
||||
if not text:
|
||||
text = "No text returned by model."
|
||||
|
||||
# Parse grounding boxes with proper coordinate scaling
|
||||
|
||||
# Parse grounding boxes (no-op for providers/text without grounding tokens)
|
||||
boxes = parse_detections(text, orig_w or 1, orig_h or 1) if ("<|det|>" in text or "<|ref|>" in text) else []
|
||||
|
||||
|
||||
# Clean grounding tags from display text, but keep the labels
|
||||
display_text = clean_grounding_text(text) if ("<|ref|>" in text or "<|grounding|>" in text) else text
|
||||
|
||||
|
||||
# If display text is empty after cleaning but we have boxes, show the labels
|
||||
if not display_text and boxes:
|
||||
display_text = ", ".join([b["label"] for b in boxes])
|
||||
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"text": display_text,
|
||||
@@ -382,30 +205,36 @@ async def ocr_inference(
|
||||
"boxes": boxes,
|
||||
"image_dims": {"w": orig_w, "h": orig_h},
|
||||
"metadata": {
|
||||
"model": provider.id,
|
||||
"model_label": provider.label,
|
||||
"mode": mode,
|
||||
"grounding": grounding or (mode in {"find_ref","layout_map","pii_redact"}),
|
||||
"grounding": grounding or (mode in GROUNDING_MODES),
|
||||
"base_size": base_size,
|
||||
"image_size": image_size,
|
||||
"crop_mode": crop_mode
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
except ProviderError as e:
|
||||
print(f"OCR provider error: {e}")
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"OCR inference error: {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="An internal error occurred during OCR processing.")
|
||||
|
||||
|
||||
finally:
|
||||
if tmp_img:
|
||||
try:
|
||||
os.remove(tmp_img)
|
||||
except Exception:
|
||||
pass
|
||||
if out_dir:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
|
||||
@app.post("/api/process-pdf")
|
||||
async def process_pdf(
|
||||
pdf_file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(None),
|
||||
mode: str = Form("plain_ocr"),
|
||||
prompt: str = Form(""),
|
||||
output_format: str = Form("markdown"), # markdown, html, docx, json
|
||||
@@ -432,8 +261,7 @@ async def process_pdf(
|
||||
- **image_size**: Image size parameter
|
||||
- **crop_mode**: Enable crop mode
|
||||
"""
|
||||
if model is None or tokenizer is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded yet")
|
||||
provider = _resolve_provider(model, mode)
|
||||
|
||||
# Validate output format
|
||||
if output_format not in ["markdown", "html", "docx", "json"]:
|
||||
@@ -456,56 +284,32 @@ async def process_pdf(
|
||||
for page_idx, img in enumerate(images):
|
||||
print(f"🔍 Processing page {page_idx + 1}/{total_pages}...")
|
||||
|
||||
# Build prompt for this page
|
||||
prompt_text = build_prompt(
|
||||
mode=mode,
|
||||
user_prompt=prompt,
|
||||
grounding=grounding,
|
||||
find_term=None,
|
||||
schema=None,
|
||||
include_caption=include_caption,
|
||||
)
|
||||
|
||||
# Save image temporarily
|
||||
tmp_img = None
|
||||
out_dir = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
|
||||
img.save(tmp, format="PNG")
|
||||
tmp_img = tmp.name
|
||||
|
||||
orig_w, orig_h = img.size
|
||||
out_dir = tempfile.mkdtemp(prefix="dsocr_pdf_")
|
||||
|
||||
# Run inference
|
||||
res = model.infer(
|
||||
tokenizer,
|
||||
prompt=prompt_text,
|
||||
image_file=tmp_img,
|
||||
output_path=out_dir,
|
||||
base_size=base_size,
|
||||
image_size=image_size,
|
||||
crop_mode=crop_mode,
|
||||
save_results=False,
|
||||
test_compress=False,
|
||||
eval_mode=True,
|
||||
# Run inference through the selected provider
|
||||
text = provider.run(
|
||||
tmp_img,
|
||||
mode=mode,
|
||||
prompt=prompt,
|
||||
grounding=grounding,
|
||||
find_term=None,
|
||||
schema=None,
|
||||
include_caption=include_caption,
|
||||
options={
|
||||
"base_size": base_size,
|
||||
"image_size": image_size,
|
||||
"crop_mode": crop_mode,
|
||||
"test_compress": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(res, str):
|
||||
text = res.strip()
|
||||
elif isinstance(res, dict) and "text" in res:
|
||||
text = str(res["text"]).strip()
|
||||
elif isinstance(res, (list, tuple)):
|
||||
text = "\n".join(map(str, res)).strip()
|
||||
else:
|
||||
text = ""
|
||||
|
||||
if not text:
|
||||
mmd = os.path.join(out_dir, "result.mmd")
|
||||
if os.path.exists(mmd):
|
||||
with open(mmd, "r", encoding="utf-8") as fh:
|
||||
text = fh.read().strip()
|
||||
if not text:
|
||||
text = f"No text returned for page {page_idx + 1}."
|
||||
|
||||
@@ -550,8 +354,6 @@ async def process_pdf(
|
||||
os.remove(tmp_img)
|
||||
except Exception:
|
||||
pass
|
||||
if out_dir:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
|
||||
print(f"✅ Processed all {total_pages} pages")
|
||||
|
||||
@@ -562,6 +364,8 @@ async def process_pdf(
|
||||
"total_pages": total_pages,
|
||||
"pages": pages_content,
|
||||
"metadata": {
|
||||
"model": provider.id,
|
||||
"model_label": provider.label,
|
||||
"mode": mode,
|
||||
"grounding": grounding,
|
||||
"extract_images": extract_images,
|
||||
@@ -590,6 +394,9 @@ async def process_pdf(
|
||||
headers={"Content-Disposition": f"attachment; filename=ocr_result.docx"}
|
||||
)
|
||||
|
||||
except ProviderError as e:
|
||||
print(f"PDF provider error: {e}")
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"Error processing PDF: {e}")
|
||||
@@ -633,6 +440,7 @@ async def commit_job(
|
||||
describe_text: str = Form(""),
|
||||
freeform_text: str = Form(""),
|
||||
mode: str = Form("plain_ocr"),
|
||||
ocr_model: str = Form(""),
|
||||
):
|
||||
"""Commit an OCR job: save the image and insert a DB record."""
|
||||
job_id = str(uuid.uuid4())
|
||||
@@ -664,13 +472,14 @@ async def commit_job(
|
||||
"""
|
||||
INSERT INTO ocr_jobs
|
||||
(id, author, book, chapter, page, image_path, original_filename,
|
||||
ocr_text, describe_text, freeform_text, mode, status)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, 'unreviewed')
|
||||
ocr_text, describe_text, freeform_text, mode, ocr_model, status)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, 'unreviewed')
|
||||
RETURNING *
|
||||
""",
|
||||
(job_id, author or None, book or None, chapter or None,
|
||||
page or None, image_path, original_filename,
|
||||
ocr_text or None, describe_text or None, freeform_text or None, mode),
|
||||
ocr_text or None, describe_text or None, freeform_text or None,
|
||||
mode, ocr_model or None),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
except Exception as exc:
|
||||
@@ -743,7 +552,7 @@ async def list_jobs(
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT id, author, book, chapter, page, submitted_at, status,
|
||||
reviewer_name, reviewed_at, mode, original_filename
|
||||
reviewer_name, reviewed_at, mode, ocr_model, original_filename
|
||||
FROM ocr_jobs {where}
|
||||
ORDER BY submitted_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
@@ -945,6 +754,75 @@ async def set_job_status(job_id: str, body: StatusRequest):
|
||||
return JSONResponse(_job_row_to_dict(row))
|
||||
|
||||
|
||||
class JobDescribeRequest(BaseModel):
    """Request body for POST /api/jobs/{job_id}/describe."""

    # OCR model id forwarded to _resolve_provider. None is passed through
    # unchanged — presumably the registry then picks its default; confirm
    # against registry.get.
    model: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/api/jobs/{job_id}/describe")
async def describe_job(job_id: str, body: JobDescribeRequest):
    """Run Describe mode on a job's stored image and save the result to describe_text.

    Args:
        job_id: Job identifier from the URL; must parse as a UUID.
        body: Request body; ``body.model`` selects the OCR provider.

    Returns:
        JSONResponse with the updated job row, as converted by
        ``_job_row_to_dict``.

    Raises:
        HTTPException: 400 for a malformed job id (or a provider rejected by
            ``_resolve_provider``); 404 if the job or its image file is
            missing; 502 on a provider error; 500 on database or unexpected
            inference errors.
    """
    try:
        uuid.UUID(job_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid job ID.") from exc

    # Look up the stored image for this job
    try:
        with get_db() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT image_path FROM ocr_jobs WHERE id = %s", (job_id,))
                row = cur.fetchone()
    except Exception as exc:
        print(f"describe_job lookup DB error: {exc}")
        raise HTTPException(status_code=500, detail="Database error.") from exc

    if not row:
        raise HTTPException(status_code=404, detail="Job not found.")
    image_path = row["image_path"]
    if not image_path or not os.path.isfile(image_path):
        raise HTTPException(status_code=404, detail="Image file not found on disk.")

    provider = _resolve_provider(body.model, "describe")

    try:
        text = provider.run(
            image_path,
            mode="describe",
            prompt="",
            grounding=False,
            find_term=None,
            schema=None,
            include_caption=False,
            options={"base_size": 1024, "image_size": 640, "crop_mode": True, "test_compress": False},
        )
    except ProviderError as e:
        print(f"describe_job provider error: {e}")
        raise HTTPException(status_code=502, detail=str(e)) from e
    except Exception as e:
        # Details go to the log only; the client gets a generic message.
        print(f"describe_job inference error: {type(e).__name__}: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred during description.",
        ) from e

    # Grounding was not requested, but strip any tags the model emitted anyway.
    display_text = clean_grounding_text(text) if ("<|ref|>" in text or "<|grounding|>" in text) else text

    # Persist the generated description on the job
    try:
        with get_db() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE ocr_jobs SET describe_text = %s WHERE id = %s RETURNING *",
                    (display_text, job_id),
                )
                updated = cur.fetchone()
    except Exception as exc:
        print(f"describe_job save DB error: {exc}")
        raise HTTPException(status_code=500, detail="Database error.") from exc

    # No row back from the UPDATE means the job no longer exists.
    if not updated:
        raise HTTPException(status_code=404, detail="Job not found.")

    return JSONResponse(_job_row_to_dict(updated))
|
||||
|
||||
|
||||
@app.delete("/api/jobs/{job_id}")
|
||||
async def delete_job(job_id: str):
|
||||
"""Delete a job record and its stored image."""
|
||||
|
||||
Reference in New Issue
Block a user