diff options
| author | eleanor-park <eleanor_park@brown.edu> | 2024-10-30 19:39:46 -0400 |
|---|---|---|
| committer | eleanor-park <eleanor_park@brown.edu> | 2024-10-30 19:39:46 -0400 |
| commit | c11c760db62f78a07b624b98b209e6ee86036c8e (patch) | |
| tree | c9587b50042a5115373e91ba8ecf9b76913cd321 /src/client/views/nodes/chatbot/agentsystem | |
| parent | b5944e87f9d4f3149161de4de0d76db486461c76 (diff) | |
| parent | 4c768162e0436115a05b9c8b0e4d837d626d45ba (diff) | |
Merge branch 'master' into eleanor-gptdraw
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem')
| -rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/Agent.ts | 82 |
1 files changed, 69 insertions, 13 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts index ccf9caf15..34e7cf5ea 100644 --- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts +++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts @@ -11,9 +11,11 @@ import { NoTool } from '../tools/NoTool'; import { RAGTool } from '../tools/RAGTool'; import { SearchTool } from '../tools/SearchTool'; import { WebsiteInfoScraperTool } from '../tools/WebsiteInfoScraperTool'; -import { AgentMessage, AssistantMessage, PROCESSING_TYPE, ProcessingInfo, Tool } from '../types/types'; +import { AgentMessage, AssistantMessage, Observation, PROCESSING_TYPE, ProcessingInfo } from '../types/types'; import { Vectorstore } from '../vectorstore/Vectorstore'; import { getReactPrompt } from './prompts'; +import { BaseTool } from '../tools/BaseTool'; +import { Parameter, ParametersType, Tool } from '../tools/ToolTypes'; dotenv.config(); @@ -24,7 +26,6 @@ dotenv.config(); export class Agent { // Private properties private client: OpenAI; - private tools: Record<string, Tool<any>>; // bcz: need a real type here private messages: AgentMessage[] = []; private interMessages: AgentMessage[] = []; private vectorstore: Vectorstore; @@ -36,6 +37,7 @@ export class Agent { private processingNumber: number = 0; private processingInfo: ProcessingInfo[] = []; private streamedAnswerParser: StreamedAnswerParser = new StreamedAnswerParser(); + private tools: Record<string, BaseTool<ReadonlyArray<Parameter>>>; /** * The constructor initializes the agent with the vector store and toolset, and sets up the OpenAI client. @@ -108,15 +110,16 @@ export class Agent { let currentAction: string | undefined; this.processingInfo = []; - // Conversation loop (up to maxTurns) - for (let i = 2; i < maxTurns; i += 2) { + let i = 2; + while (i < maxTurns) { console.log(this.interMessages); console.log(`Turn ${i}/${maxTurns}`); - // Execute a step in the conversation and get the result const result = await this.execute(onProcessingUpdate, onAnswerUpdate); this.interMessages.push({ role: 'assistant', content: result }); + i += 2; + let parsedResult; try { // Parse XML result from the assistant @@ -148,7 +151,7 @@ export class Agent { { type: 'text', text: `<stage number="${i + 1}" role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `</stage>`, - }, + } as Observation, ]; this.interMessages.push({ role: 'user', content: nextPrompt }); break; @@ -166,8 +169,8 @@ export class Agent { if (currentAction) { try { // Process the action with its input - const observation = (await this.processAction(currentAction, actionInput.inputs)) as any; // bcz: really need a type here - const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }]; + const observation = (await this.processAction(currentAction, actionInput.inputs)) as Observation[]; + const nextPrompt = [{ type: 'text', text: `<stage number="${i + 1}" role="user"> <observation>` }, ...observation, { type: 'text', text: '</observation></stage>' }] as Observation[]; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); this.processingNumber++; @@ -262,16 +265,69 @@ export class Agent { /** * Processes a specific action by invoking the appropriate tool with the provided inputs. - * @param action The action to perform. - * @param actionInput The inputs for the action. - * @returns The result of the action. + * This method ensures that the action exists and validates the types of `actionInput` + * based on the tool's parameter rules. It throws errors for missing required parameters + * or mismatched types before safely executing the tool with the validated input. + * + * Type validation includes checks for: + * - `string`, `number`, `boolean` + * - `string[]`, `number[]` (arrays of strings or numbers) + * + * @param action The action to perform. It corresponds to a registered tool. + * @param actionInput The inputs for the action, passed as an object where each key is a parameter name. + * @returns A promise that resolves to an array of `Observation` objects representing the result of the action. + * @throws An error if the action is unknown, if required parameters are missing, or if input types don't match the expected parameter types. */ - private async processAction(action: string, actionInput: unknown): Promise<unknown> { + private async processAction(action: string, actionInput: Record<string, unknown>): Promise<Observation[]> { + // Check if the action exists in the tools list if (!(action in this.tools)) { throw new Error(`Unknown action: ${action}`); } const tool = this.tools[action]; - return await tool.execute(actionInput); + + // Validate actionInput based on tool's parameter rules + for (const paramRule of tool.parameterRules) { + const inputValue = actionInput[paramRule.name]; + + if (paramRule.required && inputValue === undefined) { + throw new Error(`Missing required parameter: ${paramRule.name}`); + } + + // If the parameter is defined, check its type + if (inputValue !== undefined) { + switch (paramRule.type) { + case 'string': + if (typeof inputValue !== 'string') { + throw new Error(`Expected parameter '${paramRule.name}' to be a string.`); + } + break; + case 'number': + if (typeof inputValue !== 'number') { + throw new Error(`Expected parameter '${paramRule.name}' to be a number.`); + } + break; + case 'boolean': + if (typeof inputValue !== 'boolean') { + throw new Error(`Expected parameter '${paramRule.name}' to be a boolean.`); + } + break; + case 'string[]': + if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'string')) { + throw new Error(`Expected parameter '${paramRule.name}' to be an array of strings.`); + } + break; + case 'number[]': + if (!Array.isArray(inputValue) || !inputValue.every(item => typeof item === 'number')) { + throw new Error(`Expected parameter '${paramRule.name}' to be an array of numbers.`); + } + break; + default: + throw new Error(`Unsupported parameter type: ${paramRule.type}`); + } + } + } + + return await tool.execute(actionInput as ParametersType<typeof tool.parameterRules>); } } |
