From daa72b906e3364c2b6a836533fc1980bb63ba303 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Fri, 16 Aug 2024 15:45:23 -0400 Subject: now shows thoughts in real time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit next steps: integrate everything with the AnswerParser make sure citations work perfectly (right now clicking citations isn't perfect for urls and multiple citations for the same url source are generated—check examples for mistakes) --- src/client/views/nodes/ChatBox/Agent.ts | 75 ++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 30 deletions(-) (limited to 'src/client/views/nodes/ChatBox/Agent.ts') diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 41c91b4c6..8bad29d9a 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -2,19 +2,17 @@ import OpenAI from 'openai'; import { Tool, AgentMessage } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; -import { WikipediaTool } from './tools/WikipediaTool'; -import { CalculateTool } from './tools/CalculateTool'; -import { RAGTool } from './tools/RAGTool'; -import { NoTool } from './tools/NoTool'; import { Vectorstore } from './vectorstore/Vectorstore'; -import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources'; +import { ChatCompletionMessageParam } from 'openai/resources'; import dotenv from 'dotenv'; -import { ChatBox } from './ChatBox'; +import { CalculateTool } from './tools/CalculateTool'; +import { RAGTool } from './tools/RAGTool'; import { DataAnalysisTool } from './tools/DataAnalysisTool'; -import { string } from 'cohere-ai/core/schemas'; import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; import { SearchTool } from './tools/SearchTool'; -import { add } from 'lodash'; +import { NoTool } from './tools/NoTool'; +import { on } from 'events'; + dotenv.config(); export class Agent { @@ -28,14 +26,12 @@ export class Agent { private _csvData: () => { filename: string; id: string; text: string }[]; constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { - console.log(process.env.OPENAI_KEY); this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); this.vectorstore = _vectorstore; this._history = history; this._summaries = summaries; this._csvData = csvData; this.tools = { - //wikipedia: new WikipediaTool(addLinkedUrlDoc), calculate: new CalculateTool(), rag: new RAGTool(this.vectorstore), dataAnalysis: new DataAnalysisTool(csvData), @@ -45,26 +41,24 @@ export class Agent { }; } - async askAgent(question: string, maxTurns: number = 20): Promise { + async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise { console.log(`Starting query: ${question}`); this.messages.push({ role: 'user', content: question }); const chatHistory = this._history(); - console.log(`Chat history: ${chatHistory}`); const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory); - console.log(`System prompt: ${systemPrompt}`); this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `${question}` }); const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); let currentAction: string | undefined; + let thoughtNumber = 0; + for (let i = 2; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); - const result = await this.execute(); - console.log(`Bot response: ${result}`); + const result = await this.execute(onUpdate, thoughtNumber); this.interMessages.push({ role: 'assistant', content: result }); let parsedResult; @@ -85,18 +79,13 @@ export class Agent { for (const key in stage) { if (key === 'thought') { console.log(`Thought: ${stage[key]}`); + thoughtNumber++; } else if (key === 'action') { currentAction = stage[key] as string; console.log(`Action: ${currentAction}`); if (this.tools[currentAction]) { - const nextPrompt = [ - { - type: 'text', - text: `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``, - }, - ]; + const nextPrompt = `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``; this.interMessages.push({ role: 'user', content: nextPrompt }); - break; } else { console.log('Error: No valid action'); @@ -109,7 +98,7 @@ export class Agent { if (currentAction) { try { const observation = await this.processAction(currentAction, stage[key]); - const nextPrompt = [{ type: 'text', text: ` ` }, ...observation, { type: 'text', text: '' }]; + const nextPrompt = `${observation}`; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); break; @@ -123,24 +112,50 @@ export class Agent { } } else if (key === 'answer') { console.log('Answer found. Ending query.'); + onUpdate(`ANSWER:${stage[key]}`); return result; } } } - console.log(this.messages); console.log('Reached maximum turns. Ending query.'); return 'Reached maximum turns without finding an answer'; } - private async execute(): Promise { - console.log(this.interMessages); - const completion = await this.client.chat.completions.create({ + private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise { + const stream = await this.client.chat.completions.create({ model: 'gpt-4o', messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, + stream: true, }); - if (completion.choices[0].message.content) return completion.choices[0].message.content; - else throw new Error('No completion content found'); + + let fullResponse = ''; + let currentContent = ''; + + for await (const chunk of stream) { + const content = chunk.choices[0]?.delta?.content || ''; + fullResponse += content; + + currentContent += content; + + console.log(currentContent); + + if (currentContent.includes('')) { + onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`); + } + if (currentContent.includes('')) { + currentContent = ''; + } + if (currentContent.includes('')) { + onUpdate(`ANSWER_START:${currentContent}`); + } + if (currentContent.includes('')) { + onUpdate(`ANSWER_END:${currentContent}`); + currentContent = ''; + } + } + + return fullResponse; } private async processAction(action: string, actionInput: any): Promise { -- cgit v1.2.3-70-g09d2