diff options
author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-19 10:55:57 -0400 |
---|---|---|
committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-19 10:55:57 -0400 |
commit | ff3c041af6738d025926732115a032d40cffb859 (patch) | |
tree | e55ad36ed1ef999b498bd4f87e505bf7ff6c3263 | |
parent | 2c38022a7f21d4b498277b18ad31baf24ac3a143 (diff) |
working on making streaming work
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 125 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/AnswerParser.ts | 14 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 1 |
3 files changed, 40 insertions, 100 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 2c7c40e0c..413ecbd41 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -27,6 +27,8 @@ export class Agent { private _history: () => string; private _summaries: () => string; private _csvData: () => { filename: string; id: string; text: string }[]; + private actionNumber: number = 0; + private thoughtNumber: number = 0; constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); @@ -44,14 +46,13 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: AssistantMessage) => void): Promise<AssistantMessage> { + async askAgent(question: string, maxTurns: number = 30, onUpdate: (update: AssistantMessage) => void): Promise<AssistantMessage> { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory); this.interMessages = [{ role: 'system', content: systemPrompt }]; this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` }); - const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; @@ -75,6 +76,7 @@ export class Agent { parsedResult = parser.parse(result); } catch (error) { console.log('Error: Invalid XML response from bot'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } @@ -82,6 +84,7 @@ export class Agent { if (!stage) { console.log('Error: No stage found in response'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } @@ -89,13 +92,20 @@ export class Agent { if (!assistantMessage.actions) { assistantMessage.actions = []; } - if (key === 'action') { + if (key === 'thought') { + console.log(`Thought: ${stage[key]}`); + this.thoughtNumber++; + } else if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); - assistantMessage.actions.push({ index: assistantMessage.actions.length, action: currentAction, action_input: '' }); onUpdate({ ...assistantMessage }); if (this.tools[currentAction]) { - const nextPrompt = `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`; + const nextPrompt = [ + { + type: 'text', + text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`, + }, + ]; this.interMessages.push({ role: 'user', content: nextPrompt }); break; } else { @@ -104,35 +114,31 @@ export class Agent { break; } } else if (key === 'action_input') { - const actionInput = stage[key]; + const actionInput = builder.build({ action_input: stage[key] }); console.log(`Action input: ${actionInput}`); - if (currentAction && assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInput; - onUpdate({ ...assistantMessage }); - } if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); - const nextPrompt = `<stage number="${i + 1}" role="user"><observation>${observation}</observation></stage>`; + const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }]; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); + this.actionNumber++; //might not work with no tool break; } catch (error) { console.log(`Error processing action: ${error}`); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } } else { console.log('Error: Action input without a valid action'); + assistantMessage.content.push({ index: assistantMessage.content.length, type: TEXT_TYPE.ERROR, text: 'Invalid response from bot', citation_ids: null }); return assistantMessage; } } else if (key === 'answer') { console.log('Answer found. Ending query.'); - const parsedAnswer = AnswerParser.parse(`<answer>${stage[key]}</answer>`); - assistantMessage.content = parsedAnswer.content; - assistantMessage.follow_up_questions = parsedAnswer.follow_up_questions; - assistantMessage.citations = parsedAnswer.citations; - onUpdate({ ...assistantMessage }); - return assistantMessage; + const parsedAnswer = AnswerParser.parse(result, assistantMessage); + onUpdate({ ...parsedAnswer }); + return parsedAnswer; } } } @@ -148,12 +154,12 @@ export class Agent { stream: true, }); - let fullResponse = ''; - let currentTag = ''; - let currentContent = ''; - let isInsideTag = false; - let isInsideActionInput = false; - let actionInputContent = ''; + let fullResponse: string = ''; + let currentTag: string = ''; + let currentContent: string = ''; + let isInsideTag: boolean = false; + let isInsideActionInput: boolean = false; + let actionInputContent: string = ''; if (!assistantMessage.actions) { assistantMessage.actions = []; @@ -183,6 +189,7 @@ export class Agent { isInsideActionInput = true; } else if (currentTag === '/action_input') { isInsideActionInput = false; + console.log('Action input:', actionInputContent); assistantMessage.actions[assistantMessage.actions.length - 1].action_input = actionInputContent; actionInputContent = ''; onUpdate({ ...assistantMessage }); @@ -216,82 +223,18 @@ export class Agent { } switch (tag) { case 'thought': - if (assistantMessage.thoughts.length > 0) { - assistantMessage.thoughts[assistantMessage.thoughts.length - 1] = content; - } else { - assistantMessage.thoughts.push(content); - } + assistantMessage.thoughts[this.thoughtNumber] = content; break; case 'action': - if (assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action = content; - } else { - assistantMessage.actions.push({ index: assistantMessage.actions.length, action: content, action_input: '' }); - } + assistantMessage.actions[this.actionNumber].action = content; + break; case 'action_input': - if (assistantMessage.actions.length > 0) { - assistantMessage.actions[assistantMessage.actions.length - 1].action_input = content; - } + assistantMessage.actions[this.actionNumber].action_input = content; break; } } - private processAnswer(content: string, assistantMessage: AssistantMessage) { - const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g; - let lastIndex = 0; - let match; - - while ((match = groundedTextRegex.exec(content)) !== null) { - const [fullMatch, citationIndex, groundedText] = match; - - // Add normal text before the grounded text - if (match.index > lastIndex) { - const normalText = content.slice(lastIndex, match.index).trim(); - if (normalText) { - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: normalText, - citation_ids: null, - }); - } - } - - // Add grounded text - const citation_id = uuidv4(); - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.GROUNDED, - text: groundedText.trim(), - citation_ids: [citation_id], - }); - - // Add citation - assistantMessage.citations?.push({ - citation_id, - chunk_id: '', - type: CHUNK_TYPE.TEXT, - direct_text: '', - }); - - lastIndex = match.index + fullMatch.length; - } - - // Add any remaining normal text after the last grounded text - if (lastIndex < content.length) { - const remainingText = content.slice(lastIndex).trim(); - if (remainingText) { - assistantMessage.content.push({ - index: assistantMessage.content.length, - type: TEXT_TYPE.NORMAL, - text: remainingText, - citation_ids: null, - }); - } - } - } - private async processAction(action: string, actionInput: any): Promise<any> { if (!(action in this.tools)) { throw new Error(`Unknown action: ${action}`); diff --git a/src/client/views/nodes/ChatBox/AnswerParser.ts b/src/client/views/nodes/ChatBox/AnswerParser.ts index 9956792d8..68637b7c7 100644 --- a/src/client/views/nodes/ChatBox/AnswerParser.ts +++ b/src/client/views/nodes/ChatBox/AnswerParser.ts @@ -2,7 +2,7 @@ import { ASSISTANT_ROLE, AssistantMessage, Citation, CHUNK_TYPE, TEXT_TYPE, getC import { v4 as uuid } from 'uuid'; export class AnswerParser { - static parse(xml: string): AssistantMessage { + static parse(xml: string, currentMessage: AssistantMessage): AssistantMessage { const answerRegex = /<answer>([\s\S]*?)<\/answer>/; const citationsRegex = /<citations>([\s\S]*?)<\/citations>/; const citationRegex = /<citation index="([^"]+)" chunk_id="([^"]+)" type="([^"]+)">([\s\S]*?)<\/citation>/g; @@ -102,14 +102,10 @@ export class AnswerParser { followUpQuestions.push(questionMatch[1].trim()); } } + currentMessage.content = currentMessage.content.concat(content); + currentMessage.citations = citations; + currentMessage.follow_up_questions = followUpQuestions; - const assistantResponse: AssistantMessage = { - role: ASSISTANT_ROLE.ASSISTANT, - content, - follow_up_questions: followUpQuestions, - citations, - }; - - return assistantResponse; + return currentMessage; } } diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index efeec7b93..b4e66bdbe 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -10,6 +10,7 @@ export enum ASSISTANT_ROLE { export enum TEXT_TYPE { NORMAL = 'normal', GROUNDED = 'grounded', + ERROR = 'error', } export enum CHUNK_TYPE { |