aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/ChatBox/Agent.ts
blob: bada4b1468a91d03dab6fa25adea26efe2e9f902 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import OpenAI from 'openai';
import { Tool, AgentMessage } from './types';
import { getReactPrompt } from './prompts';
import { XMLParser, XMLBuilder } from 'fast-xml-parser';
import { WikipediaTool } from './tools/WikipediaTool';
import { CalculateTool } from './tools/CalculateTool';
import { RAGTool } from './tools/RAGTool';
import { Vectorstore } from './vectorstore/VectorstoreUpload';
import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources';
import dotenv from 'dotenv';
import { ChatBox } from './ChatBox';
// Load variables from a local .env file into process.env; Agent's constructor reads process.env.OPENAI_KEY.
dotenv.config();

/**
 * ReAct-style agent that drives an OpenAI chat model through an XML protocol
 * (`<thought>`, `<action>`, `<action_input>`, `<answer>`) and dispatches the
 * model's actions to a fixed set of tools (`wikipedia`, `calculate`, `rag`).
 */
export class Agent {
    private client: OpenAI;
    /** Tools the model may invoke, keyed by the action name it must emit. */
    private tools: Record<string, Tool<any>>;
    /** Questions asked of this agent; only the user side is recorded. */
    private messages: AgentMessage[] = [];
    /** Transcript of the current query's ReAct loop; reset on every askAgent call. */
    private interMessages: AgentMessage[] = [];
    private vectorstore: Vectorstore;
    /** Returns the chat history rendered into the system prompt. */
    private _history: () => string;

    /**
     * @param _vectorstore Vectorstore handed to the RAG tool.
     * @param summaries Callback handed to the RAG tool (presumably document summaries — TODO confirm).
     * @param history Callback returning the chat history embedded in the system prompt.
     */
    constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string) {
        // NOTE(review): dangerouslyAllowBrowser means OPENAI_KEY ships to the browser client.
        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
        this.vectorstore = _vectorstore;
        this._history = history;
        this.tools = {
            wikipedia: new WikipediaTool(),
            calculate: new CalculateTool(),
            rag: new RAGTool(this.vectorstore, summaries),
        };
    }

    /**
     * Answers `question` by running up to `maxTurns` model turns. Each turn, the model's
     * reply is parsed as XML, and the children of its first root element are handled in order:
     *  - `<thought>` is logged;
     *  - `<action>` selects a tool, whose action rules are sent back to the model;
     *  - `<action_input>` runs the selected tool and sends its output back wrapped in `<observation>`;
     *  - `<answer>` ends the loop and the raw model reply is returned.
     *
     * @param question The user's question.
     * @param maxTurns Maximum number of model calls before giving up.
     * @returns The raw model reply containing `<answer>`, or an `<error>...</error>` string.
     */
    async askAgent(question: string, maxTurns: number = 8): Promise<string> {
        console.log(`Starting query: ${question}`);
        this.messages.push({ role: 'user', content: question });
        const chatHistory = this._history();
        console.log(`Chat history: ${chatHistory}`);
        const systemPrompt = getReactPrompt(Object.values(this.tools), chatHistory);
        console.log(`System prompt: ${systemPrompt}`);
        this.interMessages = [
            { role: 'system', content: systemPrompt },
            { role: 'user', content: `<query>${question}</query>` },
        ];

        const parser = new XMLParser();
        const builder = new XMLBuilder();
        let currentAction: string | undefined;

        for (let i = 0; i < maxTurns; i++) {
            console.log(`Turn ${i + 1}/${maxTurns}`);

            const result = await this.execute();
            console.log(`Bot response: ${result}`);
            this.interMessages.push({ role: 'assistant', content: result });

            let parsedResult;
            try {
                parsedResult = parser.parse(result);
            } catch (error) {
                console.log('Error: Invalid XML response from bot');
                return '<error>Invalid response format.</error>';
            }

            const rootKey = Object.keys(parsedResult)[0];
            const step = rootKey === undefined ? undefined : parsedResult[rootKey];
            // Only a root element with child elements parses to an object; anything else carries no step keys.
            if (step === null || typeof step !== 'object') continue;

            for (const key in step) {
                if (key === 'thought') {
                    console.log(`Thought: ${step[key]}`);
                } else if (key === 'action') {
                    currentAction = step[key] as string;
                    console.log(`Action: ${currentAction}`);
                    const tool = this.getTool(currentAction);
                    if (tool) {
                        const nextPrompt = [
                            {
                                type: 'text',
                                text: builder.build({ action_rules: tool.getActionRule() }),
                            },
                        ];
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                    } else {
                        console.log('Error: No valid action');
                        // Forget the rejected name so a following <action_input> is not run against it.
                        currentAction = undefined;
                        this.interMessages.push({ role: 'user', content: 'No valid action, try again.' });
                    }
                    break;
                } else if (key === 'action_input') {
                    const actionInput = builder.build({ action_input: step[key] });
                    console.log(`Action input: ${actionInput}`);
                    if (!currentAction) {
                        console.log('Error: Action input without a valid action');
                        return '<error>Action input without a valid action</error>';
                    }
                    try {
                        const observation = await this.processAction(currentAction, step[key]);
                        const nextPrompt = [{ type: 'text', text: '<observation>' }, ...observation, { type: 'text', text: '</observation>' }];
                        console.log(observation);
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                    } catch (error) {
                        console.log(`Error processing action: ${error}`);
                        return `<error>${error}</error>`;
                    }
                    break;
                } else if (key === 'answer') {
                    console.log('Answer found. Ending query.');
                    return result;
                }
            }
        }
        console.log(this.messages);
        console.log('Reached maximum turns. Ending query.');
        return '<error>Reached maximum turns without finding an answer</error>';
    }

    /**
     * Sends the current transcript to the chat model and returns the reply text.
     *
     * @throws Error('No completion content found') when the response has no choice or empty content.
     */
    private async execute(): Promise<string> {
        console.log(this.interMessages);
        const completion = await this.client.chat.completions.create({
            model: 'gpt-4o',
            messages: this.interMessages as ChatCompletionMessageParam[],
            temperature: 0,
        });
        // Guard against an empty `choices` array as well as null content.
        const content = completion.choices[0]?.message.content;
        if (content) return content;
        throw new Error('No completion content found');
    }

    /**
     * Looks up a tool by an own key of `this.tools`, so model-supplied names such as
     * 'toString' or 'constructor' do not resolve to inherited Object.prototype members.
     */
    private getTool(name: string): Tool<any> | undefined {
        return Object.prototype.hasOwnProperty.call(this.tools, name) ? this.tools[name] : undefined;
    }

    /**
     * Runs the tool registered under `action`, copying each key of `tool.parameters`
     * from the parsed `<action_input>` into the argument object.
     *
     * @throws Error when `action` names no tool or any parameter is missing from `actionInput`.
     */
    private async processAction(action: string, actionInput: any): Promise<any> {
        const tool = this.getTool(action);
        if (!tool) {
            throw new Error(`Unknown action: ${action}`);
        }

        const args: Record<string, any> = {};
        for (const paramName in tool.parameters) {
            // Optional chaining turns a null/undefined input into a "missing parameter" error instead of a TypeError.
            const value = actionInput?.[paramName];
            if (value === undefined) {
                throw new Error(`Missing required parameter '${paramName}' for action '${action}'`);
            }
            args[paramName] = value;
        }

        return await tool.execute(args);
    }
}