From ff3c041af6738d025926732115a032d40cffb859 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Mon, 19 Aug 2024 10:55:57 -0400 Subject: working on making streaming work --- src/client/views/nodes/ChatBox/Agent.ts | 125 +++++++------------------ src/client/views/nodes/ChatBox/AnswerParser.ts | 14 +-- src/client/views/nodes/ChatBox/types.ts | 1 + 3 files changed, 40 insertions(+), 100 deletions(-) (limited to 'src') diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 2c7c40e0c..413ecbd41 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -27,6 +27,8 @@ export class Agent { private _history: () => string; private _summaries: () => string; private _csvData: () => { filename: string; id: string; text: string }[]; + private actionNumber: number = 0; + private thoughtNumber: number = 0; constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); @@ -44,14 +46,13 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise { + async askAgent(question: string, maxTurns: number = 30, onUpdate: (update: AssistantMessage) => void): Promise { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory); this.interMessages = [{ role: 'system', content: systemPrompt }]; this.interMessages.push({ role: 'user', content: `${question}` }); - const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; @@ -75,6 +76,7 @@ export class Agent { parsedResult = parser.parse(result); } catch (error) { console.log('Error: Invalid XML response from bot'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } @@ -82,6 +84,7 @@ export class Agent { if (!stage) { console.log('Error: No stage found in response'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } @@ -89,13 +92,20 @@ export class Agent { if (!assistantMessage.actions) { assistantMessage.actions = []; } - if (key === 'action') { + if (key === 'thought') { + console.log(`Thought: ${stage[key]}`); + this.thoughtNumber++; + } else if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); - assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' }); onUpdate({ ...assistantMessage }); if (this.tools[currentAction]) { - const nextPrompt = `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``; + const nextPrompt = [ + { + type: 'text', + text: `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``, + }, + ]; this.interMessages.push({ role: 'user', content: nextPrompt }); break; } else { @@ -104,35 +114,31 @@ export class Agent { break; } } else if (key === 'action_input') { - const actionInput = stage[key]; + const actionInput = builder.build({ action_input: stage[key] }); console.log(`Action input: ${actionInput}`); - if (currentAction && assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput; - onUpdate({ ...assistantMessage }); - } if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); - const nextPrompt = `${observation}`; + const nextPrompt = [{ type: 'text', text: ` ` }, ...observation, { type: 'text', text: '' }]; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); + this.actionNumber++; //might not work with no tool break; } catch (error) { console.log(`Error processing action: ${error}`); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } } else { console.log('Error: Action input without a valid action'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } } else if (key === 'answer') { console.log('Answer found. Ending query.'); - const parsedAnswer = AnswerParser.parse(`${stage[key]}`); - assistantMessage.content = parsedAnswer.content; - assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions; - assistantMessage.citations = parsedAnswer.citations; - onUpdate({ ...assistantMessage }); - return assistantMessage; + const parsedAnswer = AnswerParser.parse(result, assistantMessage); + onUpdate({ ...parsedAnswer }); + return parsedAnswer; } } } @@ -148,12 +154,12 @@ export class Agent { stream: true, }); - let fullResponse = ''; - let currentTag = ''; - let currentContent = ''; - let isInsideTag = false; - let isInsideActionInput = false; - let actionInputContent = ''; + let fullResponse: string = ''; + let currentTag: string = ''; + let currentContent: string = ''; + let isInsideTag: boolean = false; + let isInsideActionInput: boolean = false; + let actionInputContent: string = ''; if (!assistantMessage.actions) { assistantMessage.actions = []; @@ -183,6 +189,7 @@ export class Agent { isInsideActionInput = true; } else if (currentTag === '/action_input') { isInsideActionInput = false; + console.log('Action input:', actionInputContent); assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; actionInputContent = ''; onUpdate({ ...assistantMessage }); @@ -216,82 +223,18 @@ export class Agent { } switch (tag) { case 'thought': - if (assistantMessage.thoughts.length > 0) { - assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content; - } else { - assistantMessage.thoughts.push(content); - } + assistantMessage.thoughts[this.thoughtNumber] = content; break; case 'action': - if (assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action = content; - } else { - assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' }); - } + assistantMessage.actions[this.actionNumber].action = content; + break; case 'action_input': - if (assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content; - } + assistantMessage.actions[this.actionNumber].action_input = content; break; } } - private processAnswer(content: string, assistantMessage: AssistantMessage) { - const groundedTextRegex = /([\s\S]*?)<\/grounded_text>/g; - let lastIndex = 0; - let match; - - while ((match = groundedTextRegex.exec(content)) !== null) { - const [fullMatch, citationIndex, groundedText] = match; - - // Add normal text before the grounded text - if (match.index > lastIndex) { - const normalText = content.slice(lastIndex, match.index).trim(); - if (normalText) { - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: normalText, - citation_ids: null, - }); - } - } - - // Add grounded text - const citation_id = uuidv4(); - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.GROUNDED, - text: groundedText.trim(), - citation_ids: [citation_id], - }); - - // Add citation - assistantMessage.citations?.push({ - citation_id, - chunk_id: '', - type: CHUNK_TYPE.TEXT, - direct_text: '', - }); - - lastIndex = match.index + fullMatch.length; - } - - // Add any remaining normal text after the last grounded text - if (lastIndex < content.length) { - const remainingText = content.slice(lastIndex).trim(); - if (remainingText) { - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: remainingText, - citation_ids: null, - }); - } - } - } - private async processAction(action: string, actionInput: any): Promise { if (!(action in this.tools)) { throw new Error(`Unknown action: ${action}`); diff --git a/src/client/views/nodes/ChatBox/AnswerParser.ts b/src/client/views/nodes/ChatBox/AnswerParser.ts index 9956792d8..68637b7c7 100644 --- a/src/client/views/nodes/ChatBox/AnswerParser.ts +++ b/src/client/views/nodes/ChatBox/AnswerParser.ts @@ -2,7 +2,7 @@ import { ASSISTANT_ROLE, AssistantMessage, Citation, CHUNK_TYPE, TEXT_TYPE, getC import { v4 as uuid } from 'uuid'; export class AnswerParser { - static parse(xml: string): AssistantMessage { + static parse(xml: string, currentMessage: AssistantMessage): AssistantMessage { const answerRegex = /([\s\S]*?)<\/answer>/; const citationsRegex = /([\s\S]*?)<\/citations>/; const citationRegex = /([\s\S]*?)<\/citation>/g; @@ -102,14 +102,10 @@ export class AnswerParser { followUpQuestions.push(questionMatch[1].trim()); } } + currentMessage.content = currentMessage.content.concat(content); + currentMessage.citations = citations; + currentMessage.follow_up_questions = followUpQuestions; - const assistantResponse: AssistantMessage = { - role: ASSISTANT_ROLE.ASSISTANT, - content, - follow_up_questions: followUpQuestions, - citations, - }; - - return assistantResponse; + return currentMessage; } } diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index efeec7b93..b4e66bdbe 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -10,6 +10,7 @@ export enum ASSISTANT_ROLE { export enum TEXT_TYPE { NORMAL = 'normal', GROUNDED = 'grounded', + ERROR = 'error', } export enum CHUNK_TYPE { -- cgit v1.2.3-70-g09d2