import OpenAI from 'openai';
import { Tool, AgentMessage } from './types';
import { getReactPrompt } from './prompts';
import { XMLParser, XMLBuilder } from 'fast-xml-parser';
import { WikipediaTool } from './tools/WikipediaTool';
import { CalculateTool } from './tools/CalculateTool';
import { RAGTool } from './tools/RAGTool';
import { Vectorstore } from './vectorstore/VectorstoreUpload';
import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources';
import dotenv from 'dotenv';
import { ChatBox } from './ChatBox';

dotenv.config();

/**
 * Tool-using chat agent.
 *
 * Each call to `askAgent` sends a system prompt (built by `getReactPrompt` from the
 * registered tools and the chat history) plus the user's question to `gpt-4o`.
 * It then loops: every assistant reply is parsed as XML, and the children of its
 * first top-level element are interpreted as `thought`, `action`, `action_input`,
 * or `answer`.
 * - Naming an `action` makes the agent send back that tool's action rules.
 * - Supplying `action_input` runs the tool and sends back its output.
 * - An `answer` ends the query.
 */
export class Agent {
    private client: OpenAI;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Tool's type parameter is not known here.
    private tools: Record<string, Tool<any>>;
    /** User questions asked of this agent; only appended to and logged in this file. */
    private messages: AgentMessage[] = [];
    /** Transcript for the current query; replaced at the start of every `askAgent` call. */
    private interMessages: AgentMessage[] = [];
    private vectorstore: Vectorstore;
    private _history: () => string;

