aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/chatbot/agentsystem/Agent.ts
diff options
context:
space:
mode:
authoreleanor-park <eleanor_park@brown.edu>2024-10-30 19:39:46 -0400
committereleanor-park <eleanor_park@brown.edu>2024-10-30 19:39:46 -0400
commitc11c760db62f78a07b624b98b209e6ee86036c8e (patch)
treec9587b50042a5115373e91ba8ecf9b76913cd321 /src/client/views/nodes/chatbot/agentsystem/Agent.ts
parentb5944e87f9d4f3149161de4de0d76db486461c76 (diff)
parent4c768162e0436115a05b9c8b0e4d837d626d45ba (diff)
Merge branch 'master' into eleanor-gptdraw
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
-rw-r--r--src/client/views/nodes/chatbot/agentsystem/Agent.ts82
1 files changed, 69 insertions, 13 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
index ccf9caf15..34e7cf5ea 100644
--- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -11,9 +11,11 @@ import { NoTool } from '../tools/NoTool';
import { RAGTool } from '../tools/RAGTool';
import { SearchTool } from '../tools/SearchTool';
import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool';
-import { AgentMessage, AssistantMessage, PROCESSING_TYPE, ProcessingInfo, Tool } from '../types/types';
+import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, ProcessingInfo } from '../types/types';
import { Vectorstore } from '../vectorstore/Vectorstore';
import { getReactPrompt } from './prompts';
+import { BaseTool } from '../tools/BaseTool';
+import { Parameter, ParametersType, Tool } from '../tools/ToolTypes';
dotenv.config();
@@ -24,7 +26,6 @@ dotenv.config();
export class Agent {
// Private properties
private client: OpenAI;
- private tools: Record<string, Tool<any>>; // bcz: need a real type here
private messages: AgentMessage[] = [];
private interMessages: AgentMessage[] = [];
private vectorstore: Vectorstore;
@@ -36,6 +37,7 @@ export class Agent {
private processingNumber: number = 0;
private processingInfo: ProcessingInfo[] = [];
private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser();
+ private tools: Record<string, BaseTool<ReadonlyArray<Parameter>>>;
/**
* The constructor initializes the agent with the vector store and toolset, and sets up the OpenAI client.
@@ -108,15 +110,16 @@ export class Agent {
let currentAction: string | undefined;
this.processingInfo = [];
- // Conversation loop (up to maxTurns)
- for (let i = 2; i < maxTurns; i += 2) {
+ let i = 2;
+ while (i < maxTurns) {
console.log(this.interMessages);
console.log(`Turn ${i}/${maxTurns}`);
- // Execute a step in the conversation and get the result
const result = await this.execute(onProcessingUpdate, onAnswerUpdate);
this.interMessages.push({ role: 'assistant', content: result });
+ i += 2;
+
let parsedResult;
try {
// Parse XML result from the assistant
@@ -148,7 +151,7 @@ export class Agent {
{
type: 'text',
text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`,
- },
+ } as Observation,
];
this.interMessages.push({ role: 'user', content: nextPrompt });
break;
@@ -166,8 +169,8 @@ export class Agent {
if (currentAction) {
try {
// Process the action with its input
- const observation = (await this.processAction(currentAction, actionInput.inputs)) as any; // bcz: really need a type here
- const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }];
+ const observation = (await this.processAction(currentAction, actionInput.inputs)) as Observation[];
+ const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }] as Observation[];
console.log(observation);
this.interMessages.push({ role: 'user', content: nextPrompt });
this.processingNumber++;
@@ -262,16 +265,69 @@ export class Agent {
/**
* Processes a specific action by invoking the appropriate tool with the provided inputs.
- * @param action The action to perform.
- * @param actionInput The inputs for the action.
- * @returns The result of the action.
+ * This method ensures that the action exists and validates the types of `actionInput`
+ * based on the tool's parameter rules. It throws errors for missing required parameters
+ * or mismatched types before safely executing the tool with the validated input.
+ *
+ * Type validation includes checks for:
+ * - `string`, `number`, `boolean`
+ * - `string[]`, `number[]` (arrays of strings or numbers)
+ *
+ * @param action The action to perform. It corresponds to a registered tool.
+ * @param actionInput The inputs for the action, passed as an object where each key is a parameter name.
+ * @returns A promise that resolves to an array of `Observation` objects representing the result of the action.
+ * @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types.
*/
- private async processAction(action: string, actionInput: unknown): Promise<unknown> {
+ private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> {
+ // Check if the action exists in the tools list
if (!(action in this.tools)) {
throw new Error(`Unknown action: ${action}`);
}
const tool = this.tools[action];
- return await tool.execute(actionInput);
+
+ // Validate actionInput based on tool's parameter rules
+ for (const paramRule of tool.parameterRules) {
+ const inputValue = actionInput[paramRule.name];
+
+ if (paramRule.required && inputValue === undefined) {
+ throw new Error(`Missing required parameter: ${paramRule.name}`);
+ }
+
+ // If the parameter is defined, check its type
+ if (inputValue !== undefined) {
+ switch (paramRule.type) {
+ case 'string':
+ if (typeof inputValue !== 'string') {
+ throw new Error(`Expected parameter '${paramRule.name}' to be a string.`);
+ }
+ break;
+ case 'number':
+ if (typeof inputValue !== 'number') {
+ throw new Error(`Expected parameter '${paramRule.name}' to be a number.`);
+ }
+ break;
+ case 'boolean':
+ if (typeof inputValue !== 'boolean') {
+ throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`);
+ }
+ break;
+ case 'string[]':
+ if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) {
+ throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`);
+ }
+ break;
+ case 'number[]':
+ if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) {
+ throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`);
+ }
+ break;
+ default:
+ throw new Error(`Unsupported parameter type: ${paramRule.type}`);
+ }
+ }
+ }
+
+ return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>);
}
}