aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/ChatBox/Agent.ts
blob: 41c91b4c63ef5e2697998cec3f42b171a94f4e7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import OpenAI from 'openai';
import { Tool, AgentMessage } from './types';
import { getReactPrompt } from './prompts';
import { XMLParser, XMLBuilder } from 'fast-xml-parser';
import { WikipediaTool } from './tools/WikipediaTool';
import { CalculateTool } from './tools/CalculateTool';
import { RAGTool } from './tools/RAGTool';
import { NoTool } from './tools/NoTool';
import { Vectorstore } from './vectorstore/Vectorstore';
import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources';
import dotenv from 'dotenv';
import { ChatBox } from './ChatBox';
import { DataAnalysisTool } from './tools/DataAnalysisTool';
import { string } from 'cohere-ai/core/schemas';
import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool';
import { SearchTool } from './tools/SearchTool';
import { add } from 'lodash';
dotenv.config();

/**
 * ReAct-style agent that drives an OpenAI chat model through an XML "stage" protocol.
 *
 * Each model reply is parsed as a `<stage>` element whose children are inspected for
 * `thought`, `action`, `action_input`, or `answer`. The agent replies with the chosen
 * tool's action rules, the tool's observation, or an error report, until the model
 * emits an `<answer>` or the turn budget runs out.
 */
export class Agent {
    private client: OpenAI;
    private tools: Record<string, Tool<any>>;
    private messages: AgentMessage[] = [];
    private interMessages: AgentMessage[] = [];
    private vectorstore: Vectorstore;
    private _history: () => string;
    private _summaries: () => string;
    private _csvData: () => { filename: string; id: string; text: string }[];

