aboutsummaryrefslogtreecommitdiff
path: root/src/server
diff options
context:
space:
mode:
Diffstat (limited to 'src/server')
-rw-r--r--src/server/ApiManagers/FlashcardManager.ts161
-rw-r--r--src/server/flashcard/labels.py285
-rw-r--r--src/server/flashcard/requirements.txt12
3 files changed, 458 insertions, 0 deletions
diff --git a/src/server/ApiManagers/FlashcardManager.ts b/src/server/ApiManagers/FlashcardManager.ts
new file mode 100644
index 000000000..fd7c42437
--- /dev/null
+++ b/src/server/ApiManagers/FlashcardManager.ts
@@ -0,0 +1,161 @@
+/**
+ * @file FlashcardManager.ts
+ * @description This file defines the FlashcardManager class, responsible for managing API routes
+ * related to flashcard creation and manipulation. It provides functionality for handling file processing,
+ * running Python scripts in a virtual environment, and managing dependencies.
+ */
+
+import { spawn } from 'child_process';
+import * as fs from 'fs';
+import * as path from 'path';
+import { Method } from '../RouteManager';
+import ApiManager, { Registration } from './ApiManager';
+
+/**
+ * Runs a Python script with the interpreter of the given virtual environment.
+ *
+ * When `file` is non-empty, its contents are written to a per-invocation temporary file next to this
+ * module and that file's path (not the contents) is passed to the script as its first argument.
+ * The temporary file is removed once the process has finished or failed to start.
+ * Arguments that are `undefined` are omitted, so the remaining ones move up to earlier positions.
+ *
+ * @param {string} venvPath - Path to the virtual environment.
+ * @param {string} scriptPath - Path to the Python script.
+ * @param {string} [file] - Optional raw data handed to the script through a temporary file.
+ * @param {string} [drag] - Optional argument forwarded verbatim to the script (drag mode).
+ * @param {string} [smart] - Optional argument forwarded verbatim to the script (smart mode).
+ * @returns {Promise<string>} - Resolves with the script's stdout when it exits with code 0; rejects with
+ * an Error when the process cannot be started or exits with a non-zero code (stderr is in the message).
+ */
+function runPythonScript(venvPath: string, scriptPath: string, file?: string, drag?: string, smart?: string): Promise<string> {
+    return new Promise((resolve, reject) => {
+        const pythonPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python.exe') : path.join(venvPath, 'bin', 'python3');
+
+        let tempFilePath: string | undefined;
+        if (file) {
+            // A per-invocation name keeps concurrent requests from overwriting each other's payload.
+            const uniqueSuffix = `${process.pid}_${Date.now()}_${Math.random().toString(36).slice(2)}`;
+            tempFilePath = path.join(__dirname, `temp_data_${uniqueSuffix}.txt`);
+            // Write the raw file data to the temp file without conversion
+            fs.writeFileSync(tempFilePath, file, 'utf8');
+        }
+
+        // Best-effort cleanup; may run twice ('error' followed by 'close'), so failures are ignored.
+        const removeTempFile = () => {
+            if (tempFilePath) {
+                fs.unlink(tempFilePath, () => {});
+            }
+        };
+
+        const args = [scriptPath, tempFilePath, drag, smart].filter((arg): arg is string => arg !== undefined);
+        const pythonProcess = spawn(pythonPath, args);
+
+        let pythonOutput = '';
+        let stderrOutput = '';
+
+        pythonProcess.stdout.on('data', data => {
+            pythonOutput += data.toString();
+        });
+
+        pythonProcess.stderr.on('data', data => {
+            stderrOutput += data.toString();
+        });
+
+        // Without this listener a missing interpreter raises an unhandled 'error' event.
+        pythonProcess.on('error', err => {
+            removeTempFile();
+            reject(new Error(`Failed to start Python process: ${err.message}`));
+        });
+
+        pythonProcess.on('close', code => {
+            removeTempFile();
+            if (code === 0) {
+                resolve(pythonOutput);
+            } else {
+                reject(new Error(`Python process exited with code ${code}: ${stderrOutput}`));
+            }
+        });
+    });
+}
+
+/**
+ * Installs Python dependencies using pip in the specified virtual environment.
+ * pip's stdout and stderr are forwarded to the server console as they arrive.
+ *
+ * @param {string} venvPath - Path to the virtual environment.
+ * @param {string} requirementsPath - Path to the requirements.txt file.
+ * @returns {Promise<void>} - Resolves when pip exits with code 0; rejects with an Error when pip cannot
+ * be started or exits with a non-zero code.
+ */
+function installDependencies(venvPath: string, requirementsPath: string): Promise<void> {
+    return new Promise((resolve, reject) => {
+        const pipPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'pip.exe') : path.join(venvPath, 'bin', 'pip3');
+
+        const installProcess = spawn(pipPath, ['install', '-r', requirementsPath]);
+
+        installProcess.stdout.on('data', data => {
+            console.log(`pip stdout: ${data}`);
+        });
+
+        installProcess.stderr.on('data', data => {
+            console.error(`pip stderr: ${data}`);
+        });
+
+        // A missing pip executable surfaces here rather than as an unhandled 'error' event.
+        installProcess.on('error', err => {
+            reject(new Error(`Failed to start pip: ${err.message}`));
+        });
+
+        installProcess.on('close', code => {
+            if (code === 0) {
+                resolve();
+            } else {
+                reject(new Error(`Failed to install dependencies. Exit code: ${code}`));
+            }
+        });
+    });
+}
+
+/**
+ * Creates a new Python virtual environment by running `python3 -m venv <venvPath>`.
+ *
+ * @param {string} venvPath - Path to the virtual environment that will be created.
+ * @returns {Promise<void>} - Resolves when the command exits with code 0; rejects with an Error when
+ * `python3` cannot be started or exits with a non-zero code (captured stderr is included in the message).
+ */
+function createVirtualEnvironment(venvPath: string): Promise<void> {
+    return new Promise((resolve, reject) => {
+        const createVenvProcess = spawn('python3', ['-m', 'venv', venvPath]);
+
+        // Keep stderr so a failure explains itself instead of reporting only an exit code.
+        let stderrOutput = '';
+        createVenvProcess.stderr.on('data', data => {
+            stderrOutput += data.toString();
+        });
+
+        // A missing `python3` executable surfaces here rather than as an unhandled 'error' event.
+        createVenvProcess.on('error', err => {
+            reject(new Error(`Failed to start python3 to create virtual environment: ${err.message}`));
+        });
+
+        createVenvProcess.on('close', code => {
+            if (code === 0) {
+                resolve();
+            } else {
+                const details = stderrOutput ? `: ${stderrOutput}` : '';
+                reject(new Error(`Failed to create virtual environment. Exit code: ${code}${details}`));
+            }
+        });
+    });
+}
+
+/**
+ * Manages the creation of the virtual environment, installation of dependencies, and running of the Python script.
+ *
+ * The environment is considered present when its interpreter exists (`bin/python3` or `Scripts/python.exe`).
+ * If it is absent, it is created and the requirements are installed. Should either step fail, the
+ * partially built environment is deleted so that the next call starts from scratch instead of finding an
+ * interpreter without its dependencies.
+ *
+ * @param {string} [file] - Optional file data to be processed by the Python script.
+ * @param {string} [drag] - Optional argument controlling drag mode.
+ * @param {string} [smart] - Optional argument controlling smart mode.
+ * @returns {Promise<string>} - Resolves with the Python script output, or rejects on failure.
+ */
+async function manageVenvAndRunScript(file?: string, drag?: string, smart?: string): Promise<string> {
+    const venvPath = path.join(__dirname, '../flashcard/venv'); // Virtual environment path
+    const requirementsPath = path.join(__dirname, '../flashcard/requirements.txt');
+    const pythonScriptPath = path.join(__dirname, '../flashcard/labels.py');
+    console.log('venvPath:', venvPath);
+
+    // Check if the virtual environment exists (POSIX or Windows layout)
+    const venvExists = fs.existsSync(path.join(venvPath, 'bin', 'python3')) || fs.existsSync(path.join(venvPath, 'Scripts', 'python.exe'));
+
+    if (!venvExists) {
+        try {
+            await createVirtualEnvironment(venvPath);
+            await installDependencies(venvPath, requirementsPath);
+        } catch (error) {
+            // Leaving a venv without its packages would make the existence check above pass forever.
+            try {
+                fs.rmSync(venvPath, { recursive: true, force: true });
+            } catch (cleanupError) {
+                console.error('Failed to remove incomplete virtual environment:', cleanupError);
+            }
+            throw error;
+        }
+    }
+
+    return runPythonScript(venvPath, pythonScriptPath, file, drag, smart);
+}
+
+/**
+ * FlashcardManager class responsible for managing API routes related to flashcard functionality.
+ * It registers a single POST `/labels` route that forwards the request to the Python backend.
+ */
+export default class FlashcardManager extends ApiManager {
+    /**
+     * Initializes the API routes for the FlashcardManager class.
+     *
+     * POST `/labels` expects `file`, `drag` and `smart` in the request body, all strings and `file`
+     * non-empty. They are passed positionally to the Python script, so a missing value would shift the
+     * others; such requests are answered with 400 instead of being run.
+     * On success the script's stdout is returned as `{ result }` with status 200; any failure while
+     * preparing or running the script is logged and answered with status 500.
+     *
+     * @param {Registration} register - The registration function for defining API routes.
+     */
+    protected initialize(register: Registration): void {
+        register({
+            method: Method.POST,
+            subscription: '/labels',
+            secureHandler: async ({ req, res }) => {
+                const { file, drag, smart } = req.body ?? {};
+
+                if (typeof file !== 'string' || !file || typeof drag !== 'string' || typeof smart !== 'string') {
+                    res.status(400).send({
+                        error: 'Expected string fields "file", "drag" and "smart" in the request body',
+                    });
+                    return;
+                }
+
+                try {
+                    // Run the Python process
+                    const result = await manageVenvAndRunScript(file, drag, smart);
+                    res.status(200).send({ result });
+                } catch (error) {
+                    console.error('Error initiating document creation:', error);
+                    res.status(500).send({
+                        error: 'Failed to initiate document creation',
+                    });
+                }
+            },
+        });
+    }
+}
diff --git a/src/server/flashcard/labels.py b/src/server/flashcard/labels.py
new file mode 100644
index 000000000..546fc4bd3
--- /dev/null
+++ b/src/server/flashcard/labels.py
@@ -0,0 +1,285 @@
+import base64
+import numpy as np
+import base64
+import easyocr
+import sys
+from PIL import Image
+from io import BytesIO
+import requests
+import json
+import numpy as np
+
+class BoundingBoxUtils:
+ """Utility class for bounding box operations and OCR result corrections."""
+
+ @staticmethod
+ def is_close(box1, box2, x_threshold=20, y_threshold=20):
+ """
+ Determines if two bounding boxes are horizontally and vertically close.
+
+ The boxes are indexed flat here: [0] is used as the left edge, [1] as the top edge,
+ [2] as the right edge and [3] as the bottom edge. This differs from the four-point
+ layout (bbox[i][0], bbox[i][1]) used by adjust_bounding_box.
+
+ Parameters:
+ box1, box2 (list): The bounding boxes to compare.
+ x_threshold (int): The threshold for horizontal proximity.
+ y_threshold (int): The threshold for vertical proximity.
+
+ Returns:
+ bool: True if at least one horizontal condition and at least one vertical condition hold, False otherwise.
+ """
+ horizontally_close = (abs(box1[2] - box2[0]) < x_threshold or # Right edge of box1 and left edge of box2
+ abs(box2[2] - box1[0]) < x_threshold or # Right edge of box2 and left edge of box1
+ abs(box1[2] - box2[2]) < x_threshold or # Right edges of both boxes
+ abs(box2[0] - box1[0]) < x_threshold) # Left edges of both boxes
+
+ vertically_close = (abs(box1[3] - box2[1]) < y_threshold or # Bottom edge of box1 and top edge of box2
+ abs(box2[3] - box1[1]) < y_threshold or # Bottom edge of box2 and top edge of box1
+ box1[1] == box2[1] or box1[3] == box2[3]) # Identical top edges or identical bottom edges
+
+ return horizontally_close and vertically_close
+
+ @staticmethod
+ def adjust_bounding_box(bbox, original_text, corrected_text):
+ """
+ Shifts two corners of a four-point bounding box by a fixed horizontal offset.
+
+ The x-coordinate of the points at index 0 and index 2 is increased by 5; the points at
+ index 1 and index 3 are copied unchanged. The two text arguments are accepted but not
+ used, so the shift is applied whether or not the text was actually corrected.
+
+ Parameters:
+ bbox (list): The original bounding box as four [x, y] points.
+ original_text (str): The original text detected by OCR (unused).
+ corrected_text (str): The corrected text after cleaning (unused).
+
+ Returns:
+ list: A new four-point bounding box, or bbox itself when it is empty or does not have exactly four entries.
+ """
+ if not bbox or len(bbox) != 4:
+ return bbox
+
+ # Fixed horizontal offset applied to points 0 and 2 only
+ x_adjustment = 5
+ adjusted_bbox = [
+ [bbox[0][0] + x_adjustment, bbox[0][1]],
+ [bbox[1][0], bbox[1][1]],
+ [bbox[2][0] + x_adjustment, bbox[2][1]],
+ [bbox[3][0], bbox[3][1]]
+ ]
+ return adjusted_bbox
+
+ @staticmethod
+ def correct_ocr_results(results):
+ """
+ Removes '~' and '-' characters from the detected text and passes every bounding box through adjust_bounding_box.
+
+ Parameters:
+ results (list): A list of OCR results, each a (bounding box, text, confidence score) triple.
+
+ Returns:
+ list: (adjusted bounding box, cleaned text, confidence score) tuples, in input order.
+ """
+ # Character-level replacements; any character not listed is kept as is
+ corrections = {
+ "~": "", # Replace '~' with empty string
+ "-": "" # Replace '-' with empty string
+ }
+
+ corrected_results = []
+ for (bbox, text, prob) in results:
+ corrected_text = ''.join(corrections.get(char, char) for char in text)
+ adjusted_bbox = BoundingBoxUtils.adjust_bounding_box(bbox, text, corrected_text)
+ corrected_results.append((adjusted_bbox, corrected_text, prob))
+
+ return corrected_results
+
+ @staticmethod
+ def convert_to_json_serializable(data):
+ """
+ Recursively replaces numpy values inside nested lists and tuples with plain Python values.
+
+ Lists and tuples are walked element by element (tuples stay tuples), numpy integers and
+ floats become int and float, numpy arrays become lists via tolist(), and anything else
+ is returned unchanged.
+
+ Parameters:
+ data (list): A list containing numpy or other non-serializable types.
+
+ Returns:
+ list: The converted structure.
+ """
+ def convert_element(element):
+ if isinstance(element, list):
+ return [convert_element(e) for e in element]
+ elif isinstance(element, tuple):
+ return tuple(convert_element(e) for e in element)
+ elif isinstance(element, np.integer):
+ return int(element)
+ elif isinstance(element, np.floating):
+ return float(element)
+ elif isinstance(element, np.ndarray):
+ return element.tolist()
+ else:
+ return element
+
+ return convert_element(data)
+
+class ImageLabelProcessor:
+ """Class to process images and perform OCR with EasyOCR."""
+
+ # Limits used by are_vertically_close, in the same units as the box coordinates
+ VERTICAL_THRESHOLD = 20
+ HORIZONTAL_THRESHOLD = 8
+
+ def __init__(self, img_source, source_type, smart_mode):
+ """Store the inputs and load the image immediately via load_image()."""
+ self.img_source = img_source
+ self.source_type = source_type
+ self.smart_mode = smart_mode
+ self.img_val = self.load_image()
+
+ def load_image(self):
+ """Load the image as base64 when source_type is 'drag', otherwise treat img_source as a URL."""
+ if self.source_type == 'drag':
+ return self._load_base64_image()
+ else:
+ return self._load_url_image()
+
+ def _load_base64_image(self):
+ """Decode the base64 image, save it as "temp_image.jpg" and return that file name."""
+ base64_string = self.img_source
+ # Drop the part of a data URL before the first comma
+ if base64_string.startswith("data:image"):
+ base64_string = base64_string.split(",")[1]
+
+
+ # Decode the base64 string
+ image_data = base64.b64decode(base64_string)
+ image = Image.open(BytesIO(image_data)).convert('RGB')
+ # NOTE(review): fixed relative file name; runs sharing a working directory would overwrite each other
+ image.save("temp_image.jpg")
+ return "temp_image.jpg"
+
+ def _load_url_image(self):
+ """Download the image from the URL, convert it to RGB and return it as PNG-encoded bytes."""
+ url = self.img_source
+ # NOTE(review): no timeout is passed and the HTTP status is not checked before decoding
+ response = requests.get(url)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+
+ image_bytes = BytesIO()
+ image.save(image_bytes, format='PNG')
+ return image_bytes.getvalue()
+
+ def process_image(self):
+ """Run smart-mode or standard-mode OCR, depending on smart_mode, and return the response dict."""
+ if self.smart_mode:
+ return self._process_smart_mode()
+ else:
+ return self._process_standard_mode()
+
+ def _process_smart_mode(self):
+ """Process the image in smart mode: EasyOCR with paragraph=True, results returned without filtering or merging."""
+ reader = easyocr.Reader(['en'])
+ result = reader.readtext(self.img_val, detail=1, paragraph=True)
+
+ # Each result is unpacked as a (bbox, text) pair
+ all_boxes = [bbox for bbox, text in result]
+ all_texts = [text for bbox, text in result]
+
+ response_data = {
+ 'status': 'success',
+ 'message': 'Data received',
+ 'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
+ 'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
+ }
+
+ return response_data
+
+ def _process_standard_mode(self):
+ """Process the image in standard mode: keep results with confidence >= 0.7, clean them, then merge nearby boxes."""
+ reader = easyocr.Reader(['en'])
+ results = reader.readtext(self.img_val)
+
+ # Each result is unpacked as a (bbox, text, prob) triple
+ filtered_results = BoundingBoxUtils.correct_ocr_results([
+ (bbox, text, prob) for bbox, text, prob in results if prob >= 0.7
+ ])
+
+ return self._merge_and_prepare_response(filtered_results)
+
+ def are_vertically_close(self, box1, box2):
+ """
+ Check if box2 starts within VERTICAL_THRESHOLD below box1 and shares a left or right edge with it.
+
+ Boxes are four [x, y] points; points 0 and 1 are read as the top corners (left, right)
+ and points 2 and 3 as the bottom corners. The vertical distance is signed, so a box2
+ whose top lies above the bottom of box1 also satisfies the vertical condition.
+ """
+ box1_bottom = max(box1[2][1], box1[3][1])
+ box2_top = min(box2[0][1], box2[1][1])
+ vertical_distance = box2_top - box1_bottom
+
+ box1_left = box1[0][0]
+ box2_left = box2[0][0]
+ box1_right = box1[1][0]
+ box2_right = box2[1][0]
+ # Left edges or right edges must line up within HORIZONTAL_THRESHOLD
+ hori_close = abs(box2_left - box1_left) <= self.HORIZONTAL_THRESHOLD or abs(box2_right - box1_right) <= self.HORIZONTAL_THRESHOLD
+
+ return vertical_distance <= self.VERTICAL_THRESHOLD and hori_close
+
+ def merge_boxes(self, boxes, texts):
+ """Return the axis-aligned rectangle enclosing all points of the given boxes, and the texts joined with spaces."""
+ x_coords = []
+ y_coords = []
+
+ # Collect all x and y coordinates
+ for box in boxes:
+ for point in box:
+ x_coords.append(point[0])
+ y_coords.append(point[1])
+
+ # Create the merged bounding box (min/min, max/min, max/max, min/max corner order)
+ merged_box = [
+ [min(x_coords), min(y_coords)],
+ [max(x_coords), min(y_coords)],
+ [max(x_coords), max(y_coords)],
+ [min(x_coords), max(y_coords)]
+ ]
+
+ # Combine the texts
+ merged_text = ' '.join(texts)
+
+ return merged_box, merged_text
+
+ def _merge_and_prepare_response(self, filtered_results):
+ """
+ Group consecutive results that are_vertically_close accepts, merge each group and build the response dict.
+
+ Only neighbouring entries of filtered_results are compared, so the grouping depends on
+ the order in which the results arrive.
+ """
+ current_boxes, current_texts = [], []
+ all_boxes, all_texts = [], []
+
+ for ind in range(len(filtered_results) - 1):
+ # Start a new group with the current result when none is open
+ if not current_boxes:
+ current_boxes.append(filtered_results[ind][0])
+ current_texts.append(filtered_results[ind][1])
+
+ if self.are_vertically_close(filtered_results[ind][0], filtered_results[ind + 1][0]):
+ current_boxes.append(filtered_results[ind + 1][0])
+ current_texts.append(filtered_results[ind + 1][1])
+ else:
+ # The next result does not belong to this group: emit the group and reset
+ merged = self.merge_boxes(current_boxes, current_texts)
+ all_boxes.append(merged[0])
+ all_texts.append(merged[1])
+ current_boxes, current_texts = [], []
+
+ # Emit the group that is still open
+ if current_boxes:
+ merged = self.merge_boxes(current_boxes, current_texts)
+ all_boxes.append(merged[0])
+ all_texts.append(merged[1])
+
+ # No open group: the last result has not been emitted yet, so emit it on its own
+ if not current_boxes and filtered_results:
+ merged = self.merge_boxes([filtered_results[-1][0]], [filtered_results[-1][1]])
+ all_boxes.append(merged[0])
+ all_texts.append(merged[1])
+
+ response = {
+ 'status': 'success',
+ 'message': 'Data received',
+ 'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
+ 'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
+ }
+
+ return response
+
+# Main execution function
+def labels():
+ """
+ Main function to handle image OCR processing based on input arguments.
+
+ Command-line arguments (all three are read unconditionally):
+ sys.argv[1]: path of a text file whose contents are the image source (base64 data or a URL).
+ sys.argv[2]: source type; 'drag' selects base64 decoding, anything else is treated as a URL.
+ sys.argv[3]: smart mode is enabled only when this equals 'smart'.
+
+ Returns:
+ dict: The response produced by ImageLabelProcessor.process_image(), which is also printed to stdout.
+ """
+ source_type = sys.argv[2]
+ smart_mode = (sys.argv[3] == 'smart')
+ with open(sys.argv[1], 'r') as f:
+ img_source = f.read()
+ # Create ImageLabelProcessor instance
+ processor = ImageLabelProcessor(img_source, source_type, smart_mode)
+ response = processor.process_image()
+
+ # Print and return the response
+ # NOTE(review): print() emits the dict's Python repr, not JSON (the imported json module is not used here) - confirm the consumer expects this
+ print(response)
+ return response
+
+
+# Called unconditionally: there is no __main__ guard, so importing this module also runs the OCR
+labels()
diff --git a/src/server/flashcard/requirements.txt b/src/server/flashcard/requirements.txt
new file mode 100644
index 000000000..eb92a819b
--- /dev/null
+++ b/src/server/flashcard/requirements.txt
@@ -0,0 +1,12 @@
+easyocr==1.7.1
+requests==2.32.3
+pillow==10.4.0
+numpy==1.26.4
+tqdm==4.66.4
+Werkzeug==3.0.3
+python-dateutil==2.9.0.post0
+six==1.16.0
+certifi==2024.6.2
+charset-normalizer==3.3.2
+idna==3.7
+urllib3==1.26.19 \ No newline at end of file