author    A.J. Shulman <Shulman.aj@gmail.com>    2024-09-07 12:43:05 -0400
committer A.J. Shulman <Shulman.aj@gmail.com>    2024-09-07 12:43:05 -0400
commit    4791cd23af08da70895204a3a7fbaf889d9af2d5 (patch)
tree      c4c2534e64724d62bae9152763f1a74cd5a963e0 /src/client/views/nodes/chatbot/agentsystem/Agent.ts
parent    210f8f5f1cd19e9416a12524cce119b273334fd3 (diff)
completely restructured, added comments, and significantly reduced the length of the prompt (~72% shorter and cheaper)
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
-rw-r--r--    src/client/views/nodes/chatbot/agentsystem/Agent.ts    278
1 files changed, 278 insertions, 0 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
new file mode 100644
index 000000000..180d05cf3
--- /dev/null
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -0,0 +1,278 @@
+import OpenAI from 'openai';
+import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE, ProcessingInfo, PROCESSING_TYPE } from '../types/types';
+import { getReactPrompt } from './prompts';
+import { XMLParser, XMLBuilder } from 'fast-xml-parser';
+import { Vectorstore } from '../vectorstore/Vectorstore';
+import { ChatCompletionMessageParam } from 'openai/resources';
+import dotenv from 'dotenv';
+import { CalculateTool } from '../tools/CalculateTool';
+import { RAGTool } from '../tools/RAGTool';
+import { DataAnalysisTool } from '../tools/DataAnalysisTool';
+import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool';
+import { SearchTool } from '../tools/SearchTool';
+import { NoTool } from '../tools/NoTool';
+import { v4 as uuidv4 } from 'uuid';
+import { AnswerParser } from '../response_parsers/AnswerParser';
+import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser';
+import { CreateCSVTool } from '../tools/CreateCSVTool';
+
+dotenv.config();
+
+/**
+ * The Agent class orchestrates the interaction between the assistant and its available tools:
+ * it processes user queries and manages the communication flow between the tools and OpenAI.
+ */
+export class Agent {
+ // Private properties
+ private client: OpenAI;
+ private tools: Record<string, Tool<any>>;
+ private messages: AgentMessage[] = [];
+ private interMessages: AgentMessage[] = [];
+ private vectorstore: Vectorstore;
+ private _history: () => string;
+ private _summaries: () => string;
+ private _csvData: () => { filename: string; id: string; text: string }[];
+ private actionNumber: number = 0;
+ private thoughtNumber: number = 0;
+ private processingNumber: number = 0;
+ private processingInfo: ProcessingInfo[] = [];
+ private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser();
+
+ /**
+ * The constructor initializes the agent with the vector store and toolset, and sets up the OpenAI client.
+ * @param _vectorstore Vector store instance for document storage and retrieval.
+ * @param summaries A function to retrieve document summaries.
+ * @param history A function to retrieve chat history.
+ * @param csvData A function to retrieve CSV data linked to the assistant.
+ * @param addLinkedUrlDoc A function to add a linked document from a URL.
+ * @param createCSVInDash A function to create a CSV document in Dash.
+ */
+ constructor(
+ _vectorstore: Vectorstore,
+ summaries: () => string,
+ history: () => string,
+ csvData: () => { filename: string; id: string; text: string }[],
+ addLinkedUrlDoc: (url: string, id: string) => void,
+ createCSVInDash: (url: string, title: string, id: string, data: string) => void
+ ) {
+ // Initialize OpenAI client with API key from environment
+ this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
+ this.vectorstore = _vectorstore;
+ this._history = history;
+ this._summaries = summaries;
+ this._csvData = csvData;
+
+ // Define available tools for the assistant
+ this.tools = {
+ calculate: new CalculateTool(),
+ rag: new RAGTool(this.vectorstore),
+ dataAnalysis: new DataAnalysisTool(csvData),
+ websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
+ searchTool: new SearchTool(addLinkedUrlDoc),
+ createCSV: new CreateCSVTool(createCSVInDash),
+ no_tool: new NoTool(),
+ };
+ }
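+
+ /*
+ * Example construction (a minimal sketch; `vectorstore` and the callbacks shown here are
+ * hypothetical stand-ins for the real Dash wiring):
+ *
+ *   const agent = new Agent(
+ *       vectorstore,                           // Vectorstore instance
+ *       () => documentSummaries,               // summaries provider
+ *       () => chatHistoryText,                 // history provider
+ *       () => linkedCsvDocs,                   // CSV data provider
+ *       (url, id) => addLinkedUrlDoc(url, id),
+ *       (url, title, id, data) => createCSVInDash(url, title, id, data)
+ *   );
+ */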
+
+ /**
+ * This method handles the conversation flow with the assistant, processes user queries,
+ * and manages the assistant's decision-making process, including tool actions.
+ * @param question The user's question.
+ * @param onProcessingUpdate Callback function for processing updates.
+ * @param onAnswerUpdate Callback function for answer updates.
+ * @param maxTurns The maximum number of turns to allow in the conversation.
+ * @returns The final response from the assistant.
+ */
+ async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> {
+ console.log(`Starting query: ${question}`);
+
+ // Push user's question to message history
+ this.messages.push({ role: 'user', content: question });
+
+ // Retrieve chat history and generate system prompt
+ const chatHistory = this._history();
+ const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory);
+
+ // Initialize intermediate messages
+ this.interMessages = [{ role: 'system', content: systemPrompt }];
+ this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` });
+
+ // Setup XML parser and builder
+ const parser = new XMLParser({
+ ignoreAttributes: false,
+ attributeNamePrefix: '@_',
+ textNodeName: '_text',
+ isArray: name => ['query', 'url'].includes(name),
+ });
+ const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
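+
+ // Illustrative shape of the XML stages exchanged with the model (not an exhaustive schema):
+ //   <stage number="2" role="assistant"><thought>...</thought><action>rag</action></stage>
+ // Later assistant stages carry <action_input> with <inputs>, and the final stage carries <answer>.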
+
+ let currentAction: string | undefined;
+ this.processingInfo = [];
+
+ // Conversation loop (up to maxTurns)
+ for (let i = 2; i < maxTurns; i += 2) {
+ console.log(this.interMessages);
+ console.log(`Turn ${i}/${maxTurns}`);
+
+ // Execute a step in the conversation and get the result
+ const result = await this.execute(onProcessingUpdate, onAnswerUpdate);
+ this.interMessages.push({ role: 'assistant', content: result });
+
+ let parsedResult;
+ try {
+ // Parse XML result from the assistant
+ parsedResult = parser.parse(result);
+ } catch (error) {
+ throw new Error(`Error parsing response: ${error}`);
+ }
+
+ // Extract the stage from the parsed result
+ const stage = parsedResult.stage;
+ if (!stage) {
+ throw new Error(`Error: No stage found in response`);
+ }
+
+ // Handle different stage elements (thoughts, actions, inputs, answers)
+ for (const key in stage) {
+ if (key === 'thought') {
+ // Handle assistant's thoughts
+ console.log(`Thought: ${stage[key]}`);
+ this.processingNumber++;
+ } else if (key === 'action') {
+ // Handle action stage
+ currentAction = stage[key] as string;
+ console.log(`Action: ${currentAction}`);
+
+ if (this.tools[currentAction]) {
+ // Prepare the next action based on the current tool
+ const nextPrompt = [
+ {
+ type: 'text',
+ text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`,
+ },
+ ];
+ this.interMessages.push({ role: 'user', content: nextPrompt });
+ break;
+ } else {
+ // Handle error in case of an invalid action
+ console.log('Error: No valid action');
+ this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` });
+ break;
+ }
+ } else if (key === 'action_input') {
+ // Handle action input stage
+ const actionInput = stage[key];
+ console.log(`Action input:`, actionInput.inputs);
+
+ if (currentAction) {
+ try {
+ // Process the action with its input
+ const observation = await this.processAction(currentAction, actionInput.inputs);
+ const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }];
+ console.log(observation);
+ this.interMessages.push({ role: 'user', content: nextPrompt });
+ this.processingNumber++;
+ break;
+ } catch (error) {
+ throw new Error(`Error processing action: ${error}`);
+ }
+ } else {
+ throw new Error('Error: Action input without a valid action');
+ }
+ } else if (key === 'answer') {
+ // If an answer is found, end the query
+ console.log('Answer found. Ending query.');
+ this.streamedAnswerParser.reset();
+ const parsedAnswer = AnswerParser.parse(result, this.processingInfo);
+ return parsedAnswer;
+ }
+ }
+ }
+
+ throw new Error('Reached maximum turns. Ending query.');
+ }
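+
+ /*
+ * Example call (a sketch; both callbacks are hypothetical UI handlers):
+ *
+ *   const answer: AssistantMessage = await agent.askAgent(
+ *       'Summarize the linked documents',
+ *       infos => renderProcessingSteps(infos),   // streamed thoughts/tool descriptions
+ *       partial => renderPartialAnswer(partial)  // streamed answer text
+ *   );
+ */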
+
+ /**
+ * Executes a step in the conversation, processing the assistant's response and parsing it in real time.
+ * @param onProcessingUpdate Callback for processing updates.
+ * @param onAnswerUpdate Callback for answer updates.
+ * @returns The full response from the assistant.
+ */
+ private async execute(onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void): Promise<string> {
+ // Stream OpenAI response for real-time updates
+ const stream = await this.client.chat.completions.create({
+ model: 'gpt-4o',
+ messages: this.interMessages as ChatCompletionMessageParam[],
+ temperature: 0,
+ stream: true,
+ });
+
+ let fullResponse: string = '';
+ let currentTag: string = '';
+ let currentContent: string = '';
+ let isInsideTag: boolean = false;
+
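+ // Lightweight tag scanner: rather than waiting for a complete XML document, the loop below
+ // tracks which tag is currently being streamed so that <thought>, <action_input_description>,
+ // and <answer> content can be forwarded to the UI callbacks as the characters arrive.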
+ // Process each chunk of the streamed response
+ for await (const chunk of stream) {
+ const content = chunk.choices[0]?.delta?.content || '';
+ fullResponse += content;
+
+ // Parse the streamed content character by character
+ for (const char of content) {
+ if (currentTag === 'answer') {
+ // Handle answer parsing for real-time updates
+ currentContent += char;
+ const streamedAnswer = this.streamedAnswerParser.parse(char);
+ onAnswerUpdate(streamedAnswer);
+ continue;
+ } else if (char === '<') {
+ // Start of a new tag
+ isInsideTag = true;
+ currentTag = '';
+ currentContent = '';
+ } else if (char === '>') {
+ // End of the tag
+ isInsideTag = false;
+ if (currentTag.startsWith('/')) {
+ currentTag = '';
+ }
+ } else if (isInsideTag) {
+ // Append characters to the tag name
+ currentTag += char;
+ } else if (currentTag === 'thought' || currentTag === 'action_input_description') {
+ // Handle processing information for thought or action input description
+ currentContent += char;
+ const currentInfo = this.processingInfo.find(info => info.index === this.processingNumber);
+ if (currentInfo) {
+ currentInfo.content = currentContent.trim();
+ onProcessingUpdate(this.processingInfo);
+ } else {
+ this.processingInfo.push({
+ index: this.processingNumber,
+ type: currentTag === 'thought' ? PROCESSING_TYPE.THOUGHT : PROCESSING_TYPE.ACTION,
+ content: currentContent.trim(),
+ });
+ onProcessingUpdate(this.processingInfo);
+ }
+ }
+ }
+ }
+
+ return fullResponse;
+ }
+
+ /**
+ * Processes a specific action by invoking the appropriate tool with the provided inputs.
+ * @param action The action to perform.
+ * @param actionInput The inputs for the action.
+ * @returns The result of the action.
+ */
+ private async processAction(action: string, actionInput: any): Promise<any> {
+ if (!(action in this.tools)) {
+ throw new Error(`Unknown action: ${action}`);
+ }
+
+ const tool = this.tools[action];
+ return await tool.execute(actionInput);
+ }
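+
+ /*
+ * Example dispatch (illustrative; the exact input shape is defined by each tool's action rule,
+ * mirroring how askAgent forwards the parsed <action_input> inputs):
+ *
+ *   const observation = await this.processAction('rag', parsedStage.action_input.inputs);
+ */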
+}