diff options
author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-21 16:06:31 -0400 |
---|---|---|
committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-21 16:06:31 -0400 |
commit | 484eb670b291afa07f2f7b976fafe02bdc9ac71d (patch) | |
tree | c201a6ac6d3cd729ff07a219c7a05987138c409a /src | |
parent | e5464e4c04ef6f8a2bbf868b43bbcdba54239406 (diff) |
Added streaming answer parsing so the answer is parsed and displayed in real time; follow-up questions and citations are added once it's finished
Diffstat (limited to 'src')
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 54 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/AnswerParser.ts | 2 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 16 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/MessageComponent.tsx | 3 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/StreamParser.ts | 125 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/StreamedAnswerParser.ts | 73 |
6 files changed, 111 insertions, 162 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 43138bf94..4ccb179f0 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -12,9 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; import { NoTool } from './tools/NoTool'; import { on } from 'events'; -import { StreamParser } from './StreamParser'; import { v4 as uuidv4 } from 'uuid'; import { AnswerParser } from './AnswerParser'; +import { StreamedAnswerParser } from './StreamedAnswerParser'; dotenv.config(); @@ -31,6 +31,7 @@ export class Agent { private thoughtNumber: number = 0; private processingNumber: number = 0; private processingInfo: ProcessingInfo[] = []; + private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser(); constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); @@ -48,7 +49,7 @@ export class Agent { }; } - async askAgent(question: string, onUpdate: (update: ProcessingInfo[]) => void, maxTurns: number = 30): Promise<AssistantMessage> { + async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); @@ -74,7 +75,7 @@ export class Agent { console.log(this.interMessages); console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(onUpdate); + const result = await this.execute(onProcessingUpdate, onAnswerUpdate); this.interMessages.push({ role: 'assistant', content: result }); let 
parsedResult; @@ -133,6 +134,7 @@ export class Agent { } } else if (key === 'answer') { console.log('Answer found. Ending query.'); + this.streamedAnswerParser.reset(); const parsedAnswer = AnswerParser.parse(result, this.processingInfo); return parsedAnswer; } @@ -141,7 +143,7 @@ export class Agent { throw new Error('Reached maximum turns. Ending query.'); } - private async execute(onUpdate: (update: ProcessingInfo[]) => void): Promise<string> { + private async execute(onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void): Promise<string> { const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], @@ -155,11 +157,18 @@ export class Agent { let isInsideTag: boolean = false; for await (const chunk of stream) { - const content = chunk.choices[0]?.delta?.content || ''; + let content = chunk.choices[0]?.delta?.content || ''; fullResponse += content; for (const char of content) { - if (char === '<') { + if (currentTag === 'answer') { + currentContent += char; + console.log(char); + const streamedAnswer = this.streamedAnswerParser.parse(char); + console.log(streamedAnswer); + onAnswerUpdate(streamedAnswer); + continue; + } else if (char === '<') { isInsideTag = true; currentTag = ''; currentContent = ''; @@ -170,11 +179,15 @@ export class Agent { } } else if (isInsideTag) { currentTag += char; - } else { + } else if (currentTag === 'thought' || currentTag === 'action_input_description') { currentContent += char; - if (currentTag === 'thought' || currentTag === 'action_input_description') { - this.processStreamedContent(currentTag, currentContent); - onUpdate(this.processingInfo); + const current_info = this.processingInfo.find(info => info.index === this.processingNumber); + if (current_info) { + current_info.content = currentContent.trim(); + onProcessingUpdate(this.processingInfo); + } else { + this.processingInfo.push({ index: 
this.processingNumber, type: currentTag === 'thought' ? PROCESSING_TYPE.THOUGHT : PROCESSING_TYPE.ACTION, content: currentContent.trim() }); + onProcessingUpdate(this.processingInfo); } } } @@ -183,27 +196,6 @@ export class Agent { return fullResponse; } - private processStreamedContent(tag: string, streamed_content: string) { - const current_info = this.processingInfo.find(info => info.index === this.processingNumber); - switch (tag) { - case 'thought': - if (current_info) { - current_info.content = streamed_content; - } else { - console.log(`Adding thought: ${streamed_content}`); - this.processingInfo.push({ index: this.processingNumber, type: PROCESSING_TYPE.THOUGHT, content: streamed_content.trim() }); - } - break; - case 'action_input_description': - if (current_info) { - current_info.content = streamed_content; - } else { - console.log(`Adding thought: ${streamed_content}`); - this.processingInfo.push({ index: this.processingNumber, type: PROCESSING_TYPE.ACTION, content: streamed_content.trim() }); - } - } - } - private async processAction(action: string, actionInput: any): Promise<any> { if (!(action in this.tools)) { throw new Error(`Unknown action: ${action}`); diff --git a/src/client/views/nodes/ChatBox/AnswerParser.ts b/src/client/views/nodes/ChatBox/AnswerParser.ts index 1d46a366d..b18083a27 100644 --- a/src/client/views/nodes/ChatBox/AnswerParser.ts +++ b/src/client/views/nodes/ChatBox/AnswerParser.ts @@ -56,7 +56,7 @@ export class AnswerParser { while ((match = groundedTextRegex.exec(rawTextContent)) !== null) { const [fullMatch, citationIndex, groundedText] = match; - // Add normal text before the grounded text + // Add normal text that is before the grounded text if (match.index > lastIndex) { const normalText = rawTextContent.slice(lastIndex, match.index).trim(); if (normalText) { diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index 1366eb772..45f5c0a65 100644 --- 
a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -11,7 +11,7 @@ import { ViewBoxAnnotatableComponent } from '../../DocComponent'; import { FieldView, FieldViewProps } from '../FieldView'; import './ChatBox.scss'; import MessageComponentBox from './MessageComponent'; -import { ASSISTANT_ROLE, AssistantMessage, AI_Document, Citation, CHUNK_TYPE, RAGChunk, getChunkType, TEXT_TYPE, SimplifiedChunk, ProcessingInfo } from './types'; +import { ASSISTANT_ROLE, AssistantMessage, AI_Document, Citation, CHUNK_TYPE, RAGChunk, getChunkType, TEXT_TYPE, SimplifiedChunk, ProcessingInfo, MessageContent } from './types'; import { Vectorstore } from './vectorstore/Vectorstore'; import { Agent } from './Agent'; import dotenv from 'dotenv'; @@ -175,16 +175,24 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { this.isLoading = true; this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], citations: [], processing_info: [] }; - const onUpdate = (update: ProcessingInfo[]) => { + const onProcessingUpdate = (processingUpdate: ProcessingInfo[]) => { runInAction(() => { if (this.current_message) { - this.current_message = { ...this.current_message, processing_info: update }; + this.current_message = { ...this.current_message, processing_info: processingUpdate }; } }); this.scrollToBottom(); }; - const finalMessage = await this.agent.askAgent(trimmedText, onUpdate); + const onAnswerUpdate = (answerUpdate: string) => { + runInAction(() => { + if (this.current_message) { + this.current_message = { ...this.current_message, content: [{ text: answerUpdate, type: TEXT_TYPE.NORMAL, index: 0, citation_ids: [] }] }; + } + }); + }; + + const finalMessage = await this.agent.askAgent(trimmedText, onProcessingUpdate, onAnswerUpdate); runInAction(() => { if (this.current_message) { diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index 
3edfb272c..d0e78c751 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -76,15 +76,16 @@ const MessageComponentBox: React.FC<MessageComponentProps> = function ({ message return ( <div className={`message ${message.role}`}> - <div className="message-content">{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div> {hasProcessingInfo && ( <div className="processing-info"> <button className="toggle-info" onClick={() => setDropdownOpen(!dropdownOpen)}> {dropdownOpen ? 'Hide Agent Thoughts/Actions' : 'Show Agent Thoughts/Actions'} </button> {dropdownOpen && <div className="info-content">{message.processing_info.map(renderProcessingInfo)}</div>} + <br /> </div> )} + <div className="message-content">{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div> {message.follow_up_questions && message.follow_up_questions.length > 0 && ( <div className="follow-up-questions"> <h4>Follow-up Questions:</h4> diff --git a/src/client/views/nodes/ChatBox/StreamParser.ts b/src/client/views/nodes/ChatBox/StreamParser.ts deleted file mode 100644 index 9b087663a..000000000 --- a/src/client/views/nodes/ChatBox/StreamParser.ts +++ /dev/null @@ -1,125 +0,0 @@ -import { AssistantMessage, ASSISTANT_ROLE, TEXT_TYPE, Citation, CHUNK_TYPE } from './types'; -import { v4 as uuidv4 } from 'uuid'; - -export class StreamParser { - private currentMessage: AssistantMessage; - private currentTag: string | null = null; - private buffer: string = ''; - private citationIndex: number = 1; - - constructor() { - this.currentMessage = { - role: ASSISTANT_ROLE.ASSISTANT, - content: [], - thoughts: [], - actions: [], - citations: [], - }; - } - - parse(chunk: string): AssistantMessage { - this.buffer += chunk; - - while 
(this.buffer.length > 0) { - if (this.currentTag === null) { - const openTagMatch = this.buffer.match(/<(\w+)>/); - if (openTagMatch) { - this.currentTag = openTagMatch[1]; - this.buffer = this.buffer.slice(openTagMatch.index! + openTagMatch[0].length); - } else { - break; - } - } else { - const closeTagIndex = this.buffer.indexOf(`</${this.currentTag}>`); - if (closeTagIndex !== -1) { - const content = this.buffer.slice(0, closeTagIndex); - this.processTag(this.currentTag, content); - this.buffer = this.buffer.slice(closeTagIndex + this.currentTag.length + 3); - this.currentTag = null; - } else { - break; - } - } - } - - return this.currentMessage; - } - - private processTag(tag: string, content: string) { - switch (tag) { - case 'thought': - this.currentMessage.thoughts!.push(content); - break; - case 'action': - this.currentMessage.actions!.push({ index: this.currentMessage.actions!.length, action: content, action_input: '' }); - break; - case 'action_input': - if (this.currentMessage.actions!.length > 0) { - this.currentMessage.actions![this.currentMessage.actions!.length - 1].action_input = content; - } - break; - case 'answer': - this.processAnswer(content); - break; - } - } - - private processAnswer(content: string) { - const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g; - let lastIndex = 0; - let match; - - while ((match = groundedTextRegex.exec(content)) !== null) { - const [fullMatch, citationIndex, groundedText] = match; - - // Add normal text before the grounded text - if (match.index > lastIndex) { - const normalText = content.slice(lastIndex, match.index).trim(); - if (normalText) { - this.currentMessage.content.push({ - index: this.currentMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: normalText, - citation_ids: null, - }); - } - } - - // Add grounded text - const citation_id = uuidv4(); - this.currentMessage.content.push({ - index: this.currentMessage.content.length, - type: 
TEXT_TYPE.GROUNDED, - text: groundedText.trim(), - citation_ids: [citation_id], - }); - - // Add citation - this.currentMessage.citations!.push({ - citation_id, - chunk_id: '', - type: CHUNK_TYPE.TEXT, - direct_text: '', - }); - - lastIndex = match.index + fullMatch.length; - } - - // Add any remaining normal text after the last grounded text - if (lastIndex < content.length) { - const remainingText = content.slice(lastIndex).trim(); - if (remainingText) { - this.currentMessage.content.push({ - index: this.currentMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: remainingText, - citation_ids: null, - }); - } - } - } - - getResult(): AssistantMessage { - return this.currentMessage; - } -} diff --git a/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts b/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts new file mode 100644 index 000000000..3585cab4a --- /dev/null +++ b/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts @@ -0,0 +1,73 @@ +import { threadId } from 'worker_threads'; + +enum ParserState { + Outside, + InGroundedText, + InNormalText, +} + +export class StreamedAnswerParser { + private state: ParserState = ParserState.Outside; + private buffer: string = ''; + private result: string = ''; + private isStartOfLine: boolean = true; + + public parse(char: string): string { + switch (this.state) { + case ParserState.Outside: + if (char === '<') { + this.buffer = '<'; + } else if (char === '>') { + if (this.buffer.startsWith('<grounded_text')) { + this.state = ParserState.InGroundedText; + } else if (this.buffer.startsWith('<normal_text')) { + this.state = ParserState.InNormalText; + } + this.buffer = ''; + } else { + this.buffer += char; + } + break; + + case ParserState.InGroundedText: + case ParserState.InNormalText: + if (char === '<') { + this.buffer = '<'; + } else if (this.buffer.startsWith('</grounded_text') && char === '>') { + this.state = ParserState.Outside; + this.buffer = ''; + } else if (this.buffer.startsWith('</normal_text') 
&& char === '>') { + this.state = ParserState.Outside; + this.buffer = ''; + } else if (this.buffer.startsWith('<')) { + this.buffer += char; + } else { + this.processChar(char); + } + break; + } + + return this.result.trim(); + } + + private processChar(char: string): void { + if (this.isStartOfLine && char === ' ') { + // Skip leading spaces + return; + } + if (char === '\n') { + this.result += char; + this.isStartOfLine = true; + } else { + this.result += char; + this.isStartOfLine = false; + } + } + + public reset(): void { + this.state = ParserState.Outside; + this.buffer = ''; + this.result = ''; + this.isStartOfLine = true; + } +} |