From 2c38022a7f21d4b498277b18ad31baf24ac3a143 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Sun, 18 Aug 2024 10:12:35 -0400 Subject: Attempting streaming content --- src/client/views/nodes/ChatBox/Agent.ts | 198 ++++++++++++++++++++++++++------ 1 file changed, 166 insertions(+), 32 deletions(-) (limited to 'src/client/views/nodes/ChatBox/Agent.ts') diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 8bad29d9a..2c7c40e0c 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -1,5 +1,5 @@ import OpenAI from 'openai'; -import { Tool, AgentMessage } from './types'; +import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; import { Vectorstore } from './vectorstore/Vectorstore'; @@ -12,6 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; import { NoTool } from './tools/NoTool'; import { on } from 'events'; +import { StreamParser } from './StreamParser'; +import { v4 as uuidv4 } from 'uuid'; +import { AnswerParser } from './AnswerParser'; dotenv.config(); @@ -41,7 +44,7 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise { + async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); @@ -53,12 +56,18 @@ export class Agent { const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; - let thoughtNumber = 0; + let assistantMessage: AssistantMessage = { + role: ASSISTANT_ROLE.ASSISTANT, + content: [], + thoughts: [], + actions: [], + citations: [], + }; for (let i = 2; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(onUpdate, thoughtNumber); + const result = await this.execute(assistantMessage, onUpdate); this.interMessages.push({ role: 'assistant', content: result }); let parsedResult; @@ -66,23 +75,25 @@ export class Agent { parsedResult = parser.parse(result); } catch (error) { console.log('Error: Invalid XML response from bot'); - return 'Invalid response format.'; + return assistantMessage; } const stage = parsedResult.stage; if (!stage) { console.log('Error: No stage found in response'); - return 'Invalid response format: No stage found.'; + return assistantMessage; } for (const key in stage) { - if (key === 'thought') { - console.log(`Thought: ${stage[key]}`); - thoughtNumber++; - } else if (key === 'action') { + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' }); + onUpdate({ ...assistantMessage }); if (this.tools[currentAction]) { const nextPrompt = `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``; this.interMessages.push({ role: 'user', content: nextPrompt }); @@ -93,8 +104,12 @@ export class Agent { break; } } else if (key === 'action_input') { - const actionInput = builder.build({ action_input: stage[key] }); + const actionInput = stage[key]; console.log(`Action input: ${actionInput}`); + if (currentAction && assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput; + onUpdate({ ...assistantMessage }); + } if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); @@ -104,24 +119,28 @@ export class Agent { break; } catch (error) { console.log(`Error processing action: ${error}`); - return `${error}`; + return assistantMessage; } } else { console.log('Error: Action input without a valid action'); - return 'Action input without a valid action'; + return assistantMessage; } } else if (key === 'answer') { console.log('Answer found. Ending query.'); - onUpdate(`ANSWER:${stage[key]}`); - return result; + const parsedAnswer = AnswerParser.parse(`${stage[key]}`); + assistantMessage.content = parsedAnswer.content; + assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions; + assistantMessage.citations = parsedAnswer.citations; + onUpdate({ ...assistantMessage }); + return assistantMessage; } } } console.log('Reached maximum turns. Ending query.'); - return 'Reached maximum turns without finding an answer'; + return assistantMessage; } - private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise { + private async execute(assistantMessage: AssistantMessage, onUpdate: (update: AssistantMessage) => void): Promise { const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], @@ -130,32 +149,147 @@ export class Agent { }); let fullResponse = ''; + let currentTag = ''; let currentContent = ''; + let isInsideTag = false; + let isInsideActionInput = false; + let actionInputContent = ''; + + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } for await (const chunk of stream) { const content = chunk.choices[0]?.delta?.content || ''; fullResponse += content; - currentContent += content; + for (const char of content) { + if (char === '<') { + isInsideTag = true; + if (currentTag && currentContent) { + if (currentTag === 'action_input') { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + } else { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + } + onUpdate({ ...assistantMessage }); + } + currentTag = ''; + currentContent = ''; + } else if (char === '>') { + isInsideTag = false; + if (currentTag === 'action_input') { + isInsideActionInput = true; + } else if (currentTag === '/action_input') { + isInsideActionInput = false; + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + onUpdate({ ...assistantMessage }); + } + if (currentTag.startsWith('/')) { + currentTag = ''; + } + } else if (isInsideTag) { + currentTag += char; + } else if (isInsideActionInput) { + actionInputContent += char; + } else { + currentContent += char; + if (currentTag === 'thought' || currentTag === 'action') { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + onUpdate({ ...assistantMessage }); + } + } + } + } - console.log(currentContent); + return fullResponse; + } - if (currentContent.includes('')) { - onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`); - } - if (currentContent.includes('')) { - currentContent = ''; - } - if (currentContent.includes('')) { - onUpdate(`ANSWER_START:${currentContent}`); - } - if (currentContent.includes('')) { - onUpdate(`ANSWER_END:${currentContent}`); - currentContent = ''; + private processStreamedContent(tag: string, content: string, assistantMessage: AssistantMessage) { + if (!assistantMessage.thoughts) { + assistantMessage.thoughts = []; + } + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + switch (tag) { + case 'thought': + if (assistantMessage.thoughts.length > 0) { + assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content; + } else { + assistantMessage.thoughts.push(content); + } + break; + case 'action': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action = content; + } else { + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' }); + } + break; + case 'action_input': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content; + } + break; + } + } + + private processAnswer(content: string, assistantMessage: AssistantMessage) { + const groundedTextRegex = /([\s\S]*?)<\/grounded_text>/g; + let lastIndex = 0; + let match; + + while ((match = groundedTextRegex.exec(content)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; + + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = content.slice(lastIndex, match.index).trim(); + if (normalText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } } + + // Add grounded text + const citation_id = uuidv4(); + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids: [citation_id], + }); + + // Add citation + assistantMessage.citations?.push({ + citation_id, + chunk_id: '', + type: CHUNK_TYPE.TEXT, + direct_text: '', + }); + + lastIndex = match.index + fullMatch.length; } - return fullResponse; + // Add any remaining normal text after the last grounded text + if (lastIndex < content.length) { + const remainingText = content.slice(lastIndex).trim(); + if (remainingText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } } private async processAction(action: string, actionInput: any): Promise { -- cgit v1.2.3-70-g09d2