1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
|
import dotenv from 'dotenv';
import { XMLBuilder, XMLParser } from 'fast-xml-parser';
import OpenAI from 'openai';
import { ChatCompletionMessageParam } from 'openai/resources';
import { AnswerParser } from '../response_parsers/AnswerParser';
import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser';
import { CalculateTool } from '../tools/CalculateTool';
import { CreateCSVTool } from '../tools/CreateCSVTool';
import { DataAnalysisTool } from '../tools/DataAnalysisTool';
import { NoTool } from '../tools/NoTool';
import { RAGTool } from '../tools/RAGTool';
import { SearchTool } from '../tools/SearchTool';
import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool';
import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, ProcessingInfo } from '../types/types';
import { Vectorstore } from '../vectorstore/Vectorstore';
import { getReactPrompt } from './prompts';
import { BaseTool } from '../tools/BaseTool';
import { Parameter, ParametersType, Tool } from '../tools/ToolTypes';
// Load variables from a local .env file into process.env at module load time;
// the Agent constructor below reads process.env.OPENAI_KEY.
dotenv.config();
/**
 * The Agent class drives a multi-turn, XML-staged conversation between the user, the OpenAI chat
 * model and a fixed set of tools.
 *
 * Each turn streams one response from the model (forwarding thoughts / action descriptions and the
 * partial answer through callbacks while streaming), parses the completed `<stage>` element, and then
 * either returns the final answer, sends the chosen tool's action rules back to the model, or runs
 * the tool and feeds its observation back to the model.
 */
export class Agent {
    // OpenAI client used for every chat completion request.
    private client: OpenAI;
    // User questions received by askAgent (appended to; not read elsewhere in this class).
    private messages: AgentMessage[] = [];
    // Message list sent to the model for the current query: system prompt followed by all stages so far.
    private interMessages: AgentMessage[] = [];
    private vectorstore: Vectorstore;
    private _history: () => string;
    private _summaries: () => string;
    private _csvData: () => { filename: string; id: string; text: string }[];
    private actionNumber: number = 0;
    private thoughtNumber: number = 0;
    // Index of the ProcessingInfo entry that streamed thought / action-description text is written to.
    private processingNumber: number = 0;
    // Processing steps reported to the UI through onProcessingUpdate and passed to AnswerParser.
    private processingInfo: ProcessingInfo[] = [];
    private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser();
    // Registered tools, keyed by the action name the model must emit in <action>.
    private tools: Record<string, BaseTool<ReadonlyArray<Parameter>>>;

    /**
     * The constructor initializes the agent with the vector store and toolset, and sets up the OpenAI client.
     * @param _vectorstore Vector store instance for document storage and retrieval (used by the RAG tool).
     * @param summaries A function to retrieve document summaries (passed to the prompt builder).
     * @param history A function to retrieve chat history (called once per query).
     * @param csvData A function to retrieve CSV data linked to the assistant (used by the data analysis tool).
     * @param addLinkedUrlDoc A function to add a linked document from a URL (used by the scraper and search tools).
     * @param createCSVInDash A function to create a CSV document in the dashboard (used by the CSV tool).
     */
    constructor(
        _vectorstore: Vectorstore,
        summaries: () => string,
        history: () => string,
        csvData: () => { filename: string; id: string; text: string }[],
        addLinkedUrlDoc: (url: string, id: string) => void,
        createCSVInDash: (url: string, title: string, id: string, data: string) => void
    ) {
        // NOTE(review): dangerouslyAllowBrowser exposes the API key to any client running this code.
        this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
        this.vectorstore = _vectorstore;
        this._history = history;
        this._summaries = summaries;
        this._csvData = csvData;

        // Keys are the action names the model may choose in an <action> element.
        this.tools = {
            calculate: new CalculateTool(),
            rag: new RAGTool(this.vectorstore),
            dataAnalysis: new DataAnalysisTool(csvData),
            websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
            searchTool: new SearchTool(addLinkedUrlDoc),
            createCSV: new CreateCSVTool(createCSVInDash),
            no_tool: new NoTool(),
        };
    }

