diff options
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 198 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 43 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/MessageComponent.tsx | 32 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/StreamParser.ts | 125 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 1 |
5 files changed, 326 insertions, 73 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 8bad29d9a..2c7c40e0c 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -1,5 +1,5 @@ import OpenAI from 'openai'; -import { Tool, AgentMessage } from './types'; +import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; import { Vectorstore } from './vectorstore/Vectorstore'; @@ -12,6 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; import { NoTool } from './tools/NoTool'; import { on } from 'events'; +import { StreamParser } from './StreamParser'; +import { v4 as uuidv4 } from 'uuid'; +import { AnswerParser } from './AnswerParser'; dotenv.config(); @@ -41,7 +44,7 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise<string> { + async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise<AssistantMessage> { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); @@ -53,12 +56,18 @@ export class Agent { const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; - let thoughtNumber = 0; + let assistantMessage: AssistantMessage = { + role: ASSISTANT_ROLE.ASSISTANT, + content: [], + thoughts: [], + actions: [], + citations: [], + }; for (let i = 2; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(onUpdate, thoughtNumber); + const result = await this.execute(assistantMessage, onUpdate); this.interMessages.push({ role: 'assistant', content: result }); let parsedResult; @@ -66,23 +75,25 @@ export class Agent { parsedResult = parser.parse(result); } catch (error) { console.log('Error: Invalid XML response from bot'); - return '<error>Invalid response format.</error>'; + return assistantMessage; } const stage = parsedResult.stage; if (!stage) { console.log('Error: No stage found in response'); - return '<error>Invalid response format: No stage found.</error>'; + return assistantMessage; } for (const key in stage) { - if (key === 'thought') { - console.log(`Thought: ${stage[key]}`); - thoughtNumber++; - } else if (key === 'action') { + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' }); + onUpdate({ ...assistantMessage }); if (this.tools[currentAction]) { const nextPrompt = `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`; this.interMessages.push({ role: 'user', content: nextPrompt }); @@ -93,8 +104,12 @@ export class Agent { break; } } else if (key === 'action_input') { - const actionInput = builder.build({ action_input: stage[key] }); + const actionInput = stage[key]; console.log(`Action input: ${actionInput}`); + if (currentAction && assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput; + onUpdate({ ...assistantMessage }); + } if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); @@ -104,24 +119,28 @@ export class Agent { break; } catch (error) { console.log(`Error processing action: ${error}`); - return `<error>${error}</error>`; + return assistantMessage; } } else { console.log('Error: Action input without a valid action'); - return '<error>Action input without a valid action</error>'; + return assistantMessage; } } else if (key === 'answer') { console.log('Answer found. Ending query.'); - onUpdate(`ANSWER:${stage[key]}`); - return result; + const parsedAnswer = AnswerParser.parse(`<answer>${stage[key]}</answer>`); + assistantMessage.content = parsedAnswer.content; + assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions; + assistantMessage.citations = parsedAnswer.citations; + onUpdate({ ...assistantMessage }); + return assistantMessage; } } } console.log('Reached maximum turns. Ending query.'); - return '<error>Reached maximum turns without finding an answer</error>'; + return assistantMessage; } - private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise<string> { + private async execute(assistantMessage: AssistantMessage, onUpdate: (update: AssistantMessage) => void): Promise<string> { const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], @@ -130,32 +149,147 @@ export class Agent { }); let fullResponse = ''; + let currentTag = ''; let currentContent = ''; + let isInsideTag = false; + let isInsideActionInput = false; + let actionInputContent = ''; + + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } for await (const chunk of stream) { const content = chunk.choices[0]?.delta?.content || ''; fullResponse += content; - currentContent += content; + for (const char of content) { + if (char === '<') { + isInsideTag = true; + if (currentTag && currentContent) { + if (currentTag === 'action_input') { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + } else { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + } + onUpdate({ ...assistantMessage }); + } + currentTag = ''; + currentContent = ''; + } else if (char === '>') { + isInsideTag = false; + if (currentTag === 'action_input') { + isInsideActionInput = true; + } else if (currentTag === '/action_input') { + isInsideActionInput = false; + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; + actionInputContent = ''; + onUpdate({ ...assistantMessage }); + } + if (currentTag.startsWith('/')) { + currentTag = ''; + } + } else if (isInsideTag) { + currentTag += char; + } else if (isInsideActionInput) { + actionInputContent += char; + } else { + currentContent += char; + if (currentTag === 'thought' || currentTag === 'action') { + this.processStreamedContent(currentTag, currentContent, assistantMessage); + onUpdate({ ...assistantMessage }); + } + } + } + } - console.log(currentContent); + return fullResponse; + } - if (currentContent.includes('<thought>')) { - onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`); - } - if (currentContent.includes('</thought>')) { - currentContent = ''; - } - if (currentContent.includes('<answer>')) { - onUpdate(`ANSWER_START:${currentContent}`); - } - if (currentContent.includes('</answer>')) { - onUpdate(`ANSWER_END:${currentContent}`); - currentContent = ''; + private processStreamedContent(tag: string, content: string, assistantMessage: AssistantMessage) { + if (!assistantMessage.thoughts) { + assistantMessage.thoughts = []; + } + if (!assistantMessage.actions) { + assistantMessage.actions = []; + } + switch (tag) { + case 'thought': + if (assistantMessage.thoughts.length > 0) { + assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content; + } else { + assistantMessage.thoughts.push(content); + } + break; + case 'action': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action = content; + } else { + assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' }); + } + break; + case 'action_input': + if (assistantMessage.actions.length > 0) { + assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content; + } + break; + } + } + + private processAnswer(content: string, assistantMessage: AssistantMessage) { + const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g; + let lastIndex = 0; + let match; + + while ((match = groundedTextRegex.exec(content)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; + + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = content.slice(lastIndex, match.index).trim(); + if (normalText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } } + + // Add grounded text + const citation_id = uuidv4(); + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids: [citation_id], + }); + + // Add citation + assistantMessage.citations?.push({ + citation_id, + chunk_id: '', + type: CHUNK_TYPE.TEXT, + direct_text: '', + }); + + lastIndex = match.index + fullMatch.length; } - return fullResponse; + // Add any remaining normal text after the last grounded text + if (lastIndex < content.length) { + const remainingText = content.slice(lastIndex).trim(); + if (remainingText) { + assistantMessage.content.push({ + index: assistantMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } } private async processAction(action: string, actionInput: any): Promise<any> { diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index d38c71810..099c0298e 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -149,45 +149,24 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { if (trimmedText) { try { - //Make everything go through the answer parser - //Pass in all the updates to the AnswrParser and it will create the assistant messasge that will be the current message including adding in the thoughts and also waiting for the asnwer and also showing tool progress textInput.value = ''; this.history.push({ role: ASSISTANT_ROLE.USER, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: trimmedText, citation_ids: null }] }); this.isLoading = true; - this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], thoughts: [] }; - - let currentThought = ''; - - this.current_message?.thoughts?.push(currentThought); - - const onUpdate = (update: string) => { - const thoughtNumber = Number(update.match(/^THOUGHT(\d+):/)?.[1] ?? 0); - const regex = /<thought>\s*([\s\S]*?)(?:<\/thought>|$)/; - const match = update.match(regex); - const currentThought = match ? match[1].trim() : ''; - //const numericPrefix = Number(update.match(/^\d+/)?.[0]); - if (update.startsWith('THOUGHT')) { - console.log('Thought:', currentThought, thoughtNumber); - if (this.current_message?.thoughts) { - if (this.current_message.thoughts.length <= thoughtNumber) { - this.current_message.thoughts.push(currentThought); - } else { - this.current_message.thoughts[thoughtNumber] = currentThought; - } - } - console.log('Thoughts:', this.current_message?.thoughts); - } + this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], thoughts: [], actions: [], citations: [] }; + + const onUpdate = (update: AssistantMessage) => { + runInAction(() => { + this.current_message = { ...update }; + }); }; - const response = await this.agent.askAgent(trimmedText, 20, onUpdate); - const parsedAnswer = AnswerParser.parse(response); - parsedAnswer.thoughts = this.current_message?.thoughts; + const finalMessage = await this.agent.askAgent(trimmedText, 20, onUpdate); - if (this.current_message) { - this.history.push(parsedAnswer); + runInAction(() => { + this.history.push({ ...finalMessage }); this.current_message = undefined; - } - this.dataDoc.data = JSON.stringify(this.history); + this.dataDoc.data = JSON.stringify(this.history); + }); } catch (err) { console.error('Error:', err); this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index 70b0527a2..e82dcd5f7 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -55,21 +55,35 @@ const MessageComponentBox: React.FC<MessageComponentProps> = function ({ message {item.text} </span> ); + } else if ('query' in item) { + // Handle the case where the item has a query property + return ( + <span key={i} className="query-text"> + {JSON.stringify(item.query)} + </span> + ); } else { - return <span key={i}>{item.text}</span>; + // Fallback for any other unexpected cases + return <span key={i}>{JSON.stringify(item)}</span>; } }; return ( <div className={`message ${message.role}`}> - <div> - {message.thoughts && - message.thoughts.map((thought, index) => ( - <span key={index} className="thought-text"> - <i>Thought: {thought}</i> - </span> - ))} - </div> + {message.thoughts && + message.thoughts.map((thought, idx) => ( + <div key={idx} className="thought"> + <i>Thought: {thought}</i> + </div> + ))} + {message.actions && + message.actions.map((action, idx) => ( + <div key={idx} className="action"> + <strong>Action:</strong> {action.action} + <br /> + <strong>Input:</strong> {action.action_input} + </div> + ))} <div>{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div> {message.follow_up_questions && message.follow_up_questions.length > 0 && ( <div className="follow-up-questions"> diff --git a/src/client/views/nodes/ChatBox/StreamParser.ts b/src/client/views/nodes/ChatBox/StreamParser.ts new file mode 100644 index 000000000..9b087663a --- /dev/null +++ b/src/client/views/nodes/ChatBox/StreamParser.ts @@ -0,0 +1,125 @@ +import { AssistantMessage, ASSISTANT_ROLE, TEXT_TYPE, Citation, CHUNK_TYPE } from './types'; +import { v4 as uuidv4 } from 'uuid'; + +export class StreamParser { + private currentMessage: AssistantMessage; + private currentTag: string | null = null; + private buffer: string = ''; + private citationIndex: number = 1; + + constructor() { + this.currentMessage = { + role: ASSISTANT_ROLE.ASSISTANT, + content: [], + thoughts: [], + actions: [], + citations: [], + }; + } + + parse(chunk: string): AssistantMessage { + this.buffer += chunk; + + while (this.buffer.length > 0) { + if (this.currentTag === null) { + const openTagMatch = this.buffer.match(/<(\w+)>/); + if (openTagMatch) { + this.currentTag = openTagMatch[1]; + this.buffer = this.buffer.slice(openTagMatch.index! + openTagMatch[0].length); + } else { + break; + } + } else { + const closeTagIndex = this.buffer.indexOf(`</${this.currentTag}>`); + if (closeTagIndex !== -1) { + const content = this.buffer.slice(0, closeTagIndex); + this.processTag(this.currentTag, content); + this.buffer = this.buffer.slice(closeTagIndex + this.currentTag.length + 3); + this.currentTag = null; + } else { + break; + } + } + } + + return this.currentMessage; + } + + private processTag(tag: string, content: string) { + switch (tag) { + case 'thought': + this.currentMessage.thoughts!.push(content); + break; + case 'action': + this.currentMessage.actions!.push({ index: this.currentMessage.actions!.length, action: content, action_input: '' }); + break; + case 'action_input': + if (this.currentMessage.actions!.length > 0) { + this.currentMessage.actions![this.currentMessage.actions!.length - 1].action_input = content; + } + break; + case 'answer': + this.processAnswer(content); + break; + } + } + + private processAnswer(content: string) { + const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g; + let lastIndex = 0; + let match; + + while ((match = groundedTextRegex.exec(content)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; + + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = content.slice(lastIndex, match.index).trim(); + if (normalText) { + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } + } + + // Add grounded text + const citation_id = uuidv4(); + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids: [citation_id], + }); + + // Add citation + this.currentMessage.citations!.push({ + citation_id, + chunk_id: '', + type: CHUNK_TYPE.TEXT, + direct_text: '', + }); + + lastIndex = match.index + fullMatch.length; + } + + // Add any remaining normal text after the last grounded text + if (lastIndex < content.length) { + const remainingText = content.slice(lastIndex).trim(); + if (remainingText) { + this.currentMessage.content.push({ + index: this.currentMessage.content.length, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } + } + + getResult(): AssistantMessage { + return this.currentMessage; + } +} diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index 391f124e0..efeec7b93 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -48,6 +48,7 @@ export interface AssistantMessage { content: MessageContent[]; follow_up_questions?: string[]; thoughts?: string[]; + actions?: { index: number; action: string; action_input: string }[]; citations?: Citation[]; } |