| author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-09-07 12:43:05 -0400 |
|---|---|---|
| committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-09-07 12:43:05 -0400 |
| commit | 4791cd23af08da70895204a3a7fbaf889d9af2d5 (patch) | |
| tree | c4c2534e64724d62bae9152763f1a74cd5a963e0 /src/client/views/nodes/chatbot/agentsystem/Agent.ts | |
| parent | 210f8f5f1cd19e9416a12524cce119b273334fd3 (diff) | |
completely restructured, added comments, and significantly reduced the length of the prompt (~72% shorter and cheaper)
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
| Mode | File | Lines |
|---|---|---|
| -rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/Agent.ts | 278 |

1 file changed, 278 insertions, 0 deletions
```diff
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
new file mode 100644
index 000000000..180d05cf3
--- /dev/null
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -0,0 +1,278 @@
+import OpenAI from 'openai';
+import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE, ProcessingInfo, PROCESSING_TYPE } from '../types/types';
+import { getReactPrompt } from './prompts';
+import { XMLParser, XMLBuilder } from 'fast-xml-parser';
+import { Vectorstore } from '../vectorstore/Vectorstore';
+import { ChatCompletionMessageParam } from 'openai/resources';
+import dotenv from 'dotenv';
+import { CalculateTool } from '../tools/CalculateTool';
+import { RAGTool } from '../tools/RAGTool';
+import { DataAnalysisTool } from '../tools/DataAnalysisTool';
+import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool';
+import { SearchTool } from '../tools/SearchTool';
+import { NoTool } from '../tools/NoTool';
+import { v4 as uuidv4 } from 'uuid';
+import { AnswerParser } from '../response_parsers/AnswerParser';
+import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser';
+import { CreateCSVTool } from '../tools/CreateCSVTool';
+
+dotenv.config();
+
+/**
+ * The Agent class handles the interaction between the assistant and the tools available,
+ * processes user queries, and manages the communication flow between the tools and OpenAI.
+ */
+export class Agent {
+    // Private properties
+    private client: OpenAI;
+    private tools: Record<string, Tool<any>>;
+    private messages: AgentMessage[] = [];
+    private interMessages: AgentMessage[] = [];
+    private vectorstore: Vectorstore;
+    private _history: () => string;
+    private _summaries: () => string;
+    private _csvData: () => { filename: string; id: string; text: string }[];
+    private actionNumber: number = 0;
+    private thoughtNumber: number = 0;
+    private processingNumber: number = 0;
+    private processingInfo: ProcessingInfo[] = [];
+    private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser();
+
+    /**
+     * The constructor initializes the agent with the vector store and toolset, and sets up the OpenAI client.
+     * @param _vectorstore Vector store instance for document storage and retrieval.
+     * @param summaries A function to retrieve document summaries.
+     * @param history A function to retrieve chat history.
+     * @param csvData A function to retrieve CSV data linked to the assistant.
+     * @param addLinkedUrlDoc A function to add a linked document from a URL.
+     * @param createCSVInDash A function to create a CSV document in the dashboard.
+     */
+    constructor(
+        _vectorstore: Vectorstore,
+        summaries: () => string,
+        history: () => string,
+        csvData: () => { filename: string; id: string; text: string }[],
+        addLinkedUrlDoc: (url: string, id: string) => void,
+        createCSVInDash: (url: string, title: string, id: string, data: string) => void
+    ) {
+        // Initialize OpenAI client with API key from environment
+        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
+        this.vectorstore = _vectorstore;
+        this._history = history;
+        this._summaries = summaries;
+        this._csvData = csvData;
+
+        // Define available tools for the assistant
+        this.tools = {
+            calculate: new CalculateTool(),
+            rag: new RAGTool(this.vectorstore),
+            dataAnalysis: new DataAnalysisTool(csvData),
+            websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
+            searchTool: new SearchTool(addLinkedUrlDoc),
+            createCSV: new CreateCSVTool(createCSVInDash),
+            no_tool: new NoTool(),
+        };
+    }
+
+    /**
+     * This method handles the conversation flow with the assistant, processes user queries,
+     * and manages the assistant's decision-making process, including tool actions.
+     * @param question The user's question.
+     * @param onProcessingUpdate Callback function for processing updates.
+     * @param onAnswerUpdate Callback function for answer updates.
+     * @param maxTurns The maximum number of turns to allow in the conversation.
+     * @returns The final response from the assistant.
+     */
+    async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> {
+        console.log(`Starting query: ${question}`);
+
+        // Push user's question to message history
+        this.messages.push({ role: 'user', content: question });
+
+        // Retrieve chat history and generate system prompt
+        const chatHistory = this._history();
+        const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory);
+
+        // Initialize intermediate messages
+        this.interMessages = [{ role: 'system', content: systemPrompt }];
+        this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` });
+
+        // Setup XML parser and builder
+        const parser = new XMLParser({
+            ignoreAttributes: false,
+            attributeNamePrefix: '@_',
+            textNodeName: '_text',
+            isArray: (name, jpath, isLeafNode, isAttribute) => ['query', 'url'].indexOf(name) !== -1,
+        });
+        const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
+
+        let currentAction: string | undefined;
+        this.processingInfo = [];
+
+        // Conversation loop (up to maxTurns)
+        for (let i = 2; i < maxTurns; i += 2) {
+            console.log(this.interMessages);
+            console.log(`Turn ${i}/${maxTurns}`);
+
+            // Execute a step in the conversation and get the result
+            const result = await this.execute(onProcessingUpdate, onAnswerUpdate);
+            this.interMessages.push({ role: 'assistant', content: result });
+
+            let parsedResult;
+            try {
+                // Parse XML result from the assistant
+                parsedResult = parser.parse(result);
+            } catch (error) {
+                throw new Error(`Error parsing response: ${error}`);
+            }
+
+            // Extract the stage from the parsed result
+            const stage = parsedResult.stage;
+            if (!stage) {
+                throw new Error(`Error: No stage found in response`);
+            }
+
+            // Handle different stage elements (thoughts, actions, inputs, answers)
+            for (const key in stage) {
+                if (key === 'thought') {
+                    // Handle assistant's thoughts
+                    console.log(`Thought: ${stage[key]}`);
+                    this.processingNumber++;
+                } else if (key === 'action') {
+                    // Handle action stage
+                    currentAction = stage[key] as string;
+                    console.log(`Action: ${currentAction}`);
+
+                    if (this.tools[currentAction]) {
+                        // Prepare the next action based on the current tool
+                        const nextPrompt = [
+                            {
+                                type: 'text',
+                                text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`,
+                            },
+                        ];
+                        this.interMessages.push({ role: 'user', content: nextPrompt });
+                        break;
+                    } else {
+                        // Handle error in case of an invalid action
+                        console.log('Error: No valid action');
+                        this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` });
+                        break;
+                    }
+                } else if (key === 'action_input') {
+                    // Handle action input stage
+                    const actionInput = stage[key];
+                    console.log(`Action input:`, actionInput.inputs);
+
+                    if (currentAction) {
+                        try {
+                            // Process the action with its input
+                            const observation = await this.processAction(currentAction, actionInput.inputs);
+                            const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }];
+                            console.log(observation);
+                            this.interMessages.push({ role: 'user', content: nextPrompt });
+                            this.processingNumber++;
+                            break;
+                        } catch (error) {
+                            throw new Error(`Error processing action: ${error}`);
+                        }
+                    } else {
+                        throw new Error('Error: Action input without a valid action');
+                    }
+                } else if (key === 'answer') {
+                    // If an answer is found, end the query
+                    console.log('Answer found. Ending query.');
+                    this.streamedAnswerParser.reset();
+                    const parsedAnswer = AnswerParser.parse(result, this.processingInfo);
+                    return parsedAnswer;
+                }
+            }
+        }
+
+        throw new Error('Reached maximum turns. Ending query.');
+    }
+
+    /**
+     * Executes a step in the conversation, processing the assistant's response and parsing it in real-time.
+     * @param onProcessingUpdate Callback for processing updates.
+     * @param onAnswerUpdate Callback for answer updates.
+     * @returns The full response from the assistant.
+     */
+    private async execute(onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void): Promise<string> {
+        // Stream OpenAI response for real-time updates
+        const stream = await this.client.chat.completions.create({
+            model: 'gpt-4o',
+            messages: this.interMessages as ChatCompletionMessageParam[],
+            temperature: 0,
+            stream: true,
+        });
+
+        let fullResponse: string = '';
+        let currentTag: string = '';
+        let currentContent: string = '';
+        let isInsideTag: boolean = false;
+
+        // Process each chunk of the streamed response
+        for await (const chunk of stream) {
+            let content = chunk.choices[0]?.delta?.content || '';
+            fullResponse += content;
+
+            // Parse the streamed content character by character
+            for (const char of content) {
+                if (currentTag === 'answer') {
+                    // Handle answer parsing for real-time updates
+                    currentContent += char;
+                    const streamedAnswer = this.streamedAnswerParser.parse(char);
+                    onAnswerUpdate(streamedAnswer);
+                    continue;
+                } else if (char === '<') {
+                    // Start of a new tag
+                    isInsideTag = true;
+                    currentTag = '';
+                    currentContent = '';
+                } else if (char === '>') {
+                    // End of the tag
+                    isInsideTag = false;
+                    if (currentTag.startsWith('/')) {
+                        currentTag = '';
+                    }
+                } else if (isInsideTag) {
+                    // Append characters to the tag name
+                    currentTag += char;
+                } else if (currentTag === 'thought' || currentTag === 'action_input_description') {
+                    // Handle processing information for thought or action input description
+                    currentContent += char;
+                    const current_info = this.processingInfo.find(info => info.index === this.processingNumber);
+                    if (current_info) {
+                        current_info.content = currentContent.trim();
+                        onProcessingUpdate(this.processingInfo);
+                    } else {
+                        this.processingInfo.push({
+                            index: this.processingNumber,
+                            type: currentTag === 'thought' ? PROCESSING_TYPE.THOUGHT : PROCESSING_TYPE.ACTION,
+                            content: currentContent.trim(),
+                        });
+                        onProcessingUpdate(this.processingInfo);
+                    }
+                }
+            }
+        }
+
+        return fullResponse;
+    }
+
+    /**
+     * Processes a specific action by invoking the appropriate tool with the provided inputs.
+     * @param action The action to perform.
+     * @param actionInput The inputs for the action.
+     * @returns The result of the action.
+     */
+    private async processAction(action: string, actionInput: any): Promise<any> {
+        if (!(action in this.tools)) {
+            throw new Error(`Unknown action: ${action}`);
+        }
+
+        const tool = this.tools[action];
+        return await tool.execute(actionInput);
+    }
+}
```
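The XML "stage" protocol that `askAgent` parses is mandated by `getReactPrompt`, which is not part of this diff. Judging only from the parsing loop above, an assistant turn presumably looks something like the sketch below; the element names follow what the loop inspects (`thought`, `action`, `action_input`, `answer`), while the `number`/`role` attributes merely mirror the user-stage format the agent itself emits and are assumptions here.

```typescript
// Hypothetical assistant replies, shaped to satisfy Agent.askAgent's parser:
// it unwraps a <stage> element and then dispatches on its child keys.
const toolSelectionStage = `
<stage number="2" role="assistant">
    <thought>The question needs content from the user's documents, so the rag tool fits.</thought>
    <action>rag</action>
</stage>`;

const finalAnswerStage = `
<stage number="6" role="assistant">
    <answer>The uploaded paper proposes ...</answer>
</stage>`;
```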
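For orientation, a minimal sketch of how this class might be driven follows. It is not part of the commit: the real caller lives in the chatbot view, the callback bodies and query string are stand-ins, and the `Vectorstore` instance is assumed to exist already. Import paths assume the file sits next to `Agent.ts` in `agentsystem/`.

```typescript
import { Agent } from './Agent';
import { Vectorstore } from '../vectorstore/Vectorstore';
import { ProcessingInfo } from '../types/types';

// Hypothetical driver for the Agent class; every callback here is illustrative only.
async function demoAskAgent(vectorstore: Vectorstore) {
    const agent = new Agent(
        vectorstore,
        () => 'No document summaries yet.', // summaries
        () => '', // chat history
        () => [], // linked CSV data
        (url, id) => console.log(`link url doc ${id}: ${url}`), // addLinkedUrlDoc
        (url, title, id, data) => console.log(`create CSV "${title}" (${id})`) // createCSVInDash
    );

    // Streaming callbacks receive intermediate thoughts/actions and the incrementally parsed answer.
    const finalMessage = await agent.askAgent(
        'Summarize the documents I have uploaded.',
        (processing: ProcessingInfo[]) => console.log('thoughts/actions so far:', processing),
        (partialAnswer: string) => console.log('streamed answer:', partialAnswer)
    );
    console.log('final assistant message:', finalMessage);
}
```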
