author | alyssaf16 <alyssa_feinberg@brown.edu> | 2024-11-04 21:56:25 -0500
committer | alyssaf16 <alyssa_feinberg@brown.edu> | 2024-11-04 21:56:25 -0500
commit | 1e4909f04fdcc4c0b3a60b8c75e8b687e2b63b8e (patch)
tree | 16fe239082c37cd4f7e10bbfa6964d13a458046a /src
parent | 95afe2c1093dc3229375c08e6684b3d9866ef7a2 (diff)
parent | 09d7d63d1f248a0bf1d36e4da804cbde5e12e209 (diff)
Merge branch 'ajs-finalagent' into alyssa-agent
Diffstat (limited to 'src')
-rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/Agent.ts | 87
-rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/prompts.ts | 40
-rw-r--r-- | src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx | 25
-rw-r--r-- | src/client/views/nodes/chatbot/tools/SearchTool.ts | 7
-rw-r--r-- | src/client/views/nodes/chatbot/types/tool_types.ts | 2
-rw-r--r-- | src/client/views/nodes/chatbot/types/types.ts | 1
-rw-r--r-- | src/server/ApiManagers/AssistantManager.ts | 157
-rw-r--r-- | src/server/chunker/pdf_chunker.py | 70
8 files changed, 225 insertions, 164 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
index 23e7d4a9d..05d13d1db 100644
--- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -15,7 +15,7 @@ import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, Processin
 import { Vectorstore } from '../vectorstore/Vectorstore';
 import { getReactPrompt } from './prompts';
 import { BaseTool } from '../tools/BaseTool';
-import { Parameter, ParametersType } from '../types/tool_types';
+import { Parameter, ParametersType, TypeMap } from '../types/tool_types';
 import { CreateDocTool } from '../tools/CreateDocumentTool';
 import { DocumentOptions } from '../../../../documents/Documents';
@@ -269,11 +269,35 @@ export class Agent {
     }

     /**
+     * Helper function to check if a string can be parsed as an array of the expected type.
+     * @param input The input string to check.
+     * @param expectedType The expected type of the array elements ('string', 'number', or 'boolean').
+     * @returns The parsed array if valid, otherwise throws an error.
+     */
+    private parseArray<T>(input: string, expectedType: 'string' | 'number' | 'boolean'): T[] {
+        try {
+            // Parse the input string into a JSON object
+            const parsed = JSON.parse(input);
+
+            // Check if the parsed object is an array and if all elements are of the expected type
+            if (Array.isArray(parsed) && parsed.every(item => typeof item === expectedType)) {
+                return parsed;
+            } else {
+                throw new Error(`Invalid ${expectedType} array format.`);
+            }
+        } catch (error) {
+            throw new Error(`Failed to parse ${expectedType} array: ` + error);
+        }
+    }
+
+    /**
      * Processes a specific action by invoking the appropriate tool with the provided inputs.
      * This method ensures that the action exists and validates the types of `actionInput`
      * based on the tool's parameter rules. It throws errors for missing required parameters
      * or mismatched types before safely executing the tool with the validated input.
      *
+     * NOTE: In the future, this should typecheck specific tool parameter types using `TypeMap` or a similar mechanism.
+     *
      * Type validation includes checks for:
      * - `string`, `number`, `boolean`
      * - `string[]`, `number[]` (arrays of strings or numbers)
      *
      * @returns A promise that resolves to an array of `Observation` objects representing the result of the action.
      * @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types.
      */
-    private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> {
+    private async processAction(action: string, actionInput: ParametersType<ReadonlyArray<Parameter>>): Promise<Observation[]> {
         // Check if the action exists in the tools list
         if (!(action in this.tools)) {
             throw new Error(`Unknown action: ${action}`);
         }
+        console.log(actionInput);

-        const tool = this.tools[action];
-
-        // Validate actionInput based on tool's parameter rules
-        for (const paramRule of tool.parameterRules) {
-            const inputValue = actionInput[paramRule.name];
-
-            if (paramRule.required && inputValue === undefined) {
-                throw new Error(`Missing required parameter: ${paramRule.name}`);
+        for (const param of this.tools[action].parameterRules) {
+            // Check if the parameter is required and missing in the input
+            if (param.required && !(param.name in actionInput)) {
+                throw new Error(`Missing required parameter: ${param.name}`);
             }

-            // If the parameter is defined, check its type
-            if (inputValue !== undefined) {
-                switch (paramRule.type) {
-                    case 'string':
-                        if (typeof inputValue !== 'string') {
-                            throw new Error(`Expected parameter '${paramRule.name}' to be a string.`);
-                        }
-                        break;
-                    case 'number':
-                        if (typeof inputValue !== 'number') {
-                            throw new Error(`Expected parameter '${paramRule.name}' to be a number.`);
-                        }
-                        break;
-                    case 'boolean':
-                        if (typeof inputValue !== 'boolean') {
-                            throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`);
-                        }
-                        break;
-                    case 'string[]':
-                        if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) {
-                            throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`);
-                        }
-                        break;
-                    case 'number[]':
-                        if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) {
-                            throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`);
-                        }
-                        break;
-                    default:
-                        throw new Error(`Unsupported parameter type: ${paramRule.type}`);
-                }
+            // Check if the parameter type matches the expected type
+            const expectedType = param.type.replace('[]', '') as 'string' | 'number' | 'boolean';
+            const isArray = param.type.endsWith('[]');
+            const input = actionInput[param.name];
+
+            if (isArray) {
+                // Check if the input is a valid array of the expected type
+                const parsedArray = this.parseArray(input as string, expectedType);
+                actionInput[param.name] = parsedArray as TypeMap[typeof param.type];
+            } else if (typeof input !== expectedType) {
+                throw new Error(`Invalid type for parameter ${param.name}: expected ${expectedType}`);
            }
        }

-        return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>);
+        const tool = this.tools[action];
+
+        return await tool.execute(actionInput);
     }
 }
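Editor's note: the commit replaces the per-type `switch` with string manipulation on the declared parameter type plus JSON parsing for array arguments. A minimal sketch of that path in isolation (the parameter rule and input below are illustrative, not taken from the commit):

```typescript
// Illustrative parameter rule for a tool whose `queries` parameter is 'string[]'.
const param = { name: 'queries', type: 'string[]', required: true } as const;

// The LLM emits arrays as JSON text, e.g. from <queries>["a", "b"]</queries>.
const actionInput: Record<string, unknown> = {
    queries: '["Tourism impact of the 2022 World Cup in Qatar"]',
};

const expectedType = param.type.replace('[]', '') as 'string' | 'number' | 'boolean'; // 'string'
const isArray = param.type.endsWith('[]'); // true

if (isArray) {
    // This mirrors parseArray: JSON.parse the string, then verify every element's typeof.
    const parsed: unknown = JSON.parse(actionInput[param.name] as string);
    if (!Array.isArray(parsed) || !parsed.every(item => typeof item === expectedType)) {
        throw new Error(`Failed to parse ${expectedType} array`);
    }
    actionInput[param.name] = parsed; // ['Tourism impact of the 2022 World Cup in Qatar']
}
```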
diff --git a/src/client/views/nodes/chatbot/agentsystem/prompts.ts b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
index f5aec3130..140587b2f 100644
--- a/src/client/views/nodes/chatbot/agentsystem/prompts.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
@@ -7,9 +7,10 @@
  * and summarizing content from provided text chunks.
  */

-import { Tool } from '../types/types';
+import { BaseTool } from '../tools/BaseTool';
+import { Parameter } from '../types/tool_types';

-export function getReactPrompt(tools: Tool[], summaries: () => string, chatHistory: string): string {
+export function getReactPrompt(tools: BaseTool<ReadonlyArray<Parameter>>[], summaries: () => string, chatHistory: string): string {
     const toolDescriptions = tools
         .map(
             tool => `
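Editor's note: the signature change means the prompt builder now accepts the agent's `BaseTool` instances directly. A hedged sketch of the assumed call-site shape (the helper name and surrounding code are hypothetical; only `getReactPrompt`, `BaseTool`, and `Parameter` come from this diff):

```typescript
import { BaseTool } from '../tools/BaseTool';
import { Parameter } from '../types/tool_types';
import { getReactPrompt } from './prompts';

// Object.values() of the agent's tool map satisfies the new parameter type
// directly; under the old Tool[] signature a separate shape was required.
function buildSystemPrompt(
    tools: Record<string, BaseTool<ReadonlyArray<Parameter>>>,
    summaries: () => string,
    chatHistory: string
): string {
    return getReactPrompt(Object.values(tools), summaries, chatHistory);
}
```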
@@ -143,9 +144,9 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
     <stage number="6" role="assistant">
         <thought>
-            With key moments from the World Cup retrieved, I will now use the website scraper tool to gather data on Qatar's tourism impact during the World Cup.
+            With key moments from the World Cup retrieved, I will now use the search tool to gather data on Qatar's tourism impact during the World Cup.
         </thought>
-        <action>websiteInfoScraper</action>
+        <action>searchTool</action>
     </stage>

     <stage number="7" role="user">
@@ -156,7 +157,7 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
         <action_input>
             <action_input_description>Scraping websites for information about Qatar's tourism impact during the 2022 World Cup.</action_input_description>
             <inputs>
-                <query>Tourism impact of the 2022 World Cup in Qatar</query>
+                <queries>["Tourism impact of the 2022 World Cup in Qatar"]</queries>
             </inputs>
         </action_input>
     </stage>
@@ -167,11 +168,40 @@
                 <url>https://www.qatartourism.com/world-cup-impact</url>
                 <overview>During the 2022 World Cup, Qatar saw a 40% increase in tourism, with over 1.5 million visitors attending.</overview>
             </chunk>
+            ***Additional URLs and overviews omitted***
         </observation>
     </stage>

     <stage number="10" role="assistant">
         <thought>
+            After retrieving the URLs of relevant sites, I will now use the website scraping tool to gather data on Qatar's tourism impact during the World Cup from these sites.
+        </thought>
+        <action>websiteInfoScraper</action>
+    </stage>
+
+    <stage number="11" role="user">
+        <action_rules>***Action rules omitted***</action_rules>
+    </stage>
+
+    <stage number="12" role="assistant">
+        <action_input>
+            <action_input_description>Getting information from the relevant websites about Qatar's tourism impact during the World Cup.</action_input_description>
+            <inputs>
+                <urls>[***URLs to scrape elided; they will be comma-separated, double-quoted strings***]</urls>
+            </inputs>
+        </action_input>
+    </stage>
+
+    <stage number="13" role="user">
+        <observation>
+            <chunk chunk_id="5678" chunk_type="url">
+                ***Data from the websites scraped***
+            </chunk>
+            ***Additional scraped sites omitted***
+        </observation>
+    </stage>
+
+    <stage number="14" role="assistant">
+        <thought>
             Now that I have gathered both key moments from the World Cup and tourism impact data from Qatar, I will summarize the information in my final response.
         </thought>
         <answer>
diff --git a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
index e2508d752..68d4383e7 100644
--- a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
+++ b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
@@ -355,29 +355,11 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
         const linkDoc = Docs.Create.LinkDocument(this.Document, doc);
         LinkManager.Instance.addLink(linkDoc);

-        let canDisplay;
-
-        try {
-            // Fetch the URL content through the proxy
-            const { data } = await Networking.PostToServer('/proxyFetch', { url });
-
-            // Simulating header behavior since we can't fetch headers via proxy
-            const xFrameOptions = data.headers?.['x-frame-options'];
-
-            if (xFrameOptions && xFrameOptions.toUpperCase() === 'SAMEORIGIN') {
-                canDisplay = false;
-            } else {
-                canDisplay = true;
-            }
-        } catch (error) {
-            console.error('Error fetching the URL from the server:', error);
-        }

         const chunkToAdd = {
             chunkId: id,
             chunkType: CHUNK_TYPE.URL,
             url: url,
-            canDisplay: canDisplay,
         };

         doc.chunk_simpl = JSON.stringify({ chunks: [chunkToAdd] });
@@ -648,11 +630,8 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
                 });
                 break;
             case CHUNK_TYPE.URL:
-                if (!foundChunk.canDisplay) {
-                    window.open(StrCast(doc.displayUrl), '_blank');
-                } else if (foundChunk.canDisplay) {
-                    DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
-                }
+                DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
+                break;
             case CHUNK_TYPE.CSV:
                 DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});

diff --git a/src/client/views/nodes/chatbot/tools/SearchTool.ts b/src/client/views/nodes/chatbot/tools/SearchTool.ts
index 03340aae5..d22f4c189 100644
--- a/src/client/views/nodes/chatbot/tools/SearchTool.ts
+++ b/src/client/views/nodes/chatbot/tools/SearchTool.ts
@@ -6,7 +6,7 @@ import { ParametersType } from '../types/tool_types';

 const searchToolParams = [
     {
-        name: 'query',
+        name: 'queries',
         type: 'string[]',
         description: 'The search query or queries to use for finding websites',
         required: true,
@@ -20,7 +20,7 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
     private _addLinkedUrlDoc: (url: string, id: string) => void;
     private _max_results: number;

-    constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 5) {
+    constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 4) {
         super(
             'searchTool',
             'Search the web to find a wide range of websites related to a query or multiple queries',
@@ -33,8 +33,9 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
     }

     async execute(args: ParametersType<SearchToolParamsType>): Promise<Observation[]> {
-        const queries = args.query;
+        const queries = args.queries;

+        console.log(`Searching the web for queries: ${queries[0]}`);
         // Create an array of promises, each one handling a search for a query
         const searchPromises = queries.map(async query => {
             try {
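Editor's note: the hunk cuts off inside `execute`, so here is a sketch of the fan-out pattern it sets up, with the actual search call abstracted away since the diff does not show it (`runSingleSearch` is hypothetical):

```typescript
// One promise per query; failures are isolated per query so a single bad
// search does not reject the whole batch.
async function searchAll<T>(
    queries: string[],
    runSingleSearch: (q: string) => Promise<T[]>
): Promise<T[]> {
    const searchPromises = queries.map(async query => {
        try {
            return await runSingleSearch(query);
        } catch (error) {
            console.error(`Search failed for query "${query}":`, error);
            return [] as T[]; // swallow per-query errors, keep the rest
        }
    });
    // Flatten the per-query result arrays into a single observation list.
    return (await Promise.all(searchPromises)).flat();
}
```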
diff --git a/src/client/views/nodes/chatbot/types/tool_types.ts b/src/client/views/nodes/chatbot/types/tool_types.ts
index c1150534d..b2e05efe4 100644
--- a/src/client/views/nodes/chatbot/types/tool_types.ts
+++ b/src/client/views/nodes/chatbot/types/tool_types.ts
@@ -19,7 +19,7 @@ export type Parameter = {
  * A utility type that maps string representations of types to actual TypeScript types.
  * This is used to convert the `type` field of a `Parameter` into a concrete TypeScript type.
  */
-type TypeMap = {
+export type TypeMap = {
     string: string;
     number: number;
     boolean: boolean;

diff --git a/src/client/views/nodes/chatbot/types/types.ts b/src/client/views/nodes/chatbot/types/types.ts
index 7abad85f0..c65ac9820 100644
--- a/src/client/views/nodes/chatbot/types/types.ts
+++ b/src/client/views/nodes/chatbot/types/types.ts
@@ -102,7 +102,6 @@ export interface SimplifiedChunk {
     location?: string;
     chunkType: CHUNK_TYPE;
     url?: string;
-    canDisplay?: boolean;
 }

 export interface AI_Document {

diff --git a/src/server/ApiManagers/AssistantManager.ts b/src/server/ApiManagers/AssistantManager.ts
index 8447a4934..4d2068014 100644
--- a/src/server/ApiManagers/AssistantManager.ts
+++ b/src/server/ApiManagers/AssistantManager.ts
@@ -9,7 +9,7 @@
  */

 import { Readability } from '@mozilla/readability';
-import axios from 'axios';
+import axios, { AxiosResponse } from 'axios';
 import { spawn } from 'child_process';
 import * as fs from 'fs';
 import { writeFile } from 'fs';
@@ -23,6 +23,7 @@ import { AI_Document } from '../../client/views/nodes/chatbot/types/types';
 import { Method } from '../RouteManager';
 import { filesDirectory, publicDirectory } from '../SocketData';
 import ApiManager, { Registration } from './ApiManager';
+import { getServerPath } from '../../client/util/reportManager/reportManagerUtils';

 // Enumeration of directories where different file types are stored
 export enum Directory {
@@ -115,29 +116,79 @@ export default class AssistantManager extends ApiManager {
             },
         });

-        // Register Google Web Search Results API route
         register({
             method: Method.POST,
             subscription: '/getWebSearchResults',
             secureHandler: async ({ req, res }) => {
                 const { query, max_results } = req.body;

-                try {
-                    // Fetch search results using Google Custom Search API
-                    const response = await customsearch.cse.list({
+                const MIN_VALID_RESULTS_RATIO = 0.75; // 3/4 threshold
+                let startIndex = 1; // Start at the first result initially
+                let validResults: any[] = [];
+
+                const fetchSearchResults = async (start: number) => {
+                    return customsearch.cse.list({
                         q: query,
                         cx: process.env._CLIENT_GOOGLE_SEARCH_ENGINE_ID,
                         key: process.env._CLIENT_GOOGLE_API_KEY,
                         safe: 'active',
                         num: max_results,
+                        start, // This controls which result index the search starts from
                     });
+                };
+
+                const filterResultsByXFrameOptions = async (results: any[]) => {
+                    const filteredResults = await Promise.all(
+                        results.map(async result => {
+                            try {
+                                const urlResponse: AxiosResponse = await axios.head(result.url, { timeout: 5000 });
+                                const xFrameOptions = urlResponse.headers['x-frame-options'];
+                                if (xFrameOptions && xFrameOptions.toUpperCase() === 'SAMEORIGIN') {
+                                    return result;
+                                }
+                            } catch (error) {
+                                console.error(`Error checking x-frame-options for URL: ${result.url}`, error);
+                            }
+                            return null; // Exclude the result if it doesn't match
+                        })
+                    );
+                    return filteredResults.filter(result => result !== null); // Remove null results
+                };

-                    const results =
+                try {
+                    // Fetch initial search results
+                    let response = await fetchSearchResults(startIndex);
+                    let initialResults =
                         response.data.items?.map(item => ({
                             url: item.link,
                             snippet: item.snippet,
                         })) || [];

-                    res.send({ results });
+                    // Filter the initial results
+                    validResults = await filterResultsByXFrameOptions(initialResults);
+
+                    // If valid results are less than 3/4 of max_results, fetch more results
+                    while (validResults.length < max_results * MIN_VALID_RESULTS_RATIO) {
+                        // Increment the start index by the max_results to fetch the next set of results
+                        startIndex += max_results;
+                        response = await fetchSearchResults(startIndex);
+
+                        const additionalResults =
+                            response.data.items?.map(item => ({
+                                url: item.link,
+                                snippet: item.snippet,
+                            })) || [];
+
+                        const additionalValidResults = await filterResultsByXFrameOptions(additionalResults);
+                        validResults = [...validResults, ...additionalValidResults]; // Combine valid results
+
+                        // Break if no more results are available
+                        if (additionalValidResults.length === 0 || response.data.items?.length === 0) {
+                            break;
+                        }
+                    }
+
+                    // Return the filtered valid results
+                    res.send({ results: validResults.slice(0, max_results) }); // Limit the results to max_results
                 } catch (error) {
                     console.error('Error performing web search:', error);
                     res.status(500).send({
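Editor's note: Google Custom Search's `start` parameter is 1-based, so the refill loop above walks the result pages in `max_results`-sized windows. The index arithmetic, sketched on its own:

```typescript
// With max_results = 4 (SearchTool's new default), successive fetches begin
// at results 1, 5, 9, ... until enough results survive the x-frame-options
// filter or the API stops returning items.
const max_results = 4;
let startIndex = 1;
for (let page = 0; page < 3; page++) {
    console.log(`fetch window: start=${startIndex}, num=${max_results}`);
    startIndex += max_results; // 1 -> 5 -> 9
}
```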
@@ -299,47 +350,16 @@ export default class AssistantManager extends ApiManager {
             method: Method.GET,
             subscription: '/getResult/:jobId',
             secureHandler: async ({ req, res }) => {
-                const { jobId } = req.params; // Get the job ID from the URL parameters
-                // Check if the job result is available
+                const { jobId } = req.params;
                 if (jobResults[jobId]) {
                     const result = jobResults[jobId] as AI_Document & { status: string };

-                    // If the result contains image or table chunks, save the base64 data as image files
                     if (result.chunks && Array.isArray(result.chunks)) {
-                        await Promise.all(
-                            result.chunks.map(chunk => {
-                                if (chunk.metadata && (chunk.metadata.type === 'image' || chunk.metadata.type === 'table')) {
-                                    const files_directory = '/files/chunk_images/';
-                                    const directory = path.join(publicDirectory, files_directory);
-
-                                    // Ensure the directory exists or create it
-                                    if (!fs.existsSync(directory)) {
-                                        fs.mkdirSync(directory);
-                                    }
-
-                                    const fileName = path.basename(chunk.metadata.file_path); // Get the file name from the path
-                                    const filePath = path.join(directory, fileName); // Create the full file path
-
-                                    // Check if the chunk contains base64 encoded data
-                                    if (chunk.metadata.base64_data) {
-                                        // Decode the base64 data and write it to a file
-                                        const buffer = Buffer.from(chunk.metadata.base64_data, 'base64');
-                                        fs.promises.writeFile(filePath, buffer).then(() => {
-                                            // Update the file path in the chunk's metadata
-                                            chunk.metadata.file_path = path.join(files_directory, fileName);
-                                            chunk.metadata.base64_data = undefined; // Remove the base64 data from the metadata
-                                        });
-                                    } else {
-                                        console.warn(`No base64_data found for chunk: ${fileName}`);
-                                    }
-                                }
-                            })
-                        );
                         result.status = 'completed';
                     } else {
                         result.status = 'pending';
                     }
-                    res.json(result); // Send the result back to the client
+                    res.json(result);
                 } else {
                     res.status(202).send({ status: 'pending' });
                 }
@@ -367,7 +387,7 @@ export default class AssistantManager extends ApiManager {
                 // If the chunk is an image or table, read the corresponding file and encode it as base64
                 if (chunk.metadata.type === 'image' || chunk.metadata.type === 'table') {
                     try {
-                        const filePath = serverPathToFile(Directory.chunk_images, chunk.metadata.file_path); // Get the file path
+                        const filePath = path.join(pathToDirectory(Directory.chunk_images), chunk.metadata.file_path); // Get the file path

                         readFileAsync(filePath).then(imageBuffer => {
                             const base64Image = imageBuffer.toString('base64'); // Convert the image to base64
@@ -445,10 +465,12 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
     const requirementsPath = path.join(__dirname, '../chunker/requirements.txt');
     const pythonScriptPath = path.join(__dirname, '../chunker/pdf_chunker.py');

+    const outputDirectory = pathToDirectory(Directory.chunk_images);
+
     function runPythonScript() {
         const pythonPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python') : path.join(venvPath, 'bin', 'python3');

-        const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data]);
+        const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data, outputDirectory]);

         let pythonOutput = '';
         let stderrOutput = '';
@@ -460,23 +482,30 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
         pythonProcess.stderr.on('data', data => {
             stderrOutput += data.toString();
             const lines = stderrOutput.split('\n');
+            stderrOutput = lines.pop() || ''; // Save the last partial line back to stderrOutput

             lines.forEach(line => {
                 if (line.trim()) {
-                    try {
-                        const parsedOutput = JSON.parse(line);
-                        if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
-                            jobProgress[parsedOutput.job_id] = {
-                                step: parsedOutput.step,
-                                progress: parsedOutput.progress,
-                            };
-                        } else if (parsedOutput.progress !== undefined) {
-                            jobProgress[jobId] = {
-                                step: parsedOutput.step,
-                                progress: parsedOutput.progress,
-                            };
+                    if (line.startsWith('PROGRESS:')) {
+                        const jsonString = line.substring('PROGRESS:'.length);
+                        try {
+                            const parsedOutput = JSON.parse(jsonString);
+                            if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
+                                jobProgress[parsedOutput.job_id] = {
+                                    step: parsedOutput.step,
+                                    progress: parsedOutput.progress,
+                                };
+                            } else if (parsedOutput.progress !== undefined) {
+                                jobProgress[jobId] = {
+                                    step: parsedOutput.step,
+                                    progress: parsedOutput.progress,
+                                };
+                            }
+                        } catch (err) {
+                            console.error('Error parsing progress JSON:', jsonString, err);
                         }
-                    } catch (err) {
-                        console.error('Progress log from Python:', line, err);
+                    } else {
+                        // Log other stderr output
+                        console.error('Python stderr:', line);
                     }
                 }
             });
@@ -490,10 +519,24 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
                 jobProgress[jobId] = { step: 'Complete', progress: 100 };
             } catch (err) {
                 console.error('Error parsing final JSON result:', err);
+                jobResults[jobId] = { error: 'Failed to parse final result' };
             }
         } else {
             console.error(`Python process exited with code ${code}`);
-            jobResults[jobId] = { error: 'Python process failed' };
+            // Check if there was an error message in stderr
+            if (stderrOutput) {
+                // Try to parse the last line as JSON
+                const lines = stderrOutput.trim().split('\n');
+                const lastLine = lines[lines.length - 1];
+                try {
+                    const errorOutput = JSON.parse(lastLine);
+                    jobResults[jobId] = errorOutput;
+                } catch (err) {
+                    jobResults[jobId] = { error: 'Python process failed' };
+                }
+            } else {
+                jobResults[jobId] = { error: 'Python process failed' };
+            }
         }
     });
 }
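Editor's note: the `stderrOutput = lines.pop() || ''` addition fixes a framing bug: `data` events may deliver partial lines, so the incomplete tail must be buffered until the next chunk arrives. The idea in isolation (hypothetical helper, not from the commit):

```typescript
// Line-buffering over a chunked stream, mirroring the stderr handler above:
// hold back the trailing partial line until a later chunk completes it.
function makeLineBuffer(onLine: (line: string) => void): (chunk: string) => void {
    let buffer = '';
    return chunk => {
        buffer += chunk;
        const lines = buffer.split('\n');
        buffer = lines.pop() || ''; // incomplete tail, kept for the next chunk
        lines.forEach(line => {
            if (line.trim()) onLine(line);
        });
    };
}

// A PROGRESS line split across two data events still fires exactly once:
const feed = makeLineBuffer(line => console.log('complete line:', line));
feed('PROGRESS:{"job_id":"a1","step":"Chunking","pro');
feed('gress":40}\n');
```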
diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py
index 4fe3b9dbf..48b2dbf97 100644
--- a/src/server/chunker/pdf_chunker.py
+++ b/src/server/chunker/pdf_chunker.py
@@ -54,8 +54,9 @@ def update_progress(job_id, step, progress_value):
         "step": step,
         "progress": progress_value
     }
-    print(json.dumps(progress_data), file=sys.stderr)  # Use stderr for progress logs
-    sys.stderr.flush()  # Ensure it's sent immediately
+    print(f"PROGRESS:{json.dumps(progress_data)}", file=sys.stderr)
+    sys.stderr.flush()
+

 class ElementExtractor:
@@ -63,13 +64,15 @@ class ElementExtractor:
     A class that uses a YOLO model to extract tables and images from a PDF page.
     """

-    def __init__(self, output_folder: str):
+    def __init__(self, output_folder: str, doc_id: str):
         """
         Initializes the ElementExtractor with the output folder for saving images and the YOLO model.

         :param output_folder: Path to the folder where extracted elements will be saved.
         """
-        self.output_folder = output_folder
+        self.doc_id = doc_id
+        self.output_folder = os.path.join(output_folder, doc_id)
+        os.makedirs(self.output_folder, exist_ok=True)
         self.model = YOLO('keremberke/yolov8m-table-extraction')  # Load YOLO model for table extraction
         self.model.overrides['conf'] = 0.25  # Set confidence threshold for detection
         self.model.overrides['iou'] = 0.45  # Set Intersection over Union (IoU) threshold
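Editor's note: taken together with the `AssistantManager` changes above, this establishes a path contract: Python writes each element to `<chunk_images>/<doc_id>/<filename>` and reports the client-relative `doc_id/filename`, which the server later re-joins against the same base directory. A minimal sketch of the server-side re-join (`pathToDirectory` and `Directory.chunk_images` appear in this diff; the wrapper function is illustrative):

```typescript
import * as path from 'path';

// file_path now arrives as "<doc_id>/<filename>", e.g. "ab12/table_page3_1.png";
// re-joining it against the chunk_images base recovers the absolute path that
// the Python side wrote to:
//   chunkImageAbsolutePath(pathToDirectory(Directory.chunk_images), chunk.metadata.file_path)
function chunkImageAbsolutePath(chunkImagesDir: string, filePathForClient: string): string {
    return path.join(chunkImagesDir, filePathForClient);
}
```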
""" - def __init__(self, output_folder: str): + def __init__(self, output_folder: str, doc_id: str): """ Initializes the ElementExtractor with the output folder for saving images and the YOLO model. :param output_folder: Path to the folder where extracted elements will be saved. """ - self.output_folder = output_folder + self.doc_id = doc_id + self.output_folder = os.path.join(output_folder, doc_id) + os.makedirs(self.output_folder, exist_ok=True) self.model = YOLO('keremberke/yolov8m-table-extraction') # Load YOLO model for table extraction self.model.overrides['conf'] = 0.25 # Set confidence threshold for detection self.model.overrides['iou'] = 0.45 # Set Intersection over Union (IoU) threshold @@ -116,17 +119,16 @@ class ElementExtractor: table_path = os.path.join(self.output_folder, table_filename) page_with_outline.save(table_path) - # Convert the full-page image with red outline to base64 - base64_data = self.image_to_base64(page_with_outline) + file_path_for_client = f"{self.doc_id}/{table_filename}" tables.append({ 'metadata': { "type": "table", "location": [x1 / img.width, y1 / img.height, x2 / img.width, y2 / img.height], - "file_path": table_path, + "file_path": file_path_for_client, "start_page": page_num, "end_page": page_num, - "base64_data": base64_data, + "base64_data": self.image_to_base64(page_with_outline) } }) @@ -175,18 +177,17 @@ class ElementExtractor: image_path = os.path.join(self.output_folder, image_filename) page_with_outline.save(image_path) - # Convert the full-page image with red outline to base64 - base64_data = self.image_to_base64(page_with_outline) + file_path_for_client = f"{self.doc_id}/{image_filename}" images.append({ 'metadata': { "type": "image", "location": [x1 / page.rect.width, y1 / page.rect.height, x2 / page.rect.width, y2 / page.rect.height], - "file_path": image_path, + "file_path": file_path_for_client, "start_page": page_num, "end_page": page_num, - "base64_data": base64_data, + "base64_data": self.image_to_base64(image) } }) @@ -268,7 +269,7 @@ class PDFChunker: The main class responsible for chunking PDF files into text and visual elements (tables/images). """ - def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None: + def __init__(self, output_folder: str = "output", doc_id: str = '', image_batch_size: int = 5) -> None: """ Initializes the PDFChunker with an output folder and an element extractor for visual elements. 
@@ -278,7 +279,8 @@ class PDFChunker:
         self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))  # Initialize the Anthropic API client
         self.output_folder = output_folder
         self.image_batch_size = image_batch_size  # Batch size for image processing
-        self.element_extractor = ElementExtractor(output_folder)  # Initialize the element extractor
+        self.doc_id = doc_id  # Add doc_id
+        self.element_extractor = ElementExtractor(output_folder, doc_id)

     async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]:
         """
@@ -363,6 +365,7 @@ class PDFChunker:
                 for j, elem in enumerate(batch, start=1):
                     if j in summaries:
                         elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j])
+                        elem['metadata']['base64_data'] = ''
                     processed_elements.append(elem)

                 progress = ((i // image_batch_size) + 1) / total_batches * 100  # Calculate progress
@@ -628,10 +631,11 @@ class PDFChunker:
             return summaries

-        except Exception:
-            #print(f"Error in batch_summarize_images: {str(e)}")
-            #print("Returning placeholder summaries")
-            return {number: "Error: No summary available" for number in images}
+        except Exception as e:
+            # Print errors to stderr so they don't interfere with JSON output
+            print(json.dumps({"error": str(e)}), file=sys.stderr)
+            sys.stderr.flush()
+

 class DocumentType(Enum):
     """
@@ -664,7 +668,7 @@ class Document:
     Represents a document being processed, such as a PDF, handling chunking, embedding, and summarization.
     """

-    def __init__(self, file_data: bytes, file_name: str, job_id: str):
+    def __init__(self, file_data: bytes, file_name: str, job_id: str, output_folder: str):
         """
         Initialize the Document with file data, file name, and job ID.
@@ -672,6 +676,7 @@ class Document:
         :param file_name: The name of the file being processed.
         :param job_id: The job ID associated with this document processing task.
         """
+        self.output_folder = output_folder
         self.file_data = file_data
         self.file_name = file_name
         self.job_id = job_id
@@ -680,14 +685,13 @@ class Document:
         self.chunks = []  # List to hold text and visual chunks
         self.num_pages = 0  # Number of pages in the document (if applicable)
         self.summary = ""  # The generated summary for the document
-        self._process()  # Start processing the document

     def _process(self):
         """
         Process the document: extract chunks, embed them, and generate a summary.
         """
-        pdf_chunker = PDFChunker(output_folder="output")  # Initialize the PDF chunker
+        pdf_chunker = PDFChunker(output_folder=self.output_folder, doc_id=self.doc_id)  # Initialize PDFChunker

         self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id))  # Extract chunks
         self.num_pages = self._get_pdf_pages()  # Get the number of pages in the document
@@ -796,8 +800,7 @@ class Document:
         "doc_id": self.doc_id
     }, indent=2)  # Convert the document's attributes to JSON format

-
-def process_document(file_data, file_name, job_id):
+def process_document(file_data, file_name, job_id, output_folder):
     """
     Top-level function to process a document and return the JSON output.
@@ -806,28 +809,30 @@
     :param job_id: The job ID for this document processing task.
     :return: The processed document's data in JSON format.
""" - new_document = Document(file_data, file_name, job_id) # Create a new Document object - return new_document.to_json() # Return the document's JSON data - + new_document = Document(file_data, file_name, job_id, output_folder) + return new_document.to_json() def main(): """ Main entry point for the script, called with arguments from Node.js. """ - if len(sys.argv) != 4: - print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) # Print error if incorrect number of arguments + if len(sys.argv) != 5: + print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) return - job_id = sys.argv[1] # Get the job ID from command-line arguments - file_name = sys.argv[2] # Get the file name from command-line arguments - file_data = sys.argv[3] # Get the base64-encoded file data from command-line arguments + job_id = sys.argv[1] + file_name = sys.argv[2] + file_data = sys.argv[3] + output_folder = sys.argv[4] # Get the output folder from arguments try: + os.makedirs(output_folder, exist_ok=True) + # Decode the base64 file data file_bytes = base64.b64decode(file_data) # Process the document - document_result = process_document(file_bytes, file_name, job_id) + document_result = process_document(file_bytes, file_name, job_id, output_folder) # Pass output_folder # Output the final result as JSON to stdout print(document_result) @@ -839,5 +844,6 @@ def main(): sys.stderr.flush() + if __name__ == "__main__": main() # Execute the main function when the script is run |