aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/ChatBox/Agent.ts
diff options
context:
space:
mode:
authorA.J. Shulman <Shulman.aj@gmail.com>2024-07-10 16:16:26 -0400
committerA.J. Shulman <Shulman.aj@gmail.com>2024-07-10 16:16:26 -0400
commitcab0311e2fd9a6379628c000d11ddcd805e01f64 (patch)
tree60cb3f397426cb3931c13ebe3b8a1e8eb98480dd /src/client/views/nodes/ChatBox/Agent.ts
parentd0e09ff3526e4f6b9aad824fad1020d083a87631 (diff)
first attempt at integrating everything
Diffstat (limited to 'src/client/views/nodes/ChatBox/Agent.ts')
-rw-r--r--src/client/views/nodes/ChatBox/Agent.ts123
1 files changed, 123 insertions, 0 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts
new file mode 100644
index 000000000..f20a75a8d
--- /dev/null
+++ b/src/client/views/nodes/ChatBox/Agent.ts
@@ -0,0 +1,123 @@
+import OpenAI from 'openai';
+import { Tool, AgentMessage } from './types';
+import { getReactPrompt } from './prompts';
+import { XMLParser, XMLBuilder } from 'fast-xml-parser';
+
+export class Agent {
+ private client: OpenAI;
+ private tools: Record<string, Tool>;
+ private messages: AgentMessage[] = [];
+ private interMessages: AgentMessage[] = [];
+ private summaries: string;
+
+ constructor(private vectorstore: Vectorstore) {
+ this.client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
+ this.summaries = this.vectorstore ? this.vectorstore.getSummaries() : 'No documents available.';
+ this.tools = {
+ wikipedia: new WikipediaTool(),
+ calculate: new CalculateTool(),
+ rag: new RAGTool(vectorstore, this.summaries),
+ };
+ }
+
+ private formatChatHistory(): string {
+ let history = '<chat_history>\n';
+ for (const message of this.messages) {
+ if (message.role === 'user') {
+ history += `<user>${message.content}</user>\n`;
+ } else if (message.role === 'assistant') {
+ history += `<assistant>${message.content}</assistant>\n`;
+ }
+ }
+ history += '</chat_history>';
+ return history;
+ }
+
+ async askAgent(question: string, maxTurns: number = 5): Promise<string> {
+ console.log(`Starting query: ${question}`);
+ this.messages.push({ role: 'user', content: question });
+ const chatHistory = this.formatChatHistory();
+ console.log(`Chat history: ${chatHistory}`);
+ const systemPrompt = getReactPrompt(Object.values(this.tools), chatHistory);
+ console.log(`System prompt: ${systemPrompt}`);
+ this.interMessages = [{ role: 'system', content: systemPrompt }];
+
+ this.interMessages.push({ role: 'assistant', content: `<query>${question}</query>` });
+
+ for (let i = 0; i < maxTurns; i++) {
+ console.log(`Turn ${i + 1}/${maxTurns}`);
+
+ const result = await this.execute();
+ console.log(`Bot response: ${result}`);
+ this.interMessages.push({ role: 'assistant', content: result });
+
+ try {
+ const parser = new XMLParser();
+ const parsedResult = parser.parse(result);
+ const step = parsedResult[`step${i + 1}`];
+
+ if (step.thought) console.log(`Thought: ${step.thought}`);
+ if (step.action) {
+ console.log(`Action: ${step.action}`);
+ const action = step.action;
+ const actionRules = new XMLBuilder().build({
+ action_rules: this.tools[action].getActionRule(),
+ });
+ this.interMessages.push({ role: 'user', content: actionRules });
+ }
+ if (step.action_input) {
+ const actionInput = new XMLBuilder().build({ action_input: step.action_input });
+ console.log(`Action input: ${actionInput}`);
+ try {
+ const observation = await this.processAction(action, step.action_input);
+ const nextPrompt = [{ type: 'text', text: '<observation>' }, ...observation, { type: 'text', text: '</observation>' }];
+ this.interMessages.push({ role: 'user', content: nextPrompt });
+ } catch (e) {
+ console.error(`Error processing action: ${e}`);
+ return `<error>${e}</error>`;
+ }
+ }
+ if (step.answer) {
+ console.log('Answer found. Ending query.');
+ const answerContent = new XMLBuilder().build({ answer: step.answer });
+ this.messages.push({ role: 'assistant', content: answerContent });
+ this.interMessages = [];
+ return answerContent;
+ }
+ } catch (e) {
+ console.error('Error: Invalid XML response from bot');
+ return '<error>Invalid response format.</error>';
+ }
+ }
+
+ console.log('Reached maximum turns. Ending query.');
+ return '<error>Reached maximum turns without finding an answer</error>';
+ }
+
+ private async execute(): Promise<string> {
+ const completion = await this.client.chat.completions.create({
+ model: 'gpt-4',
+ messages: this.interMessages,
+ temperature: 0,
+ });
+ return completion.choices[0].message.content;
+ }
+
+ private async processAction(action: string, actionInput: any): Promise<any> {
+ if (!(action in this.tools)) {
+ throw new Error(`Unknown action: ${action}`);
+ }
+
+ const tool = this.tools[action];
+ const args: Record<string, any> = {};
+ for (const paramName in tool.parameters) {
+ if (actionInput[paramName] !== undefined) {
+ args[paramName] = actionInput[paramName];
+ } else {
+ throw new Error(`Missing required parameter '${paramName}' for action '${action}'`);
+ }
+ }
+
+ return await tool.execute(args);
+ }
+}