diff options
| author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:16:26 -0400 |
|---|---|---|
| committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:16:26 -0400 |
| commit | cab0311e2fd9a6379628c000d11ddcd805e01f64 (patch) | |
| tree | 60cb3f397426cb3931c13ebe3b8a1e8eb98480dd /src/client/views/nodes/ChatBox/Agent.ts | |
| parent | d0e09ff3526e4f6b9aad824fad1020d083a87631 (diff) | |
first attempt at integrating everything
Diffstat (limited to 'src/client/views/nodes/ChatBox/Agent.ts')
| -rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts new file mode 100644 index 000000000..f20a75a8d --- /dev/null +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -0,0 +1,123 @@ +import OpenAI from 'openai'; +import { Tool, AgentMessage } from './types'; +import { getReactPrompt } from './prompts'; +import { XMLParser, XMLBuilder } from 'fast-xml-parser'; + +export class Agent { + private client: OpenAI; + private tools: Record<string, Tool>; + private messages: AgentMessage[] = []; + private interMessages: AgentMessage[] = []; + private summaries: string; + + constructor(private vectorstore: Vectorstore) { + this.client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); + this.summaries = this.vectorstore ? this.vectorstore.getSummaries() : 'No documents available.'; + this.tools = { + wikipedia: new WikipediaTool(), + calculate: new CalculateTool(), + rag: new RAGTool(vectorstore, this.summaries), + }; + } + + private formatChatHistory(): string { + let history = '<chat_history>\n'; + for (const message of this.messages) { + if (message.role === 'user') { + history += `<user>${message.content}</user>\n`; + } else if (message.role === 'assistant') { + history += `<assistant>${message.content}</assistant>\n`; + } + } + history += '</chat_history>'; + return history; + } + + async askAgent(question: string, maxTurns: number = 5): Promise<string> { + console.log(`Starting query: ${question}`); + this.messages.push({ role: 'user', content: question }); + const chatHistory = this.formatChatHistory(); + console.log(`Chat history: ${chatHistory}`); + const systemPrompt = getReactPrompt(Object.values(this.tools), chatHistory); + console.log(`System prompt: ${systemPrompt}`); + this.interMessages = [{ role: 'system', content: systemPrompt }]; + + this.interMessages.push({ role: 'assistant', content: `<query>${question}</query>` }); + + for (let i = 0; i < maxTurns; i++) { + console.log(`Turn ${i + 1}/${maxTurns}`); + + const result = await this.execute(); + console.log(`Bot response: ${result}`); + this.interMessages.push({ role: 'assistant', content: result }); + + try { + const parser = new XMLParser(); + const parsedResult = parser.parse(result); + const step = parsedResult[`step${i + 1}`]; + + if (step.thought) console.log(`Thought: ${step.thought}`); + if (step.action) { + console.log(`Action: ${step.action}`); + const action = step.action; + const actionRules = new XMLBuilder().build({ + action_rules: this.tools[action].getActionRule(), + }); + this.interMessages.push({ role: 'user', content: actionRules }); + } + if (step.action_input) { + const actionInput = new XMLBuilder().build({ action_input: step.action_input }); + console.log(`Action input: ${actionInput}`); + try { + const observation = await this.processAction(action, step.action_input); + const nextPrompt = [{ type: 'text', text: '<observation>' }, ...observation, { type: 'text', text: '</observation>' }]; + this.interMessages.push({ role: 'user', content: nextPrompt }); + } catch (e) { + console.error(`Error processing action: ${e}`); + return `<error>${e}</error>`; + } + } + if (step.answer) { + console.log('Answer found. Ending query.'); + const answerContent = new XMLBuilder().build({ answer: step.answer }); + this.messages.push({ role: 'assistant', content: answerContent }); + this.interMessages = []; + return answerContent; + } + } catch (e) { + console.error('Error: Invalid XML response from bot'); + return '<error>Invalid response format.</error>'; + } + } + + console.log('Reached maximum turns. Ending query.'); + return '<error>Reached maximum turns without finding an answer</error>'; + } + + private async execute(): Promise<string> { + const completion = await this.client.chat.completions.create({ + model: 'gpt-4', + messages: this.interMessages, + temperature: 0, + }); + return completion.choices[0].message.content; + } + + private async processAction(action: string, actionInput: any): Promise<any> { + if (!(action in this.tools)) { + throw new Error(`Unknown action: ${action}`); + } + + const tool = this.tools[action]; + const args: Record<string, any> = {}; + for (const paramName in tool.parameters) { + if (actionInput[paramName] !== undefined) { + args[paramName] = actionInput[paramName]; + } else { + throw new Error(`Missing required parameter '${paramName}' for action '${action}'`); + } + } + + return await tool.execute(args); + } +} |
