Diffstat (limited to 'src')
-rw-r--r--  src/client/views/StyleProviderQuiz.tsx      |  17
-rw-r--r--  src/server/ApiManagers/FlashcardManager.ts  | 161
-rw-r--r--  src/server/flashcard/labels.py              | 285
-rw-r--r--  src/server/flashcard/requirements.txt       |  12
4 files changed, 463 insertions, 12 deletions
diff --git a/src/client/views/StyleProviderQuiz.tsx b/src/client/views/StyleProviderQuiz.tsx
index 1f2ad1485..b9dd4016c 100644
--- a/src/client/views/StyleProviderQuiz.tsx
+++ b/src/client/views/StyleProviderQuiz.tsx
@@ -19,6 +19,7 @@ import { FieldViewProps } from './nodes/FieldView';
 import { ImageBox } from './nodes/ImageBox';
 import { ImageUtility } from './nodes/generativeFill/generativeFillUtils/ImageHandler';
 import './StyleProviderQuiz.scss';
+import { Networking } from '../Network';
 
 export namespace styleProviderQuiz {
     enum quizMode {
@@ -85,18 +86,10 @@ export namespace styleProviderQuiz {
 
         if (!quizBoxes.length) {
             imgBox.Loading = true;
-            const img = {
-                file: i ? i : imgBox.paths[0],
-                drag: i ? 'drag' : 'full',
-                smart: quiz,
-            };
-            const response = await axios.post('http://localhost:105/labels/', img, {
-                headers: {
-                    'Content-Type': 'application/json',
-                },
-            });
-            if (response.data['boxes'].length != 0) {
-                createBoxes(imgBox, response.data['boxes'], response.data['text']);
+            const response = await Networking.PostToServer('/labels', { file: i ? i : imgBox.paths[0], drag: i ? 'drag' : 'full', smart: quiz });
+            const parsedResponse = JSON.parse(response.result.replace(/'/g, '"'));
+            if (parsedResponse['boxes'].length != 0) {
+                createBoxes(imgBox, parsedResponse['boxes'], parsedResponse['text']);
             } else {
                 imgBox.Loading = false;
             }
diff --git a/src/server/ApiManagers/FlashcardManager.ts b/src/server/ApiManagers/FlashcardManager.ts
new file mode 100644
index 000000000..fd7c42437
--- /dev/null
+++ b/src/server/ApiManagers/FlashcardManager.ts
@@ -0,0 +1,161 @@
+/**
+ * @file FlashcardManager.ts
+ * @description Defines the FlashcardManager class, responsible for managing API routes
+ * related to flashcard creation and manipulation. It handles file processing,
+ * runs Python scripts in a virtual environment, and manages their dependencies.
+ */
+
+import { spawn } from 'child_process';
+import * as fs from 'fs';
+import * as path from 'path';
+import { Method } from '../RouteManager';
+import ApiManager, { Registration } from './ApiManager';
+
+/**
+ * Runs a Python script using the provided virtual environment and passes file and option arguments.
+ * @param {string} venvPath - Path to the virtual environment.
+ * @param {string} scriptPath - Path to the Python script.
+ * @param {string} [file] - Optional file data to pass to the Python script.
+ * @param {string} [drag] - Optional argument to control drag mode.
+ * @param {string} [smart] - Optional argument to control smart mode.
+ * @returns {Promise<string>} - Resolves with the output from the Python script, or rejects on error.
+ */
+function runPythonScript(venvPath: string, scriptPath: string, file?: string, drag?: string, smart?: string): Promise<string> {
+    return new Promise((resolve, reject) => {
+        const pythonPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python.exe') : path.join(venvPath, 'bin', 'python3');
+
+        const tempFilePath = path.join(__dirname, 'temp_data.txt'); // Scratch file handed to the Python script
+
+        if (file) {
+            // Write the raw file data to the temp file without conversion
+            fs.writeFileSync(tempFilePath, file, 'utf8');
+        }
+
+        const pythonProcess = spawn(
+            pythonPath,
+            [scriptPath, file ? tempFilePath : undefined, drag, smart].filter(arg => arg !== undefined)
+        );
+
+        let pythonOutput = '';
+        let stderrOutput = '';
+
+        pythonProcess.stdout.on('data', data => {
+            pythonOutput += data.toString();
+        });
+
+        pythonProcess.stderr.on('data', data => {
+            stderrOutput += data.toString();
+        });
+
+        pythonProcess.on('close', code => {
+            if (code === 0) {
+                resolve(pythonOutput);
+            } else {
+                reject(`Python process exited with code ${code}: ${stderrOutput}`);
+            }
+        });
+    });
+}
+
+/**
+ * Installs Python dependencies using pip in the specified virtual environment.
+ * @param {string} venvPath - Path to the virtual environment.
+ * @param {string} requirementsPath - Path to the requirements.txt file.
+ * @returns {Promise<void>} - Resolves when dependencies are successfully installed, rejects on failure.
+ */
+function installDependencies(venvPath: string, requirementsPath: string): Promise<void> {
+    return new Promise((resolve, reject) => {
+        const pipPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'pip.exe') : path.join(venvPath, 'bin', 'pip3');
+
+        const installProcess = spawn(pipPath, ['install', '-r', requirementsPath]);
+
+        installProcess.stdout.on('data', data => {
+            console.log(`pip stdout: ${data}`);
+        });
+
+        installProcess.stderr.on('data', data => {
+            console.error(`pip stderr: ${data}`);
+        });
+
+        installProcess.on('close', code => {
+            if (code !== 0) {
+                reject(`Failed to install dependencies. Exit code: ${code}`);
+            } else {
+                resolve();
+            }
+        });
+    });
+}
+
+/**
+ * Creates a new Python virtual environment.
+ * @param {string} venvPath - Path to the virtual environment that will be created.
+ * @returns {Promise<void>} - Resolves when the virtual environment is successfully created, rejects on failure.
+ */
+function createVirtualEnvironment(venvPath: string): Promise<void> {
+    return new Promise((resolve, reject) => {
+        const createVenvProcess = spawn('python3', ['-m', 'venv', venvPath]);
+
+        createVenvProcess.on('close', code => {
+            if (code !== 0) {
+                reject(`Failed to create virtual environment. Exit code: ${code}`);
+            } else {
+                resolve();
+            }
+        });
+    });
+}
+
+/**
+ * Manages the creation of the virtual environment, installation of dependencies, and running of the Python script.
+ * @param {string} [file] - Optional file data to be processed by the Python script.
+ * @param {string} [drag] - Optional argument controlling drag mode.
+ * @param {string} [smart] - Optional argument controlling smart mode.
+ * @returns {Promise<string>} - Resolves with the Python script output, or rejects on failure.
+ */
+async function manageVenvAndRunScript(file?: string, drag?: string, smart?: string): Promise<string> {
+    const venvPath = path.join(__dirname, '../flashcard/venv'); // Virtual environment path
+    const requirementsPath = path.join(__dirname, '../flashcard/requirements.txt');
+    const pythonScriptPath = path.join(__dirname, '../flashcard/labels.py');
+
+    // Bootstrap the virtual environment and its dependencies on first use
+    if (!fs.existsSync(path.join(venvPath, 'bin', 'python3')) && !fs.existsSync(path.join(venvPath, 'Scripts', 'python.exe'))) {
+        await createVirtualEnvironment(venvPath);
+        await installDependencies(venvPath, requirementsPath);
+    }
+
+    return runPythonScript(venvPath, pythonScriptPath, file, drag, smart);
+}
+
+/**
+ * FlashcardManager class responsible for managing API routes related to flashcard functionality.
+ * It registers the '/labels' route, which runs the Python OCR backend to generate label boxes.
+ */
+export default class FlashcardManager extends ApiManager {
+    /**
+     * Initializes the API routes for the FlashcardManager class.
+     * @param {Registration} register - The registration function for defining API routes.
+     */
+    protected initialize(register: Registration): void {
+        register({
+            method: Method.POST,
+            subscription: '/labels',
+            secureHandler: async ({ req, res }) => {
+                const { file, drag, smart } = req.body;
+
+                try {
+                    // Run the Python process
+                    const result = await manageVenvAndRunScript(file, drag, smart);
+                    res.status(200).send({ result });
+                } catch (error) {
+                    console.error('Error initiating document creation:', error);
+                    res.status(500).send({
+                        error: 'Failed to initiate document creation',
+                    });
+                }
+            },
+        });
+    }
+}
diff --git a/src/server/flashcard/labels.py b/src/server/flashcard/labels.py
new file mode 100644
index 000000000..546fc4bd3
--- /dev/null
+++ b/src/server/flashcard/labels.py
@@ -0,0 +1,285 @@
+import base64
+import sys
+from io import BytesIO
+
+import easyocr
+import numpy as np
+import requests
+from PIL import Image
+
+
+class BoundingBoxUtils:
+    """Utility class for bounding box operations and OCR result corrections."""
+
+    @staticmethod
+    def is_close(box1, box2, x_threshold=20, y_threshold=20):
+        """
+        Determines if two bounding boxes are horizontally and vertically close.
+
+        Parameters:
+            box1, box2 (list): The bounding boxes to compare.
+            x_threshold (int): The threshold for horizontal proximity.
+            y_threshold (int): The threshold for vertical proximity.
+
+        Returns:
+            bool: True if boxes are close, False otherwise.
+        """
+        horizontally_close = (abs(box1[2] - box2[0]) < x_threshold or  # Right edge of box1 and left edge of box2
+                              abs(box2[2] - box1[0]) < x_threshold or  # Right edge of box2 and left edge of box1
+                              abs(box1[2] - box2[2]) < x_threshold or
+                              abs(box2[0] - box1[0]) < x_threshold)
+
+        vertically_close = (abs(box1[3] - box2[1]) < y_threshold or  # Bottom edge of box1 and top edge of box2
+                            abs(box2[3] - box1[1]) < y_threshold or
+                            box1[1] == box2[1] or box1[3] == box2[3])
+
+        return horizontally_close and vertically_close
+
+    @staticmethod
+    def adjust_bounding_box(bbox, original_text, corrected_text):
+        """
+        Applies a small horizontal offset to a bounding box to compensate for
+        characters stripped during text correction.
+
+        Parameters:
+            bbox (list): The original bounding box coordinates.
+            original_text (str): The original text detected by OCR.
+            corrected_text (str): The corrected text after cleaning.
+
+        Returns:
+            list: The adjusted bounding box.
+        """
+        if not bbox or len(bbox) != 4:
+            return bbox
+
+        # Shift the x-coordinates slightly to account for text correction
+        x_adjustment = 5
+        adjusted_bbox = [
+            [bbox[0][0] + x_adjustment, bbox[0][1]],
+            [bbox[1][0], bbox[1][1]],
+            [bbox[2][0] + x_adjustment, bbox[2][1]],
+            [bbox[3][0], bbox[3][1]]
+        ]
+        return adjusted_bbox
+
+    @staticmethod
+    def correct_ocr_results(results):
+        """
+        Corrects common OCR misinterpretations in the detected text and adjusts bounding boxes accordingly.
+
+        Parameters:
+            results (list): A list of OCR results, each containing bounding box, text, and confidence score.
+
+        Returns:
+            list: Corrected OCR results with adjusted bounding boxes.
+ """ + corrections = { + "~": "", # Replace '~' with empty string + "-": "" # Replace '-' with empty string + } + + corrected_results = [] + for (bbox, text, prob) in results: + corrected_text = ''.join(corrections.get(char, char) for char in text) + adjusted_bbox = BoundingBoxUtils.adjust_bounding_box(bbox, text, corrected_text) + corrected_results.append((adjusted_bbox, corrected_text, prob)) + + return corrected_results + + @staticmethod + def convert_to_json_serializable(data): + """ + Converts a list containing various types, including numpy types, to a JSON-serializable format. + + Parameters: + data (list): A list containing numpy or other non-serializable types. + + Returns: + list: A JSON-serializable version of the input list. + """ + def convert_element(element): + if isinstance(element, list): + return [convert_element(e) for e in element] + elif isinstance(element, tuple): + return tuple(convert_element(e) for e in element) + elif isinstance(element, np.integer): + return int(element) + elif isinstance(element, np.floating): + return float(element) + elif isinstance(element, np.ndarray): + return element.tolist() + else: + return element + + return convert_element(data) + +class ImageLabelProcessor: + """Class to process images and perform OCR with EasyOCR.""" + + VERTICAL_THRESHOLD = 20 + HORIZONTAL_THRESHOLD = 8 + + def __init__(self, img_source, source_type, smart_mode): + self.img_source = img_source + self.source_type = source_type + self.smart_mode = smart_mode + self.img_val = self.load_image() + + def load_image(self): + """Load image from either a base64 string or URL.""" + if self.source_type == 'drag': + return self._load_base64_image() + else: + return self._load_url_image() + + def _load_base64_image(self): + """Decode and save the base64 image.""" + base64_string = self.img_source + if base64_string.startswith("data:image"): + base64_string = base64_string.split(",")[1] + + + # Decode the base64 string + image_data = base64.b64decode(base64_string) + image = Image.open(BytesIO(image_data)).convert('RGB') + image.save("temp_image.jpg") + return "temp_image.jpg" + + def _load_url_image(self): + """Download image from URL and return it in byte format.""" + url = self.img_source + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert('RGB') + + image_bytes = BytesIO() + image.save(image_bytes, format='PNG') + return image_bytes.getvalue() + + def process_image(self): + """Process the image and return the OCR results.""" + if self.smart_mode: + return self._process_smart_mode() + else: + return self._process_standard_mode() + + def _process_smart_mode(self): + """Process the image in smart mode using EasyOCR.""" + reader = easyocr.Reader(['en']) + result = reader.readtext(self.img_val, detail=1, paragraph=True) + + all_boxes = [bbox for bbox, text in result] + all_texts = [text for bbox, text in result] + + response_data = { + 'status': 'success', + 'message': 'Data received', + 'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes), + 'text': BoundingBoxUtils.convert_to_json_serializable(all_texts), + } + + return response_data + + def _process_standard_mode(self): + """Process the image in standard mode using EasyOCR.""" + reader = easyocr.Reader(['en']) + results = reader.readtext(self.img_val) + + filtered_results = BoundingBoxUtils.correct_ocr_results([ + (bbox, text, prob) for bbox, text, prob in results if prob >= 0.7 + ]) + + return self._merge_and_prepare_response(filtered_results) + + def are_vertically_close(self, 
+        """Check if two bounding boxes are vertically close."""
+        box1_bottom = max(box1[2][1], box1[3][1])
+        box2_top = min(box2[0][1], box2[1][1])
+        vertical_distance = box2_top - box1_bottom
+
+        box1_left = box1[0][0]
+        box2_left = box2[0][0]
+        box1_right = box1[1][0]
+        box2_right = box2[1][0]
+        hori_close = abs(box2_left - box1_left) <= self.HORIZONTAL_THRESHOLD or abs(box2_right - box1_right) <= self.HORIZONTAL_THRESHOLD
+
+        return vertical_distance <= self.VERTICAL_THRESHOLD and hori_close
+
+    def merge_boxes(self, boxes, texts):
+        """Merge multiple bounding boxes and their associated text."""
+        x_coords = []
+        y_coords = []
+
+        # Collect all x and y coordinates
+        for box in boxes:
+            for point in box:
+                x_coords.append(point[0])
+                y_coords.append(point[1])
+
+        # Create the merged bounding box
+        merged_box = [
+            [min(x_coords), min(y_coords)],
+            [max(x_coords), min(y_coords)],
+            [max(x_coords), max(y_coords)],
+            [min(x_coords), max(y_coords)]
+        ]
+
+        # Combine the texts
+        merged_text = ' '.join(texts)
+
+        return merged_box, merged_text
+
+    def _merge_and_prepare_response(self, filtered_results):
+        """Merge vertically close boxes and prepare the final response."""
+        current_boxes, current_texts = [], []
+        all_boxes, all_texts = [], []
+
+        for ind in range(len(filtered_results) - 1):
+            if not current_boxes:
+                current_boxes.append(filtered_results[ind][0])
+                current_texts.append(filtered_results[ind][1])
+
+            if self.are_vertically_close(filtered_results[ind][0], filtered_results[ind + 1][0]):
+                current_boxes.append(filtered_results[ind + 1][0])
+                current_texts.append(filtered_results[ind + 1][1])
+            else:
+                merged = self.merge_boxes(current_boxes, current_texts)
+                all_boxes.append(merged[0])
+                all_texts.append(merged[1])
+                current_boxes, current_texts = [], []
+
+        # Flush any run that is still open after the loop
+        if current_boxes:
+            merged = self.merge_boxes(current_boxes, current_texts)
+            all_boxes.append(merged[0])
+            all_texts.append(merged[1])
+
+        # If the loop ended with no open run, the last result was never merged; merge it on its own
+        if not current_boxes and filtered_results:
+            merged = self.merge_boxes([filtered_results[-1][0]], [filtered_results[-1][1]])
+            all_boxes.append(merged[0])
+            all_texts.append(merged[1])
+
+        response = {
+            'status': 'success',
+            'message': 'Data received',
+            'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
+            'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
+        }
+
+        return response
+
+
+# Main execution function
+def labels():
+    """Main function to handle image OCR processing based on input arguments."""
+    # argv[1]: temp file holding the image data, argv[2]: 'drag' or 'full', argv[3]: smart-mode flag
+    source_type = sys.argv[2]
+    smart_mode = (sys.argv[3] == 'smart')
+    with open(sys.argv[1], 'r') as f:
+        img_source = f.read()
+
+    # Create ImageLabelProcessor instance and run OCR
+    processor = ImageLabelProcessor(img_source, source_type, smart_mode)
+    response = processor.process_image()
+
+    # Print the response so the Node.js caller can capture it on stdout
+    print(response)
+    return response
+
+
+labels()
diff --git a/src/server/flashcard/requirements.txt b/src/server/flashcard/requirements.txt
new file mode 100644
index 000000000..eb92a819b
--- /dev/null
+++ b/src/server/flashcard/requirements.txt
@@ -0,0 +1,12 @@
+easyocr==1.7.1
+requests==2.32.3
+pillow==10.4.0
+numpy==1.26.4
+tqdm==4.66.4
+Werkzeug==3.0.3
+python-dateutil==2.9.0.post0
+six==1.16.0
+certifi==2024.6.2
+charset-normalizer==3.3.2
+idna==3.7
+urllib3==1.26.19
\ No newline at end of file
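
For context, a minimal sketch of the round trip this commit wires up: the client posts the image payload to the new /labels route, the server relays whatever labels.py printed to stdout back as a result string, and the client parses it. Since the script prints a Python dict rather than JSON, the client swaps single quotes for double quotes before JSON.parse. The helper name fetchLabelBoxes and the standalone framing are illustrative assumptions; only Networking.PostToServer, the request fields, and the response shape come from the diff above.

    // Hypothetical usage sketch (not part of the commit); mirrors StyleProviderQuiz.tsx.
    async function fetchLabelBoxes(imageUrl: string) {
        const response = await Networking.PostToServer('/labels', {
            file: imageUrl,  // base64 data URI when dragged, image URL otherwise
            drag: 'full',    // 'drag' = base64 payload, 'full' = URL (see labels.py)
            smart: 'smart',  // 'smart' selects EasyOCR's paragraph mode
        });
        // labels.py prints a Python dict, so quotes are rewritten before parsing;
        // note this breaks if the OCR text itself contains quote characters.
        const parsed = JSON.parse(response.result.replace(/'/g, '"'));
        return { boxes: parsed.boxes, text: parsed.text };
    }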