diff options
| author | bobzel <zzzman@gmail.com> | 2025-03-10 16:13:04 -0400 |
|---|---|---|
| committer | bobzel <zzzman@gmail.com> | 2025-03-10 16:13:04 -0400 |
| commit | b7989dded8bb001876de6cbca59bf77935f0daf7 (patch) | |
| tree | 0dba0665674db7bb84770833df0a4100d0520701 /src/client/views/nodes/chatbot/agentsystem/Agent.ts | |
| parent | 4979415d4604d280e81a162bf9a9d39c731d3738 (diff) | |
| parent | 5bf944035c0ba94ad15245416f51ca0329a51bde (diff) | |
Merge branch 'master' into alyssa-starter
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
| -rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/Agent.ts | 265 |
1 files changed, 209 insertions, 56 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts index 34e7cf5ea..e93fb87db 100644 --- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts +++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts @@ -1,21 +1,30 @@ import dotenv from 'dotenv'; import { XMLBuilder, XMLParser } from 'fast-xml-parser'; +import { escape } from 'lodash'; // Imported escape from lodash import OpenAI from 'openai'; -import { ChatCompletionMessageParam } from 'openai/resources'; +import { DocumentOptions } from '../../../../documents/Documents'; import { AnswerParser } from '../response_parsers/AnswerParser'; import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser'; +import { BaseTool } from '../tools/BaseTool'; import { CalculateTool } from '../tools/CalculateTool'; -import { CreateCSVTool } from '../tools/CreateCSVTool'; +//import { CreateAnyDocumentTool } from '../tools/CreateAnyDocTool'; +import { CreateDocTool } from '../tools/CreateDocumentTool'; import { DataAnalysisTool } from '../tools/DataAnalysisTool'; +import { ImageCreationTool } from '../tools/ImageCreationTool'; import { NoTool } from '../tools/NoTool'; -import { RAGTool } from '../tools/RAGTool'; import { SearchTool } from '../tools/SearchTool'; -import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool'; -import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, ProcessingInfo } from '../types/types'; +import { Parameter, ParametersType, TypeMap } from '../types/tool_types'; +import { AgentMessage, ASSISTANT_ROLE, AssistantMessage, Observation, PROCESSING_TYPE, ProcessingInfo, TEXT_TYPE } from '../types/types'; import { Vectorstore } from '../vectorstore/Vectorstore'; import { getReactPrompt } from './prompts'; -import { BaseTool } from '../tools/BaseTool'; -import { Parameter, ParametersType, Tool } from '../tools/ToolTypes'; +//import { DictionaryTool } from '../tools/DictionaryTool'; +import { ChatCompletionMessageParam } from 'openai/resources'; +import { Doc } from '../../../../../fields/Doc'; +import { parsedDoc } from '../chatboxcomponents/ChatBox'; +import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool'; +import { Upload } from '../../../../../server/SharedMediaTypes'; +import { RAGTool } from '../tools/RAGTool'; +//import { CreateTextDocTool } from '../tools/CreateTextDocumentTool'; dotenv.config(); @@ -54,6 +63,9 @@ export class Agent { history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void, + createImage: (result: Upload.FileInformation & Upload.InspectionResults, options: DocumentOptions) => void, + addLinkedDoc: (doc: parsedDoc) => Doc | undefined, + // eslint-disable-next-line @typescript-eslint/no-unused-vars createCSVInDash: (url: string, title: string, id: string, data: string) => void ) { // Initialize OpenAI client with API key from environment @@ -70,8 +82,13 @@ export class Agent { dataAnalysis: new DataAnalysisTool(csvData), websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc), searchTool: new SearchTool(addLinkedUrlDoc), - createCSV: new CreateCSVTool(createCSVInDash), - no_tool: new NoTool(), + // createCSV: new CreateCSVTool(createCSVInDash), + noTool: new NoTool(), + imageCreationTool: new ImageCreationTool(createImage), + // createTextDoc: new CreateTextDocTool(addLinkedDoc), + createDoc: new CreateDocTool(addLinkedDoc), + // createAnyDocument: new CreateAnyDocumentTool(addLinkedDoc), + // dictionary: new DictionaryTool(), }; } @@ -86,9 +103,17 @@ export class Agent { */ async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> { console.log(`Starting query: ${question}`); + const MAX_QUERY_LENGTH = 1000; // adjust the limit as needed + + // Check if the question exceeds the maximum length + if (question.length > MAX_QUERY_LENGTH) { + return { role: ASSISTANT_ROLE.ASSISTANT, content: [{ text: 'User query too long. Please shorten your question and try again.', index: 0, type: TEXT_TYPE.NORMAL, citation_ids: null }], processing_info: [] }; + } + + const sanitizedQuestion = escape(question); // Sanitized user input - // Push user's question to message history - this.messages.push({ role: 'user', content: question }); + // Push sanitized user's question to message history + this.messages.push({ role: 'user', content: sanitizedQuestion }); // Retrieve chat history and generate system prompt const chatHistory = this._history(); @@ -96,14 +121,20 @@ export class Agent { // Initialize intermediate messages this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` }); + + this.interMessages.push({ + role: 'user', + content: this.constructUserPrompt(1, 'user', `<query>${sanitizedQuestion}</query>`), + }); // Setup XML parser and builder const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_', textNodeName: '_text', - isArray: (name /* , jpath, isLeafNode, isAttribute */) => ['query', 'url'].indexOf(name) !== -1, + isArray: name => ['query', 'url'].indexOf(name) !== -1, + processEntities: false, // Disable processing of entities + stopNodes: ['*.entity'], // Do not process any entities }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); @@ -115,6 +146,7 @@ export class Agent { console.log(this.interMessages); console.log(`Turn ${i}/${maxTurns}`); + // eslint-disable-next-line no-await-in-loop const result = await this.execute(onProcessingUpdate, onAnswerUpdate); this.interMessages.push({ role: 'assistant', content: result }); @@ -124,8 +156,11 @@ export class Agent { try { // Parse XML result from the assistant parsedResult = parser.parse(result); + + // Validate the structure of the parsedResult + this.validateAssistantResponse(parsedResult); } catch (error) { - throw new Error(`Error parsing response: ${error}`); + throw new Error(`Error parsing or validating response: ${error}`); } // Extract the stage from the parsed result @@ -158,17 +193,22 @@ export class Agent { } else { // Handle error in case of an invalid action console.log('Error: No valid action'); - this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` }); + this.interMessages.push({ + role: 'user', + content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>`, + }); break; } } else if (key === 'action_input') { // Handle action input stage const actionInput = stage[key]; + console.log(`Action input full:`, actionInput); console.log(`Action input:`, actionInput.inputs); if (currentAction) { try { // Process the action with its input + // eslint-disable-next-line no-await-in-loop const observation = (await this.processAction(currentAction, actionInput.inputs)) as Observation[]; const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }] as Observation[]; console.log(observation); @@ -194,6 +234,10 @@ export class Agent { throw new Error('Reached maximum turns. Ending query.'); } + private constructUserPrompt(stageNumber: number, role: string, content: string): string { + return `<stage number="${stageNumber}" role="${role}">${content}</stage>`; + } + /** * Executes a step in the conversation, processing the assistant's response and parsing it in real-time. * @param onProcessingUpdate Callback for processing updates. @@ -207,6 +251,7 @@ export class Agent { messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, stream: true, + stop: ['</stage>'], }); let fullResponse: string = ''; @@ -264,11 +309,140 @@ export class Agent { } /** + * Validates the assistant's response to ensure it conforms to the expected XML structure. + * @param response The parsed XML response from the assistant. + * @throws An error if the response does not meet the expected structure. + */ + private validateAssistantResponse(response: { stage: { [key: string]: object | string } }) { + if (!response.stage) { + throw new Error('Response does not contain a <stage> element'); + } + + // Validate that the stage has the required attributes + const stage = response.stage; + if (!stage['@_number'] || !stage['@_role']) { + throw new Error('Stage element must have "number" and "role" attributes'); + } + + // Extract the role of the stage to determine expected content + const role = stage['@_role']; + + // Depending on the role, validate the presence of required elements + if (role === 'assistant') { + // Assistant's response should contain either 'thought', 'action', 'action_input', or 'answer' + if (!('thought' in stage || 'action' in stage || 'action_input' in stage || 'answer' in stage)) { + throw new Error('Assistant stage must contain a thought, action, action_input, or answer element'); + } + + // If 'thought' is present, validate it + if ('thought' in stage) { + if (typeof stage.thought !== 'string' || stage.thought.trim() === '') { + throw new Error('Thought must be a non-empty string'); + } + } + + // If 'action' is present, validate it + if ('action' in stage) { + if (typeof stage.action !== 'string' || stage.action.trim() === '') { + throw new Error('Action must be a non-empty string'); + } + + // Optional: Check if the action is among allowed actions + const allowedActions = Object.keys(this.tools); + if (!allowedActions.includes(stage.action)) { + throw new Error(`Action "${stage.action}" is not a valid tool`); + } + } + + // If 'action_input' is present, validate its structure + if ('action_input' in stage) { + const actionInput = stage.action_input as object; + + if (!('action_input_description' in actionInput) || typeof actionInput.action_input_description !== 'string') { + throw new Error('action_input must contain an action_input_description string'); + } + + if (!('inputs' in actionInput)) { + throw new Error('action_input must contain an inputs object'); + } + + // Further validation of inputs can be done here based on the expected parameters of the action + } + + // If 'answer' is present, validate its structure + if ('answer' in stage) { + const answer = stage.answer as object; + + // Ensure answer contains at least one of the required elements + if (!('grounded_text' in answer || 'normal_text' in answer)) { + throw new Error('Answer must contain grounded_text or normal_text'); + } + + // Validate follow_up_questions + if (!('follow_up_questions' in answer)) { + throw new Error('Answer must contain follow_up_questions'); + } + + // Validate loop_summary + if (!('loop_summary' in answer)) { + throw new Error('Answer must contain a loop_summary'); + } + + // Additional validation for citations, grounded_text, etc., can be added here + } + } else if (role === 'user') { + // User's stage should contain 'query' or 'observation' + if (!('query' in stage || 'observation' in stage)) { + throw new Error('User stage must contain a query or observation element'); + } + + // Validate 'query' if present + if ('query' in stage && typeof stage.query !== 'string') { + throw new Error('Query must be a string'); + } + + // Validate 'observation' if present + if ('observation' in stage) { + // Ensure observation has the correct structure + // This can be expanded based on how observations are structured + } + } else { + throw new Error(`Unknown role "${role}" in stage`); + } + + // Add any additional validation rules as necessary + } + + /** + * Helper function to check if a string can be parsed as an array of the expected type. + * @param input The input string to check. + * @param expectedType The expected type of the array elements ('string', 'number', or 'boolean'). + * @returns The parsed array if valid, otherwise throws an error. + */ + private parseArray<T>(input: string, expectedType: 'string' | 'number' | 'boolean'): T[] { + try { + // Parse the input string into a JSON object + const parsed = JSON.parse(input); + + // Check if the parsed object is an array and if all elements are of the expected type + if (Array.isArray(parsed) && parsed.every(item => typeof item === expectedType)) { + return parsed; + } else { + throw new Error(`Invalid ${expectedType} array format.`); + } + } catch (error) { + throw new Error(`Failed to parse ${expectedType} array: ` + error); + } + } + + /** * Processes a specific action by invoking the appropriate tool with the provided inputs. * This method ensures that the action exists and validates the types of `actionInput` * based on the tool's parameter rules. It throws errors for missing required parameters * or mismatched types before safely executing the tool with the validated input. * + * NOTE: In the future, it should typecheck for specific tool parameter types using the `TypeMap` or otherwise. + * * Type validation includes checks for: * - `string`, `number`, `boolean` * - `string[]`, `number[]` (arrays of strings or numbers) @@ -278,56 +452,35 @@ export class Agent { * @returns A promise that resolves to an array of `Observation` objects representing the result of the action. * @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types. */ - private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> { + private async processAction(action: string, actionInput: ParametersType<ReadonlyArray<Parameter>>): Promise<Observation[]> { // Check if the action exists in the tools list if (!(action in this.tools)) { throw new Error(`Unknown action: ${action}`); } + console.log(actionInput); - const tool = this.tools[action]; - - // Validate actionInput based on tool's parameter rules - for (const paramRule of tool.parameterRules) { - const inputValue = actionInput[paramRule.name]; - - if (paramRule.required && inputValue === undefined) { - throw new Error(`Missing required parameter: ${paramRule.name}`); + for (const param of this.tools[action].parameterRules) { + // Check if the parameter is required and missing in the input + if (param.required && !(param.name in actionInput) && !this.tools[action].inputValidator(actionInput)) { + throw new Error(`Missing required parameter: ${param.name}`); } - // If the parameter is defined, check its type - if (inputValue !== undefined) { - switch (paramRule.type) { - case 'string': - if (typeof inputValue !== 'string') { - throw new Error(`Expected parameter '${paramRule.name}' to be a string.`); - } - break; - case 'number': - if (typeof inputValue !== 'number') { - throw new Error(`Expected parameter '${paramRule.name}' to be a number.`); - } - break; - case 'boolean': - if (typeof inputValue !== 'boolean') { - throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`); - } - break; - case 'string[]': - if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) { - throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`); - } - break; - case 'number[]': - if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) { - throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`); - } - break; - default: - throw new Error(`Unsupported parameter type: ${paramRule.type}`); - } + // Check if the parameter type matches the expected type + const expectedType = param.type.replace('[]', '') as 'string' | 'number' | 'boolean'; + const isArray = param.type.endsWith('[]'); + const input = actionInput[param.name]; + + if (isArray) { + // Check if the input is a valid array of the expected type + const parsedArray = this.parseArray(input as string, expectedType); + actionInput[param.name] = parsedArray as TypeMap[typeof param.type]; + } else if (input !== undefined && typeof input !== expectedType) { + throw new Error(`Invalid type for parameter ${param.name}: expected ${expectedType}`); } } - return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>); + const tool = this.tools[action]; + + return await tool.execute(actionInput); } } |
