aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client/views/nodes/chatbot/agentsystem/Agent.ts98
-rw-r--r--src/client/views/nodes/chatbot/agentsystem/prompts.ts40
-rw-r--r--src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx290
-rw-r--r--src/client/views/nodes/chatbot/tools/BaseTool.ts4
-rw-r--r--src/client/views/nodes/chatbot/tools/CalculateTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/CreateCSVTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/CreateDocumentTool.ts392
-rw-r--r--src/client/views/nodes/chatbot/tools/DataAnalysisTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/GetDocsTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/NoTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/RAGTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/SearchTool.ts12
-rw-r--r--src/client/views/nodes/chatbot/tools/WebsiteInfoScraperTool.ts2
-rw-r--r--src/client/views/nodes/chatbot/tools/WikipediaTool.ts4
-rw-r--r--src/client/views/nodes/chatbot/types/tool_types.ts (renamed from src/client/views/nodes/chatbot/tools/ToolTypes.ts)34
-rw-r--r--src/client/views/nodes/chatbot/types/types.ts1
-rw-r--r--src/client/views/nodes/chatbot/vectorstore/Vectorstore.ts4
-rw-r--r--src/server/ApiManagers/AssistantManager.ts157
-rw-r--r--src/server/chunker/pdf_chunker.py70
-rw-r--r--src/server/flashcard/labels.py285
-rw-r--r--src/server/flashcard/requirements.txt12
-rw-r--r--src/server/flashcard/venv/pyvenv.cfg3
22 files changed, 1195 insertions, 225 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
index 34e7cf5ea..0b0e211eb 100644
--- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -15,7 +15,9 @@ import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, Processin
import { Vectorstore } from '../vectorstore/Vectorstore';
import { getReactPrompt } from './prompts';
import { BaseTool } from '../tools/BaseTool';
-import { Parameter, ParametersType, Tool } from '../tools/ToolTypes';
+import { Parameter, ParametersType, TypeMap } from '../types/tool_types';
+import { CreateDocTool } from '../tools/CreateDocumentTool';
+import { DocumentOptions } from '../../../../documents/Documents';
dotenv.config();
@@ -54,6 +56,7 @@ export class Agent {
history: () => string,
csvData: () => { filename: string; id: string; text: string }[],
addLinkedUrlDoc: (url: string, id: string) => void,
+ addLinkedDoc: (doc_type: string, data: string, options: DocumentOptions, id: string) => void,
createCSVInDash: (url: string, title: string, id: string, data: string) => void
) {
// Initialize OpenAI client with API key from environment
@@ -66,12 +69,13 @@ export class Agent {
// Define available tools for the assistant
this.tools = {
calculate: new CalculateTool(),
- rag: new RAGTool(this.vectorstore),
+ // rag: new RAGTool(this.vectorstore),
dataAnalysis: new DataAnalysisTool(csvData),
- websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
+ // websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
searchTool: new SearchTool(addLinkedUrlDoc),
createCSV: new CreateCSVTool(createCSVInDash),
- no_tool: new NoTool(),
+ noTool: new NoTool(),
+ createDoc: new CreateDocTool(addLinkedDoc),
};
}
@@ -164,6 +168,7 @@ export class Agent {
} else if (key === 'action_input') {
// Handle action input stage
const actionInput = stage[key];
+ console.log(`Action input full:`, actionInput);
console.log(`Action input:`, actionInput.inputs);
if (currentAction) {
@@ -264,11 +269,35 @@ export class Agent {
}
/**
+ * Helper function to check if a string can be parsed as an array of the expected type.
+ * @param input The input string to check.
+ * @param expectedType The expected type of the array elements ('string', 'number', or 'boolean').
+ * @returns The parsed array if valid, otherwise throws an error.
+ */
+ private parseArray<T>(input: string, expectedType: 'string' | 'number' | 'boolean'): T[] {
+ try {
+ // Parse the input string into a JSON object
+ const parsed = JSON.parse(input);
+
+ // Check if the parsed object is an array and if all elements are of the expected type
+ if (Array.isArray(parsed) && parsed.every(item => typeof item === expectedType)) {
+ return parsed;
+ } else {
+ throw new Error(`Invalid ${expectedType} array format.`);
+ }
+ } catch (error) {
+ throw new Error(`Failed to parse ${expectedType} array: ` + error);
+ }
+ }
+
+ /**
* Processes a specific action by invoking the appropriate tool with the provided inputs.
* This method ensures that the action exists and validates the types of `actionInput`
* based on the tool's parameter rules. It throws errors for missing required parameters
* or mismatched types before safely executing the tool with the validated input.
*
+ * NOTE: In the future, it should typecheck for specific tool parameter types using the `TypeMap` or otherwise.
+ *
* Type validation includes checks for:
* - `string`, `number`, `boolean`
* - `string[]`, `number[]` (arrays of strings or numbers)
@@ -278,56 +307,35 @@ export class Agent {
* @returns A promise that resolves to an array of `Observation` objects representing the result of the action.
* @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types.
*/
- private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> {
+ private async processAction(action: string, actionInput: ParametersType<ReadonlyArray<Parameter>>): Promise<Observation[]> {
// Check if the action exists in the tools list
if (!(action in this.tools)) {
throw new Error(`Unknown action: ${action}`);
}
+ console.log(actionInput);
- const tool = this.tools[action];
-
- // Validate actionInput based on tool's parameter rules
- for (const paramRule of tool.parameterRules) {
- const inputValue = actionInput[paramRule.name];
-
- if (paramRule.required && inputValue === undefined) {
- throw new Error(`Missing required parameter: ${paramRule.name}`);
+ for (const param of this.tools[action].parameterRules) {
+ // Check if the parameter is required and missing in the input
+ if (param.required && !(param.name in actionInput)) {
+ throw new Error(`Missing required parameter: ${param.name}`);
}
- // If the parameter is defined, check its type
- if (inputValue !== undefined) {
- switch (paramRule.type) {
- case 'string':
- if (typeof inputValue !== 'string') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a string.`);
- }
- break;
- case 'number':
- if (typeof inputValue !== 'number') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a number.`);
- }
- break;
- case 'boolean':
- if (typeof inputValue !== 'boolean') {
- throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`);
- }
- break;
- case 'string[]':
- if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) {
- throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`);
- }
- break;
- case 'number[]':
- if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) {
- throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`);
- }
- break;
- default:
- throw new Error(`Unsupported parameter type: ${paramRule.type}`);
- }
+ // Check if the parameter type matches the expected type
+ const expectedType = param.type.replace('[]', '') as 'string' | 'number' | 'boolean';
+ const isArray = param.type.endsWith('[]');
+ const input = actionInput[param.name];
+
+ if (isArray) {
+ // Check if the input is a valid array of the expected type
+ const parsedArray = this.parseArray(input as string, expectedType);
+ actionInput[param.name] = parsedArray as TypeMap[typeof param.type];
+ } else if (typeof input !== expectedType) {
+ throw new Error(`Invalid type for parameter ${param.name}: expected ${expectedType}`);
}
}
- return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>);
+ const tool = this.tools[action];
+
+ return await tool.execute(actionInput);
}
}
diff --git a/src/client/views/nodes/chatbot/agentsystem/prompts.ts b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
index f5aec3130..140587b2f 100644
--- a/src/client/views/nodes/chatbot/agentsystem/prompts.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/prompts.ts
@@ -7,9 +7,10 @@
* and summarizing content from provided text chunks.
*/
-import { Tool } from '../types/types';
+import { BaseTool } from '../tools/BaseTool';
+import { Parameter } from '../types/tool_types';
-export function getReactPrompt(tools: Tool[], summaries: () => string, chatHistory: string): string {
+export function getReactPrompt(tools: BaseTool<ReadonlyArray<Parameter>>[], summaries: () => string, chatHistory: string): string {
const toolDescriptions = tools
.map(
tool => `
@@ -143,9 +144,9 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<stage number="6" role="assistant">
<thought>
- With key moments from the World Cup retrieved, I will now use the website scraper tool to gather data on Qatar's tourism impact during the World Cup.
+ With key moments from the World Cup retrieved, I will now use the search tool to gather data on Qatar's tourism impact during the World Cup.
</thought>
- <action>websiteInfoScraper</action>
+ <action>searchTool</action>
</stage>
<stage number="7" role="user">
@@ -156,7 +157,7 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<action_input>
<action_input_description>Scraping websites for information about Qatar's tourism impact during the 2022 World Cup.</action_input_description>
<inputs>
- <query>Tourism impact of the 2022 World Cup in Qatar</query>
+ <queries>["Tourism impact of the 2022 World Cup in Qatar"]</queries>
</inputs>
</action_input>
</stage>
@@ -167,11 +168,40 @@ export function getReactPrompt(tools: Tool[], summaries: () => string, chatHisto
<url>https://www.qatartourism.com/world-cup-impact</url>
<overview>During the 2022 World Cup, Qatar saw a 40% increase in tourism, with over 1.5 million visitors attending.</overview>
</chunk>
+ ***Additional URLs and overviews omitted***
</observation>
</stage>
<stage number="10" role="assistant">
<thought>
+ After retrieving the urls of relevant sites, I will now use the website scraping tool to gather data on Qatar's tourism impact during the World Cup from these sites.</thought>
+ <action>websiteInfoScraper</action>
+ </stage>
+
+ <stage number="11" role="user">
+ <action_rules>***Action rules omitted***</action_rules>
+ </stage>
+
+ <stage number="12" role="assistant">
+ <action_input>
+ <action_input_description>Getting information from the relevant websites about Qatar's tourism impact during the World Cup.</action_input_description>
+ <inputs>
+ <urls>[***URLs to search elided, but they will be comma-separated, double-quoted strings***]</urls>
+ </inputs>
+ </action_input>
+ </stage>
+
+ <stage number="13" role="user">
+ <observation>
+ <chunk chunk_id="5678" chunk_type="url">
+ ***Data from the websites scraped***
+ </chunk>
+ ***Additional scraped sites omitted***
+ </observation>
+ </stage>
+
+ <stage number="14" role="assistant">
+ <thought>
Now that I have gathered both key moments from the World Cup and tourism impact data from Qatar, I will summarize the information in my final response.
</thought>
<answer>
diff --git a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
index 44c231c87..542d8ea58 100644
--- a/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
+++ b/src/client/views/nodes/chatbot/chatboxcomponents/ChatBox.tsx
@@ -16,11 +16,11 @@ import { v4 as uuidv4 } from 'uuid';
import { ClientUtils } from '../../../../../ClientUtils';
import { Doc, DocListCast } from '../../../../../fields/Doc';
import { DocData, DocViews } from '../../../../../fields/DocSymbols';
-import { CsvCast, DocCast, PDFCast, RTFCast, StrCast } from '../../../../../fields/Types';
+import { CsvCast, DocCast, PDFCast, RTFCast, StrCast, NumCast } from '../../../../../fields/Types';
import { Networking } from '../../../../Network';
import { DocUtils } from '../../../../documents/DocUtils';
import { DocumentType } from '../../../../documents/DocumentTypes';
-import { Docs } from '../../../../documents/Documents';
+import { Docs, DocumentOptions } from '../../../../documents/Documents';
import { DocumentManager } from '../../../../util/DocumentManager';
import { LinkManager } from '../../../../util/LinkManager';
import { ViewBoxAnnotatableComponent } from '../../../DocComponent';
@@ -33,6 +33,7 @@ import { Vectorstore } from '../vectorstore/Vectorstore';
import './ChatBox.scss';
import MessageComponentBox from './MessageComponent';
import { ProgressBar } from './ProgressBar';
+import { RichTextField } from '../../../../../fields/RichTextField';
dotenv.config();
@@ -89,7 +90,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
this.vectorstore_id = StrCast(this.dataDoc.vectorstore_id);
}
this.vectorstore = new Vectorstore(this.vectorstore_id, this.retrieveDocIds);
- this.agent = new Agent(this.vectorstore, this.retrieveSummaries, this.retrieveFormattedHistory, this.retrieveCSVData, this.addLinkedUrlDoc, this.createCSVInDash);
+ this.agent = new Agent(this.vectorstore, this.retrieveSummaries, this.retrieveFormattedHistory, this.retrieveCSVData, this.addLinkedUrlDoc, this.createDocInDash, this.createCSVInDash);
this.messagesRef = React.createRef<HTMLDivElement>();
// Reaction to update dataDoc when chat history changes
@@ -354,29 +355,11 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
const linkDoc = Docs.Create.LinkDocument(this.Document, doc);
LinkManager.Instance.addLink(linkDoc);
- let canDisplay;
-
- try {
- // Fetch the URL content through the proxy
- const { data } = await Networking.PostToServer('/proxyFetch', { url });
-
- // Simulating header behavior since we can't fetch headers via proxy
- const xFrameOptions = data.headers?.['x-frame-options'];
-
- if (xFrameOptions && xFrameOptions.toUpperCase() === 'SAMEORIGIN') {
- canDisplay = false;
- } else {
- canDisplay = true;
- }
- } catch (error) {
- console.error('Error fetching the URL from the server:', error);
- }
const chunkToAdd = {
chunkId: id,
chunkType: CHUNK_TYPE.URL,
url: url,
- canDisplay: canDisplay,
};
doc.chunk_simpl = JSON.stringify({ chunks: [chunkToAdd] });
@@ -411,6 +394,253 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
};
/**
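+ * Builds documents from an array of document descriptions, recursing into entries whose
+ * `doc_type` is 'collection' (each such entry becomes a freeform collection of its children).
+ * @param data The array of document descriptions to create documents from.
+ * @param insideCol Passed to `whichDoc` for non-collection entries; nested levels always pass true.
+ * @returns A promise resolving to the flattened array of created documents.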
+ */
+ @action
+ private createCollectionWithChildren = async (data: any, insideCol: boolean): Promise<Doc[]> => {
+ // Create an array of promises for each document
+ const childDocPromises = data.map(async doc => {
+ const parsedDoc = doc;
+ if (parsedDoc.doc_type !== 'collection') {
+ // Handle non-collection documents
+ return await this.whichDoc(parsedDoc.doc_type, parsedDoc.data, { backgroundColor: parsedDoc.backgroundColor, _width: parsedDoc.width, _height: parsedDoc.height }, parsedDoc.id, insideCol);
+ } else {
+ // Recursively process collections
+ const nestedDocs = await this.createCollectionWithChildren(parsedDoc.data, true);
+ const collectionOptions: DocumentOptions = {
+ title: parsedDoc.title,
+ backgroundColor: parsedDoc.backgroundColor,
+ _width: parsedDoc.width,
+ _height: parsedDoc.height,
+ _layout_fitWidth: true,
+ _freeform_backgroundGrid: true,
+ };
+ const collectionDoc = DocCast(Docs.Create.FreeformDocument(nestedDocs, collectionOptions));
+ return collectionDoc;
+ }
+ });
+
+ // Await all child document creations concurrently
+ const nestedResults = await Promise.all(childDocPromises);
+ // Flatten any nested arrays from recursive collection calls
+ const childDocs = nestedResults.flat() as Doc[];
+ childDocs.forEach(doc => {
+ console.log(DocCast(doc));
+ console.log(DocCast(doc)[DocData].data);
+ console.log(DocCast(doc)[DocData].data);
+ });
+ return childDocs;
+ };
+
+ // @action
+ // createSingleFlashcard = (data: any, options: DocumentOptions) => {
+
+ // }
+
+ @action
+ whichDoc = async (doc_type: string, data: string, options: DocumentOptions, id: string, insideCol: boolean): Promise<Doc> => {
+ let doc;
+ switch (doc_type) {
+ case 'text':
+ doc = DocCast(Docs.Create.TextDocument(data, options));
+ break;
+ case 'flashcard':
+ doc = this.createFlashcard(data, options);
+ break;
+ case 'deck':
+ doc = this.createDeck(data, options);
+ break;
+ case 'image':
+ doc = DocCast(Docs.Create.ImageDocument(data, options));
+ break;
+ case 'equation':
+ doc = DocCast(Docs.Create.EquationDocument(data, options));
+ break;
+ case 'noteboard':
+ doc = DocCast(Docs.Create.NoteTakingDocument([], options));
+ break;
+ case 'simulation':
+ doc = DocCast(Docs.Create.SimulationDocument(options));
+ break;
+ case 'collection': {
+ const arr = await this.createCollectionWithChildren(data, true);
+ options._layout_fitWidth = true;
+ options._freeform_backgroundGrid = true;
+ if (options.type_collection == 'tree') {
+ doc = DocCast(Docs.Create.TreeDocument(arr, options));
+ } else if (options.type_collection == 'masonry') {
+ doc = DocCast(Docs.Create.MasonryDocument(arr, options));
+ } else if (options.type_collection == 'card') {
+ doc = DocCast(Docs.Create.CardDeckDocument(arr, options));
+ } else if (options.type_collection == 'carousel') {
+ doc = DocCast(Docs.Create.CarouselDocument(arr, options));
+ } else if (options.type_collection == '3d-carousel') {
+ doc = DocCast(Docs.Create.Carousel3DDocument(arr, options));
+ } else if (options.type_collection == 'multicolumn') {
+ doc = DocCast(Docs.Create.CarouselDocument(arr, options));
+ } else {
+ doc = DocCast(Docs.Create.FreeformDocument(arr, options));
+ }
+ break;
+ }
+ case 'web':
+ options.data_useCors = true;
+ doc = DocCast(Docs.Create.WebDocument(data, options));
+ break;
+ case 'comparison':
+ doc = this.createComparison(data, options);
+ break;
+ case 'diagram':
+ doc = Docs.Create.DiagramDocument(options);
+ break;
+ case 'audio':
+ doc = Docs.Create.AudioDocument(data, options);
+ break;
+ case 'map':
+ doc = Docs.Create.MapDocument([], options);
+ break;
+ case 'screengrab':
+ doc = Docs.Create.ScreenshotDocument(options);
+ break;
+ case 'webcam':
+ doc = Docs.Create.WebCamDocument('', options);
+ break;
+ case 'button':
+ doc = Docs.Create.ButtonDocument(options);
+ break;
+ case 'script':
+ doc = Docs.Create.ScriptingDocument(null, options);
+ break;
+ case 'dataviz':
+ doc = Docs.Create.DataVizDocument('/users/rz/Downloads/addresses.csv', options);
+ break;
+ case 'chat':
+ doc = Docs.Create.ChatDocument(options);
+ break;
+ case 'trail':
+ doc = Docs.Create.PresDocument(options);
+ break;
+ case 'tab':
+ doc = Docs.Create.FreeformDocument([], options);
+ break;
+ case 'slide':
+ doc = Docs.Create.TreeDocument([], options);
+ break;
+ default:
+ doc = DocCast(Docs.Create.TextDocument(data, options));
+ }
+ doc!.x = NumCast(options.x ?? 0) + (insideCol ? 0 : NumCast(this.layoutDoc.x) + NumCast(this.layoutDoc.width)) + 100;
+ doc!.y = NumCast(options.y) + (insideCol ? 0 : NumCast(this.layoutDoc.y));
+ return doc;
+ };
+
+ /**
+ * Creates a document in the dashboard.
+ *
+ * @param {string} doc_type - The type of document to create.
+ * @param {string} data - The data used to generate the document.
+ * @param {DocumentOptions} options - Configuration options for the document.
+ * @param {string} id - Unique identifier for the document.
+ * @returns {Promise<void>} A promise that resolves once the document is created and displayed.
+ */
+ @action
+ createDocInDash = async (doc_type: string, data: string, options: DocumentOptions, id: string) => {
+ const doc = await this.whichDoc(doc_type, data, options, id);
+ const linkDoc = Docs.Create.LinkDocument(this.Document, doc);
+ LinkManager.Instance.addLink(linkDoc);
+ doc && this._props.addDocument?.(doc);
+ await DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
+ };
+
+ /**
+ * Creates a deck of flashcards.
+ *
+ * @param {any} data - The data used to generate the flashcards. Can be a string or an object.
+ * @param {DocumentOptions} options - Configuration options for the flashcard deck.
+ * @returns {Doc} A carousel document containing the flashcard deck.
+ */
+ @action
+ createDeck = (data: any, options: DocumentOptions) => {
+ const flashcardDeck: Doc[] = [];
+ // Parse `data` only if it’s a string
+ const deckData = typeof data === 'string' ? JSON.parse(data) : data;
+ const flashcardArray = Array.isArray(deckData) ? deckData : Object.values(deckData);
+ // Process each flashcard document in the `deckData` array
+ if (flashcardArray.length == 2 && flashcardArray[0].doc_type == 'text' && flashcardArray[1].doc_type == 'text') {
+ this.createFlashcard(flashcardArray, options);
+ } else {
+ flashcardArray.forEach(doc => {
+ const flashcardDoc = this.createFlashcard(doc, options);
+ if (flashcardDoc) flashcardDeck.push(flashcardDoc);
+ });
+ }
+
+ // Create a carousel to contain the flashcard deck
+ const carouselDoc = DocCast(
+ Docs.Create.CarouselDocument(flashcardDeck, {
+ title: options.title || 'Flashcard Deck',
+ _width: options._width || 300,
+ _height: options._height || 300,
+ _layout_fitWidth: false,
+ _layout_autoHeight: true,
+ })
+ );
+ return carouselDoc;
+ };
+
+ /**
+ * Creates a single flashcard document.
+ *
+ * @param {any} data - The data used to generate the flashcard. Can be a string or an object.
+ * @param {any} options - Configuration options for the flashcard.
+ * @returns {Doc | undefined} The created flashcard document, or undefined if the flashcard cannot be created.
+ */
+ @action
+ createFlashcard = (data: any, options: any) => {
+ const deckData = typeof data === 'string' ? JSON.parse(data) : data;
+ const flashcardArray = Array.isArray(deckData) ? deckData : Object.values(deckData)[2];
+ const [front, back] = flashcardArray;
+
+ if (front.doc_type === 'text' && back.doc_type === 'text') {
+ const sideOptions: DocumentOptions = {
+ backgroundColor: options.backgroundColor,
+ _width: options._width,
+ _height: options._height,
+ };
+
+ // Create front and back text documents
+ const side1 = Docs.Create.CenteredTextCreator(front.title, front.data, sideOptions);
+ const side2 = Docs.Create.CenteredTextCreator(back.title, back.data, sideOptions);
+
+ // Create the flashcard document with both sides
+ const flashcardDoc = DocCast(Docs.Create.FlashcardDocument(data.title, side1, side2, sideOptions));
+ return flashcardDoc;
+ }
+ };
+
+ /**
+ * Creates a comparison document.
+ *
+ * @param {any} doc - The document data containing left and right components for comparison.
+ * @param {any} options - Configuration options for the comparison document.
+ * @returns {Doc} The created comparison document.
+ */
+ @action
+ createComparison = (doc: any, options: any) => {
+ const comp = Docs.Create.ComparisonDocument(options.title, { _width: options.width, _height: options.height | 300, backgroundColor: options.backgroundColor });
+ const [left, right] = doc;
+ const docLeft = DocCast(Docs.Create.TextDocument(left.data, { backgroundColor: left.backgroundColor, _width: left.width, _height: left.height }));
+ const docRight = DocCast(Docs.Create.TextDocument(right.data, { backgroundColor: right.backgroundColor, _width: right.width, _height: right.height }));
+ comp[DocData].data_back = docLeft;
+ comp[DocData].data_front = docRight;
+ return comp;
+ };
+
+ /**
* Event handler to manage citations click in the message components.
* @param citation The citation object clicked by the user.
*/
@@ -462,11 +692,8 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
});
break;
case CHUNK_TYPE.URL:
- if (!foundChunk.canDisplay) {
- window.open(StrCast(doc.displayUrl), '_blank');
- } else if (foundChunk.canDisplay) {
- DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
- }
+ DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
+
break;
case CHUNK_TYPE.CSV:
DocumentManager.Instance.showDocument(doc, { willZoomCentered: true }, () => {});
@@ -709,17 +936,10 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
</div>
<div className="chat-messages" ref={this.messagesRef}>
{this.history.map((message, index) => (
- <MessageComponentBox key={index} message={message} index={index} onFollowUpClick={this.handleFollowUpClick} onCitationClick={this.handleCitationClick} updateMessageCitations={this.updateMessageCitations} />
+ <MessageComponentBox key={index} message={message} onFollowUpClick={this.handleFollowUpClick} onCitationClick={this.handleCitationClick} updateMessageCitations={this.updateMessageCitations} />
))}
{this.current_message && (
- <MessageComponentBox
- key={this.history.length}
- message={this.current_message}
- index={this.history.length}
- onFollowUpClick={this.handleFollowUpClick}
- onCitationClick={this.handleCitationClick}
- updateMessageCitations={this.updateMessageCitations}
- />
+ <MessageComponentBox key={this.history.length} message={this.current_message} onFollowUpClick={this.handleFollowUpClick} onCitationClick={this.handleCitationClick} updateMessageCitations={this.updateMessageCitations} />
)}
</div>
<form onSubmit={this.askGPT} className="chat-input">
diff --git a/src/client/views/nodes/chatbot/tools/BaseTool.ts b/src/client/views/nodes/chatbot/tools/BaseTool.ts
index 58cd514d9..05ca83b26 100644
--- a/src/client/views/nodes/chatbot/tools/BaseTool.ts
+++ b/src/client/views/nodes/chatbot/tools/BaseTool.ts
@@ -1,5 +1,5 @@
import { Observation } from '../types/types';
-import { Parameter, Tool, ParametersType } from './ToolTypes';
+import { Parameter, ParametersType } from '../types/tool_types';
/**
* @file BaseTool.ts
@@ -14,7 +14,7 @@ import { Parameter, Tool, ParametersType } from './ToolTypes';
* It is generic over a type parameter `P`, which extends `ReadonlyArray<Parameter>`.
* This means `P` is a readonly array of `Parameter` objects that cannot be modified (immutable).
*/
-export abstract class BaseTool<P extends ReadonlyArray<Parameter>> implements Tool<P> {
+export abstract class BaseTool<P extends ReadonlyArray<Parameter>> {
// The name of the tool (e.g., "calculate", "searchTool")
name: string;
// A description of the tool's functionality
diff --git a/src/client/views/nodes/chatbot/tools/CalculateTool.ts b/src/client/views/nodes/chatbot/tools/CalculateTool.ts
index e96c9a98a..139ede8f0 100644
--- a/src/client/views/nodes/chatbot/tools/CalculateTool.ts
+++ b/src/client/views/nodes/chatbot/tools/CalculateTool.ts
@@ -1,5 +1,5 @@
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
import { BaseTool } from './BaseTool';
const calculateToolParams = [
diff --git a/src/client/views/nodes/chatbot/tools/CreateCSVTool.ts b/src/client/views/nodes/chatbot/tools/CreateCSVTool.ts
index b321d98ba..2cc513d6c 100644
--- a/src/client/views/nodes/chatbot/tools/CreateCSVTool.ts
+++ b/src/client/views/nodes/chatbot/tools/CreateCSVTool.ts
@@ -1,7 +1,7 @@
import { BaseTool } from './BaseTool';
import { Networking } from '../../../../Network';
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
const createCSVToolParams = [
{
diff --git a/src/client/views/nodes/chatbot/tools/CreateDocumentTool.ts b/src/client/views/nodes/chatbot/tools/CreateDocumentTool.ts
new file mode 100644
index 000000000..63a6004a7
--- /dev/null
+++ b/src/client/views/nodes/chatbot/tools/CreateDocumentTool.ts
@@ -0,0 +1,392 @@
+import { v4 as uuidv4 } from 'uuid';
+import { BaseTool } from './BaseTool';
+import { Observation } from '../types/types';
+import { ParametersType } from '../types/tool_types';
+import { DocumentOptions } from '../../../../documents/Documents';
+
+/**
+ * The CreateDocTool class is responsible for creating
+ * documents of various types (e.g., text, flashcards, collections) and organizing them in a
+ * structured manner. The tool supports creating dashboards with diverse document types and
+ * ensures proper placement of documents without overlap.
+ */
+
+// Example document structure for various document types
+const example = [
+ {
+ doc_type: 'equation',
+ title: 'quadratic',
+ data: 'x^2 + y^2 = 3',
+ width: 300,
+ height: 300,
+ x: 0,
+ y: 0,
+ },
+ {
+ doc_type: 'collection',
+ title: 'Advanced Biology',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'Cell Structure',
+ data: 'Cells are the basic building blocks of all living organisms.',
+ width: 300,
+ height: 300,
+ x: 500,
+ y: 0,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 600,
+ height: 600,
+ x: 600,
+ y: 0,
+ type_collection: 'tree',
+ },
+ {
+ doc_type: 'image',
+ title: 'experiment',
+ data: 'https://plus.unsplash.com/premium_photo-1694819488591-a43907d1c5cc?q=80&w=2628&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D',
+ width: 300,
+ height: 300,
+ x: 600,
+ y: 300,
+ },
+ {
+ doc_type: 'deck',
+ title: 'Chemistry',
+ data: [
+ {
+ doc_type: 'flashcard',
+ title: 'Photosynthesis',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'front_Photosynthesis',
+ data: 'What is photosynthesis?',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 600,
+ },
+ {
+ doc_type: 'text',
+ title: 'back_photosynthesis',
+ data: 'The process by which plants make food.',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 700,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 300,
+ height: 300,
+ x: 300,
+ y: 1000,
+ },
+ {
+ doc_type: 'flashcard',
+ title: 'Photosynthesis',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'front_Photosynthesis',
+ data: 'What is photosynthesis?',
+ width: 300,
+ height: 300,
+ x: 200,
+ y: 800,
+ },
+ {
+ doc_type: 'text',
+ title: 'back_photosynthesis',
+ data: 'The process by which plants make food.',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: -100,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 300,
+ height: 300,
+ x: 10,
+ y: 70,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 600,
+ height: 600,
+ x: 200,
+ y: 800,
+ },
+ {
+ doc_type: 'web',
+ title: 'Brown University Wikipedia',
+ data: 'https://en.wikipedia.org/wiki/Brown_University',
+ width: 300,
+ height: 300,
+ x: 1000,
+ y: 2000,
+ },
+ {
+ doc_type: 'simulation',
+ title: 'Physics simulation',
+ data: '',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 100,
+ },
+ {
+ doc_type: 'comparison',
+ title: 'WWI vs. WWII',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'WWI',
+ data: 'From 1914 to 1918, fighting took place across several continents, at sea and, for the first time, in the air.',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 100,
+ },
+ {
+ doc_type: 'text',
+ title: 'WWII',
+ data: 'A devastating global conflict spanning from 1939 to 1945, saw the Allied powers fight against the Axis powers.',
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 100,
+ },
+ ],
+ width: 300,
+ height: 300,
+ x: 100,
+ y: 100,
+ },
+ {
+ doc_type: 'collection',
+ title: 'Science Collection',
+ data: [
+ {
+ doc_type: 'flashcard',
+ title: 'Photosynthesis',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'front_Photosynthesis',
+ data: 'What is photosynthesis?',
+ width: 300,
+ height: 300,
+ },
+ {
+ doc_type: 'text',
+ title: 'back_photosynthesis',
+ data: 'The process by which plants make food.',
+ width: 300,
+ height: 300,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 300,
+ height: 300,
+ },
+ {
+ doc_type: 'web',
+ title: 'Brown University Wikipedia',
+ data: 'https://en.wikipedia.org/wiki/Brown_University',
+ width: 300,
+ height: 300,
+ x: 1100,
+ y: 1100,
+ },
+ {
+ doc_type: 'text',
+ title: 'Water Cycle',
+ data: 'The continuous movement of water on, above, and below the Earth’s surface.',
+ width: 300,
+ height: 300,
+ x: 1500,
+ y: 500,
+ },
+ {
+ doc_type: 'collection',
+ title: 'Advanced Biology',
+ data: [
+ {
+ doc_type: 'text',
+ title: 'Cell Structure',
+ data: 'Cells are the basic building blocks of all living organisms.',
+ width: 300,
+ height: 300,
+ },
+ ],
+ backgroundColor: '#00ff00',
+ width: 600,
+ height: 600,
+ x: 1100,
+ y: 500,
+ type_collection: 'freeform',
+ },
+ ],
+ width: 600,
+ height: 600,
+ x: 500,
+ y: 500,
+ type_collection: 'freeform',
+ },
+];
+
+// Serialize the example once; the resulting string is embedded in the tool parameter descriptions and prompts below
+const finalJsonString = JSON.stringify(example);
+
+// Per-document-type instructions describing what the "data" field must contain.
+// This object is used as the description of the "data" parameter in createDocToolParams,
+// so every string here is prompt text read by the model.
+const docInstructions = {
+    // Collections are recursive: their data is itself a list of documents, illustrated by the shared example.
+    collection: {
+        description:
+            'A recursive collection of documents as a stringified array. Each document can be a "text", "deck", "flashcard", "image", "web", "comparison", "equation", "noteboard", "simulation", "diagram", "map", "screengrab", "webcam", "button", or another "collection".',
+        example: finalJsonString,
+    },
+    text: 'Provide text content as a plain string. Example: "This is a standalone text document."',
+    flashcard: 'Two text documents with content for the front and back.',
+    deck: "A deck's data is an array of flashcards.",
+    web: 'A URL to a webpage. Example: https://en.wikipedia.org/wiki/Brown_University',
+    equation: 'Create an equation document, not a text document. Data is math equation.',
+    noteboard: 'Create a noteboard document',
+    simulation: 'Create a simulation document',
+    // No example URL is given: the original text ended in a dangling "Example: " with nothing after it.
+    audio: 'A URL to an audio recording.',
+} as const;
+
+// Schema of a single document entry. It is not registered as the tool's parameter list;
+// it is referenced from the description of the "docs" parameter in createListDocToolParams.
+// Field names must match the keys that CreateDocTool.execute reads from each parsed document.
+const createDocToolParams = [
+    {
+        name: 'data',
+        type: 'string', // Accepts either string or array, supporting individual and nested data
+        description: docInstructions,
+        required: true,
+    },
+    {
+        name: 'doc_type',
+        type: 'string',
+        description: 'The type of the document. Options: "collection", "text", "flashcard", "deck", "web", "equation", "noteboard", "simulation", "comparison", "audio".',
+        required: true,
+    },
+    {
+        name: 'title',
+        type: 'string',
+        description: 'The title of the document.',
+        required: true,
+    },
+    {
+        name: 'x',
+        type: 'number',
+        description: 'The x location of the document; 0 <= x.',
+        required: true,
+    },
+    {
+        name: 'y',
+        type: 'number',
+        description: 'The y location of the document; 0 <= y.',
+        required: true,
+    },
+    {
+        // Named "backgroundColor" (not "background_color") to match the example documents and the key read in execute.
+        name: 'backgroundColor',
+        type: 'string',
+        description: 'The background color of the document as a hex string.',
+        required: false,
+    },
+    {
+        name: 'font_color',
+        type: 'string',
+        description: 'The font color of the document as a hex string.',
+        required: false,
+    },
+    {
+        name: 'width',
+        type: 'number',
+        description: 'The width of the document in pixels.',
+        required: true,
+    },
+    {
+        name: 'height',
+        type: 'number',
+        description: 'The height of the document in pixels.',
+        required: true,
+    },
+    {
+        name: 'type_collection',
+        type: 'string',
+        description: 'Either freeform, card, carousel, 3d-carousel, multicolumn, multirow, linear, map, notetaking, schema, stacking, grid, tree, or masonry.',
+        required: false,
+    },
+] as const;
+
+// Parameters for creating a list of documents.
+// The tool takes a single "docs" argument: a stringified JSON array of document objects.
+const createListDocToolParams = [
+    {
+        name: 'docs',
+        type: 'string',
+        description:
+            'Array of documents in stringified JSON format. Each item in the array should be an individual stringified JSON object. Each document can be of type "text", "flashcard", "web", or "collection" (for nested documents). ' +
+            'Use this structure for nesting collections within collections. Each document should follow the structure in ' +
+            // Serialize explicitly: concatenating the array directly coerces it to "[object Object],[object Object],...".
+            JSON.stringify(createDocToolParams) +
+            '. Example: ' +
+            finalJsonString,
+        required: true,
+    },
+] as const;
+
+type CreateListDocToolParamsType = typeof createListDocToolParams;
+
+/**
+ * Tool that creates one or more documents from a stringified JSON list supplied in the "docs" argument.
+ * Each parsed document is handed to the `addLinkedDoc` callback together with its layout options and a fresh id.
+ */
+export class CreateDocTool extends BaseTool<CreateListDocToolParamsType> {
+    private _addLinkedDoc: (doc_type: string, data: string, options: DocumentOptions, id: string) => void;
+
+    /**
+     * @param addLinkedDoc Callback invoked once per top-level document with its type, data, options and a generated id.
+     */
+    constructor(addLinkedDoc: (doc_type: string, data: string, options: DocumentOptions, id: string) => void) {
+        super(
+            'createDoc',
+            'Creates one or more documents that best fit the user’s request. If the user requests a "dashboard," first call the search tool and then generate a variety of document types individually, with absolutely a minimum of 20 documents with two stacks of flashcards that are small and it should have a couple nested freeform collections of things, each with different content and color schemes. For example, create multiple individual documents like "text," "deck," "web", "equation," and "comparison." Use decks instead of flashcards for dashboards. Decks should have at least three flashcards. Really think about what documents are useful to the user. If they ask for a dashboard about the skeletal system, include flashcards, as they would be helpful. Arrange the documents in a grid layout, ensuring that the x and y coordinates are calculated so no documents overlap but they should be directly next to each other with 20 padding in between. Take into account the width and height of each document, spacing them appropriately to prevent collisions. Use a systematic approach, such as placing each document in a grid cell based on its order, where cell dimensions match the document dimensions plus a fixed margin for spacing. Do not nest all documents within a single collection unless explicitly requested by the user. Instead, create a set of independent documents with diverse document types. Each type should appear separately unless specified otherwise.',
+            createListDocToolParams,
+            'Use the "data" parameter for document content and include title, color, and document dimensions. Ensure web documents use URLs from the search tool if relevant. Each document in a dashboard should be unique and well-differentiated in type and content, without repetition of similar types in any single collection.',
+            'When creating a dashboard, ensure that it consists of a broad range of document types. Include a variety of documents, such as text, web, deck, comparison, image, simulation, and equation documents, each with distinct titles and colors, following the user’s preferences. ' +
+                'Do not overuse collections or nest all document types within a single collection; instead, represent document types individually. Use this example for reference: ' +
+                finalJsonString +
+                '. Which documents are created should be random with different numbers of each document type and different for each dashboard. Must use search tool before creating a dashboard.'
+        );
+        this._addLinkedDoc = addLinkedDoc;
+    }
+
+    /**
+     * Parses `args.docs` and creates one document per entry via the `addLinkedDoc` callback.
+     * @param args Tool arguments; `docs` is a stringified JSON array of document objects (a single object is also accepted).
+     * @returns A single text observation reporting either how many documents were created or why creation failed.
+     */
+    async execute(args: ParametersType<CreateListDocToolParamsType>): Promise<Observation[]> {
+        try {
+            const parsed = JSON.parse(args.docs);
+            // Accept a lone document object as well as the documented array form, instead of throwing on `.forEach`.
+            const docs = Array.isArray(parsed) ? parsed : [parsed];
+
+            // Validate every entry before creating anything so a malformed item cannot leave a half-created batch.
+            const invalidIndex = docs.findIndex(doc => typeof doc !== 'object' || doc === null);
+            if (invalidIndex !== -1) {
+                return [{ type: 'text', text: `Error creating documents: item ${invalidIndex} is not a JSON object.` }];
+            }
+
+            docs.forEach(doc => {
+                this._addLinkedDoc(
+                    doc['doc_type'],
+                    // NOTE(review): for nested types the example passes an array here while the callback is typed `string` - confirm the callback handles both.
+                    doc['data'],
+                    {
+                        title: doc['title'],
+                        // The example uses "backgroundColor"; also honor the snake_case spelling so the color is not silently dropped.
+                        backgroundColor: doc['backgroundColor'] ?? doc['background_color'],
+                        text_fontColor: doc['font_color'],
+                        _width: doc['width'],
+                        _height: doc['height'],
+                        type_collection: doc['type_collection'],
+                        _layout_fitWidth: false,
+                        _layout_autoHeight: true,
+                        x: doc['x'],
+                        y: doc['y'],
+                    },
+                    uuidv4()
+                );
+            });
+            return [{ type: 'text', text: `Created ${docs.length} document(s).` }];
+        } catch (error) {
+            // `error` is `unknown` under strict settings; narrow before building the message.
+            const reason = error instanceof Error ? error.message : String(error);
+            return [{ type: 'text', text: 'Error creating documents: ' + reason }];
+        }
+    }
+}
diff --git a/src/client/views/nodes/chatbot/tools/DataAnalysisTool.ts b/src/client/views/nodes/chatbot/tools/DataAnalysisTool.ts
index d9b75219d..97b9ee023 100644
--- a/src/client/views/nodes/chatbot/tools/DataAnalysisTool.ts
+++ b/src/client/views/nodes/chatbot/tools/DataAnalysisTool.ts
@@ -1,5 +1,5 @@
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
import { BaseTool } from './BaseTool';
const dataAnalysisToolParams = [
diff --git a/src/client/views/nodes/chatbot/tools/GetDocsTool.ts b/src/client/views/nodes/chatbot/tools/GetDocsTool.ts
index 26756522c..4286e7ffe 100644
--- a/src/client/views/nodes/chatbot/tools/GetDocsTool.ts
+++ b/src/client/views/nodes/chatbot/tools/GetDocsTool.ts
@@ -1,5 +1,5 @@
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
import { BaseTool } from './BaseTool';
import { DocServer } from '../../../../DocServer';
import { Docs } from '../../../../documents/Documents';
diff --git a/src/client/views/nodes/chatbot/tools/NoTool.ts b/src/client/views/nodes/chatbot/tools/NoTool.ts
index a92e3fa23..5d652fd8d 100644
--- a/src/client/views/nodes/chatbot/tools/NoTool.ts
+++ b/src/client/views/nodes/chatbot/tools/NoTool.ts
@@ -1,6 +1,6 @@
import { BaseTool } from './BaseTool';
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
const noToolParams = [] as const;
diff --git a/src/client/views/nodes/chatbot/tools/RAGTool.ts b/src/client/views/nodes/chatbot/tools/RAGTool.ts
index 482069f36..fcd93a07a 100644
--- a/src/client/views/nodes/chatbot/tools/RAGTool.ts
+++ b/src/client/views/nodes/chatbot/tools/RAGTool.ts
@@ -1,6 +1,6 @@
import { Networking } from '../../../../Network';
import { Observation, RAGChunk } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
import { Vectorstore } from '../vectorstore/Vectorstore';
import { BaseTool } from './BaseTool';
diff --git a/src/client/views/nodes/chatbot/tools/SearchTool.ts b/src/client/views/nodes/chatbot/tools/SearchTool.ts
index fd5144dd6..d22f4c189 100644
--- a/src/client/views/nodes/chatbot/tools/SearchTool.ts
+++ b/src/client/views/nodes/chatbot/tools/SearchTool.ts
@@ -2,11 +2,11 @@ import { v4 as uuidv4 } from 'uuid';
import { Networking } from '../../../../Network';
import { BaseTool } from './BaseTool';
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
const searchToolParams = [
{
- name: 'query',
+ name: 'queries',
type: 'string[]',
description: 'The search query or queries to use for finding websites',
required: true,
@@ -20,7 +20,7 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
private _addLinkedUrlDoc: (url: string, id: string) => void;
private _max_results: number;
- constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 5) {
+ constructor(addLinkedUrlDoc: (url: string, id: string) => void, max_results: number = 4) {
super(
'searchTool',
'Search the web to find a wide range of websites related to a query or multiple queries',
@@ -33,8 +33,9 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
}
async execute(args: ParametersType<SearchToolParamsType>): Promise<Observation[]> {
- const queries = args.query;
+ const queries = args.queries;
+ console.log(`Searching the web for queries: ${queries[0]}`);
// Create an array of promises, each one handling a search for a query
const searchPromises = queries.map(async query => {
try {
@@ -44,9 +45,10 @@ export class SearchTool extends BaseTool<SearchToolParamsType> {
});
const data = results.map((result: { url: string; snippet: string }) => {
const id = uuidv4();
+ this._addLinkedUrlDoc(result.url, id);
return {
type: 'text',
- text: `<chunk chunk_id="${id}" chunk_type="text"><url>${result.url}</url><overview>${result.snippet}</overview></chunk>`,
+ text: `<chunk chunk_id="${id}" chunk_type="url"><url>${result.url}</url><overview>${result.snippet}</overview></chunk>`,
};
});
return data;
diff --git a/src/client/views/nodes/chatbot/tools/WebsiteInfoScraperTool.ts b/src/client/views/nodes/chatbot/tools/WebsiteInfoScraperTool.ts
index f2e3863a6..ce659e344 100644
--- a/src/client/views/nodes/chatbot/tools/WebsiteInfoScraperTool.ts
+++ b/src/client/views/nodes/chatbot/tools/WebsiteInfoScraperTool.ts
@@ -2,7 +2,7 @@ import { v4 as uuidv4 } from 'uuid';
import { Networking } from '../../../../Network';
import { BaseTool } from './BaseTool';
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
const websiteInfoScraperToolParams = [
{
diff --git a/src/client/views/nodes/chatbot/tools/WikipediaTool.ts b/src/client/views/nodes/chatbot/tools/WikipediaTool.ts
index 4fcffe2ed..f2dbf3cfd 100644
--- a/src/client/views/nodes/chatbot/tools/WikipediaTool.ts
+++ b/src/client/views/nodes/chatbot/tools/WikipediaTool.ts
@@ -2,7 +2,7 @@ import { v4 as uuidv4 } from 'uuid';
import { Networking } from '../../../../Network';
import { BaseTool } from './BaseTool';
import { Observation } from '../types/types';
-import { ParametersType } from './ToolTypes';
+import { ParametersType } from '../types/tool_types';
const wikipediaToolParams = [
{
@@ -38,7 +38,7 @@ export class WikipediaTool extends BaseTool<WikipediaToolParamsType> {
return [
{
type: 'text',
- text: `<chunk chunk_id="${id}" chunk_type="text"> ${text} </chunk>`,
+ text: `<chunk chunk_id="${id}" chunk_type="url"> ${text} </chunk>`,
},
];
} catch (error) {
diff --git a/src/client/views/nodes/chatbot/tools/ToolTypes.ts b/src/client/views/nodes/chatbot/types/tool_types.ts
index d47a38952..b2e05efe4 100644
--- a/src/client/views/nodes/chatbot/tools/ToolTypes.ts
+++ b/src/client/views/nodes/chatbot/types/tool_types.ts
@@ -1,34 +1,4 @@
-import { Observation } from '../types/types';
-
-/**
- * The `Tool` interface represents a generic tool in the system.
- * It is generic over a type parameter `P`, which extends `ReadonlyArray<Parameter>`.
- * @template P - An array of `Parameter` objects defining the tool's parameters.
- */
-export interface Tool<P extends ReadonlyArray<Parameter>> {
- // The name of the tool (e.g., "calculate", "searchTool")
- name: string;
- // A description of the tool's functionality
- description: string;
- // An array of parameter definitions for the tool
- parameterRules: P;
- // Guidelines for how to handle citations when using the tool
- citationRules: string;
- // A brief summary of the tool's purpose
- briefSummary: string;
- /**
- * Executes the tool's main functionality.
- * @param args - The arguments for execution, with types inferred from `ParametersType<P>`.
- * @returns A promise that resolves to an array of `Observation` objects.
- */
- execute: (args: ParametersType<P>) => Promise<Observation[]>;
- /**
- * Generates an action rule object that describes the tool's usage.
- * @returns An object representing the tool's action rules.
- */
- getActionRule: () => Record<string, unknown>;
-}
-
+import { Observation } from './types';
/**
* The `Parameter` type defines the structure of a parameter configuration.
*/
@@ -49,7 +19,7 @@ export type Parameter = {
* A utility type that maps string representations of types to actual TypeScript types.
* This is used to convert the `type` field of a `Parameter` into a concrete TypeScript type.
*/
-type TypeMap = {
+export type TypeMap = {
string: string;
number: number;
boolean: boolean;
diff --git a/src/client/views/nodes/chatbot/types/types.ts b/src/client/views/nodes/chatbot/types/types.ts
index 7abad85f0..c65ac9820 100644
--- a/src/client/views/nodes/chatbot/types/types.ts
+++ b/src/client/views/nodes/chatbot/types/types.ts
@@ -102,7 +102,6 @@ export interface SimplifiedChunk {
location?: string;
chunkType: CHUNK_TYPE;
url?: string;
- canDisplay?: boolean;
}
export interface AI_Document {
diff --git a/src/client/views/nodes/chatbot/vectorstore/Vectorstore.ts b/src/client/views/nodes/chatbot/vectorstore/Vectorstore.ts
index f96f55997..cf7fa0ff3 100644
--- a/src/client/views/nodes/chatbot/vectorstore/Vectorstore.ts
+++ b/src/client/views/nodes/chatbot/vectorstore/Vectorstore.ts
@@ -37,14 +37,14 @@ export class Vectorstore {
* @param doc_ids A function that returns a list of document IDs.
*/
constructor(id: string, doc_ids: () => string[]) {
- const pineconeApiKey = process.env.PINECONE_API_KEY;
+ const pineconeApiKey = '51738e9a-bea2-4c11-b6bf-48a825e774dc';
if (!pineconeApiKey) {
throw new Error('PINECONE_API_KEY is not defined.');
}
// Initialize Pinecone and Cohere clients with API keys from the environment.
this.pinecone = new Pinecone({ apiKey: pineconeApiKey });
- this.cohere = new CohereClient({ token: process.env.COHERE_API_KEY });
+ // this.cohere = new CohereClient({ token: process.env.COHERE_API_KEY });
this._id = id;
this._doc_ids = doc_ids();
this.initializeIndex();
diff --git a/src/server/ApiManagers/AssistantManager.ts b/src/server/ApiManagers/AssistantManager.ts
index 8447a4934..4d2068014 100644
--- a/src/server/ApiManagers/AssistantManager.ts
+++ b/src/server/ApiManagers/AssistantManager.ts
@@ -9,7 +9,7 @@
*/
import { Readability } from '@mozilla/readability';
-import axios from 'axios';
+import axios, { AxiosResponse } from 'axios';
import { spawn } from 'child_process';
import * as fs from 'fs';
import { writeFile } from 'fs';
@@ -23,6 +23,7 @@ import { AI_Document } from '../../client/views/nodes/chatbot/types/types';
import { Method } from '../RouteManager';
import { filesDirectory, publicDirectory } from '../SocketData';
import ApiManager, { Registration } from './ApiManager';
+import { getServerPath } from '../../client/util/reportManager/reportManagerUtils';
// Enumeration of directories where different file types are stored
export enum Directory {
@@ -115,29 +116,79 @@ export default class AssistantManager extends ApiManager {
},
});
- // Register Google Web Search Results API route
register({
method: Method.POST,
subscription: '/getWebSearchResults',
secureHandler: async ({ req, res }) => {
const { query, max_results } = req.body;
- try {
- // Fetch search results using Google Custom Search API
- const response = await customsearch.cse.list({
+ const MIN_VALID_RESULTS_RATIO = 0.75; // 3/4 threshold
+ let startIndex = 1; // Start at the first result initially
+ let validResults: any[] = [];
+
+ const fetchSearchResults = async (start: number) => {
+ return customsearch.cse.list({
q: query,
cx: process.env._CLIENT_GOOGLE_SEARCH_ENGINE_ID,
key: process.env._CLIENT_GOOGLE_API_KEY,
safe: 'active',
num: max_results,
+ start, // This controls which result index the search starts from
});
+ };
+
+ const filterResultsByXFrameOptions = async (results: any[]) => {
+ const filteredResults = await Promise.all(
+ results.map(async result => {
+ try {
+ const urlResponse: AxiosResponse = await axios.head(result.url, { timeout: 5000 });
+ const xFrameOptions = urlResponse.headers['x-frame-options'];
+ if (xFrameOptions && xFrameOptions.toUpperCase() === 'SAMEORIGIN') {
+ return result;
+ }
+ } catch (error) {
+ console.error(`Error checking x-frame-options for URL: ${result.url}`, error);
+ }
+ return null; // Exclude the result if it doesn't match
+ })
+ );
+ return filteredResults.filter(result => result !== null); // Remove null results
+ };
- const results =
+ try {
+ // Fetch initial search results
+ let response = await fetchSearchResults(startIndex);
+ let initialResults =
response.data.items?.map(item => ({
url: item.link,
snippet: item.snippet,
})) || [];
- res.send({ results });
+ // Filter the initial results
+ validResults = await filterResultsByXFrameOptions(initialResults);
+
+ // If valid results are less than 3/4 of max_results, fetch more results
+ while (validResults.length < max_results * MIN_VALID_RESULTS_RATIO) {
+ // Increment the start index by the max_results to fetch the next set of results
+ startIndex += max_results;
+ response = await fetchSearchResults(startIndex);
+
+ const additionalResults =
+ response.data.items?.map(item => ({
+ url: item.link,
+ snippet: item.snippet,
+ })) || [];
+
+ const additionalValidResults = await filterResultsByXFrameOptions(additionalResults);
+ validResults = [...validResults, ...additionalValidResults]; // Combine valid results
+
+ // Break if no more results are available
+ if (additionalValidResults.length === 0 || response.data.items?.length === 0) {
+ break;
+ }
+ }
+
+ // Return the filtered valid results
+ res.send({ results: validResults.slice(0, max_results) }); // Limit the results to max_results
} catch (error) {
console.error('Error performing web search:', error);
res.status(500).send({
@@ -299,47 +350,16 @@ export default class AssistantManager extends ApiManager {
method: Method.GET,
subscription: '/getResult/:jobId',
secureHandler: async ({ req, res }) => {
- const { jobId } = req.params; // Get the job ID from the URL parameters
- // Check if the job result is available
+ const { jobId } = req.params;
if (jobResults[jobId]) {
const result = jobResults[jobId] as AI_Document & { status: string };
- // If the result contains image or table chunks, save the base64 data as image files
if (result.chunks && Array.isArray(result.chunks)) {
- await Promise.all(
- result.chunks.map(chunk => {
- if (chunk.metadata && (chunk.metadata.type === 'image' || chunk.metadata.type === 'table')) {
- const files_directory = '/files/chunk_images/';
- const directory = path.join(publicDirectory, files_directory);
-
- // Ensure the directory exists or create it
- if (!fs.existsSync(directory)) {
- fs.mkdirSync(directory);
- }
-
- const fileName = path.basename(chunk.metadata.file_path); // Get the file name from the path
- const filePath = path.join(directory, fileName); // Create the full file path
-
- // Check if the chunk contains base64 encoded data
- if (chunk.metadata.base64_data) {
- // Decode the base64 data and write it to a file
- const buffer = Buffer.from(chunk.metadata.base64_data, 'base64');
- fs.promises.writeFile(filePath, buffer).then(() => {
- // Update the file path in the chunk's metadata
- chunk.metadata.file_path = path.join(files_directory, fileName);
- chunk.metadata.base64_data = undefined; // Remove the base64 data from the metadata
- });
- } else {
- console.warn(`No base64_data found for chunk: ${fileName}`);
- }
- }
- })
- );
result.status = 'completed';
} else {
result.status = 'pending';
}
- res.json(result); // Send the result back to the client
+ res.json(result);
} else {
res.status(202).send({ status: 'pending' });
}
@@ -367,7 +387,7 @@ export default class AssistantManager extends ApiManager {
// If the chunk is an image or table, read the corresponding file and encode it as base64
if (chunk.metadata.type === 'image' || chunk.metadata.type === 'table') {
try {
- const filePath = serverPathToFile(Directory.chunk_images, chunk.metadata.file_path); // Get the file path
+ const filePath = path.join(pathToDirectory(Directory.chunk_images), chunk.metadata.file_path); // Get the file path
readFileAsync(filePath).then(imageBuffer => {
const base64Image = imageBuffer.toString('base64'); // Convert the image to base64
@@ -445,10 +465,12 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
const requirementsPath = path.join(__dirname, '../chunker/requirements.txt');
const pythonScriptPath = path.join(__dirname, '../chunker/pdf_chunker.py');
+ const outputDirectory = pathToDirectory(Directory.chunk_images);
+
function runPythonScript() {
const pythonPath = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python') : path.join(venvPath, 'bin', 'python3');
- const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data]);
+ const pythonProcess = spawn(pythonPath, [pythonScriptPath, jobId, file_name, file_data, outputDirectory]);
let pythonOutput = '';
let stderrOutput = '';
@@ -460,23 +482,30 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
pythonProcess.stderr.on('data', data => {
stderrOutput += data.toString();
const lines = stderrOutput.split('\n');
+ stderrOutput = lines.pop() || ''; // Save the last partial line back to stderrOutput
lines.forEach(line => {
if (line.trim()) {
- try {
- const parsedOutput = JSON.parse(line);
- if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
- jobProgress[parsedOutput.job_id] = {
- step: parsedOutput.step,
- progress: parsedOutput.progress,
- };
- } else if (parsedOutput.progress !== undefined) {
- jobProgress[jobId] = {
- step: parsedOutput.step,
- progress: parsedOutput.progress,
- };
+ if (line.startsWith('PROGRESS:')) {
+ const jsonString = line.substring('PROGRESS:'.length);
+ try {
+ const parsedOutput = JSON.parse(jsonString);
+ if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
+ jobProgress[parsedOutput.job_id] = {
+ step: parsedOutput.step,
+ progress: parsedOutput.progress,
+ };
+ } else if (parsedOutput.progress !== undefined) {
+ jobProgress[jobId] = {
+ step: parsedOutput.step,
+ progress: parsedOutput.progress,
+ };
+ }
+ } catch (err) {
+ console.error('Error parsing progress JSON:', jsonString, err);
}
- } catch (err) {
- console.error('Progress log from Python:', line, err);
+ } else {
+ // Log other stderr output
+ console.error('Python stderr:', line);
}
}
});
@@ -490,10 +519,24 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
jobProgress[jobId] = { step: 'Complete', progress: 100 };
} catch (err) {
console.error('Error parsing final JSON result:', err);
+ jobResults[jobId] = { error: 'Failed to parse final result' };
}
} else {
console.error(`Python process exited with code ${code}`);
- jobResults[jobId] = { error: 'Python process failed' };
+ // Check if there was an error message in stderr
+ if (stderrOutput) {
+ // Try to parse the last line as JSON
+ const lines = stderrOutput.trim().split('\n');
+ const lastLine = lines[lines.length - 1];
+ try {
+ const errorOutput = JSON.parse(lastLine);
+ jobResults[jobId] = errorOutput;
+ } catch (err) {
+ jobResults[jobId] = { error: 'Python process failed' };
+ }
+ } else {
+ jobResults[jobId] = { error: 'Python process failed' };
+ }
}
});
}
diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py
index 4fe3b9dbf..48b2dbf97 100644
--- a/src/server/chunker/pdf_chunker.py
+++ b/src/server/chunker/pdf_chunker.py
@@ -54,8 +54,9 @@ def update_progress(job_id, step, progress_value):
"step": step,
"progress": progress_value
}
- print(json.dumps(progress_data), file=sys.stderr) # Use stderr for progress logs
- sys.stderr.flush() # Ensure it's sent immediately
+ print(f"PROGRESS:{json.dumps(progress_data)}", file=sys.stderr)
+ sys.stderr.flush()
+
class ElementExtractor:
@@ -63,13 +64,15 @@ class ElementExtractor:
A class that uses a YOLO model to extract tables and images from a PDF page.
"""
- def __init__(self, output_folder: str):
+ def __init__(self, output_folder: str, doc_id: str):
"""
Initializes the ElementExtractor with the output folder for saving images and the YOLO model.
:param output_folder: Path to the folder where extracted elements will be saved.
"""
- self.output_folder = output_folder
+ self.doc_id = doc_id
+ self.output_folder = os.path.join(output_folder, doc_id)
+ os.makedirs(self.output_folder, exist_ok=True)
self.model = YOLO('keremberke/yolov8m-table-extraction') # Load YOLO model for table extraction
self.model.overrides['conf'] = 0.25 # Set confidence threshold for detection
self.model.overrides['iou'] = 0.45 # Set Intersection over Union (IoU) threshold
@@ -116,17 +119,16 @@ class ElementExtractor:
table_path = os.path.join(self.output_folder, table_filename)
page_with_outline.save(table_path)
- # Convert the full-page image with red outline to base64
- base64_data = self.image_to_base64(page_with_outline)
+ file_path_for_client = f"{self.doc_id}/{table_filename}"
tables.append({
'metadata': {
"type": "table",
"location": [x1 / img.width, y1 / img.height, x2 / img.width, y2 / img.height],
- "file_path": table_path,
+ "file_path": file_path_for_client,
"start_page": page_num,
"end_page": page_num,
- "base64_data": base64_data,
+ "base64_data": self.image_to_base64(page_with_outline)
}
})
@@ -175,18 +177,17 @@ class ElementExtractor:
image_path = os.path.join(self.output_folder, image_filename)
page_with_outline.save(image_path)
- # Convert the full-page image with red outline to base64
- base64_data = self.image_to_base64(page_with_outline)
+ file_path_for_client = f"{self.doc_id}/{image_filename}"
images.append({
'metadata': {
"type": "image",
"location": [x1 / page.rect.width, y1 / page.rect.height, x2 / page.rect.width,
y2 / page.rect.height],
- "file_path": image_path,
+ "file_path": file_path_for_client,
"start_page": page_num,
"end_page": page_num,
- "base64_data": base64_data,
+ "base64_data": self.image_to_base64(image)
}
})
@@ -268,7 +269,7 @@ class PDFChunker:
The main class responsible for chunking PDF files into text and visual elements (tables/images).
"""
- def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None:
+ def __init__(self, output_folder: str = "output", doc_id: str = '', image_batch_size: int = 5) -> None:
"""
Initializes the PDFChunker with an output folder and an element extractor for visual elements.
@@ -278,7 +279,8 @@ class PDFChunker:
self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Initialize the Anthropic API client
self.output_folder = output_folder
self.image_batch_size = image_batch_size # Batch size for image processing
- self.element_extractor = ElementExtractor(output_folder) # Initialize the element extractor
+ self.doc_id = doc_id # Add doc_id
+ self.element_extractor = ElementExtractor(output_folder, doc_id)
async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]:
"""
@@ -363,6 +365,7 @@ class PDFChunker:
for j, elem in enumerate(batch, start=1):
if j in summaries:
elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j])
+ elem['metadata']['base64_data'] = ''
processed_elements.append(elem)
progress = ((i // image_batch_size) + 1) / total_batches * 100 # Calculate progress
@@ -628,10 +631,11 @@ class PDFChunker:
return summaries
- except Exception:
- #print(f"Error in batch_summarize_images: {str(e)}")
- #print("Returning placeholder summaries")
- return {number: "Error: No summary available" for number in images}
+ except Exception as e:
+ # Print errors to stderr so they don't interfere with JSON output
+ print(json.dumps({"error": str(e)}), file=sys.stderr)
+ sys.stderr.flush()
+
class DocumentType(Enum):
"""
@@ -664,7 +668,7 @@ class Document:
Represents a document being processed, such as a PDF, handling chunking, embedding, and summarization.
"""
- def __init__(self, file_data: bytes, file_name: str, job_id: str):
+ def __init__(self, file_data: bytes, file_name: str, job_id: str, output_folder: str):
"""
Initialize the Document with file data, file name, and job ID.
@@ -672,6 +676,7 @@ class Document:
:param file_name: The name of the file being processed.
:param job_id: The job ID associated with this document processing task.
"""
+ self.output_folder = output_folder
self.file_data = file_data
self.file_name = file_name
self.job_id = job_id
@@ -680,14 +685,13 @@ class Document:
self.chunks = [] # List to hold text and visual chunks
self.num_pages = 0 # Number of pages in the document (if applicable)
self.summary = "" # The generated summary for the document
-
self._process() # Start processing the document
def _process(self):
"""
Process the document: extract chunks, embed them, and generate a summary.
"""
- pdf_chunker = PDFChunker(output_folder="output") # Initialize the PDF chunker
+ pdf_chunker = PDFChunker(output_folder=self.output_folder, doc_id=self.doc_id) # Initialize PDFChunker
self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) # Extract chunks
self.num_pages = self._get_pdf_pages() # Get the number of pages in the document
@@ -796,8 +800,7 @@ class Document:
"doc_id": self.doc_id
}, indent=2) # Convert the document's attributes to JSON format
-
-def process_document(file_data, file_name, job_id):
+def process_document(file_data, file_name, job_id, output_folder):
"""
Top-level function to process a document and return the JSON output.
@@ -806,28 +809,30 @@ def process_document(file_data, file_name, job_id):
:param job_id: The job ID for this document processing task.
:return: The processed document's data in JSON format.
"""
- new_document = Document(file_data, file_name, job_id) # Create a new Document object
- return new_document.to_json() # Return the document's JSON data
-
+ new_document = Document(file_data, file_name, job_id, output_folder)
+ return new_document.to_json()
def main():
"""
Main entry point for the script, called with arguments from Node.js.
"""
- if len(sys.argv) != 4:
- print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) # Print error if incorrect number of arguments
+ if len(sys.argv) != 5:
+ print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr)
return
- job_id = sys.argv[1] # Get the job ID from command-line arguments
- file_name = sys.argv[2] # Get the file name from command-line arguments
- file_data = sys.argv[3] # Get the base64-encoded file data from command-line arguments
+ job_id = sys.argv[1]
+ file_name = sys.argv[2]
+ file_data = sys.argv[3]
+ output_folder = sys.argv[4] # Get the output folder from arguments
try:
+ os.makedirs(output_folder, exist_ok=True)
+
# Decode the base64 file data
file_bytes = base64.b64decode(file_data)
# Process the document
- document_result = process_document(file_bytes, file_name, job_id)
+ document_result = process_document(file_bytes, file_name, job_id, output_folder) # Pass output_folder
# Output the final result as JSON to stdout
print(document_result)
@@ -839,5 +844,6 @@ def main():
sys.stderr.flush()
+
if __name__ == "__main__":
main() # Execute the main function when the script is run
diff --git a/src/server/flashcard/labels.py b/src/server/flashcard/labels.py
new file mode 100644
index 000000000..546fc4bd3
--- /dev/null
+++ b/src/server/flashcard/labels.py
@@ -0,0 +1,285 @@
+import base64
+import numpy as np
+import base64
+import easyocr
+import sys
+from PIL import Image
+from io import BytesIO
+import requests
+import json
+import numpy as np
+
+class BoundingBoxUtils:
+    """Stateless helpers for comparing boxes, cleaning OCR results and JSON-safe conversion."""
+
+    @staticmethod
+    def is_close(box1, box2, x_threshold=20, y_threshold=20):
+        """
+        Determines if two bounding boxes are both horizontally and vertically close.
+
+        Parameters:
+        box1, box2 (list): Flat boxes read as [0]=left, [1]=top, [2]=right, [3]=bottom.
+        x_threshold (int): Horizontal edge distance (exclusive) below which boxes count as close.
+        y_threshold (int): Vertical edge distance (exclusive) below which boxes count as close.
+
+        Returns:
+        bool: True if boxes are close on both axes, False otherwise.
+        """
+        horizontally_close = (abs(box1[2] - box2[0]) < x_threshold or  # Right edge of box1 vs left edge of box2
+                              abs(box2[2] - box1[0]) < x_threshold or  # Right edge of box2 vs left edge of box1
+                              abs(box1[2] - box2[2]) < x_threshold or  # Right edges nearly aligned
+                              abs(box2[0] - box1[0]) < x_threshold)    # Left edges nearly aligned
+
+        vertically_close = (abs(box1[3] - box2[1]) < y_threshold or  # Bottom edge of box1 vs top edge of box2
+                            abs(box2[3] - box1[1]) < y_threshold or  # Bottom edge of box2 vs top edge of box1
+                            box1[1] == box2[1] or box1[3] == box2[3])  # Identical top or bottom edge
+
+        return horizontally_close and vertically_close
+
+    @staticmethod
+    def adjust_bounding_box(bbox, original_text, corrected_text):
+        """
+        Returns a copy of a four-point box with two of its corners shifted by +5 in x.
+
+        Parameters:
+        bbox (list): Four [x, y] points; an empty value or any other length is returned unchanged.
+        original_text (str): The original text detected by OCR (currently not consulted).
+        corrected_text (str): The corrected text after cleaning (currently not consulted).
+
+        Returns:
+        list: New box whose points 0 and 2 are moved by +5 in x; points 1 and 3 are copied as-is.
+        """
+        if not bbox or len(bbox) != 4:
+            return bbox
+
+        # Fixed shift applied on every call, whether or not the text actually changed
+        x_adjustment = 5
+        adjusted_bbox = [
+            [bbox[0][0] + x_adjustment, bbox[0][1]],
+            [bbox[1][0], bbox[1][1]],
+            [bbox[2][0] + x_adjustment, bbox[2][1]],
+            [bbox[3][0], bbox[3][1]]
+        ]
+        return adjusted_bbox
+
+    @staticmethod
+    def correct_ocr_results(results):
+        """
+        Strips stray '~' and '-' characters from OCR text and adjusts each bounding box.
+
+        Parameters:
+        results (list): OCR results as (bounding box, text, confidence score) triples.
+
+        Returns:
+        list: (adjusted bounding box, cleaned text, confidence score) triples in input order.
+        """
+        corrections = {
+            "~": "",  # Replace '~' with empty string
+            "-": ""   # Replace '-' with empty string
+        }
+
+        corrected_results = []
+        for (bbox, text, prob) in results:
+            corrected_text = ''.join(corrections.get(char, char) for char in text)
+            adjusted_bbox = BoundingBoxUtils.adjust_bounding_box(bbox, text, corrected_text)
+            corrected_results.append((adjusted_bbox, corrected_text, prob))
+
+        return corrected_results
+
+    @staticmethod
+    def convert_to_json_serializable(data):
+        """
+        Recursively replaces numpy values nested in lists, tuples or dict values with plain Python ones.
+
+        Parameters:
+        data: A value or nested container that may hold numpy scalars or arrays.
+
+        Returns:
+        The same structure with numpy values converted; tuples stay tuples, other leaves pass through.
+        """
+        def convert_element(element):
+            if isinstance(element, dict):
+                return {key: convert_element(value) for key, value in element.items()}
+            if isinstance(element, list):
+                return [convert_element(e) for e in element]
+            if isinstance(element, tuple):
+                return tuple(convert_element(e) for e in element)
+            if isinstance(element, np.ndarray):
+                return element.tolist()
+            for numpy_type, python_type in ((np.bool_, bool), (np.integer, int), (np.floating, float)):
+                if isinstance(element, numpy_type):
+                    return python_type(element)
+            return element
+
+        return convert_element(data)
+
+class ImageLabelProcessor:
+    """Loads an image from a base64 payload or URL and extracts labelled text boxes with EasyOCR."""
+
+    # Largest vertical gap / horizontal edge offset at which two boxes are merged into one label
+    VERTICAL_THRESHOLD = 20
+    HORIZONTAL_THRESHOLD = 8
+    # Seconds to wait on the remote server before a URL download is abandoned
+    REQUEST_TIMEOUT = 30
+
+    def __init__(self, img_source, source_type, smart_mode):
+        """
+        Store the request parameters and eagerly load the image into `img_val`.
+
+        Parameters:
+        img_source (str): Base64 image data (optionally a data: URL) or an image URL.
+        source_type (str): 'drag' selects base64 decoding; any other value is treated as a URL.
+        smart_mode (bool): True for paragraph-level OCR, False for the filter-and-merge mode.
+        """
+        self.img_source = img_source
+        self.source_type = source_type
+        self.smart_mode = smart_mode
+        self.img_val = self.load_image()
+
+    def load_image(self):
+        """Load the image as PNG bytes from either a base64 string or URL."""
+        if self.source_type == 'drag':
+            return self._load_base64_image()
+        else:
+            return self._load_url_image()
+
+    @staticmethod
+    def _to_png_bytes(image):
+        """Serialize a PIL image into PNG-encoded bytes."""
+        image_bytes = BytesIO()
+        image.save(image_bytes, format='PNG')
+        return image_bytes.getvalue()
+
+    def _load_base64_image(self):
+        """Decode the base64 image and return it as PNG bytes."""
+        base64_string = self.img_source
+        if base64_string.startswith("data:image"):
+            # Drop the "data:image/...;base64," header and keep only the payload
+            base64_string = base64_string.split(",", 1)[1]
+
+        image_data = base64.b64decode(base64_string)
+        image = Image.open(BytesIO(image_data)).convert('RGB')
+
+        # Stay in memory: a fixed "temp_image.jpg" in the working directory is shared by concurrent runs
+        return self._to_png_bytes(image)
+
+    def _load_url_image(self):
+        """Download image from URL and return it as PNG bytes."""
+        response = requests.get(self.img_source, timeout=self.REQUEST_TIMEOUT)
+        # Surface HTTP failures directly instead of feeding an error page to the image decoder
+        response.raise_for_status()
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+        return self._to_png_bytes(image)
+
+    def process_image(self):
+        """Process the image and return the OCR results for the selected mode."""
+        if self.smart_mode:
+            return self._process_smart_mode()
+        else:
+            return self._process_standard_mode()
+
+    def _process_smart_mode(self):
+        """Run EasyOCR in paragraph mode and report every (box, text) pair it returns."""
+        reader = easyocr.Reader(['en'])
+        result = reader.readtext(self.img_val, detail=1, paragraph=True)
+
+        all_boxes = [bbox for bbox, text in result]
+        all_texts = [text for bbox, text in result]
+
+        response_data = {
+            'status': 'success',
+            'message': 'Data received',
+            'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
+            'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
+        }
+
+        return response_data
+
+    def _process_standard_mode(self):
+        """Run EasyOCR, keep detections with confidence >= 0.7, clean them and merge stacked boxes."""
+        reader = easyocr.Reader(['en'])
+        results = reader.readtext(self.img_val)
+
+        filtered_results = BoundingBoxUtils.correct_ocr_results([
+            (bbox, text, prob) for bbox, text, prob in results if prob >= 0.7
+        ])
+
+        return self._merge_and_prepare_response(filtered_results)
+
+    def are_vertically_close(self, box1, box2):
+        """
+        Check whether box2 starts at most VERTICAL_THRESHOLD below box1 with an aligned left or right edge.
+
+        Parameters:
+        box1, box2 (list): Four [x, y] points, read here as top-left, top-right, bottom-right, bottom-left.
+
+        Returns:
+        bool: True if the vertical gap and at least one edge offset are within their thresholds.
+        """
+        box1_bottom = max(box1[2][1], box1[3][1])
+        box2_top = min(box2[0][1], box2[1][1])
+        # Signed, not absolute: a box2 that starts above box1's bottom always passes the gap check
+        vertical_distance = box2_top - box1_bottom
+
+        left_offset = abs(box2[0][0] - box1[0][0])
+        right_offset = abs(box2[1][0] - box1[1][0])
+        hori_close = left_offset <= self.HORIZONTAL_THRESHOLD or right_offset <= self.HORIZONTAL_THRESHOLD
+
+        return vertical_distance <= self.VERTICAL_THRESHOLD and hori_close
+
+    def merge_boxes(self, boxes, texts):
+        """Return the axis-aligned box enclosing all `boxes` and their texts joined by single spaces."""
+        # Flatten every corner of every box into separate x and y lists
+        x_coords = [point[0] for box in boxes for point in box]
+        y_coords = [point[1] for box in boxes for point in box]
+        left, right = min(x_coords), max(x_coords)
+        top, bottom = min(y_coords), max(y_coords)
+
+        # Corner order: top-left, top-right, bottom-right, bottom-left
+        merged_box = [[left, top], [right, top], [right, bottom], [left, bottom]]
+        return merged_box, ' '.join(texts)
+
+    def _merge_and_prepare_response(self, filtered_results):
+        """Merge runs of consecutive, vertically close boxes and prepare the final response."""
+        # Each group is a ([boxes], [texts]) pair covering one run of adjacent detections
+        groups = []
+        previous_box = None
+        for result in filtered_results:
+            box, text = result[0], result[1]
+            if groups and self.are_vertically_close(previous_box, box):
+                groups[-1][0].append(box)
+                groups[-1][1].append(text)
+            else:
+                groups.append(([box], [text]))
+            previous_box = box
+
+        merged = [self.merge_boxes(boxes, texts) for boxes, texts in groups]
+        all_boxes = [merged_box for merged_box, _ in merged]
+        all_texts = [merged_text for _, merged_text in merged]
+
+        response = {
+            'status': 'success',
+            'message': 'Data received',
+            'boxes': BoundingBoxUtils.convert_to_json_serializable(all_boxes),
+            'text': BoundingBoxUtils.convert_to_json_serializable(all_texts),
+        }
+
+        return response
+
+# Main execution function: argv[1] = file holding the image source, argv[2] = source type, argv[3] = mode
+def labels():
+    """Run OCR on the image described by the command-line arguments; print the result as JSON and return it."""
+    source_type = sys.argv[2]
+    smart_mode = (sys.argv[3] == 'smart')
+    with open(sys.argv[1], 'r') as f:
+        img_source = f.read()
+    # Create ImageLabelProcessor instance (this loads the image) and run the selected OCR mode
+    processor = ImageLabelProcessor(img_source, source_type, smart_mode)
+    response = processor.process_image()
+
+    # Emit strict JSON rather than a Python dict repr so the calling process can parse stdout
+    print(json.dumps(response))
+    return response
+
+if __name__ == "__main__":
+    labels()
diff --git a/src/server/flashcard/requirements.txt b/src/server/flashcard/requirements.txt
new file mode 100644
index 000000000..eb92a819b
--- /dev/null
+++ b/src/server/flashcard/requirements.txt
@@ -0,0 +1,12 @@
+easyocr==1.7.1
+requests==2.32.3
+pillow==10.4.0
+numpy==1.26.4
+tqdm==4.66.4
+Werkzeug==3.0.3
+python-dateutil==2.9.0.post0
+six==1.16.0
+certifi==2024.6.2
+charset-normalizer==3.3.2
+idna==3.7
+urllib3==1.26.19 \ No newline at end of file
diff --git a/src/server/flashcard/venv/pyvenv.cfg b/src/server/flashcard/venv/pyvenv.cfg
new file mode 100644
index 000000000..740014e00
--- /dev/null
+++ b/src/server/flashcard/venv/pyvenv.cfg
@@ -0,0 +1,3 @@
+home = /Library/Frameworks/Python.framework/Versions/3.10/bin
+include-system-site-packages = false
+version = 3.10.11