author    alyssaf16 <alyssa_feinberg@brown.edu>  2024-11-04 21:56:25 -0500
committer alyssaf16 <alyssa_feinberg@brown.edu>  2024-11-04 21:56:25 -0500
commit    1e4909f04fdcc4c0b3a60b8c75e8b687e2b63b8e (patch)
tree      16fe239082c37cd4f7e10bbfa6964d13a458046a /src
parent    95afe2c1093dc3229375c08e6684b3d9866ef7a2 (diff)
parent    09d7d63d1f248a0bf1d36e4da804cbde5e12e209 (diff)
Merge branch 'ajs-finalagent' into alyssa-agent
Diffstat (limited to 'src')
-rw-r--r--  src/client/views/nodes/chatbot/agentsystem/Agent.ts            87
-rw-r--r--  src/client/views/nodes/chatbot/agentsystem/prompts.ts          40
-rw-r--r--  src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx   25
-rw-r--r--  src/client/views/nodes/chatbot/tools/SearchTool.ts              7
-rw-r--r--  src/client/views/nodes/chatbot/types/tool_types.ts              2
-rw-r--r--  src/client/views/nodes/chatbot/types/types.ts                   1
-rw-r--r--  src/server/ApiManagers/AssistantManager.ts                    157
-rw-r--r--  src/server/chunker/pdf_chunker.py                              70
8 files changed, 225 insertions, 164 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
index 23e7d4a9d..05d13d1db 100644
--- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -15,7 +15,7 @@ import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, Processin
import { Vectorstore } from '../vectorstore/Vectorstore';
import { getReactPrompt } from './prompts';
import { BaseTool } from '../tools/BaseTool';
-import { Parameter, ParametersType } from '../types/tool_types';
+import { Parameter, ParametersType, TypeMap } from '../types/tool_types';
import { CreateDocTool } from '../tools/CreateDocumentTool';
import { DocumentOptions } from '../../../../documents/Documents';
@@ -269,11 +269,35 @@ export class Agent {
}
/**
+ * Helper function to check if a string can be parsed as an array of the expected type.
+ * @param input The input string to check.
+ * @param expectedType The expected type of the array elements ('string', 'number', or 'boolean').
+ * @returns The parsed array if valid, otherwise throws an error.
+ */
+ private parseArray<T>(input: string, expectedType: 'string' | 'number' | 'boolean'): T[] {
+ try {
+            // Parse the input string as JSON (array-typed arguments arrive as JSON-encoded strings)
+ const parsed = JSON.parse(input);
+
+ // Check if the parsed object is an array and if all elements are of the expected type
+ if (Array.isArray(parsed) && parsed.every(item => typeof item === expectedType)) {
+ return parsed;
+ } else {
+ throw new Error(`Invalid ${expectedType} array format.`);
+ }
+ } catch (error) {
+ throw new Error(`Failed to parse ${expectedType} array: ` + error);
+ }
+ }
+
+ /**
* Processes a specific action by invoking the appropriate tool with the provided inputs.
* This method ensures that the action exists and validates the types of `actionInput`
* based on the tool's parameter rules. It throws errors for missing required parameters
* or mismatched types before safely executing the tool with the validated input.
*
+     * NOTE: In the future, this should type-check against each tool's specific parameter types using `TypeMap` or a similar mechanism.
+ *
* Type validation includes checks for:
* - `string`, `number`, `boolean`
* - `string[]`, `number[]` (arrays of strings or numbers)
@@ -283,56 +307,35 @@ export class Agent {
* @returns A promise that resolves to an array of `Observation` objects representing the result of the action.
* @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types.
*/
- private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> {
+ private async processAction(action: string, actionInput: ParametersType<ReadonlyArray<Parameter>>): Promise<Observation[]> {
// Check if the action exists in the tools list
if (!(action in this.tools)) {
throw new Error(`Unknown action: ${action}`);
}
+ console.log(actionInput);
- const tool = this.tools[action];
-
- // Validate actionInput based on tool's parameter rules
- for (const paramRule of tool.parameterRules) {
- const inputValue = actionInput[paramRule.name];
-
- if (paramRule.required && inputValue === undefined) {
- throw new Error(`Missing required parameter: ${paramRule.name}`);
+ for (const param of this.tools[action].parameterRules) {
+ // Check if the parameter is required and missing in the input
+ if (param.required && !(param.name in actionInput)) {
+ throw new Error(`Missing required parameter: ${param.name}`);
}
- // If the parameter is defined, check its type
- if (inputValue !== undefined) {
- switch (paramRule.type) {
- case 'string':
- if (typeof inputValue !== 'string') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a string.`);
- }
- break;
- case 'number':
- if (typeof inputValue !== 'number') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a number.`);
- }
- break;
- case 'boolean':
- if (typeof inputValue !== 'boolean') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`);
- }
- break;
- case 'string[]':
- if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) {
- throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`);
- }
- break;
- case 'number[]':
- if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) {
- throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`);
- }
- break;
- default:
- throw new Error(`Unsupported parameter type: ${paramRule.type}`);
- }
+ // Check if the parameter type matches the expected type
+ const expectedType = param.type.replace('[]', '') as 'string' | 'number' | 'boolean';
+ const isArray = param.type.endsWith('[]');
+ const input = actionInput[param.name];
+
+ if (isArray) {
+ // Check if the input is a valid array of the expected type
+ const parsedArray = this.parseArray(input as string, expectedType);
+ actionInput[param.name] = parsedArray as TypeMap[typeof param.type];
+ } else if (typeof input !== expectedType) {
+ throw new Error(`Invalid type for parameter ${param.name}: expected ${expectedType}`);
}
}
- return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>);
+ const tool = this.tools[action];
+
+ return await tool.execute(actionInput);
}
}
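The coercion this hunk introduces is easiest to see in isolation. A minimal standalone sketch (repo types elided; the query value is illustrative), assuming array-typed arguments arrive from the model as JSON-encoded strings, as the validation loop above implies:

// Standalone sketch of the parseArray coercion (not the repo's exact types):
function parseArray<T>(input: string, expectedType: 'string' | 'number' | 'boolean'): T[] {
    const parsed = JSON.parse(input);
    if (Array.isArray(parsed) && parsed.every(item => typeof item === expectedType)) return parsed;
    throw new Error(`Invalid ${expectedType} array format.`);
}

// A 'string[]' parameter such as searchTool's `queries`:
const queries = parseArray<string>('["Tourism impact of the 2022 World Cup in Qatar"]', 'string');
// queries -> ['Tourism impact of the 2022 World Cup in Qatar']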
diff --git a/src/client/views/nodes/chatbot/agentsystem/prompts.ts b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
index f5aec3130..140587b2f 100644
--- a/src/client/views/nodes/chatbot/agentsystem/prompts.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
@@ -7,9 +7,10 @@
* and summarizing content from provided text chunks.
*/
-import { Tool } from '../types/types';
+import { BaseTool } from '../tools/BaseTool';
+import { Parameter } from '../types/tool_types';
-export function getReactPrompt(tools: Tool[], summaries: () => string, chatHistory: string): string {
+export function getReactPrompt(tools: BaseTool<ReadonlyArray<Parameter>>[], summaries: () => string, chatHistory: string): string {
const toolDescriptions = tools
.map(
tool => `
@@ -143,9 +144,9 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<stage number="6" role="assistant">
<thought>
- With key moments from the World Cup retrieved, I will now use the website scraper tool to gather data on Qatar's tourism impact during the World Cup.
+ With key moments from the World Cup retrieved, I will now use the search tool to gather data on Qatar's tourism impact during the World Cup.
</thought>
- <action>websiteInfoScraper</action>
+ <action>searchTool</action>
</stage>
<stage number="7" role="user">
@@ -156,7 +157,7 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<action_input>
<action_input_description>Scraping websites for information about Qatar's tourism impact during the 2022 World Cup.</action_input_description>
<inputs>
- <query>Tourism impact of the 2022 World Cup in Qatar</query>
+                    <queries>["Tourism impact of the 2022 World Cup in Qatar"]</queries>
</inputs>
</action_input>
</stage>
@@ -167,11 +168,40 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<url>https://www.qatartourism.com/world-cup-impact</url>
<overview>During the 2022 World Cup, Qatar saw a 40% increase in tourism, with over 1.5 million visitors attending.</overview>
</chunk>
+ ***Additional URLs and overviews omitted***
</observation>
</stage>
<stage number="10" role="assistant">
<thought>
+            After retrieving the URLs of relevant sites, I will now use the website scraping tool to gather data on Qatar's tourism impact during the World Cup from those sites.
+        </thought>
+        <action>websiteInfoScraper</action>
+ </stage>
+
+ <stage number="11" role="user">
+ <action_rules>***Action rules omitted***</action_rules>
+ </stage>
+
+ <stage number="12" role="assistant">
+ <action_input>
+ <action_input_description>Getting information from the relevant websites about Qatar's tourism impact during the World Cup.</action_input_description>
+ <inputs>
+                    <urls>[***URLs to search elided, but they will be comma-separated, double-quoted strings***]</urls>
+ </inputs>
+ </action_input>
+ </stage>
+
+ <stage number="13" role="user">
+ <observation>
+ <chunk chunk_id="5678" chunk_type="url">
+ ***Data from the websites scraped***
+ </chunk>
+ ***Additional scraped sites omitted***
+ </observation>
+ </stage>
+
+ <stage number="14" role="assistant">
+ <thought>
Now that I have gathered both key moments from the World Cup and tourism impact data from Qatar, I will summarize the information in my final response.
</thought>
<answer>
diff --git a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
index e2508d752..68d4383e7 100644
--- a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
+++ b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
@@ -355,29 +355,11 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
const linkDoc = Docs.Create.LinkDocument(this.Document, doc);
LinkManager.Instance.addLink(linkDoc);
- let canDisplay;
-
- try {
- // Fetch the URL content through the proxy
- const { data } = await Networking.PostToServer('/proxyFetch', { url });
-
- // Simulating header behavior since we can't fetch headers via proxy
- const xFrameOptions = data.headers?.['x-frame-options'];
-
- if (xFrameOptions && xFrameOptions.toUpperCase() === 'SAMEORIGIN') {
- canDisplay = false;
- } else {
- canDisplay = true;
- }
- } catch (error) {
- console.error('Error fetching the URL from the server:', error);
- }
const chunkToAdd = {
chunkId: id,
chunkType: CHUNK_TYPE.URL,
url: url,
- canDisplay: canDisplay,
};
doc.chunk_simpl = JSON.stringify({ chunks: [chunkToAdd] });
@@ -648,11 +630,8 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
});
break;
case CHUNK_TYPE.URL:
- if (!foundChunk.canDisplay) {
- window.open(StrCast(doc.displayUrl), '_blank');
- } else if (foundChunk.canDisplay) {
- DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
- }
+ DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
+
break;
case CHUNK_TYPE.CSV:
DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
diff --git a/src/client/views/nodes/chatbot/tools/SearchTool.ts b/src/client/views/nodes/chatbot/tools/SearchTool.ts
index 03340aae5..d22f4c189 100644
--- a/src/client/views/nodes/chatbot/tools/SearchTool.ts
+++ b/src/client/views/nodes/chatbot/tools/SearchTool.ts
@@ -6,7 +6,7 @@ import { ParametersType } from '../types/tool_types';
const searchToolParams = [
{
- name: 'query',
+ name: 'queries',
type: 'string[]',
description: 'The search query or queries to use for finding websites',
required: true,
@@ -20,7 +20,7 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
private _addLinkedUrlDoc: (url: string, id: string) => void;
private _max_results: number;
- constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 5) {
+ constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 4) {
super(
'searchTool',
'Search the web to find a wide range of websites related to a query or multiple queries',
@@ -33,8 +33,9 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
}
async execute(args: ParametersType<SearchToolParamsType>): Promise<Observation[]> {
- const queries = args.query;
+ const queries = args.queries;
+        console.log(`Searching the web for queries: ${queries.join(', ')}`);
// Create an array of promises, each one handling a search for a query
const searchPromises = queries.map(async query => {
try {
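With the rename, callers now pass `queries` as an array rather than a single `query`. A hedged usage sketch of the updated call shape (the link callback and query string are illustrative):

// Hypothetical usage of the renamed parameter (callback body illustrative):
async function demoSearch() {
    const search = new SearchTool((url, id) => console.log(`linked doc ${id}: ${url}`)); // max_results now defaults to 4
    return search.execute({ queries: ['Tourism impact of the 2022 World Cup in Qatar'] });
}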
diff --git a/src/client/views/nodes/chatbot/types/tool_types.ts b/src/client/views/nodes/chatbot/types/tool_types.ts
index c1150534d..b2e05efe4 100644
--- a/src/client/views/nodes/chatbot/types/tool_types.ts
+++ b/src/client/views/nodes/chatbot/types/tool_types.ts
@@ -19,7 +19,7 @@ export type Parameter = {
* A utility type that maps string representations of types to actual TypeScript types.
* This is used to convert the `type` field of a `Parameter` into a concrete TypeScript type.
*/
-type TypeMap = {
+export type TypeMap = {
string: string;
number: number;
boolean: boolean;
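Exporting `TypeMap` lets `Agent.processAction` cast parsed values to a parameter's declared type. A small sketch of the mapping (the visible hunk shows only the scalar entries; the array entries are assumed from the `TypeMap[typeof param.type]` cast in Agent.ts):

// Assumed usage, extrapolating from the hunk above:
type Example1 = TypeMap['string'];   // string
type Example2 = TypeMap['number'];   // number
// and, presumably, TypeMap['string[]'] -> string[] for parameters like searchTool's `queries`.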
diff --git a/src/client/views/nodes/chatbot/types/types.ts b/src/client/views/nodes/chatbot/types/types.ts
index 7abad85f0..c65ac9820 100644
--- a/src/client/views/nodes/chatbot/types/types.ts
+++ b/src/client/views/nodes/chatbot/types/types.ts
@@ -102,7 +102,6 @@ export interface SimplifiedChunk {
location?: string;
chunkType: CHUNK_TYPE;
url?: string;
- canDisplay?: boolean;
}
export interface AI_Document {
diff --git a/src/server/ApiManagers/AssistantManager.ts b/src/server/ApiManagers/AssistantManager.ts
index 8447a4934..4d2068014 100644
--- a/src/server/ApiManagers/AssistantManager.ts
+++ b/src/server/ApiManagers/AssistantManager.ts
@@ -9,7 +9,7 @@
*/
import { Readability } from '@mozilla/readability';
-import axios from 'axios';
+import axios, { AxiosResponse } from 'axios';
import { spawn } from 'child_process';
import * as fs from 'fs';
import { writeFile } from 'fs';
@@ -23,6 +23,7 @@ import { AI_Document } from '../../client/views/nodes/chatbot/types/types';
import { Method } from '../RouteManager';
import { filesDirectory, publicDirectory } from '../SocketData';
import ApiManager, { Registration } from './ApiManager';
+import { getServerPath } from '../../client/util/reportManager/reportManagerUtils';
// Enumeration of directories where different file types are stored
export enum Directory {
@@ -115,29 +116,79 @@ export default class AssistantManager extends ApiManager {
},
});
- // Register Google Web Search Results API route
register({
method: Method.POST,
subscription: '/getWebSearchResults',
secureHandler: async ({ req, res }) => {
const { query, max_results } = req.body;
- try {
- // Fetch search results using Google Custom Search API
- const response = await customsearch.cse.list({
+ const MIN_VALID_RESULTS_RATIO = 0.75; // 3/4 threshold
+ let startIndex = 1; // Start at the first result initially
+ let validResults: any[] = [];
+
+ const fetchSearchResults = async (start: number) => {
+ return customsearch.cse.list({
q: query,
cx: process.env._CLIENT_GOOGLE_SEARCH_ENGINE_ID,
key: process.env._CLIENT_GOOGLE_API_KEY,
safe: 'active',
num: max_results,
+ start, // This controls which result index the search starts from
});
+ };
+
+ const filterResultsByXFrameOptions = async (results: any[]) => {
+ const filteredResults = await Promise.all(
+ results.map(async result => {
+ try {
+ const urlResponse: AxiosResponse = await axios.head(result.url, { timeout: 5000 });
+ const xFrameOptions = urlResponse.headers['x-frame-options'];
+                            // Keep only embeddable results: drop pages that send a restrictive X-Frame-Options header
+                            if (!xFrameOptions || !['SAMEORIGIN', 'DENY'].includes(xFrameOptions.toUpperCase())) {
+                                return result;
+                            }
+ } catch (error) {
+ console.error(`Error checking x-frame-options for URL: ${result.url}`, error);
+ }
+ return null; // Exclude the result if it doesn't match
+ })
+ );
+ return filteredResults.filter(result => result !== null); // Remove null results
+ };
- const results =
+ try {
+ // Fetch initial search results
+ let response = await fetchSearchResults(startIndex);
+ let initialResults =
response.data.items?.map(item => ({
url: item.link,
snippet: item.snippet,
})) || [];
- res.send({ results });
+ // Filter the initial results
+ validResults = await filterResultsByXFrameOptions(initialResults);
+
+ // If valid results are less than 3/4 of max_results, fetch more results
+ while (validResults.length < max_results * MIN_VALID_RESULTS_RATIO) {
+ // Increment the start index by the max_results to fetch the next set of results
+ startIndex += max_results;
+ response = await fetchSearchResults(startIndex);
+
+ const additionalResults =
+ response.data.items?.map(item => ({
+ url: item.link,
+ snippet: item.snippet,
+ })) || [];
+
+ const additionalValidResults = await filterResultsByXFrameOptions(additionalResults);
+ validResults = [...validResults, ...additionalValidResults]; // Combine valid results
+
+ // Break if no more results are available
+ if (additionalValidResults.length === 0 || response.data.items?.length === 0) {
+ break;
+ }
+ }
+
+ // Return the filtered valid results
+ res.send({ results: validResults.slice(0, max_results) }); // Limit the results to max_results
} catch (error) {
console.error('Error performing web search:', error);
res.status(500).send({
@@ -299,47 +350,16 @@ export default class AssistantManager extends ApiManager {
method: Method.GET,
subscription: '/getResult/:jobId',
secureHandler: async ({ req, res }) => {
- const { jobId } = req.params; // Get the job ID from the URL parameters
- // Check if the job result is available
+ const { jobId } = req.params;
if (jobResults[jobId]) {
const result = jobResults[jobId] as AI_Document & { status: string };
- // If the result contains image or table chunks, save the base64 data as image files
if (result.chunks && Array.isArray(result.chunks)) {
- await Promise.all(
- result.chunks.map(chunk => {
- if (chunk.metadata && (chunk.metadata.type === 'image' || chunk.metadata.type === 'table')) {
- const files_directory = '/files/chunk_images/';
- const directory = path.join(publicDirectory, files_directory);
-
- // Ensure the directory exists or create it
- if (!fs.existsSync(directory)) {
- fs.mkdirSync(directory);
- }
-
- const fileName = path.basename(chunk.metadata.file_path); // Get the file name from the path
- const filePath = path.join(directory, fileName); // Create the full file path
-
- // Check if the chunk contains base64 encoded data
- if (chunk.metadata.base64_data) {
- // Decode the base64 data and write it to a file
- const buffer = Buffer.from(chunk.metadata.base64_data, 'base64');
- fs.promises.writeFile(filePath, buffer).then(() => {
- // Update the file path in the chunk's metadata
- chunk.metadata.file_path = path.join(files_directory, fileName);
- chunk.metadata.base64_data = undefined; // Remove the base64 data from the metadata
- });
- } else {
- console.warn(`No base64_data found for chunk: ${fileName}`);
- }
- }
- })
- );
result.status = 'completed';
} else {
result.status = 'pending';
}
- res.json(result); // Send the result back to the client
+ res.json(result);
} else {
res.status(202).send({ status: 'pending' });
}
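Seen from the client, the two routes above pair up: one POST for pre-filtered search results, then polling for the chunker job. A hedged sketch of that flow (request and response shapes inferred from the handlers; `Networking.PostToServer` appears in the removed ChatBox code; the polling cadence is illustrative):

// Hypothetical client-side flow (shapes inferred from the handlers above):
async function searchAndAwaitChunks(jobId: string) {
    const { results } = await Networking.PostToServer('/getWebSearchResults', {
        query: 'Tourism impact of the 2022 World Cup in Qatar',
        max_results: 4,
    }); // results: [{ url, snippet }, ...], already filtered for embeddability

    // Poll until the chunker reports completion; a 202 response carries { status: 'pending' }.
    for (;;) {
        const result = await fetch(`/getResult/${jobId}`).then(r => r.json());
        if (result.status === 'completed') return { results, document: result };
        await new Promise(resolve => setTimeout(resolve, 2000));
    }
}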
@@ -367,7 +387,7 @@ export default class AssistantManager extends ApiManager {
// If the chunk is an image or table, read the corresponding file and encode it as base64
if (chunk.metadata.type === 'image' || chunk.metadata.type === 'table') {
try {
- const filePath = serverPathToFile(Directory.chunk_images, chunk.metadata.file_path); // Get the file path
+ const filePath = path.join(pathToDirectory(Directory.chunk_images), chunk.metadata.file_path); // Get the file path
readFileAsync(filePath).then(imageBuffer => {
const base64Image = imageBuffer.toString('base64'); // Convert the image to base64
@@ -445,10 +465,12 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
const requirementsPath = path.join(__dirname, '../chunker/requirements.txt');
const pythonScriptPath = path.join(__dirname, '../chunker/pdf_chunker.py');
+ const outputDirectory = pathToDirectory(Directory.chunk_images);
+
function runPythonScript() {
const pythonPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python') : path.join(venvPath, 'bin', 'python3');
- const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data]);
+ const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data, outputDirectory]);
let pythonOutput = '';
let stderrOutput = '';
@@ -460,23 +482,30 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
pythonProcess.stderr.on('data', data => {
stderrOutput += data.toString();
const lines = stderrOutput.split('\n');
+ stderrOutput = lines.pop() || ''; // Save the last partial line back to stderrOutput
lines.forEach(line => {
if (line.trim()) {
- try {
- const parsedOutput = JSON.parse(line);
- if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
- jobProgress[parsedOutput.job_id] = {
- step: parsedOutput.step,
- progress: parsedOutput.progress,
- };
- } else if (parsedOutput.progress !== undefined) {
- jobProgress[jobId] = {
- step: parsedOutput.step,
- progress: parsedOutput.progress,
- };
+ if (line.startsWith('PROGRESS:')) {
+ const jsonString = line.substring('PROGRESS:'.length);
+ try {
+ const parsedOutput = JSON.parse(jsonString);
+ if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
+ jobProgress[parsedOutput.job_id] = {
+ step: parsedOutput.step,
+ progress: parsedOutput.progress,
+ };
+ } else if (parsedOutput.progress !== undefined) {
+ jobProgress[jobId] = {
+ step: parsedOutput.step,
+ progress: parsedOutput.progress,
+ };
+ }
+ } catch (err) {
+ console.error('Error parsing progress JSON:', jsonString, err);
}
- } catch (err) {
- console.error('Progress log from Python:', line, err);
+ } else {
+ // Log other stderr output
+ console.error('Python stderr:', line);
}
}
});
@@ -490,10 +519,24 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
jobProgress[jobId] = { step: 'Complete', progress: 100 };
} catch (err) {
console.error('Error parsing final JSON result:', err);
+ jobResults[jobId] = { error: 'Failed to parse final result' };
}
} else {
console.error(`Python process exited with code ${code}`);
- jobResults[jobId] = { error: 'Python process failed' };
+ // Check if there was an error message in stderr
+ if (stderrOutput) {
+ // Try to parse the last line as JSON
+ const lines = stderrOutput.trim().split('\n');
+ const lastLine = lines[lines.length - 1];
+ try {
+ const errorOutput = JSON.parse(lastLine);
+ jobResults[jobId] = errorOutput;
+ } catch (err) {
+ jobResults[jobId] = { error: 'Python process failed' };
+ }
+ } else {
+ jobResults[jobId] = { error: 'Python process failed' };
+ }
}
});
}
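The Node/Python boundary above is a line-oriented stderr protocol: progress lines carry a `PROGRESS:` prefix, any other stderr line is logged as-is, and the final JSON result arrives on stdout. A representative progress line and its parse (values made up):

// One stderr line from pdf_chunker.py, per the handler above (values illustrative):
const line = 'PROGRESS:{"job_id": "job-42", "step": "Summarizing images", "progress": 60}';
if (line.startsWith('PROGRESS:')) {
    const { job_id, step, progress } = JSON.parse(line.substring('PROGRESS:'.length));
    // job_id === 'job-42', step === 'Summarizing images', progress === 60
}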
diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py
index 4fe3b9dbf..48b2dbf97 100644
--- a/src/server/chunker/pdf_chunker.py
+++ b/src/server/chunker/pdf_chunker.py
@@ -54,8 +54,9 @@ def update_progress(job_id, step, progress_value):
"step": step,
"progress": progress_value
}
- print(json.dumps(progress_data), file=sys.stderr) # Use stderr for progress logs
- sys.stderr.flush() # Ensure it's sent immediately
+ print(f"PROGRESS:{json.dumps(progress_data)}", file=sys.stderr)
+ sys.stderr.flush()
+
class ElementExtractor:
@@ -63,13 +64,15 @@ class ElementExtractor:
A class that uses a YOLO model to extract tables and images from a PDF page.
"""
- def __init__(self, output_folder: str):
+ def __init__(self, output_folder: str, doc_id: str):
"""
Initializes the ElementExtractor with the output folder for saving images and the YOLO model.
:param output_folder: Path to the folder where extracted elements will be saved.
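+        :param doc_id: Document ID; extracted elements are saved in a per-document subfolder named after it.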
"""
- self.output_folder = output_folder
+ self.doc_id = doc_id
+ self.output_folder = os.path.join(output_folder, doc_id)
+ os.makedirs(self.output_folder, exist_ok=True)
self.model = YOLO('keremberke/yolov8m-table-extraction') # Load YOLO model for table extraction
self.model.overrides['conf'] = 0.25 # Set confidence threshold for detection
self.model.overrides['iou'] = 0.45 # Set Intersection over Union (IoU) threshold
@@ -116,17 +119,16 @@ class ElementExtractor:
table_path = os.path.join(self.output_folder, table_filename)
page_with_outline.save(table_path)
- # Convert the full-page image with red outline to base64
- base64_data = self.image_to_base64(page_with_outline)
+ file_path_for_client = f"{self.doc_id}/{table_filename}"
tables.append({
'metadata': {
"type": "table",
"location": [x1 / img.width, y1 / img.height, x2 / img.width, y2 / img.height],
- "file_path": table_path,
+ "file_path": file_path_for_client,
"start_page": page_num,
"end_page": page_num,
- "base64_data": base64_data,
+ "base64_data": self.image_to_base64(page_with_outline)
}
})
@@ -175,18 +177,17 @@ class ElementExtractor:
image_path = os.path.join(self.output_folder, image_filename)
page_with_outline.save(image_path)
- # Convert the full-page image with red outline to base64
- base64_data = self.image_to_base64(page_with_outline)
+ file_path_for_client = f"{self.doc_id}/{image_filename}"
images.append({
'metadata': {
"type": "image",
"location": [x1 / page.rect.width, y1 / page.rect.height, x2 / page.rect.width,
y2 / page.rect.height],
- "file_path": image_path,
+ "file_path": file_path_for_client,
"start_page": page_num,
"end_page": page_num,
- "base64_data": base64_data,
+ "base64_data": self.image_to_base64(image)
}
})
@@ -268,7 +269,7 @@ class PDFChunker:
The main class responsible for chunking PDF files into text and visual elements (tables/images).
"""
- def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None:
+ def __init__(self, output_folder: str = "output", doc_id: str = '', image_batch_size: int = 5) -> None:
"""
Initializes the PDFChunker with an output folder and an element extractor for visual elements.
@@ -278,7 +279,8 @@ class PDFChunker:
self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Initialize the Anthropic API client
self.output_folder = output_folder
self.image_batch_size = image_batch_size # Batch size for image processing
- self.element_extractor = ElementExtractor(output_folder) # Initialize the element extractor
+ self.doc_id = doc_id # Add doc_id
+ self.element_extractor = ElementExtractor(output_folder, doc_id)
async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]:
"""
@@ -363,6 +365,7 @@ class PDFChunker:
for j, elem in enumerate(batch, start=1):
if j in summaries:
elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j])
+ elem['metadata']['base64_data'] = ''
processed_elements.append(elem)
progress = ((i // image_batch_size) + 1) / total_batches * 100 # Calculate progress
@@ -628,10 +631,11 @@ class PDFChunker:
return summaries
- except Exception:
- #print(f"Error in batch_summarize_images: {str(e)}")
- #print("Returning placeholder summaries")
- return {number: "Error: No summary available" for number in images}
+        except Exception as e:
+            # Print errors to stderr so they don't interfere with JSON output on stdout
+            print(json.dumps({"error": str(e)}), file=sys.stderr)
+            sys.stderr.flush()
+            # Still return placeholder summaries so callers can safely index into the result
+            return {number: "Error: No summary available" for number in images}
+
class DocumentType(Enum):
"""
@@ -664,7 +668,7 @@ class Document:
Represents a document being processed, such as a PDF, handling chunking, embedding, and summarization.
"""
- def __init__(self, file_data: bytes, file_name: str, job_id: str):
+ def __init__(self, file_data: bytes, file_name: str, job_id: str, output_folder: str):
"""
Initialize the Document with file data, file name, and job ID.
@@ -672,6 +676,7 @@ class Document:
:param file_name: The name of the file being processed.
:param job_id: The job ID associated with this document processing task.
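+        :param output_folder: Root directory under which this document's chunk images are written.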
"""
+ self.output_folder = output_folder
self.file_data = file_data
self.file_name = file_name
self.job_id = job_id
@@ -680,14 +685,13 @@ class Document:
self.chunks = [] # List to hold text and visual chunks
self.num_pages = 0 # Number of pages in the document (if applicable)
self.summary = "" # The generated summary for the document
-
self._process() # Start processing the document
def _process(self):
"""
Process the document: extract chunks, embed them, and generate a summary.
"""
- pdf_chunker = PDFChunker(output_folder="output") # Initialize the PDF chunker
+ pdf_chunker = PDFChunker(output_folder=self.output_folder, doc_id=self.doc_id) # Initialize PDFChunker
self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) # Extract chunks
self.num_pages = self._get_pdf_pages() # Get the number of pages in the document
@@ -796,8 +800,7 @@ class Document:
"doc_id": self.doc_id
}, indent=2) # Convert the document's attributes to JSON format
-
-def process_document(file_data, file_name, job_id):
+def process_document(file_data, file_name, job_id, output_folder):
"""
Top-level function to process a document and return the JSON output.
@@ -806,28 +809,30 @@ def process_document(file_data, file_name, job_id):
:param job_id: The job ID for this document processing task.
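+    :param output_folder: Directory where extracted chunk images are written.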
:return: The processed document's data in JSON format.
"""
- new_document = Document(file_data, file_name, job_id) # Create a new Document object
- return new_document.to_json() # Return the document's JSON data
-
+ new_document = Document(file_data, file_name, job_id, output_folder)
+ return new_document.to_json()
def main():
"""
Main entry point for the script, called with arguments from Node.js.
"""
- if len(sys.argv) != 4:
- print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) # Print error if incorrect number of arguments
+ if len(sys.argv) != 5:
+ print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr)
return
- job_id = sys.argv[1] # Get the job ID from command-line arguments
- file_name = sys.argv[2] # Get the file name from command-line arguments
- file_data = sys.argv[3] # Get the base64-encoded file data from command-line arguments
+ job_id = sys.argv[1]
+ file_name = sys.argv[2]
+ file_data = sys.argv[3]
+ output_folder = sys.argv[4] # Get the output folder from arguments
try:
+ os.makedirs(output_folder, exist_ok=True)
+
# Decode the base64 file data
file_bytes = base64.b64decode(file_data)
# Process the document
- document_result = process_document(file_bytes, file_name, job_id)
+ document_result = process_document(file_bytes, file_name, job_id, output_folder) # Pass output_folder
# Output the final result as JSON to stdout
print(document_result)
@@ -839,5 +844,6 @@ def main():
sys.stderr.flush()
+
if __name__ == "__main__":
main() # Execute the main function when the script is run