aboutsummaryrefslogtreecommitdiff
path: root/src/server/flashcard/labels.py
diff options
context:
space:
mode:
authoralyssaf16 <alyssa_feinberg@brown.edu>2024-10-15 12:19:38 -0400
committeralyssaf16 <alyssa_feinberg@brown.edu>2024-10-15 12:19:38 -0400
commit63d1731bb675b71c20c0e460cf50ef9d674a6c08 (patch)
tree2dcb0ce2842f0fc77e065b35aacba64171f14360 /src/server/flashcard/labels.py
parentbb8fe2933154c6db70cfe5da1e890535bc9012d4 (diff)
flashcard move to server
Diffstat (limited to 'src/server/flashcard/labels.py')
-rw-r--r--src/server/flashcard/labels.py285
1 file changed, 285 insertions, 0 deletions
diff --git a/src/server/flashcard/labels.py b/src/server/flashcard/labels.py
new file mode 100644
index 000000000..546fc4bd3
--- /dev/null
+++ b/src/server/flashcard/labels.py
@@ -0,0 +1,285 @@
+import base64
+import numpy as np
+import base64
+import easyocr
+import sys
+from PIL import Image
+from io import BytesIO
+import requests
+import json
+import numpy as np
+
class BoundingBoxUtils:
    """Utility helpers for bounding-box geometry and OCR text cleanup.

    Bounding boxes appear in two formats in this module:
    * flat ``[x1, y1, x2, y2]`` lists (``is_close``), and
    * four-corner ``[[x, y], ...]`` lists in EasyOCR order
      (top-left, top-right, bottom-right, bottom-left) everywhere else.
    """

    @staticmethod
    def is_close(box1, box2, x_threshold=20, y_threshold=20):
        """
        Determine whether two flat [x1, y1, x2, y2] boxes are near each other.

        Parameters:
            box1, box2 (list): The bounding boxes to compare.
            x_threshold (int): The threshold for horizontal proximity (pixels).
            y_threshold (int): The threshold for vertical proximity (pixels).

        Returns:
            bool: True if boxes are both horizontally and vertically close.
        """
        horizontally_close = (abs(box1[2] - box2[0]) < x_threshold or  # right edge of box1 vs left edge of box2
                              abs(box2[2] - box1[0]) < x_threshold or  # right edge of box2 vs left edge of box1
                              abs(box1[2] - box2[2]) < x_threshold or  # right edges roughly aligned
                              abs(box2[0] - box1[0]) < x_threshold)    # left edges roughly aligned

        vertically_close = (abs(box1[3] - box2[1]) < y_threshold or  # bottom of box1 vs top of box2
                            abs(box2[3] - box1[1]) < y_threshold or  # bottom of box2 vs top of box1
                            box1[1] == box2[1] or box1[3] == box2[3])  # exactly level tops or bottoms

        return horizontally_close and vertically_close

    @staticmethod
    def adjust_bounding_box(bbox, original_text, corrected_text):
        """
        Adjust a four-corner box to account for characters removed by cleanup.

        Bug fix: the previous implementation ignored both text arguments and
        shifted every box by 5px unconditionally — even boxes whose text was
        never corrected. The shift is now applied only when the corrected
        text actually differs from the original.

        Parameters:
            bbox (list): Four [x, y] corner points (EasyOCR order).
            original_text (str): The original text detected by OCR.
            corrected_text (str): The corrected text after cleaning.

        Returns:
            list: The (possibly) adjusted bounding box.
        """
        if not bbox or len(bbox) != 4:
            # Malformed box: return it untouched rather than guessing.
            return bbox
        if corrected_text == original_text:
            # Nothing was removed, so the box still matches its text exactly.
            return bbox

        # Shift the top-left and bottom-right corners slightly to follow the
        # shortened text (preserves the original 5px heuristic).
        x_adjustment = 5
        return [
            [bbox[0][0] + x_adjustment, bbox[0][1]],
            [bbox[1][0], bbox[1][1]],
            [bbox[2][0] + x_adjustment, bbox[2][1]],
            [bbox[3][0], bbox[3][1]],
        ]

    @staticmethod
    def correct_ocr_results(results):
        """
        Strip common OCR artifacts from detected text and adjust boxes.

        Parameters:
            results (list): (bbox, text, confidence) triples from EasyOCR.

        Returns:
            list: Triples with cleaned text; boxes are adjusted only where
            the text actually changed.
        """
        # Characters the OCR engine frequently misreads on these labels:
        # delete every '~' and '-' in one C-level pass.
        deletions = str.maketrans("", "", "~-")

        corrected_results = []
        for bbox, text, prob in results:
            corrected_text = text.translate(deletions)
            adjusted_bbox = BoundingBoxUtils.adjust_bounding_box(bbox, text, corrected_text)
            corrected_results.append((adjusted_bbox, corrected_text, prob))

        return corrected_results

    @staticmethod
    def convert_to_json_serializable(data):
        """
        Convert nested data containing numpy types to a JSON-serializable form.

        Parameters:
            data: Arbitrarily nested lists/tuples possibly containing numpy
                integers, floats, or arrays.

        Returns:
            The same structure with every numpy value replaced by its native
            Python equivalent, suitable for ``json.dumps``.
        """
        def convert_element(element):
            if isinstance(element, list):
                return [convert_element(e) for e in element]
            if isinstance(element, tuple):
                return tuple(convert_element(e) for e in element)
            if isinstance(element, np.integer):
                return int(element)
            if isinstance(element, np.floating):
                return float(element)
            if isinstance(element, np.ndarray):
                return element.tolist()
            return element

        return convert_element(data)