    /**
     * This method handles the conversation flow with the assistant, processes user queries,
     * and manages the assistant's decision-making process, including tool actions.
     *
     * Stage numbering: stage 1 is the user's query. The model's responses are tracked as stages
     * 2, 4, 6, ... (the loop counter `i`), and each follow-up user/system stage is numbered `i + 1`.
     *
     * @param question The user's question.
     * @param onProcessingUpdate Callback invoked with the full processing list whenever a streamed thought
     *   or action-input description changes.
     * @param onAnswerUpdate Callback invoked with the streamed answer text as it is parsed.
     * @param maxTurns Upper bound on the stage counter; the model is queried while `i < maxTurns`, with `i`
     *   starting at 2 and stepping by 2.
     * @returns The final response from the assistant, parsed by AnswerParser.
     * @throws Error if a response is not parseable XML, has no <stage>, supplies action input without a
     *   valid action, if the tool fails, or if the turn limit is reached without an answer.
     */
    async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> {
        console.log(`Starting query: ${question}`);
        this.messages.push({ role: 'user', content: question });

        const chatHistory = this._history();
        const systemPrompt = getReactPrompt(Object.values(this.tools), this._summaries, chatHistory);

        this.interMessages = [{ role: 'system', content: systemPrompt }];
        this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` });

        const parser = new XMLParser({
            ignoreAttributes: false,
            attributeNamePrefix: '@_',
            textNodeName: '_text',
            // Always parse <query> and <url> as arrays so single and multiple occurrences look the same.
            isArray: (name /* , jpath, isLeafNode, isAttribute */) => ['query', 'url'].indexOf(name) !== -1,
        });
        const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });

        // Action chosen by the model in an earlier turn, consumed when its <action_input> arrives.
        let currentAction: string | undefined;
        this.processingInfo = [];

        // `i` is the stage number of the assistant response produced in this iteration.
        for (let i = 2; i < maxTurns; i += 2) {
            console.log(this.interMessages);
            console.log(`Turn ${i}/${maxTurns}`);

            const result = await this.execute(onProcessingUpdate, onAnswerUpdate);
            this.interMessages.push({ role: 'assistant', content: result });

            // Number of the user/system stage that follows this assistant stage.
            const nextStage = i + 1;

            let parsedResult;
            try {
                parsedResult = parser.parse(result);
            } catch (error) {
                throw new Error(`Error parsing response: ${error}`);
            }

            const stage = parsedResult.stage;
            if (!stage) {
                throw new Error(`Error: No stage found in response`);
            }

            // Handle the stage's child elements in document order; the first action / action_input /
            // answer element ends processing of this response.
            for (const key in stage) {
                if (key === 'thought') {
                    console.log(`Thought: ${stage[key]}`);
                    this.processingNumber++;
                } else if (key === 'action') {
                    const actionName = String(stage[key]);
                    console.log(`Action: ${actionName}`);
                    const tool = this.getTool(actionName);
                    if (tool) {
                        currentAction = actionName;
                        // Send the chosen tool's action rules so the model can produce its input next.
                        const nextPrompt = [
                            {
                                type: 'text',
                                text: `<stage number="${nextStage}" role="user">` + builder.build({ action_rules: tool.getActionRule() }) + `</stage>`,
                            } as Observation,
                        ];
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                        break;
                    }
                    // Forget the rejected name so a stray <action_input> is reported as having no valid action.
                    console.log('Error: No valid action');
                    currentAction = undefined;
                    this.interMessages.push({ role: 'user', content: `<stage number="${nextStage}" role="system-error-reporter">No valid action, try again.</stage>` });
                    break;
                } else if (key === 'action_input') {
                    const actionInput = stage[key];
                    const rawInputs = actionInput?.inputs;
                    console.log(`Action input:`, rawInputs);
                    if (!currentAction) {
                        throw new Error('Error: Action input without a valid action');
                    }
                    // An empty or text-only <inputs> element does not parse to an object; treat it as "no
                    // inputs" so processAction reports missing parameters instead of failing on property access.
                    const inputs: Record<string, unknown> = typeof rawInputs === 'object' && rawInputs !== null ? rawInputs : {};
                    try {
                        const observation = (await this.processAction(currentAction, inputs)) as Observation[];
                        const nextPrompt = [{ type: 'text', text: `<stage number="${nextStage}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }] as Observation[];
                        console.log(observation);
                        this.interMessages.push({ role: 'user', content: nextPrompt });
                        this.processingNumber++;
                        break;
                    } catch (error) {
                        throw new Error(`Error processing action: ${error}`);
                    }
                } else if (key === 'answer') {
                    console.log('Answer found. Ending query.');
                    this.streamedAnswerParser.reset();
                    return AnswerParser.parse(result, this.processingInfo);
                }
            }
        }
        throw new Error('Reached maximum turns. Ending query.');
    }

    /**
     * Executes a step in the conversation, processing the assistant's response and parsing it in real-time.
     *
     * While streaming, text inside <thought> and <action_input_description> tags is written to the
     * ProcessingInfo entry at index `processingNumber` (created on first use). Once an <answer> tag is
     * opened, every later character is routed to the streamed answer parser.
     *
     * @param onProcessingUpdate Callback for processing updates.
     * @param onAnswerUpdate Callback for answer updates.
     * @returns The full, concatenated response text from the assistant.
     */
    private async execute(onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void): Promise<string> {
        const stream = await this.client.chat.completions.create({
            model: 'gpt-4o',
            messages: this.interMessages as ChatCompletionMessageParam[],
            temperature: 0,
            stream: true,
        });

        let fullResponse = '';
        // Name (plus any attributes) of the most recently opened tag; '' after a closing tag.
        let currentTag = '';
        // Text accumulated since the last '<' character.
        let currentContent = '';
        let isInsideTag = false;

        for await (const chunk of stream) {
            const content = chunk.choices[0]?.delta?.content ?? '';
            fullResponse += content;

            for (const char of content) {
                if (currentTag === 'answer') {
                    // Everything after <answer> is forwarded to the incremental answer parser.
                    currentContent += char;
                    const streamedAnswer = this.streamedAnswerParser.parse(char);
                    onAnswerUpdate(streamedAnswer);
                    continue;
                } else if (char === '<') {
                    isInsideTag = true;
                    currentTag = '';
                    currentContent = '';
                } else if (char === '>') {
                    isInsideTag = false;
                    if (currentTag.startsWith('/')) {
                        currentTag = '';
                    }
                } else if (isInsideTag) {
                    currentTag += char;
                } else if (currentTag === 'thought' || currentTag === 'action_input_description') {
                    currentContent += char;
                    const currentInfo = this.processingInfo.find(info => info.index === this.processingNumber);
                    if (currentInfo) {
                        currentInfo.content = currentContent.trim();
                        onProcessingUpdate(this.processingInfo);
                    } else {
                        this.processingInfo.push({
                            index: this.processingNumber,
                            type: currentTag === 'thought' ? PROCESSING_TYPE.THOUGHT : PROCESSING_TYPE.ACTION,
                            content: currentContent.trim(),
                        });
                        onProcessingUpdate(this.processingInfo);
                    }
                }
            }
        }
        return fullResponse;
    }

    /**
     * Looks up a registered tool by action name.
     *
     * Only own properties of the tool map are considered, so model-supplied names such as
     * "constructor" or "toString" can never resolve to an inherited Object.prototype member.
     *
     * @param name The action name, typically taken from model output.
     * @returns The registered tool, or undefined when `name` is not a string or not a registered action.
     */
    private getTool(name: unknown): BaseTool<ReadonlyArray<Parameter>> | undefined {
        if (typeof name !== 'string' || !Object.prototype.hasOwnProperty.call(this.tools, name)) {
            return undefined;
        }
        return this.tools[name];
    }

    /**
     * Processes a specific action by invoking the appropriate tool with the provided inputs.
     * This method ensures that the action exists and validates the types of `actionInput`
     * based on the tool's parameter rules. It throws errors for missing required parameters
     * or mismatched types before executing the tool with the validated input.
     *
     * Type validation includes checks for:
     * - `string`, `number`, `boolean`
     * - `string[]`, `number[]` (arrays of strings or numbers)
     *
     * @param action The action to perform. It must be the name of a registered tool.
     * @param actionInput The inputs for the action, passed as an object where each key is a parameter name.
     * @returns A promise that resolves to an array of `Observation` objects representing the result of the action.
     * @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types.
     */
    private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> {
        const tool = this.getTool(action);
        if (!tool) {
            throw new Error(`Unknown action: ${action}`);
        }

        for (const paramRule of tool.parameterRules) {
            const inputValue = actionInput[paramRule.name];

            if (paramRule.required && inputValue === undefined) {
                throw new Error(`Missing required parameter: ${paramRule.name}`);
            }

            // Optional parameters that were not supplied are not type-checked.
            if (inputValue !== undefined) {
                switch (paramRule.type) {
                    case 'string':
                        if (typeof inputValue !== 'string') {
                            throw new Error(`Expected parameter '${paramRule.name}' to be a string.`);
                        }
                        break;
                    case 'number':
                        if (typeof inputValue !== 'number') {
                            throw new Error(`Expected parameter '${paramRule.name}' to be a number.`);
                        }
                        break;
                    case 'boolean':
                        if (typeof inputValue !== 'boolean') {
                            throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`);
                        }
                        break;
                    case 'string[]':
                        if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) {
                            throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`);
                        }
                        break;
                    case 'number[]':
                        if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) {
                            throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`);
                        }
                        break;
                    default:
                        throw new Error(`Unsupported parameter type: ${paramRule.type}`);
                }
            }
        }

        return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>);
    }
}
|