import OpenAI from 'openai';
import { Tool, AgentMessage } from './types';
import { getReactPrompt } from './prompts';
import { XMLParser, XMLBuilder } from 'fast-xml-parser';
import { Vectorstore } from './vectorstore/Vectorstore';
import { ChatCompletionMessageParam } from 'openai/resources';
import dotenv from 'dotenv';
import { CalculateTool } from './tools/CalculateTool';
import { RAGTool } from './tools/RAGTool';
import { DataAnalysisTool } from './tools/DataAnalysisTool';
import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool';
import { SearchTool } from './tools/SearchTool';
import { NoTool } from './tools/NoTool';
// NOTE(review): `on` is not referenced anywhere in this file as shown — confirm and remove.
import { on } from 'events';

// Load variables from a .env file into process.env; the constructor below reads OPENAI_KEY.
dotenv.config();

/**
 * Tool-using agent that drives an OpenAI chat model through an XML turn protocol.
 *
 * Each model reply is parsed as XML and the children of its top-level `stage` element are
 * inspected: `thought` is logged and counted, `action` names one of the registered tools (the
 * agent replies with that tool's action rules), `action_input` runs the most recently named
 * tool and feeds its result back, and `answer` ends the query.
 */
export class Agent {
    private client: OpenAI;
    // Registered tools keyed by the action name the model is expected to emit.
    private tools: Record>;
    // Every question passed to askAgent (only user messages are appended in this file).
    private messages: AgentMessage[] = [];
    // Conversation sent to the model for the current query; reset at the start of each askAgent call.
    private interMessages: AgentMessage[] = [];
    private vectorstore: Vectorstore;
    private _history: () => string;
    private _summaries: () => string;
    // Stored but not read elsewhere in this file; DataAnalysisTool receives the callback directly.
    private _csvData: () => { filename: string; id: string; text: string }[];

