diff options
-rw-r--r-- | src/server/chunker/pdf_chunker.py | 317 |
1 files changed, 215 insertions, 102 deletions
diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py index 12e71c29d..4fe3b9dbf 100644 --- a/src/server/chunker/pdf_chunker.py +++ b/src/server/chunker/pdf_chunker.py @@ -32,7 +32,6 @@ import warnings warnings.filterwarnings('ignore', message="Valid config keys have changed") warnings.filterwarnings('ignore', message="torch.load") - dotenv.load_dotenv() # Load environment variables # Fix for newer versions of PIL @@ -45,6 +44,10 @@ current_progress = {} def update_progress(job_id, step, progress_value): """ Output the progress in JSON format to stdout for the Node.js process to capture. + + :param job_id: The unique identifier for the processing job. + :param step: The current step of the job. + :param progress_value: The percentage of completion for the current step. """ progress_data = { "job_id": job_id, @@ -56,27 +59,50 @@ def update_progress(job_id, step, progress_value): class ElementExtractor: + """ + A class that uses a YOLO model to extract tables and images from a PDF page. + """ + def __init__(self, output_folder: str): + """ + Initializes the ElementExtractor with the output folder for saving images and the YOLO model. + + :param output_folder: Path to the folder where extracted elements will be saved. + """ self.output_folder = output_folder - self.model = YOLO('keremberke/yolov8m-table-extraction') - self.model.overrides['conf'] = 0.25 - self.model.overrides['iou'] = 0.45 - self.padding = 5 + self.model = YOLO('keremberke/yolov8m-table-extraction') # Load YOLO model for table extraction + self.model.overrides['conf'] = 0.25 # Set confidence threshold for detection + self.model.overrides['iou'] = 0.45 # Set Intersection over Union (IoU) threshold + self.padding = 5 # Padding around detected elements async def extract_elements(self, page, padding: int = 20) -> List[Dict[str, Any]]: + """ + Asynchronously extract tables and images from a PDF page. + + :param page: A Page object representing a PDF page. 
+ :param padding: Padding around the extracted elements. + :return: A list of dictionaries containing the extracted elements. + """ tasks = [ - asyncio.create_task(self.extract_tables(page.image, page.page_num)), - asyncio.create_task(self.extract_images(page.page, page.image, page.page_num)) + asyncio.create_task(self.extract_tables(page.image, page.page_num)), # Extract tables from the page + asyncio.create_task(self.extract_images(page.page, page.image, page.page_num)) # Extract images from the page ] - results = await asyncio.gather(*tasks) - return [item for sublist in results for item in sublist] + results = await asyncio.gather(*tasks) # Wait for both tasks to complete + return [item for sublist in results for item in sublist] # Flatten and return results async def extract_tables(self, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: - results = self.model.predict(img, verbose=False) + """ + Asynchronously extract tables from a given page image using the YOLO model. + + :param img: The image of the PDF page. + :param page_num: The current page number. + :return: A list of dictionaries with metadata about the detected tables. + """ + results = self.model.predict(img, verbose=False) # Predict table locations using YOLO tables = [] for idx, box in enumerate(results[0].boxes): - x1, y1, x2, y2 = map(int, box.xyxy[0]) + x1, y1, x2, y2 = map(int, box.xyxy[0]) # Extract bounding box coordinates # Draw a red rectangle on the full page image around the table page_with_outline = img.copy() @@ -107,20 +133,27 @@ class ElementExtractor: return tables async def extract_images(self, page: fitz.Page, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: + """ + Asynchronously extract embedded images from a PDF page. + + :param page: A fitz.Page object representing the PDF page. + :param img: The image of the PDF page. + :param page_num: The current page number. + :return: A list of dictionaries with metadata about the detected images. 
+ """ images = [] - image_list = page.get_images(full=True) + image_list = page.get_images(full=True) # Get a list of images on the page if not image_list: return images for img_index, img_info in enumerate(image_list): - xref = img_info[0] - #try: - base_image = page.parent.extract_image(xref) + xref = img_info[0] # XREF of the image in the PDF + base_image = page.parent.extract_image(xref) # Extract the image by its XREF image_bytes = base_image["image"] - image = Image.open(io.BytesIO(image_bytes)) - width_ratio = img.width / page.rect.width - height_ratio = img.height / page.rect.height + image = Image.open(io.BytesIO(image_bytes)) # Convert bytes to PIL image + width_ratio = img.width / page.rect.width # Scale factor for width + height_ratio = img.height / page.rect.height # Scale factor for height # Get image coordinates or default to page rectangle rect_list = page.get_image_rects(xref) @@ -157,15 +190,19 @@ class ElementExtractor: } }) - #except Exception as e: - # print(f"Error processing image on page {page_num + 1}, image {img_index + 1}: {str(e)}") return images @staticmethod def image_to_base64(image: Image.Image) -> str: + """ + Convert a PIL image to a base64-encoded string. + + :param image: The PIL image to be converted. + :return: The base64-encoded string of the image. + """ buffered = io.BytesIO() - image.save(buffered, format="PNG") - return base64.b64encode(buffered.getvalue()).decode('utf-8') + image.save(buffered, format="PNG") # Save image as PNG to an in-memory buffer + return base64.b64encode(buffered.getvalue()).decode('utf-8') # Convert to base64 and return class ChunkMetaData(TypedDict): @@ -198,6 +235,12 @@ class Page: """ def __init__(self, page: fitz.Page, page_num: int): + """ + Initializes the Page with its page number and the image representation of the page. + + :param page: A fitz.Page object representing the PDF page. + :param page_num: The number of the page in the PDF. 
+ """ self.page = page self.page_num = page_num # Get high-resolution image of the page (for table/image extraction) @@ -210,12 +253,14 @@ class Page: def add_element(self, element): """ Adds a detected element (table/image) to the page and masks its location on the page image. + + :param element: A dictionary containing metadata about the detected element. """ self.elements.append(element) # Mask the element on the page image by drawing a white rectangle over its location x1, y1, x2, y2 = [coord * self.image.width if i % 2 == 0 else coord * self.image.height for i, coord in enumerate(element['metadata']['location'])] - self.draw.rectangle([x1, y1, x2, y2], fill="white") + self.draw.rectangle([x1, y1, x2, y2], fill="white") # Draw a white rectangle to mask the element class PDFChunker: @@ -224,6 +269,12 @@ class PDFChunker: """ def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None: + """ + Initializes the PDFChunker with an output folder and an element extractor for visual elements. + + :param output_folder: Folder to store the output files (extracted tables/images). + :param image_batch_size: The batch size for processing visual elements. + """ self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Initialize the Anthropic API client self.output_folder = output_folder self.image_batch_size = image_batch_size # Batch size for image processing @@ -232,22 +283,28 @@ class PDFChunker: async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]: """ Processes a PDF file, extracting text and visual elements, and returning structured chunks. + + :param file_data: The binary data of the PDF file. + :param file_name: The name of the PDF file. + :param doc_id: The unique document ID for this job. + :param job_id: The unique job ID for the processing task. + :return: A list of structured chunks containing text and visual elements. 
""" with fitz.open(stream=file_data, filetype="pdf") as pdf_document: num_pages = len(pdf_document) # Get the total number of pages in the PDF - pages = [Page(pdf_document[i], i) for i in tqdm(range(num_pages), desc="Initializing Pages")] + pages = [Page(pdf_document[i], i) for i in tqdm(range(num_pages), desc="Initializing Pages")] # Initialize each page update_progress(job_id, "Extracting tables and images...", 0) - await self.extract_and_mask_elements(pages, job_id) + await self.extract_and_mask_elements(pages, job_id) # Extract and mask elements (tables/images) update_progress(job_id, "Processing tables and images...", 0) - await self.process_visual_elements(pages, self.image_batch_size, job_id) + await self.process_visual_elements(pages, self.image_batch_size, job_id) # Process visual elements update_progress(job_id, "Extracting text...", 0) - page_texts = await self.extract_text_from_masked_pages(pages, job_id) + page_texts = await self.extract_text_from_masked_pages(pages, job_id) # Extract text from masked pages update_progress(job_id, "Processing text...", 0) - text_chunks = self.chunk_text_with_metadata(page_texts, max_words=1000, job_id=job_id) + text_chunks = self.chunk_text_with_metadata(page_texts, max_words=1000, job_id=job_id) # Chunk text into smaller parts # Combine text and visual elements into a unified structure (chunks) chunks = self.combine_chunks(text_chunks, [elem for page in pages for elem in page.elements], file_name, @@ -258,13 +315,16 @@ class PDFChunker: async def extract_and_mask_elements(self, pages: List[Page], job_id: str): """ Extract visual elements (tables and images) from each page and mask them on the page. + + :param pages: A list of Page objects representing the PDF pages. + :param job_id: The unique job ID for the processing task. 
""" total_pages = len(pages) tasks = [] for i, page in enumerate(pages): - tasks.append(asyncio.create_task(self.element_extractor.extract_elements(page))) - progress = ((i + 1) / total_pages) * 100 + tasks.append(asyncio.create_task(self.element_extractor.extract_elements(page))) # Extract elements asynchronously + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Extracting tables and images...", progress) # Gather all extraction results @@ -273,16 +333,20 @@ class PDFChunker: # Mask the detected elements on the page images for page, elements in zip(pages, results): for element in elements: - page.add_element(element) + page.add_element(element) # Mask each extracted element on the page - async def process_visual_elements(self, pages: List[Page], image_batch_size: int, job_id: str) -> List[ - Dict[str, Any]]: + async def process_visual_elements(self, pages: List[Page], image_batch_size: int, job_id: str) -> List[Dict[str, Any]]: """ Process extracted visual elements in batches, generating summaries or descriptions. + + :param pages: A list of Page objects representing the PDF pages. + :param image_batch_size: The batch size for processing visual elements. + :param job_id: The unique job ID for the processing task. + :return: A list of processed elements with metadata and generated summaries. 
""" pre_elements = [element for page in pages for element in page.elements] # Flatten list of elements processed_elements = [] - total_batches = (len(pre_elements) // image_batch_size) + 1 + total_batches = (len(pre_elements) // image_batch_size) + 1 # Calculate total number of batches loop = asyncio.get_event_loop() with concurrent.futures.ThreadPoolExecutor() as executor: @@ -301,7 +365,7 @@ class PDFChunker: elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j]) processed_elements.append(elem) - progress = ((i // image_batch_size) + 1) / total_batches * 100 + progress = ((i // image_batch_size) + 1) / total_batches * 100 # Calculate progress update_progress(job_id, "Processing tables and images...", progress) return processed_elements @@ -309,13 +373,17 @@ class PDFChunker: async def extract_text_from_masked_pages(self, pages: List[Page], job_id: str) -> Dict[int, str]: """ Extract text from masked page images (where tables and images have been masked out). + + :param pages: A list of Page objects representing the PDF pages. + :param job_id: The unique job ID for the processing task. + :return: A dictionary mapping page numbers to extracted text. """ total_pages = len(pages) tasks = [] for i, page in enumerate(pages): - tasks.append(asyncio.create_task(self.extract_text(page.masked_image, page.page_num))) - progress = ((i + 1) / total_pages) * 100 + tasks.append(asyncio.create_task(self.extract_text(page.masked_image, page.page_num))) # Perform OCR on each page + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Extracting text...", progress) # Return extracted text from each page @@ -325,13 +393,22 @@ class PDFChunker: async def extract_text(image: Image.Image, page_num: int) -> (int, str): """ Perform OCR on the provided image to extract text. + + :param image: The PIL image of the page. + :param page_num: The current page number. + :return: A tuple containing the page number and the extracted text. 
""" - result = pytesseract.image_to_string(image) + result = pytesseract.image_to_string(image) # Extract text using Tesseract OCR return page_num + 1, result.strip() # Return the page number and extracted text def chunk_text_with_metadata(self, page_texts: Dict[int, str], max_words: int, job_id: str) -> List[Dict[str, Any]]: """ Break the extracted text into smaller chunks with metadata (e.g., page numbers). + + :param page_texts: A dictionary mapping page numbers to extracted text. + :param max_words: The maximum number of words allowed in a chunk. + :param job_id: The unique job ID for the processing task. + :return: A list of dictionaries containing text chunks with metadata. """ chunks = [] current_chunk = "" @@ -362,7 +439,7 @@ class PDFChunker: total_words += word_count current_chunk += "\n\n" - progress = ((i + 1) / total_pages) * 100 + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Processing text...", progress) # Add the last chunk if there is leftover text @@ -375,6 +452,9 @@ class PDFChunker: def split_into_sentences(text): """ Split the text into sentences using regular expressions. + + :param text: The raw text to be split into sentences. + :return: A list of sentences. """ return re.split(r'(?<=[.!?])\s+', text) @@ -383,6 +463,12 @@ class PDFChunker: doc_id: str) -> List[Chunk]: """ Combine text and visual chunks into a unified list. + + :param text_chunks: A list of dictionaries containing text chunks with metadata. + :param visual_elements: A list of dictionaries containing visual elements (tables/images) with metadata. + :param pdf_path: The path to the original PDF file. + :param doc_id: The unique document ID for this job. + :return: A list of Chunk objects representing the combined data. 
""" combined_chunks = [] # Add text chunks @@ -399,7 +485,7 @@ class PDFChunker: "doc_id": doc_id, } chunk_dict: Chunk = { - "id": str(uuid.uuid4()), + "id": str(uuid.uuid4()), # Generate a unique ID for the chunk "values": [], "metadata": chunk_metadata, } @@ -419,7 +505,7 @@ class PDFChunker: "original_document": pdf_path, } visual_chunk_dict: Chunk = { - "id": str(uuid.uuid4()), + "id": str(uuid.uuid4()), # Generate a unique ID for the visual chunk "values": [], "metadata": visual_chunk_metadata, } @@ -430,6 +516,9 @@ class PDFChunker: def batch_summarize_images(self, images: Dict[int, str]) -> Dict[int, str]: """ Summarize images or tables by generating descriptive text. + + :param images: A dictionary mapping image numbers to base64-encoded image data. + :return: A dictionary mapping image numbers to their generated summaries. """ # Prompt for the AI model to summarize images and tables prompt = f"""<instruction> @@ -544,118 +633,136 @@ class PDFChunker: #print("Returning placeholder summaries") return {number: "Error: No summary available" for number in images} - class DocumentType(Enum): - PDF = "pdf" - CSV = "csv" - TXT = "txt" - HTML = "html" + """ + Enum representing different types of documents that can be processed. + """ + PDF = "pdf" # PDF file type + CSV = "csv" # CSV file type + TXT = "txt" # Plain text file type + HTML = "html" # HTML file type class FileTypeNotSupportedException(Exception): """ - Exception raised for unsupported file types. + Exception raised when a file type is unsupported during document processing. """ def __init__(self, file_extension: str): + """ + Initialize the exception with the unsupported file extension. + + :param file_extension: The file extension that triggered the exception. + """ self.file_extension = file_extension self.message = f"File type '{file_extension}' is not supported." 
- super().__init__(self.message) + super().__init__(self.message) # Call the parent class constructor with the message class Document: """ - Represents a document being processed, such as a PDF, handling chunking and embedding. + Represents a document being processed, such as a PDF, handling chunking, embedding, and summarization. """ def __init__(self, file_data: bytes, file_name: str, job_id: str): + """ + Initialize the Document with file data, file name, and job ID. + + :param file_data: The binary data of the file being processed. + :param file_name: The name of the file being processed. + :param job_id: The job ID associated with this document processing task. + """ self.file_data = file_data self.file_name = file_name self.job_id = job_id - self.type = self._get_document_type(file_name) - self.doc_id = job_id # Use job_id as document ID - self.chunks = [] - self.num_pages = 0 - self.summary = "" + self.type = self._get_document_type(file_name) # Determine the document type (PDF, CSV, etc.) + self.doc_id = job_id # Use the job ID as the document ID + self.chunks = [] # List to hold text and visual chunks + self.num_pages = 0 # Number of pages in the document (if applicable) + self.summary = "" # The generated summary for the document self._process() # Start processing the document def _process(self): """ - Process the document: chunk it, embed chunks, and generate a summary. + Process the document: extract chunks, embed them, and generate a summary. 
""" - pdf_chunker = PDFChunker(output_folder="output") - self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) + pdf_chunker = PDFChunker(output_folder="output") # Initialize the PDF chunker + self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) # Extract chunks - self.num_pages = self._get_pdf_pages() # Get the number of pages - self._embed_chunks() # Embed the text chunks - self.summary = self._generate_summary() # Generate a summary + self.num_pages = self._get_pdf_pages() # Get the number of pages in the document + self._embed_chunks() # Embed the text chunks into embeddings + self.summary = self._generate_summary() # Generate a summary for the document def _get_document_type(self, file_name: str) -> DocumentType: """ Determine the document type based on its file extension. + + :param file_name: The name of the file being processed. + :return: The DocumentType enum value corresponding to the file extension. """ - _, extension = os.path.splitext(file_name) - extension = extension.lower().lstrip('.') + _, extension = os.path.splitext(file_name) # Split the file name to get the extension + extension = extension.lower().lstrip('.') # Convert to lowercase and remove leading period try: - return DocumentType(extension) + return DocumentType(extension) # Try to match the extension to a DocumentType except ValueError: - raise FileTypeNotSupportedException(extension) + raise FileTypeNotSupportedException(extension) # Raise exception if file type is unsupported def _get_pdf_pages(self) -> int: """ - Get the total number of pages in the PDF. + Get the total number of pages in the PDF document. + + :return: The number of pages in the PDF. 
""" - pdf_file = io.BytesIO(self.file_data) - pdf_reader = PdfReader(pdf_file) - return len(pdf_reader.pages) + pdf_file = io.BytesIO(self.file_data) # Convert the file data to an in-memory binary stream + pdf_reader = PdfReader(pdf_file) # Initialize PDF reader + return len(pdf_reader.pages) # Return the number of pages in the PDF def _embed_chunks(self) -> None: """ Embed the text chunks using the Cohere API. """ - co = cohere.Client(os.getenv("COHERE_API_KEY")) - batch_size = 90 - chunks_len = len(self.chunks) + co = cohere.Client(os.getenv("COHERE_API_KEY")) # Initialize Cohere client with API key + batch_size = 90 # Batch size for embedding + chunks_len = len(self.chunks) # Total number of chunks to embed for i in tqdm(range(0, chunks_len, batch_size), desc="Embedding Chunks"): - batch = self.chunks[i: min(i + batch_size, chunks_len)] - texts = [chunk['metadata']['text'] for chunk in batch] - #try: + batch = self.chunks[i: min(i + batch_size, chunks_len)] # Get batch of chunks + texts = [chunk['metadata']['text'] for chunk in batch] # Extract text from each chunk chunk_embs_batch = co.embed( texts=texts, - model="embed-english-v3.0", - input_type="search_document" + model="embed-english-v3.0", # Use Cohere's embedding model + input_type="search_document" # Specify input type ) for j, emb in enumerate(chunk_embs_batch.embeddings): - self.chunks[i + j]['values'] = emb - #except Exception as e: - #print(f"Error embedding batch for {self.file_name}: {str(e)}") + self.chunks[i + j]['values'] = emb # Store the embeddings in the corresponding chunks def _generate_summary(self) -> str: """ Generate a summary of the document using KMeans clustering and a language model. + + :return: The generated summary of the document. 
+ """ - num_clusters = min(10, len(self.chunks)) - kmeans = KMeans(n_clusters=num_clusters, random_state=42) - doc_chunks = [chunk['values'] for chunk in self.chunks if 'values' in chunk] - cluster_labels = kmeans.fit_predict(doc_chunks) + num_clusters = min(10, len(self.chunks)) # Set number of clusters for KMeans, capped at 10 + kmeans = KMeans(n_clusters=num_clusters, random_state=42) # Initialize KMeans with num_clusters clusters (at most 10) + doc_chunks = [chunk['values'] for chunk in self.chunks if 'values' in chunk] # Extract embeddings + cluster_labels = kmeans.fit_predict(doc_chunks) # Assign each chunk to a cluster # Select representative chunks from each cluster selected_chunks = [] for i in range(num_clusters): - cluster_chunks = [chunk for chunk, label in zip(self.chunks, cluster_labels) if label == i] - cluster_embs = [emb for emb, label in zip(doc_chunks, cluster_labels) if label == i] - centroid = kmeans.cluster_centers_[i] - distances = [np.linalg.norm(np.array(emb) - centroid) for emb in cluster_embs] - closest_chunk = cluster_chunks[np.argmin(distances)] + cluster_chunks = [chunk for chunk, label in zip(self.chunks, cluster_labels) if label == i] # Get all chunks in this cluster + cluster_embs = [emb for emb, label in zip(doc_chunks, cluster_labels) if label == i] # Get embeddings for this cluster + centroid = kmeans.cluster_centers_[i] # Get the centroid of the cluster + distances = [np.linalg.norm(np.array(emb) - centroid) for emb in cluster_embs] # Compute distance to centroid + closest_chunk = cluster_chunks[np.argmin(distances)] # Select chunk closest to the centroid selected_chunks.append(closest_chunk) # Combine selected chunks into a summary - combined_text = "\n\n".join([chunk['metadata']['text'] for chunk in selected_chunks]) + combined_text = "\n\n".join([chunk['metadata']['text'] for chunk in selected_chunks]) # Concatenate chunk texts - client = OpenAI() # Call OpenAI API for text generation (summarization) + client = OpenAI() # Initialize OpenAI client for
text generation completion = client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-3.5-turbo", # Specify the language model messages=[ {"role": "system", "content": "You are an AI assistant tasked with summarizing a document. You are provided with important chunks from the document and provide a summary, as best you can, of what the document will contain overall. Be concise and brief with your response."}, @@ -670,13 +777,15 @@ class Document: Summary: """} ], - max_tokens=300 + max_tokens=300 # Set max tokens for the summary ) - return completion.choices[0].message.content.strip() + return completion.choices[0].message.content.strip() # Return the generated summary def to_json(self) -> str: """ Return the document's data in JSON format. + + :return: JSON string representing the document's metadata, chunks, and summary. """ return json.dumps({ "file_name": self.file_name, @@ -685,16 +794,20 @@ class Document: "chunks": self.chunks, "type": self.type.value, "doc_id": self.doc_id - }, indent=2) + }, indent=2) # Convert the document's attributes to JSON format def process_document(file_data, file_name, job_id): """ Top-level function to process a document and return the JSON output. - """ - new_document = Document(file_data, file_name, job_id) - return new_document.to_json() + :param file_data: The binary data of the file being processed. + :param file_name: The name of the file being processed. + :param job_id: The job ID for this document processing task. + :return: The processed document's data in JSON format. + """ + new_document = Document(file_data, file_name, job_id) # Create a new Document object + return new_document.to_json() # Return the document's JSON data def main(): @@ -702,12 +815,12 @@ def main(): Main entry point for the script, called with arguments from Node.js. 
""" if len(sys.argv) != 4: - print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) + print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) # Print error if incorrect number of arguments return - job_id = sys.argv[1] - file_name = sys.argv[2] - file_data = sys.argv[3] + job_id = sys.argv[1] # Get the job ID from command-line arguments + file_name = sys.argv[2] # Get the file name from command-line arguments + file_data = sys.argv[3] # Get the base64-encoded file data from command-line arguments try: # Decode the base64 file data @@ -727,4 +840,4 @@ def main(): if __name__ == "__main__": - main() + main() # Execute the main function when the script is run |