+
class ImageLabelProcessor:
    """Run EasyOCR over an image supplied as base64 data or as a URL."""

    # Maximum pixel gaps for two boxes to be considered part of one label.
    VERTICAL_THRESHOLD = 20
    HORIZONTAL_THRESHOLD = 8

    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 10

    def __init__(self, img_source, source_type, smart_mode):
        """
        Parameters:
            img_source (str): Base64 image data (source_type 'drag') or a URL.
            source_type (str): 'drag' for base64 input; anything else is a URL.
            smart_mode (bool): True to use EasyOCR's paragraph grouping.
        """
        self.img_source = img_source
        self.source_type = source_type
        self.smart_mode = smart_mode
        # Either a temp-file path (base64 input) or raw PNG bytes (URL input);
        # EasyOCR accepts both.
        self.img_val = self.load_image()

    def load_image(self):
        """Load the image from either a base64 string or a URL."""
        if self.source_type == 'drag':
            return self._load_base64_image()
        else:
            return self._load_url_image()

    def _load_base64_image(self):
        """Decode the base64 image, save it to disk, and return the path."""
        base64_string = self.img_source
        if base64_string.startswith("data:image"):
            # Drop the "data:image/...;base64," data-URL prefix.
            base64_string = base64_string.split(",")[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(BytesIO(image_data)).convert('RGB')
        # NOTE(review): a fixed filename is a collision hazard if two requests
        # are processed concurrently — consider tempfile.NamedTemporaryFile.
        image.save("temp_image.jpg")
        return "temp_image.jpg"

    def _load_url_image(self):
        """Download the image from a URL and return it as PNG bytes."""
        # Bug fix: requests.get without a timeout can hang this process
        # forever if the remote host never responds.
        response = requests.get(self.img_source, timeout=self.REQUEST_TIMEOUT)
        # Fail fast on HTTP errors instead of letting PIL choke on an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')

        image_bytes = BytesIO()
        image.save(image_bytes, format='PNG')
        return image_bytes.getvalue()

    def process_image(self):
        """Run OCR and return the response dict for the chosen mode."""
        if self.smart_mode:
            return self._process_smart_mode()
        else:
            return self._process_standard_mode()

    def _process_smart_mode(self):
        """OCR with EasyOCR's paragraph grouping; no extra merging needed."""
        reader = easyocr.Reader(['en'])
        # paragraph=True makes EasyOCR return (bbox, text) pairs (no score).
        result = reader.readtext(self.img_val, detail=1, paragraph=True)

        all_boxes = [bbox for bbox, text in result]
        all_texts = [text for bbox, text in result]

        return {
            'status': 'success',
            'message': 'Data received',
            'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
            'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
        }

    def _process_standard_mode(self):
        """OCR line-by-line, keep confident hits, then merge nearby lines."""
        reader = easyocr.Reader(['en'])
        results = reader.readtext(self.img_val)

        # Keep only confident detections (>= 0.7) and clean their text.
        filtered_results = BoundingBoxUtils.correct_ocr_results([
            (bbox, text, prob) for bbox, text, prob in results if prob >= 0.7
        ])

        return self._merge_and_prepare_response(filtered_results)

    def are_vertically_close(self, box1, box2):
        """
        Check whether box2 sits just below box1 and is roughly left/right
        aligned with it (boxes are four-corner EasyOCR boxes, box1 above).
        """
        box1_bottom = max(box1[2][1], box1[3][1])
        box2_top = min(box2[0][1], box2[1][1])
        vertical_distance = box2_top - box1_bottom

        box1_left = box1[0][0]
        box2_left = box2[0][0]
        box1_right = box1[1][0]
        box2_right = box2[1][0]
        # Aligned if either the left or the right edges nearly line up.
        hori_close = (abs(box2_left - box1_left) <= self.HORIZONTAL_THRESHOLD or
                      abs(box2_right - box1_right) <= self.HORIZONTAL_THRESHOLD)

        return vertical_distance <= self.VERTICAL_THRESHOLD and hori_close

    def merge_boxes(self, boxes, texts):
        """
        Merge several four-corner boxes into their bounding rectangle and
        join their texts with spaces.

        Returns:
            tuple: (merged_box, merged_text).
        """
        x_coords = []
        y_coords = []
        for box in boxes:
            for point in box:
                x_coords.append(point[0])
                y_coords.append(point[1])

        # Axis-aligned rectangle covering every input corner.
        merged_box = [
            [min(x_coords), min(y_coords)],
            [max(x_coords), min(y_coords)],
            [max(x_coords), max(y_coords)],
            [min(x_coords), max(y_coords)]
        ]

        merged_text = ' '.join(texts)
        return merged_box, merged_text

    def _merge_and_prepare_response(self, filtered_results):
        """
        Merge runs of vertically-close boxes (assumes reading order) and
        build the final response dict.
        """
        current_boxes, current_texts = [], []
        all_boxes, all_texts = [], []

        # Walk consecutive pairs; a break in closeness flushes the run.
        for ind in range(len(filtered_results) - 1):
            if not current_boxes:
                current_boxes.append(filtered_results[ind][0])
                current_texts.append(filtered_results[ind][1])

            if self.are_vertically_close(filtered_results[ind][0], filtered_results[ind + 1][0]):
                current_boxes.append(filtered_results[ind + 1][0])
                current_texts.append(filtered_results[ind + 1][1])
            else:
                merged = self.merge_boxes(current_boxes, current_texts)
                all_boxes.append(merged[0])
                all_texts.append(merged[1])
                current_boxes, current_texts = [], []

        # Flush the trailing run, if any.
        if current_boxes:
            merged = self.merge_boxes(current_boxes, current_texts)
            all_boxes.append(merged[0])
            all_texts.append(merged[1])

        # The last element never entered a run: emit it on its own.
        if not current_boxes and filtered_results:
            merged = self.merge_boxes([filtered_results[-1][0]], [filtered_results[-1][1]])
            all_boxes.append(merged[0])
            all_texts.append(merged[1])

        return {
            'status': 'success',
            'message': 'Data received',
            'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
            'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
        }
+
# Main execution function
def labels():
    """
    Script entry point: read CLI arguments, run OCR, and print the result.

    Command line:
        argv[1]: path of a file holding the image source (base64 data or URL)
        argv[2]: source type ('drag' for base64, anything else for URL)
        argv[3]: 'smart' to enable paragraph-grouping mode

    Returns:
        dict: The response from ImageLabelProcessor.process_image(); it is
        also printed to stdout so a parent process can capture it.
    """
    source_type = sys.argv[2]
    smart_mode = (sys.argv[3] == 'smart')
    with open(sys.argv[1], 'r') as f:
        img_source = f.read()

    processor = ImageLabelProcessor(img_source, source_type, smart_mode)
    response = processor.process_image()

    # Print so the calling process can read the result from stdout.
    print(response)
    return response


# Bug fix: the call was unconditional, so merely importing this module
# read sys.argv and kicked off a full OCR run. Guarding with __main__
# keeps script behavior identical while making the module importable.
if __name__ == "__main__":
    labels()