    /**
     * @param _vectorstore store handed to the RAG tool
     * @param summaries callback stored for building the system prompt
     * @param history callback invoked once per query to obtain chat-history text for the prompt
     * @param csvData callback handed to DataAnalysisTool
     * @param addLinkedUrlDoc callback handed to the website-scraper and search tools
     */
    constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) {
        // dangerouslyAllowBrowser lets the SDK run in a browser, where the OPENAI_KEY value would be
        // visible to the client. NOTE(review): confirm this exposure is intended.
        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
        this.vectorstore = _vectorstore;
        this._history = history;
        this._summaries = summaries;
        this._csvData = csvData;
        this.tools = {
            calculate: new CalculateTool(),
            rag: new RAGTool(this.vectorstore),
            dataAnalysis: new DataAnalysisTool(csvData),
            websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
            searchTool: new SearchTool(addLinkedUrlDoc),
            no_tool: new NoTool(),
        };
    }

    /**
     * Returns true only when `name` is an own key of `tools`.
     *
     * Action names come from model output, so a plain `this.tools[name]` / `name in this.tools`
     * check would also accept inherited Object.prototype names such as "constructor",
     * "toString" or "__proto__", which are not tools and crash when used as one.
     */
    private hasTool(name: string): boolean {
        return Object.prototype.hasOwnProperty.call(this.tools, name);
    }

    /**
     * Answers `question` by alternating streamed model calls with tool executions.
     *
     * The loop counter starts at 2 and advances by 2, so for maxTurns >= 2 the model is called
     * at most ceil((maxTurns - 2) / 2) times (9 times for the default of 20).
     *
     * @param question user question; also appended to `messages`
     * @param maxTurns loop bound described above (default 20). Because `onUpdate` has no default,
     *   callers must pass this argument (or `undefined`) in order to supply `onUpdate`.
     * @param onUpdate receives streaming progress strings from execute() and a final
     *   `ANSWER:<text>` update when an answer is found
     * @returns the raw model reply (the whole XML text) when an `answer` is found; otherwise a
     *   status/error string. Errors thrown while running a tool are returned in string form,
     *   not rethrown.
     */
    async askAgent(question: string, maxTurns: number = 20, onUpdate: (update: string) => void): Promise {
        console.log(`Starting query: ${question}`);
        this.messages.push({ role: 'user', content: question });
        const chatHistory = this._history();
        // NOTE(review): this._summaries is passed without being called, unlike this._history()
        // above — confirm getReactPrompt expects the callback itself rather than its string.
        const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory);
        this.interMessages = [{ role: 'system', content: systemPrompt }];
        this.interMessages.push({ role: 'user', content: `${question}` });
        const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' });
        const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
        // Persists across turns: an action_input in a later reply runs the action named earlier.
        let currentAction: string | undefined;
        let thoughtNumber = 0;
        for (let i = 2; i < maxTurns; i += 2) {
            console.log(`Turn ${i}/${maxTurns}`);
            const result = await this.execute(onUpdate, thoughtNumber);
            this.interMessages.push({ role: 'assistant', content: result });
            let parsedResult;
            try {
                // NOTE(review): fast-xml-parser does not validate unless asked to, so some malformed
                // replies may parse without throwing — confirm whether validation is wanted.
                parsedResult = parser.parse(result);
            } catch (error) {
                console.log('Error: Invalid XML response from bot');
                return 'Invalid response format.';
            }
            const stage = parsedResult.stage;
            if (!stage) {
                console.log('Error: No stage found in response');
                return 'Invalid response format: No stage found.';
            }
            // Keys are visited in parsed order. `thought` keys are logged and iteration continues;
            // the action, action_input and answer branches all leave this loop, so at most one of
            // them is handled per reply. A reply with only thoughts appends no user message before
            // the next model call.
            for (const key in stage) {
                if (key === 'thought') {
                    console.log(`Thought: ${stage[key]}`);
                    thoughtNumber++;
                } else if (key === 'action') {
                    // Assigned before validation, so an invalid name stays current and a later
                    // action_input fails in processAction with "Unknown action".
                    currentAction = stage[key] as string;
                    console.log(`Action: ${currentAction}`);
                    if (this.hasTool(currentAction)) {
                        // NOTE(review): the template literals around the built XML are empty as shown,
                        // so nextPrompt is just the serialized action rules — confirm nothing is missing.
                        const nextPrompt = `` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + ``;
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                        break;
                    } else {
                        console.log('Error: No valid action');
                        this.interMessages.push({ role: 'user', content: `No valid action, try again.` });
                        break;
                    }
                } else if (key === 'action_input') {
                    const actionInput = builder.build({ action_input: stage[key] });
                    console.log(`Action input: ${actionInput}`);
                    if (currentAction) {
                        try {
                            const observation = await this.processAction(currentAction, stage[key]);
                            const nextPrompt = `${observation}`;
                            console.log(observation);
                            this.interMessages.push({ role: 'user', content: nextPrompt });
                            break;
                        } catch (error) {
                            // Tool failures end the query; the error text is returned to the caller.
                            console.log(`Error processing action: ${error}`);
                            return `${error}`;
                        }
                    } else {
                        console.log('Error: Action input without a valid action');
                        return 'Action input without a valid action';
                    }
                } else if (key === 'answer') {
                    console.log('Answer found. Ending query.');
                    onUpdate(`ANSWER:${stage[key]}`);
                    return result;
                }
            }
        }
        console.log('Reached maximum turns. Ending query.');
        return 'Reached maximum turns without finding an answer';
    }

    /**
     * Streams one completion (model 'gpt-4o', temperature 0) for the current `interMessages`.
     *
     * `currentContent` buffers streamed text since its last reset and is checked against marker
     * strings to emit THOUGHT / ANSWER_START / ANSWER_END progress updates.
     * NOTE(review): every marker passed to includes() is an empty string as shown, and
     * `s.includes('')` is always true. As written, each streamed chunk therefore emits
     * `THOUGHT<n>:<chunk>`, then `ANSWER_START:` and `ANSWER_END:` with empty text. The intended
     * markers look missing — confirm against the prompt format.
     *
     * @param onUpdate progress callback
     * @param thoughtNumber index used to label THOUGHT updates
     * @returns the full concatenated response text
     */
    private async execute(onUpdate: (update: string) => void, thoughtNumber: number): Promise {
        const stream = await this.client.chat.completions.create({
            model: 'gpt-4o',
            messages: this.interMessages as ChatCompletionMessageParam[],
            temperature: 0,
            stream: true,
        });
        let fullResponse = '';
        let currentContent = '';
        for await (const chunk of stream) {
            // Chunks without a content delta contribute an empty string.
            const content = chunk.choices[0]?.delta?.content || '';
            fullResponse += content;
            currentContent += content;
            console.log(currentContent);
            if (currentContent.includes('')) {
                onUpdate(`THOUGHT${thoughtNumber}:${currentContent}`);
            }
            if (currentContent.includes('')) {
                currentContent = '';
            }
            if (currentContent.includes('')) {
                onUpdate(`ANSWER_START:${currentContent}`);
            }
            if (currentContent.includes('')) {
                onUpdate(`ANSWER_END:${currentContent}`);
                currentContent = '';
            }
        }
        return fullResponse;
    }

    /**
     * Runs the tool registered as `action` with arguments taken from `actionInput`.
     *
     * Only parameter names declared in `tool.parameters` are copied; any other keys in
     * `actionInput` are ignored.
     *
     * @param action tool key chosen by the model
     * @param actionInput value fast-xml-parser produced for the `action_input` element
     *   (an object when it has child elements; presumably a primitive otherwise — confirm)
     * @returns whatever the tool's execute() resolves to
     * @throws Error if `action` is not a registered tool or a required parameter is missing
     */
    private async processAction(action: string, actionInput: any): Promise {
        if (!this.hasTool(action)) {
            throw new Error(`Unknown action: ${action}`);
        }
        const tool = this.tools[action];
        // Treat a null/undefined input as "no arguments", so a missing required parameter is
        // reported by name rather than as a TypeError from the property lookup.
        const input = actionInput ?? {};
        const args: Record = {};
        for (const paramName in tool.parameters) {
            if (input[paramName] !== undefined) {
                args[paramName] = input[paramName];
            } else if (tool.parameters[paramName].required === 'true') {
                // NOTE(review): `required` is compared with the string 'true'; a boolean true would
                // not match — confirm how tools declare this flag.
                throw new Error(`Missing required parameter '${paramName}' for action '${action}'`);
            }
        }
        return await tool.execute(args);
    }
}