aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/ChatBox/Agent.ts
diff options
context:
space:
mode:
authorA.J. Shulman <Shulman.aj@gmail.com>2024-08-18 10:12:35 -0400
committerA.J. Shulman <Shulman.aj@gmail.com>2024-08-18 10:12:35 -0400
commit2c38022a7f21d4b498277b18ad31baf24ac3a143 (patch)
tree006e70734530ad5cc9e08a3cadea200cceefdba5 /src/client/views/nodes/ChatBox/Agent.ts
parentdaa72b906e3364c2b6a836533fc1980bb63ba303 (diff)
Attempting streaming content
Diffstat (limited to 'src/client/views/nodes/ChatBox/Agent.ts')
-rw-r--r--src/client/views/nodes/ChatBox/Agent.ts198
1 files changed, 166 insertions, 32 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts
index 8bad29d9a..2c7c40e0c 100644
--- a/src/client/views/nodes/ChatBox/Agent.ts
+++ b/src/client/views/nodes/ChatBox/Agent.ts
@@ -1,5 +1,5 @@
import OpenAI from 'openai';
-import { Tool, AgentMessage } from './types';
+import { Tool, AgentMessage, AssistantMessage, TEXT_TYPE, CHUNK_TYPE, ASSISTANT_ROLE } from './types';
import { getReactPrompt } from './prompts';
import { XMLParser, XMLBuilder } from 'fast-xml-parser';
import { Vectorstore } from './vectorstore/Vectorstore';
@@ -12,6 +12,9 @@ import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool';
import { SearchTool } from './tools/SearchTool';
import { NoTool } from './tools/NoTool';
import { on } from 'events';
+import { StreamParser } from './StreamParser';
+import { v4 as uuidv4 } from 'uuid';
+import { AnswerParser } from './AnswerParser';
dotenv.config();
@@ -41,7 +44,7 @@ export class Agent {
};
}
- async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise<string> {
+ async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise<AssistantMessage> {
console.log(`Starting query: ${question}`);
this.messages.push({ role: 'user', content: question });
const chatHistory = this._history();
@@ -53,12 +56,18 @@ export class Agent {
const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
let currentAction: string | undefined;
- let thoughtNumber = 0;
+ let assistantMessage: AssistantMessage = {
+ role: ASSISTANT_ROLE.ASSISTANT,
+ content: [],
+ thoughts: [],
+ actions: [],
+ citations: [],
+ };
for (let i = 2; i < maxTurns; i += 2) {
console.log(`Turn ${i}/${maxTurns}`);
- const result = await this.execute(onUpdate, thoughtNumber);
+ const result = await this.execute(assistantMessage, onUpdate);
this.interMessages.push({ role: 'assistant', content: result });
let parsedResult;
@@ -66,23 +75,25 @@ export class Agent {
parsedResult = parser.parse(result);
} catch (error) {
console.log('Error: Invalid XML response from bot');
- return '<error>Invalid response format.</error>';
+ return assistantMessage;
}
const stage = parsedResult.stage;
if (!stage) {
console.log('Error: No stage found in response');
- return '<error>Invalid response format: No stage found.</error>';
+ return assistantMessage;
}
for (const key in stage) {
- if (key === 'thought') {
- console.log(`Thought: ${stage[key]}`);
- thoughtNumber++;
- } else if (key === 'action') {
+ if (!assistantMessage.actions) {
+ assistantMessage.actions = [];
+ }
+ if (key === 'action') {
currentAction = stage[key] as string;
console.log(`Action: ${currentAction}`);
+ assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' });
+ onUpdate({ ...assistantMessage });
if (this.tools[currentAction]) {
const nextPrompt = `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`;
this.interMessages.push({ role: 'user', content: nextPrompt });
@@ -93,8 +104,12 @@ export class Agent {
break;
}
} else if (key === 'action_input') {
- const actionInput = builder.build({ action_input: stage[key] });
+ const actionInput = stage[key];
console.log(`Action input: ${actionInput}`);
+ if (currentAction && assistantMessage.actions.length > 0) {
+ assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput;
+ onUpdate({ ...assistantMessage });
+ }
if (currentAction) {
try {
const observation = await this.processAction(currentAction, stage[key]);
@@ -104,24 +119,28 @@ export class Agent {
break;
} catch (error) {
console.log(`Error processing action: ${error}`);
- return `<error>${error}</error>`;
+ return assistantMessage;
}
} else {
console.log('Error: Action input without a valid action');
- return '<error>Action input without a valid action</error>';
+ return assistantMessage;
}
} else if (key === 'answer') {
console.log('Answer found. Ending query.');
- onUpdate(`ANSWER:${stage[key]}`);
- return result;
+ const parsedAnswer = AnswerParser.parse(`<answer>${stage[key]}</answer>`);
+ assistantMessage.content = parsedAnswer.content;
+ assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions;
+ assistantMessage.citations = parsedAnswer.citations;
+ onUpdate({ ...assistantMessage });
+ return assistantMessage;
}
}
}
console.log('Reached maximum turns. Ending query.');
- return '<error>Reached maximum turns without finding an answer</error>';
+ return assistantMessage;
}
- private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise<string> {
+ private async execute(assistantMessage: AssistantMessage, onUpdate: (update: AssistantMessage) => void): Promise<string> {
const stream = await this.client.chat.completions.create({
model: 'gpt-4o',
messages: this.interMessages as ChatCompletionMessageParam[],
@@ -130,32 +149,147 @@ export class Agent {
});
let fullResponse = '';
+ let currentTag = '';
let currentContent = '';
+ let isInsideTag = false;
+ let isInsideActionInput = false;
+ let actionInputContent = '';
+
+ if (!assistantMessage.actions) {
+ assistantMessage.actions = [];
+ }
for await (const chunk of stream) {
const content = chunk.choices[0]?.delta?.content || '';
fullResponse += content;
- currentContent += content;
+ for (const char of content) {
+ if (char === '<') {
+ isInsideTag = true;
+ if (currentTag && currentContent) {
+ if (currentTag === 'action_input') {
+ assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent;
+ actionInputContent = '';
+ } else {
+ this.processStreamedContent(currentTag, currentContent, assistantMessage);
+ }
+ onUpdate({ ...assistantMessage });
+ }
+ currentTag = '';
+ currentContent = '';
+ } else if (char === '>') {
+ isInsideTag = false;
+ if (currentTag === 'action_input') {
+ isInsideActionInput = true;
+ } else if (currentTag === '/action_input') {
+ isInsideActionInput = false;
+ assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent;
+ actionInputContent = '';
+ onUpdate({ ...assistantMessage });
+ }
+ if (currentTag.startsWith('/')) {
+ currentTag = '';
+ }
+ } else if (isInsideTag) {
+ currentTag += char;
+ } else if (isInsideActionInput) {
+ actionInputContent += char;
+ } else {
+ currentContent += char;
+ if (currentTag === 'thought' || currentTag === 'action') {
+ this.processStreamedContent(currentTag, currentContent, assistantMessage);
+ onUpdate({ ...assistantMessage });
+ }
+ }
+ }
+ }
- console.log(currentContent);
+ return fullResponse;
+ }
- if (currentContent.includes('<thought>')) {
- onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`);
- }
- if (currentContent.includes('</thought>')) {
- currentContent = '';
- }
- if (currentContent.includes('<answer>')) {
- onUpdate(`ANSWER_START:${currentContent}`);
- }
- if (currentContent.includes('</answer>')) {
- onUpdate(`ANSWER_END:${currentContent}`);
- currentContent = '';
+ private processStreamedContent(tag: string, content: string, assistantMessage: AssistantMessage) {
+ if (!assistantMessage.thoughts) {
+ assistantMessage.thoughts = [];
+ }
+ if (!assistantMessage.actions) {
+ assistantMessage.actions = [];
+ }
+ switch (tag) {
+ case 'thought':
+ if (assistantMessage.thoughts.length > 0) {
+ assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content;
+ } else {
+ assistantMessage.thoughts.push(content);
+ }
+ break;
+ case 'action':
+ if (assistantMessage.actions.length > 0) {
+ assistantMessage.actions[assistantMessage.actions.length - 1].action = content;
+ } else {
+ assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' });
+ }
+ break;
+ case 'action_input':
+ if (assistantMessage.actions.length > 0) {
+ assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content;
+ }
+ break;
+ }
+ }
+
+ private processAnswer(content: string, assistantMessage: AssistantMessage) {
+ const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g;
+ let lastIndex = 0;
+ let match;
+
+ while ((match = groundedTextRegex.exec(content)) !== null) {
+ const [fullMatch, citationIndex, groundedText] = match;
+
+ // Add normal text before the grounded text
+ if (match.index > lastIndex) {
+ const normalText = content.slice(lastIndex, match.index).trim();
+ if (normalText) {
+ assistantMessage.content.push({
+ index: assistantMessage.content.length,
+ type: TEXT_TYPE.NORMAL,
+ text: normalText,
+ citation_ids: null,
+ });
+ }
}
+
+ // Add grounded text
+ const citation_id = uuidv4();
+ assistantMessage.content.push({
+ index: assistantMessage.content.length,
+ type: TEXT_TYPE.GROUNDED,
+ text: groundedText.trim(),
+ citation_ids: [citation_id],
+ });
+
+ // Add citation
+ assistantMessage.citations?.push({
+ citation_id,
+ chunk_id: '',
+ type: CHUNK_TYPE.TEXT,
+ direct_text: '',
+ });
+
+ lastIndex = match.index + fullMatch.length;
}
- return fullResponse;
+ // Add any remaining normal text after the last grounded text
+ if (lastIndex < content.length) {
+ const remainingText = content.slice(lastIndex).trim();
+ if (remainingText) {
+ assistantMessage.content.push({
+ index: assistantMessage.content.length,
+ type: TEXT_TYPE.NORMAL,
+ text: remainingText,
+ citation_ids: null,
+ });
+ }
+ }
}
private async processAction(action: string, actionInput: any): Promise<any> {