From 2c38022a7f21d4b498277b18ad31baf24ac3a143 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Sun, 18 Aug 2024 10:12:35 -0400 Subject: Attempting streaming content --- src/client/views/nodes/ChatBox/Agent.ts | 198 +++++++++++++++++---- src/client/views/nodes/ChatBox/ChatBox.tsx | 43 ++--- .../views/nodes/ChatBox/MessageComponent.tsx | 32 +++- src/client/views/nodes/ChatBox/StreamParser.ts | 125 +++++++++++++ src/client/views/nodes/ChatBox/types.ts | 1 + 5 files changed, 326 insertions(+), 73 deletions(-) create mode 100644 src/client/views/nodes/ChatBox/StreamParser.ts (limited to 'src') diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 8bad29d9a..2c7c40e0c 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -1,5 +1,5 @@ import OpenAI from 'openai'; -import { Tool, AgentMessage } from './types'; +import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; import { Vectorstore } from './vectorstore/Vectorstore'; @@ -12,6 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; import { NoTool } from './tools/NoTool'; import { on } from 'events'; +import { StreamParser } from './StreamParser'; +import { v4 as uuidv4 } from 'uuid'; +import { AnswerParser } from './AnswerParser'; dotenv.config(); @@ -41,7 +44,7 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise { + async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); @@ -53,12 +56,18 @@ export class Agent { const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; - let thoughtNumber = 0; + let assistantMessage: AssistantMessage = { + role: ASSISTANT_ROLE.ASSISTANT, + content: [], + thoughts: [], + actions: [], + citations: [], + }; for (let i = 2; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(onUpdate, thoughtNumber); + const result = await this.execute(assistantMessage, onUpdate); this.interMessages.push({ role: 'assistant', content: result }); let parsedResult; @@ -66,23 +75,25 @@ export class Agent { parsedResult = parser.parse(result); } catch (error) { console.log('Error: Invalid XML response from bot'); - return 'Invalid response format.'; + return assistantMessage; } const stage = parsedResult.stage; if (!stage) { console.log('Error: No stage found in response'); - return 'Invalid response format: No stage found.'; + return assistantMessage; } for (const key in stage) { - if (key === 'thought') { - console.log(`Thought: ${stage[key]}`); - thoughtNumber++; - } else if (key === 'action') { + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' }); + onUpdate({ ...assistantMessage }); if (this.tools[currentAction]) { const nextPrompt = `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``; this.interMessages.push({ role: 'user', content: nextPrompt }); @@ -93,8 +104,12 @@ export class Agent { break; } } else if (key === 'action_input') { - const actionInput = builder.build({ action_input: stage[key] }); + const actionInput = stage[key]; console.log(`Action input: ${actionInput}`); + if (currentAction && assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput; + onUpdate({ ...assistantMessage }); + } if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); @@ -104,24 +119,28 @@ export class Agent { break; } catch (error) { console.log(`Error processing action: ${error}`); - return `${error}`; + return assistantMessage; } } else { console.log('Error: Action input without a valid action'); - return 'Action input without a valid action'; + return assistantMessage; } } else if (key === 'answer') { console.log('Answer found. Ending query.'); - onUpdate(`ANSWER:${stage[key]}`); - return result; + const parsedAnswer = AnswerParser.parse(`${stage[key]}`); + assistantMessage.content = parsedAnswer.content; + assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions; + assistantMessage.citations = parsedAnswer.citations; + onUpdate({ ...assistantMessage }); + return assistantMessage; } } } console.log('Reached maximum turns. Ending query.'); - return 'Reached maximum turns without finding an answer'; + return assistantMessage; } - private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise { + private async execute(assistantMessage: AssistantMessage, onUpdate: (update: AssistantMessage) => void): Promise { const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], @@ -130,32 +149,147 @@ export class Agent { }); let fullResponse = ''; + let currentTag = ''; let currentContent = ''; + let isInsideTag = false; + let isInsideActionInput = false; + let actionInputContent = ''; + + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } for await (const chunk of stream) { const content = chunk.choices[0]?.delta?.content || ''; fullResponse += content; - currentContent += content; + for (const char of content) { + if (char === '<') { + isInsideTag = true; + if (currentTag && currentContent) { + if (currentTag === 'action_input') { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + } else { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + } + onUpdate({ ...assistantMessage }); + } + currentTag = ''; + currentContent = ''; + } else if (char === '>') { + isInsideTag = false; + if (currentTag === 'action_input') { + isInsideActionInput = true; + } else if (currentTag === '/action_input') { + isInsideActionInput = false; + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + onUpdate({ ...assistantMessage }); + } + if (currentTag.startsWith('/')) { + currentTag = ''; + } + } else if (isInsideTag) { + currentTag += char; + } else if (isInsideActionInput) { + actionInputContent += char; + } else { + currentContent += char; + if (currentTag === 'thought' || currentTag === 'action') { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + onUpdate({ ...assistantMessage }); + } + } + } + } - console.log(currentContent); + return fullResponse; + } - if (currentContent.includes('')) { - onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`); - } - if (currentContent.includes('')) { - currentContent = ''; - } - if (currentContent.includes('')) { - onUpdate(`ANSWER_START:${currentContent}`); - } - if (currentContent.includes('')) { - onUpdate(`ANSWER_END:${currentContent}`); - currentContent = ''; + private processStreamedContent(tag: string, content: string, assistantMessage: AssistantMessage) { + if (!assistantMessage.thoughts) { + assistantMessage.thoughts = []; + } + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + switch (tag) { + case 'thought': + if (assistantMessage.thoughts.length > 0) { + assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content; + } else { + assistantMessage.thoughts.push(content); + } + break; + case 'action': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action = content; + } else { + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' }); + } + break; + case 'action_input': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content; + } + break; + } + } + + private processAnswer(content: string, assistantMessage: AssistantMessage) { + const groundedTextRegex = /([\s\S]*?)<\/grounded_text>/g; + let lastIndex = 0; + let match; + + while ((match = groundedTextRegex.exec(content)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; + + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = content.slice(lastIndex, match.index).trim(); + if (normalText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } } + + // Add grounded text + const citation_id = uuidv4(); + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids: [citation_id], + }); + + // Add citation + assistantMessage.citations?.push({ + citation_id, + chunk_id: '', + type: CHUNK_TYPE.TEXT, + direct_text: '', + }); + + lastIndex = match.index + fullMatch.length; } - return fullResponse; + // Add any remaining normal text after the last grounded text + if (lastIndex < content.length) { + const remainingText = content.slice(lastIndex).trim(); + if (remainingText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } } private async processAction(action: string, actionInput: any): Promise { diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index d38c71810..099c0298e 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -149,45 +149,24 @@ export class ChatBox extends ViewBoxAnnotatableComponent() { if (trimmedText) { try { - //Make everything go through the answer parser - //Pass in all the updates to the AnswrParser and it will create the assistant messasge that will be the current message including adding in the thoughts and also waiting for the asnwer and also showing tool progress textInput.value = ''; this.history.push({ role: ASSISTANT_ROLE.USER, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: trimmedText, citation_ids: null }] }); this.isLoading = true; - this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], thoughts: [] }; - - let currentThought = ''; - - this.current_message?.thoughts?.push(currentThought); - - const onUpdate = (update: string) => { - const thoughtNumber = Number(update.match(/^THOUGHT(\d+):/)?.[1] ?? 0); - const regex = /\s*([\s\S]*?)(?:<\/thought>|$)/; - const match = update.match(regex); - const currentThought = match ? match[1].trim() : ''; - //const numericPrefix = Number(update.match(/^\d+/)?.[0]); - if (update.startsWith('THOUGHT')) { - console.log('Thought:', currentThought, thoughtNumber); - if (this.current_message?.thoughts) { - if (this.current_message.thoughts.length <= thoughtNumber) { - this.current_message.thoughts.push(currentThought); - } else { - this.current_message.thoughts[thoughtNumber] = currentThought; - } - } - console.log('Thoughts:', this.current_message?.thoughts); - } + this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], thoughts: [], actions: [], citations: [] }; + + const onUpdate = (update: AssistantMessage) => { + runInAction(() => { + this.current_message = { ...update }; + }); }; - const response = await this.agent.askAgent(trimmedText, 20, onUpdate); - const parsedAnswer = AnswerParser.parse(response); - parsedAnswer.thoughts = this.current_message?.thoughts; + const finalMessage = await this.agent.askAgent(trimmedText, 20, onUpdate); - if (this.current_message) { - this.history.push(parsedAnswer); + runInAction(() => { + this.history.push({ ...finalMessage }); this.current_message = undefined; - } - this.dataDoc.data = JSON.stringify(this.history); + this.dataDoc.data = JSON.stringify(this.history); + }); } catch (err) { console.error('Error:', err); this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index 70b0527a2..e82dcd5f7 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -55,21 +55,35 @@ const MessageComponentBox: React.FC = function ({ message {item.text} ); + } else if ('query' in item) { + // Handle the case where the item has a query property + return ( + + {JSON.stringify(item.query)} + + ); } else { - return {item.text}; + // Fallback for any other unexpected cases + return {JSON.stringify(item)}; } }; return (
-
- {message.thoughts && - message.thoughts.map((thought, index) => ( - - Thought: {thought} - - ))} -
+ {message.thoughts && + message.thoughts.map((thought, idx) => ( +
+ Thought: {thought} +
+ ))} + {message.actions && + message.actions.map((action, idx) => ( +
+ Action: {action.action} +
+ Input: {action.action_input} +
+ ))}
{message.content && message.content.map(messageFragment => {renderContent(messageFragment)})}
{message.follow_up_questions && message.follow_up_questions.length > 0 && (
diff --git a/src/client/views/nodes/ChatBox/StreamParser.ts b/src/client/views/nodes/ChatBox/StreamParser.ts new file mode 100644 index 000000000..9b087663a --- /dev/null +++ b/src/client/views/nodes/ChatBox/StreamParser.ts @@ -0,0 +1,125 @@ +import { AssistantMessage, ASSISTANT_ROLE, TEXT_TYPE, Citation, CHUNK_TYPE } from './types'; +import { v4 as uuidv4 } from 'uuid'; + +export class StreamParser { + private currentMessage: AssistantMessage; + private currentTag: string | null = null; + private buffer: string = ''; + private citationIndex: number = 1; + + constructor() { + this.currentMessage = { + role: ASSISTANT_ROLE.ASSISTANT, + content: [], + thoughts: [], + actions: [], + citations: [], + }; + } + + parse(chunk: string): AssistantMessage { + this.buffer += chunk; + + while (this.buffer.length > 0) { + if (this.currentTag === null) { + const openTagMatch = this.buffer.match(/<(\w+)>/); + if (openTagMatch) { + this.currentTag = openTagMatch[1]; + this.buffer = this.buffer.slice(openTagMatch.index! + openTagMatch[0].length); + } else { + break; + } + } else { + const closeTagIndex = this.buffer.indexOf(``); + if (closeTagIndex !== -1) { + const content = this.buffer.slice(0, closeTagIndex); + this.processTag(this.currentTag, content); + this.buffer = this.buffer.slice(closeTagIndex + this.currentTag.length + 3); + this.currentTag = null; + } else { + break; + } + } + } + + return this.currentMessage; + } + + private processTag(tag: string, content: string) { + switch (tag) { + case 'thought': + this.currentMessage.thoughts!.push(content); + break; + case 'action': + this.currentMessage.actions!.push({ index: this.currentMessage.actions!.length, action: content, action_input: '' }); + break; + case 'action_input': + if (this.currentMessage.actions!.length > 0) { + this.currentMessage.actions![this.currentMessage.actions!.length - 1].action_input = content; + } + break; + case 'answer': + this.processAnswer(content); + break; + } + } + + private processAnswer(content: string) { + const groundedTextRegex = /([\s\S]*?)<\/grounded_text>/g; + let lastIndex = 0; + let match; + + while ((match = groundedTextRegex.exec(content)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; + + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = content.slice(lastIndex, match.index).trim(); + if (normalText) { + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } + } + + // Add grounded text + const citation_id = uuidv4(); + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids: [citation_id], + }); + + // Add citation + this.currentMessage.citations!.push({ + citation_id, + chunk_id: '', + type: CHUNK_TYPE.TEXT, + direct_text: '', + }); + + lastIndex = match.index + fullMatch.length; + } + + // Add any remaining normal text after the last grounded text + if (lastIndex < content.length) { + const remainingText = content.slice(lastIndex).trim(); + if (remainingText) { + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } + } + + getResult(): AssistantMessage { + return this.currentMessage; + } +} diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index 391f124e0..efeec7b93 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -48,6 +48,7 @@ export interface AssistantMessage { content: MessageContent[]; follow_up_questions?: string[]; thoughts?: string[]; + actions?: { index: number; action: string; action_input: string }[]; citations?: Citation[]; } -- cgit v1.2.3-70-g09d2