aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorA.J. Shulman <Shulman.aj@gmail.com>2024-08-21 16:06:31 -0400
committerA.J. Shulman <Shulman.aj@gmail.com>2024-08-21 16:06:31 -0400
commit484eb670b291afa07f2f7b976fafe02bdc9ac71d (patch)
treec201a6ac6d3cd729ff07a219c7a05987138c409a
parente5464e4c04ef6f8a2bbf868b43bbcdba54239406 (diff)
added answer streaming parsing so it provides real-time parsing, and then follow-up questions and citations are added when it's finished
-rw-r--r--src/client/views/nodes/ChatBox/Agent.ts54
-rw-r--r--src/client/views/nodes/ChatBox/AnswerParser.ts2
-rw-r--r--src/client/views/nodes/ChatBox/ChatBox.tsx16
-rw-r--r--src/client/views/nodes/ChatBox/MessageComponent.tsx3
-rw-r--r--src/client/views/nodes/ChatBox/StreamParser.ts125
-rw-r--r--src/client/views/nodes/ChatBox/StreamedAnswerParser.ts73
6 files changed, 111 insertions, 162 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts
index 43138bf94..4ccb179f0 100644
--- a/src/client/views/nodes/ChatBox/Agent.ts
+++ b/src/client/views/nodes/ChatBox/Agent.ts
@@ -12,9 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool';
import { SearchTool } from './tools/SearchTool';
import { NoTool } from './tools/NoTool';
import { on } from 'events';
-import { StreamParser } from './StreamParser';
import { v4 as uuidv4 } from 'uuid';
import { AnswerParser } from './AnswerParser';
+import { StreamedAnswerParser } from './StreamedAnswerParser';
dotenv.config();
@@ -31,6 +31,7 @@ export class Agent {
private thoughtNumber: number = 0;
private processingNumber: number = 0;
private processingInfo: ProcessingInfo[] = [];
+ private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser();
constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) {
this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
@@ -48,7 +49,7 @@ export class Agent {
};
}
- async askAgent(question: string, onUpdate: (update: ProcessingInfo[]) => void, maxTurns: number = 30): Promise<AssistantMessage> {
+ async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> {
console.log(`Starting query: ${question}`);
this.messages.push({ role: 'user', content: question });
const chatHistory = this._history();
@@ -74,7 +75,7 @@ export class Agent {
console.log(this.interMessages);
console.log(`Turn ${i}/${maxTurns}`);
- const result = await this.execute(onUpdate);
+ const result = await this.execute(onProcessingUpdate, onAnswerUpdate);
this.interMessages.push({ role: 'assistant', content: result });
let parsedResult;
@@ -133,6 +134,7 @@ export class Agent {
}
} else if (key === 'answer') {
console.log('Answer found. Ending query.');
+ this.streamedAnswerParser.reset();
const parsedAnswer = AnswerParser.parse(result, this.processingInfo);
return parsedAnswer;
}
@@ -141,7 +143,7 @@ export class Agent {
throw new Error('Reached maximum turns. Ending query.');
}
- private async execute(onUpdate: (update: ProcessingInfo[]) => void): Promise<string> {
+ private async execute(onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void): Promise<string> {
const stream = await this.client.chat.completions.create({
model: 'gpt-4o',
messages: this.interMessages as ChatCompletionMessageParam[],
@@ -155,11 +157,18 @@ export class Agent {
let isInsideTag: boolean = false;
for await (const chunk of stream) {
- const content = chunk.choices[0]?.delta?.content || '';
+ let content = chunk.choices[0]?.delta?.content || '';
fullResponse += content;
for (const char of content) {
- if (char === '<') {
+ if (currentTag === 'answer') {
+ currentContent += char;
+ console.log(char);
+ const streamedAnswer = this.streamedAnswerParser.parse(char);
+ console.log(streamedAnswer);
+ onAnswerUpdate(streamedAnswer);
+ continue;
+ } else if (char === '<') {
isInsideTag = true;
currentTag = '';
currentContent = '';
@@ -170,11 +179,15 @@ export class Agent {
}
} else if (isInsideTag) {
currentTag += char;
- } else {
+ } else if (currentTag === 'thought' || currentTag === 'action_input_description') {
currentContent += char;
- if (currentTag === 'thought' || currentTag === 'action_input_description') {
- this.processStreamedContent(currentTag, currentContent);
- onUpdate(this.processingInfo);
+ const current_info = this.processingInfo.find(info => info.index === this.processingNumber);
+ if (current_info) {
+ current_info.content = currentContent.trim();
+ onProcessingUpdate(this.processingInfo);
+ } else {
+ this.processingInfo.push({ index: this.processingNumber, type: currentTag === 'thought' ? PROCESSING_TYPE.THOUGHT : PROCESSING_TYPE.ACTION, content: currentContent.trim() });
+ onProcessingUpdate(this.processingInfo);
}
}
}
@@ -183,27 +196,6 @@ export class Agent {
return fullResponse;
}
- private processStreamedContent(tag: string, streamed_content: string) {
- const current_info = this.processingInfo.find(info => info.index === this.processingNumber);
- switch (tag) {
- case 'thought':
- if (current_info) {
- current_info.content = streamed_content;
- } else {
- console.log(`Adding thought: ${streamed_content}`);
- this.processingInfo.push({ index: this.processingNumber, type: PROCESSING_TYPE.THOUGHT, content: streamed_content.trim() });
- }
- break;
- case 'action_input_description':
- if (current_info) {
- current_info.content = streamed_content;
- } else {
- console.log(`Adding thought: ${streamed_content}`);
- this.processingInfo.push({ index: this.processingNumber, type: PROCESSING_TYPE.ACTION, content: streamed_content.trim() });
- }
- }
- }
-
private async processAction(action: string, actionInput: any): Promise<any> {
if (!(action in this.tools)) {
throw new Error(`Unknown action: ${action}`);
diff --git a/src/client/views/nodes/ChatBox/AnswerParser.ts b/src/client/views/nodes/ChatBox/AnswerParser.ts
index 1d46a366d..b18083a27 100644
--- a/src/client/views/nodes/ChatBox/AnswerParser.ts
+++ b/src/client/views/nodes/ChatBox/AnswerParser.ts
@@ -56,7 +56,7 @@ export class AnswerParser {
while ((match = groundedTextRegex.exec(rawTextContent)) !== null) {
const [fullMatch, citationIndex, groundedText] = match;
- // Add normal text before the grounded text
+ // Add normal text that is before the grounded text
if (match.index > lastIndex) {
const normalText = rawTextContent.slice(lastIndex, match.index).trim();
if (normalText) {
diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx
index 1366eb772..45f5c0a65 100644
--- a/src/client/views/nodes/ChatBox/ChatBox.tsx
+++ b/src/client/views/nodes/ChatBox/ChatBox.tsx
@@ -11,7 +11,7 @@ import { ViewBoxAnnotatableComponent } from '../../DocComponent';
import { FieldView, FieldViewProps } from '../FieldView';
import './ChatBox.scss';
import MessageComponentBox from './MessageComponent';
-import { ASSISTANT_ROLE, AssistantMessage, AI_Document, Citation, CHUNK_TYPE, RAGChunk, getChunkType, TEXT_TYPE, SimplifiedChunk, ProcessingInfo } from './types';
+import { ASSISTANT_ROLE, AssistantMessage, AI_Document, Citation, CHUNK_TYPE, RAGChunk, getChunkType, TEXT_TYPE, SimplifiedChunk, ProcessingInfo, MessageContent } from './types';
import { Vectorstore } from './vectorstore/Vectorstore';
import { Agent } from './Agent';
import dotenv from 'dotenv';
@@ -175,16 +175,24 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() {
this.isLoading = true;
this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], citations: [], processing_info: [] };
- const onUpdate = (update: ProcessingInfo[]) => {
+ const onProcessingUpdate = (processingUpdate: ProcessingInfo[]) => {
runInAction(() => {
if (this.current_message) {
- this.current_message = { ...this.current_message, processing_info: update };
+ this.current_message = { ...this.current_message, processing_info: processingUpdate };
}
});
this.scrollToBottom();
};
- const finalMessage = await this.agent.askAgent(trimmedText, onUpdate);
+ const onAnswerUpdate = (answerUpdate: string) => {
+ runInAction(() => {
+ if (this.current_message) {
+ this.current_message = { ...this.current_message, content: [{ text: answerUpdate, type: TEXT_TYPE.NORMAL, index: 0, citation_ids: [] }] };
+ }
+ });
+ };
+
+ const finalMessage = await this.agent.askAgent(trimmedText, onProcessingUpdate, onAnswerUpdate);
runInAction(() => {
if (this.current_message) {
diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx
index 3edfb272c..d0e78c751 100644
--- a/src/client/views/nodes/ChatBox/MessageComponent.tsx
+++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx
@@ -76,15 +76,16 @@ const MessageComponentBox: React.FC<MessageComponentProps> = function ({ message
return (
<div className={`message ${message.role}`}>
- <div className="message-content">{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div>
{hasProcessingInfo && (
<div className="processing-info">
<button className="toggle-info" onClick={() => setDropdownOpen(!dropdownOpen)}>
{dropdownOpen ? 'Hide Agent Thoughts/Actions' : 'Show Agent Thoughts/Actions'}
</button>
{dropdownOpen && <div className="info-content">{message.processing_info.map(renderProcessingInfo)}</div>}
+ <br />
</div>
)}
+ <div className="message-content">{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div>
{message.follow_up_questions && message.follow_up_questions.length > 0 && (
<div className="follow-up-questions">
<h4>Follow-up Questions:</h4>
diff --git a/src/client/views/nodes/ChatBox/StreamParser.ts b/src/client/views/nodes/ChatBox/StreamParser.ts
deleted file mode 100644
index 9b087663a..000000000
--- a/src/client/views/nodes/ChatBox/StreamParser.ts
+++ /dev/null
@@ -1,125 +0,0 @@
-import { AssistantMessage, ASSISTANT_ROLE, TEXT_TYPE, Citation, CHUNK_TYPE } from './types';
-import { v4 as uuidv4 } from 'uuid';
-
-export class StreamParser {
- private currentMessage: AssistantMessage;
- private currentTag: string | null = null;
- private buffer: string = '';
- private citationIndex: number = 1;
-
- constructor() {
- this.currentMessage = {
- role: ASSISTANT_ROLE.ASSISTANT,
- content: [],
- thoughts: [],
- actions: [],
- citations: [],
- };
- }
-
- parse(chunk: string): AssistantMessage {
- this.buffer += chunk;
-
- while (this.buffer.length > 0) {
- if (this.currentTag === null) {
- const openTagMatch = this.buffer.match(/<(\w+)>/);
- if (openTagMatch) {
- this.currentTag = openTagMatch[1];
- this.buffer = this.buffer.slice(openTagMatch.index! + openTagMatch[0].length);
- } else {
- break;
- }
- } else {
- const closeTagIndex = this.buffer.indexOf(`</${this.currentTag}>`);
- if (closeTagIndex !== -1) {
- const content = this.buffer.slice(0, closeTagIndex);
- this.processTag(this.currentTag, content);
- this.buffer = this.buffer.slice(closeTagIndex + this.currentTag.length + 3);
- this.currentTag = null;
- } else {
- break;
- }
- }
- }
-
- return this.currentMessage;
- }
-
- private processTag(tag: string, content: string) {
- switch (tag) {
- case 'thought':
- this.currentMessage.thoughts!.push(content);
- break;
- case 'action':
- this.currentMessage.actions!.push({ index: this.currentMessage.actions!.length, action: content, action_input: '' });
- break;
- case 'action_input':
- if (this.currentMessage.actions!.length > 0) {
- this.currentMessage.actions![this.currentMessage.actions!.length - 1].action_input = content;
- }
- break;
- case 'answer':
- this.processAnswer(content);
- break;
- }
- }
-
- private processAnswer(content: string) {
- const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g;
- let lastIndex = 0;
- let match;
-
- while ((match = groundedTextRegex.exec(content)) !== null) {
- const [fullMatch, citationIndex, groundedText] = match;
-
- // Add normal text before the grounded text
- if (match.index > lastIndex) {
- const normalText = content.slice(lastIndex, match.index).trim();
- if (normalText) {
- this.currentMessage.content.push({
- index: this.currentMessage.content.length,
- type: TEXT_TYPE.NORMAL,
- text: normalText,
- citation_ids: null,
- });
- }
- }
-
- // Add grounded text
- const citation_id = uuidv4();
- this.currentMessage.content.push({
- index: this.currentMessage.content.length,
- type: TEXT_TYPE.GROUNDED,
- text: groundedText.trim(),
- citation_ids: [citation_id],
- });
-
- // Add citation
- this.currentMessage.citations!.push({
- citation_id,
- chunk_id: '',
- type: CHUNK_TYPE.TEXT,
- direct_text: '',
- });
-
- lastIndex = match.index + fullMatch.length;
- }
-
- // Add any remaining normal text after the last grounded text
- if (lastIndex < content.length) {
- const remainingText = content.slice(lastIndex).trim();
- if (remainingText) {
- this.currentMessage.content.push({
- index: this.currentMessage.content.length,
- type: TEXT_TYPE.NORMAL,
- text: remainingText,
- citation_ids: null,
- });
- }
- }
- }
-
- getResult(): AssistantMessage {
- return this.currentMessage;
- }
-}
diff --git a/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts b/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts
new file mode 100644
index 000000000..3585cab4a
--- /dev/null
+++ b/src/client/views/nodes/ChatBox/StreamedAnswerParser.ts
@@ -0,0 +1,73 @@
+import { threadId } from 'worker_threads';
+
+enum ParserState {
+ Outside,
+ InGroundedText,
+ InNormalText,
+}
+
+export class StreamedAnswerParser {
+ private state: ParserState = ParserState.Outside;
+ private buffer: string = '';
+ private result: string = '';
+ private isStartOfLine: boolean = true;
+
+ public parse(char: string): string {
+ switch (this.state) {
+ case ParserState.Outside:
+ if (char === '<') {
+ this.buffer = '<';
+ } else if (char === '>') {
+ if (this.buffer.startsWith('<grounded_text')) {
+ this.state = ParserState.InGroundedText;
+ } else if (this.buffer.startsWith('<normal_text')) {
+ this.state = ParserState.InNormalText;
+ }
+ this.buffer = '';
+ } else {
+ this.buffer += char;
+ }
+ break;
+
+ case ParserState.InGroundedText:
+ case ParserState.InNormalText:
+ if (char === '<') {
+ this.buffer = '<';
+ } else if (this.buffer.startsWith('</grounded_text') && char === '>') {
+ this.state = ParserState.Outside;
+ this.buffer = '';
+ } else if (this.buffer.startsWith('</normal_text') && char === '>') {
+ this.state = ParserState.Outside;
+ this.buffer = '';
+ } else if (this.buffer.startsWith('<')) {
+ this.buffer += char;
+ } else {
+ this.processChar(char);
+ }
+ break;
+ }
+
+ return this.result.trim();
+ }
+
+ private processChar(char: string): void {
+ if (this.isStartOfLine && char === ' ') {
+ // Skip leading spaces
+ return;
+ }
+ if (char === '\n') {
+ this.result += char;
+ this.isStartOfLine = true;
+ } else {
+ this.result += char;
+ this.isStartOfLine = false;
+ }
+ }
+
+ public reset(): void {
+ this.state = ParserState.Outside;
+ this.buffer = '';
+ this.result = '';
+ this.isStartOfLine = true;
+ }
+}