    /**
     * @param _vectorstore vector store handed to the RAG tool
     * @param summaries callback passed to `getReactPrompt` when building the system prompt
     * @param history callback returning the chat history embedded in the system prompt
     * @param csvData callback supplying CSV documents to the data-analysis tool
     * @param addLinkedUrlDoc callback passed to the website-scraper and search tools
     */
    constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) {
        // The API key is intentionally never logged: with dangerouslyAllowBrowser this runs
        // client-side, where console output is visible to anyone with devtools open.
        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
        this.vectorstore = _vectorstore;
        this._history = history;
        this._summaries = summaries;
        this._csvData = csvData;
        this.tools = {
            //wikipedia: new WikipediaTool(addLinkedUrlDoc),
            calculate: new CalculateTool(),
            rag: new RAGTool(this.vectorstore),
            dataAnalysis: new DataAnalysisTool(csvData),
            websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
            searchTool: new SearchTool(addLinkedUrlDoc),
            no_tool: new NoTool(),
        };
    }

    /**
     * Runs the ReAct loop for a single user question.
     *
     * Stage numbers advance by 2 per model call (model stages are even, agent replies odd),
     * so the model is invoked for i = 2, 4, ... while i < maxTurns.
     *
     * @param question the user's query
     * @param maxTurns upper bound on the stage counter
     * @returns the raw XML model response once it contains an `answer`; otherwise an
     *          `<error>...</error>` string describing why the loop stopped
     * @throws if the completion request fails or returns no content (execute() is not wrapped)
     */
    async askAgent(question: string, maxTurns: number = 20): Promise<string> {
        console.log(`Starting query: ${question}`);
        this.messages.push({ role: 'user', content: question });
        const chatHistory = this._history();
        console.log(`Chat history: ${chatHistory}`);
        const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory);
        console.log(`System prompt: ${systemPrompt}`);
        this.interMessages = [{ role: 'system', content: systemPrompt }];

        this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` });

        const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_' });
        const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
        // Persists across turns: an `action` stage selects the tool, a later `action_input` stage runs it.
        let currentAction: string | undefined;

        for (let i = 2; i < maxTurns; i += 2) {
            console.log(`Turn ${i}/${maxTurns}`);

            const result = await this.execute();
            console.log(`Bot response: ${result}`);
            this.interMessages.push({ role: 'assistant', content: result });

            let parsedResult;
            try {
                parsedResult = parser.parse(result);
            } catch (error) {
                console.log('Error: Invalid XML response from bot');
                return '<error>Invalid response format.</error>';
            }

            const stage = parsedResult.stage;

            if (!stage) {
                console.log('Error: No stage found in response');
                return '<error>Invalid response format: No stage found.</error>';
            }

            // Only the first actionable key (action / action_input / answer) is handled per turn;
            // each branch breaks or returns. If none is present (e.g. only a thought), no user
            // message is appended before the next model call.
            for (const key in stage) {
                if (key === 'thought') {
                    console.log(`Thought: ${stage[key]}`);
                } else if (key === 'action') {
                    currentAction = stage[key] as string;
                    console.log(`Action: ${currentAction}`);
                    const tool = this.getTool(currentAction);
                    if (tool) {
                        const nextPrompt = [
                            {
                                type: 'text',
                                text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: tool.getActionRule() }) + `</stage>`,
                            },
                        ];
                        this.interMessages.push({ role: 'user', content: nextPrompt });

                        break;
                    } else {
                        console.log('Error: No valid action');
                        this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` });
                        break;
                    }
                } else if (key === 'action_input') {
                    const actionInput = builder.build({ action_input: stage[key] });
                    console.log(`Action input: ${actionInput}`);
                    if (currentAction) {
                        try {
                            const observation = await this.processAction(currentAction, stage[key]);
                            const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }];
                            console.log(observation);
                            this.interMessages.push({ role: 'user', content: nextPrompt });
                            break;
                        } catch (error) {
                            console.log(`Error processing action: ${error}`);
                            return `<error>${error}</error>`;
                        }
                    } else {
                        console.log('Error: Action input without a valid action');
                        return '<error>Action input without a valid action</error>';
                    }
                } else if (key === 'answer') {
                    console.log('Answer found. Ending query.');
                    return result;
                }
            }
        }
        console.log(this.messages);
        console.log('Reached maximum turns. Ending query.');
        return '<error>Reached maximum turns without finding an answer</error>';
    }

    /**
     * Looks up a registered tool by name using an own-property check.
     *
     * Tool names come from model output. A plain `this.tools[name]` or `name in this.tools`
     * would also match inherited members such as `toString` or `constructor`.
     */
    private getTool(name: string): Tool<any> | undefined {
        return Object.prototype.hasOwnProperty.call(this.tools, name) ? this.tools[name] : undefined;
    }

    /**
     * Sends the accumulated conversation to the chat model and returns its text reply.
     *
     * @throws Error('No completion content found') when the response has no choice or empty content
     */
    private async execute(): Promise<string> {
        console.log(this.interMessages);
        const completion = await this.client.chat.completions.create({
            model: 'gpt-4o',
            messages: this.interMessages as ChatCompletionMessageParam[],
            temperature: 0,
        });
        // Guard against an empty `choices` array as well as null/empty content.
        const content = completion.choices[0]?.message?.content;
        if (content) return content;
        throw new Error('No completion content found');
    }

    /**
     * Validates the parsed `action_input` against the tool's declared parameters and runs the tool.
     *
     * @param action tool name selected by a previous `action` stage
     * @param actionInput parsed contents of the `action_input` element
     * @returns whatever the tool's `execute` resolves to (spread into the observation prompt by the caller)
     * @throws Error for an unknown action or a missing required parameter
     */
    private async processAction(action: string, actionInput: any): Promise<any> {
        const tool = this.getTool(action);
        if (!tool) {
            throw new Error(`Unknown action: ${action}`);
        }

        // The parser may produce a primitive (e.g. '' for an empty element) or nothing at all;
        // treat anything that is not an object as "no named parameters supplied".
        const input: Record<string, any> = actionInput !== null && typeof actionInput === 'object' ? actionInput : {};

        const args: Record<string, any> = {};
        for (const paramName in tool.parameters) {
            if (input[paramName] !== undefined) {
                args[paramName] = input[paramName];
            } else if (String(tool.parameters[paramName].required) === 'true') {
                // String(...) accepts both the string 'true' and a boolean true.
                throw new Error(`Missing required parameter '${paramName}' for action '${action}'`);
            }
        }

        return await tool.execute(args);
    }
}