    /**
     * @param _vectorstore Store queried by the RAG tool.
     * @param summaries Callback passed through to the RAG tool.
     * @param history Callback that returns the chat history string embedded in the system prompt.
     */
    constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string) {
        // NOTE(review): `dangerouslyAllowBrowser` lets this client run client-side, which
        // exposes OPENAI_KEY to anyone who can load the bundle. Consider proxying through a backend.
        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
        this.vectorstore = _vectorstore;
        this._history = history;
        this.tools = {
            wikipedia: new WikipediaTool(),
            calculate: new CalculateTool(),
            rag: new RAGTool(this.vectorstore, summaries),
        };
    }

    /**
     * Runs the reasoning loop for one question.
     *
     * @param question The user's question.
     * @param maxTurns Maximum number of model calls before giving up.
     * @returns One of:
     *   - the full raw XML reply that contained an `answer` element;
     *   - one of several fixed error strings (invalid XML, action input with no action,
     *     turn limit reached);
     *   - the stringified error thrown while running a tool.
     * @throws Whatever `execute` throws (OpenAI request failures, empty completion content).
     */
    async askAgent(question: string, maxTurns: number = 8): Promise<string> {
        console.log(`Starting query: ${question}`);
        this.messages.push({ role: 'user', content: question });

        const chatHistory = this._history();
        console.log(`Chat history: ${chatHistory}`);
        const systemPrompt = getReactPrompt(Object.values(this.tools), chatHistory);
        console.log(`System prompt: ${systemPrompt}`);

        this.interMessages = [
            { role: 'system', content: systemPrompt },
            { role: 'user', content: question },
        ];

        const parser = new XMLParser();
        const builder = new XMLBuilder();
        // Persists across turns: the model names an action in one reply and may
        // supply its action_input in a later one.
        let currentAction: string | undefined;

        for (let i = 0; i < maxTurns; i++) {
            console.log(`Turn ${i + 1}/${maxTurns}`);
            const result = await this.execute();
            console.log(`Bot response: ${result}`);
            this.interMessages.push({ role: 'assistant', content: result });

            // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped parser output
            let parsedResult: any;
            try {
                parsedResult = parser.parse(result);
            } catch (error) {
                console.log('Error: Invalid XML response from bot');
                return 'Invalid response format.';
            }

            // Only the first top-level element is inspected; its children carry the protocol keys.
            const step = parsedResult[Object.keys(parsedResult)[0]];
            if (step === null || typeof step !== 'object') {
                // No element children to act on (no XML, or a text-only root).
                // Nothing is sent back; the next turn simply calls the model again.
                continue;
            }

            // Keys are handled in document order. Handling an action or action_input
            // ends processing of this reply so the model can react to the feedback first.
            for (const key in step) {
                if (key === 'thought') {
                    console.log(`Thought: ${step[key]}`);
                } else if (key === 'action') {
                    currentAction = String(step[key]);
                    console.log(`Action: ${currentAction}`);
                    const tool = this.getTool(currentAction);
                    if (tool) {
                        const nextPrompt = [
                            {
                                type: 'text',
                                text: builder.build({ action_rules: tool.getActionRule() }),
                            },
                        ];
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                    } else {
                        console.log('Error: No valid action');
                        this.interMessages.push({ role: 'user', content: 'No valid action, try again.' });
                    }
                    break;
                } else if (key === 'action_input') {
                    const actionInput = builder.build({ action_input: step[key] });
                    console.log(`Action input: ${actionInput}`);
                    if (!currentAction) {
                        console.log('Error: Action input without a valid action');
                        return 'Action input without a valid action';
                    }
                    try {
                        const observation = await this.processAction(currentAction, step[key]);
                        // NOTE(review): the surrounding text parts are empty. If they were meant to
                        // delimit the observation (e.g. with XML tags), fill them in — confirm.
                        const nextPrompt = [{ type: 'text', text: '' }, ...observation, { type: 'text', text: '' }];
                        console.log(observation);
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                    } catch (error) {
                        console.log(`Error processing action: ${error}`);
                        return `${error}`;
                    }
                    break;
                } else if (key === 'answer') {
                    console.log('Answer found. Ending query.');
                    return result;
                }
            }
        }

        console.log(this.messages);
        console.log('Reached maximum turns. Ending query.');
        return 'Reached maximum turns without finding an answer';
    }

    /**
     * Sends the current transcript to the model.
     *
     * @returns The text content of the first completion choice.
     * @throws Error('No completion content found') when there is no choice or its content is empty/null.
     */
    private async execute(): Promise<string> {
        console.log(this.interMessages);
        const completion = await this.client.chat.completions.create({
            model: 'gpt-4o',
            messages: this.interMessages as ChatCompletionMessageParam[],
            temperature: 0,
        });
        // Guard against an empty `choices` array as well as null/empty content.
        const content = completion.choices[0]?.message?.content;
        if (!content) {
            throw new Error('No completion content found');
        }
        return content;
    }

    /**
     * Looks up a registered tool by name.
     *
     * Only own properties of `this.tools` count. Without this, model-supplied names
     * such as "constructor" or "toString" would resolve to Object.prototype members.
     */
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- mirrors the `tools` field type
    private getTool(name: string): Tool<any> | undefined {
        return Object.prototype.hasOwnProperty.call(this.tools, name) ? this.tools[name] : undefined;
    }

    /**
     * Runs the named tool with arguments taken from the parsed `action_input`.
     *
     * Every key of `tool.parameters` must be present in `actionInput`.
     *
     * @param action Tool name.
     * @param actionInput Parsed `action_input` value from the model reply.
     * @returns Whatever the tool's `execute` resolves to. The caller spreads it into a
     *   content-part array, so it is presumably an array of message parts.
     * @throws Error when the action is unknown or a parameter is missing.
     */
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped parser output / tool result
    private async processAction(action: string, actionInput: any): Promise<any> {
        const tool = this.getTool(action);
        if (!tool) {
            throw new Error(`Unknown action: ${action}`);
        }

        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- argument values are untyped
        const args: Record<string, any> = {};
        for (const paramName in tool.parameters) {
            // Optional chaining: a null/undefined input reports a missing parameter
            // instead of crashing with a TypeError.
            const value = actionInput?.[paramName];
            if (value === undefined) {
                throw new Error(`Missing required parameter '${paramName}' for action '${action}'`);
            }
            args[paramName] = value;
        }
        return await tool.execute(args);
    }
}