diff options
author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-16 15:45:23 -0400 |
---|---|---|
committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-16 15:45:23 -0400 |
commit | daa72b906e3364c2b6a836533fc1980bb63ba303 (patch) | |
tree | da6b0d25f7ff547e460832b1de823cd2a909ac85 | |
parent | d97405e0a172b03a759452a1e9a7291974d89248 (diff) |
now shows thoughts in real time
next steps:
integrate everything with the AnswerParser
make sure citations work perfectly (right now clicking citations isn't perfect for urls and multiple citations for the same url source are generated—check examples for mistakes)
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 75 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.scss | 4 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 62 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/MessageComponent.tsx | 8 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 1 |
5 files changed, 106 insertions, 44 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 41c91b4c6..8bad29d9a 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -2,19 +2,17 @@ import OpenAI from 'openai'; import { Tool, AgentMessage } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; -import { WikipediaTool } from './tools/WikipediaTool'; -import { CalculateTool } from './tools/CalculateTool'; -import { RAGTool } from './tools/RAGTool'; -import { NoTool } from './tools/NoTool'; import { Vectorstore } from './vectorstore/Vectorstore'; -import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources'; +import { ChatCompletionMessageParam } from 'openai/resources'; import dotenv from 'dotenv'; -import { ChatBox } from './ChatBox'; +import { CalculateTool } from './tools/CalculateTool'; +import { RAGTool } from './tools/RAGTool'; import { DataAnalysisTool } from './tools/DataAnalysisTool'; -import { string } from 'cohere-ai/core/schemas'; import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; -import { add } from 'lodash'; +import { NoTool } from './tools/NoTool'; +import { on } from 'events'; + dotenv.config(); export class Agent { @@ -28,14 +26,12 @@ export class Agent { private _csvData: () => { filename: string; id: string; text: string }[]; constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { - console.log(process.env.OPENAI_KEY); this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); this.vectorstore = _vectorstore; this._history = history; this._summaries = summaries; this._csvData = csvData; this.tools = { - //wikipedia: new WikipediaTool(addLinkedUrlDoc), calculate: new CalculateTool(), rag: new RAGTool(this.vectorstore), dataAnalysis: new DataAnalysisTool(csvData), @@ -45,26 +41,24 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20): Promise<string> { + async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise<string> { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); - console.log(`Chat history: ${chatHistory}`); const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory); - console.log(`System prompt: ${systemPrompt}`); this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` }); const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; + let thoughtNumber = 0; + for (let i = 2; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(); - console.log(`Bot response: ${result}`); + const result = await this.execute(onUpdate, thoughtNumber); this.interMessages.push({ role: 'assistant', content: result }); let parsedResult; @@ -85,18 +79,13 @@ export class Agent { for (const key in stage) { if (key === 'thought') { console.log(`Thought: ${stage[key]}`); + thoughtNumber++; } else if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); if (this.tools[currentAction]) { - const nextPrompt = [ - { - type: 'text', - text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`, - }, - ]; + const nextPrompt = `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`; this.interMessages.push({ role: 'user', content: nextPrompt }); - break; } else { console.log('Error: No valid action'); @@ -109,7 +98,7 @@ export class Agent { if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); - const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }]; + const nextPrompt = `<stage number="${i + 1}" role="user"><observation>${observation}</observation></stage>`; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); break; @@ -123,24 +112,50 @@ export class Agent { } } else if (key === 'answer') { console.log('Answer found. Ending query.'); + onUpdate(`ANSWER:${stage[key]}`); return result; } } } - console.log(this.messages); console.log('Reached maximum turns. Ending query.'); return '<error>Reached maximum turns without finding an answer</error>'; } - private async execute(): Promise<string> { - console.log(this.interMessages); - const completion = await this.client.chat.completions.create({ + private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise<string> { + const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, + stream: true, }); - if (completion.choices[0].message.content) return completion.choices[0].message.content; - else throw new Error('No completion content found'); + + let fullResponse = ''; + let currentContent = ''; + + for await (const chunk of stream) { + const content = chunk.choices[0]?.delta?.content || ''; + fullResponse += content; + + currentContent += content; + + console.log(currentContent); + + if (currentContent.includes('<thought>')) { + onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`); + } + if (currentContent.includes('</thought>')) { + currentContent = ''; + } + if (currentContent.includes('<answer>')) { + onUpdate(`ANSWER_START:${currentContent}`); + } + if (currentContent.includes('</answer>')) { + onUpdate(`ANSWER_END:${currentContent}`); + currentContent = ''; + } + } + + return fullResponse; } private async processAction(action: string, actionInput: any): Promise<any> { diff --git a/src/client/views/nodes/ChatBox/ChatBox.scss b/src/client/views/nodes/ChatBox/ChatBox.scss index e39938c4f..91bb3aba7 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.scss +++ b/src/client/views/nodes/ChatBox/ChatBox.scss @@ -239,4 +239,8 @@ $follow-up-hover-bg-color: #dee2e6; } } } + .thought-text { + color: #6c757d; + font-style: italic; + } } diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index 8d09cde1e..d38c71810 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -149,26 +149,50 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { if (trimmedText) { try { + //Make everything go through the answer parser + //Pass in all the updates to the AnswrParser and it will create the assistant messasge that will be the current message including adding in the thoughts and also waiting for the asnwer and also showing tool progress textInput.value = ''; - runInAction(() => { - this.history.push({ role: ASSISTANT_ROLE.USER, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: trimmedText, citation_ids: null }] }); - this.isLoading = true; - }); + this.history.push({ role: ASSISTANT_ROLE.USER, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: trimmedText, citation_ids: null }] }); + this.isLoading = true; + this.current_message = { role: ASSISTANT_ROLE.ASSISTANT, content: [], thoughts: [] }; + + let currentThought = ''; + + this.current_message?.thoughts?.push(currentThought); + + const onUpdate = (update: string) => { + const thoughtNumber = Number(update.match(/^THOUGHT(\d+):/)?.[1] ?? 0); + const regex = /<thought>\s*([\s\S]*?)(?:<\/thought>|$)/; + const match = update.match(regex); + const currentThought = match ? match[1].trim() : ''; + //const numericPrefix = Number(update.match(/^\d+/)?.[0]); + if (update.startsWith('THOUGHT')) { + console.log('Thought:', currentThought, thoughtNumber); + if (this.current_message?.thoughts) { + if (this.current_message.thoughts.length <= thoughtNumber) { + this.current_message.thoughts.push(currentThought); + } else { + this.current_message.thoughts[thoughtNumber] = currentThought; + } + } + console.log('Thoughts:', this.current_message?.thoughts); + } + }; - const response = await this.agent.askAgent(trimmedText); // Use the chatbot to get the response - runInAction(() => { - this.history.push(AnswerParser.parse(response)); - }); + const response = await this.agent.askAgent(trimmedText, 20, onUpdate); + const parsedAnswer = AnswerParser.parse(response); + parsedAnswer.thoughts = this.current_message?.thoughts; + + if (this.current_message) { + this.history.push(parsedAnswer); + this.current_message = undefined; + } this.dataDoc.data = JSON.stringify(this.history); } catch (err) { console.error('Error:', err); - runInAction(() => { - this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); - }); + this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); } finally { - runInAction(() => { - this.isLoading = false; - }); + this.isLoading = false; } } }; @@ -416,6 +440,16 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { {this.history.map((message, index) => ( <MessageComponentBox key={index} message={message} index={index} onFollowUpClick={this.handleFollowUpClick} onCitationClick={this.handleCitationClick} updateMessageCitations={this.updateMessageCitations} /> ))} + {!this.current_message ? null : ( + <MessageComponentBox + key={this.history.length} + message={this.current_message} + index={this.history.length} + onFollowUpClick={this.handleFollowUpClick} + onCitationClick={this.handleCitationClick} + updateMessageCitations={this.updateMessageCitations} + /> + )} </div> </div> <form onSubmit={this.askGPT} className="chat-form"> diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index 07bfd4e3d..70b0527a2 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -62,6 +62,14 @@ const MessageComponentBox: React.FC<MessageComponentProps> = function ({ message return ( <div className={`message ${message.role}`}> + <div> + {message.thoughts && + message.thoughts.map((thought, index) => ( + <span key={index} className="thought-text"> + <i>Thought: {thought}</i> + </span> + ))} + </div> <div>{message.content && message.content.map(messageFragment => <React.Fragment key={messageFragment.index}>{renderContent(messageFragment)}</React.Fragment>)}</div> {message.follow_up_questions && message.follow_up_questions.length > 0 && ( <div className="follow-up-questions"> diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index 1c7aaa4b7..391f124e0 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -47,6 +47,7 @@ export interface AssistantMessage { role: ASSISTANT_ROLE; content: MessageContent[]; follow_up_questions?: string[]; + thoughts?: string[]; citations?: Citation[]; } |