diff options
| author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-10-17 10:41:49 -0400 |
|---|---|---|
| committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-10-17 10:41:49 -0400 |
| commit | 80d86bd5ae3e1d3dc70e7636f72a872a5fb2f01d (patch) | |
| tree | 0eaea49f596bd16720f05a6535958ab8270673c8 /src/client/views/nodes/chatbot/tools/SearchTool.ts | |
| parent | 596502c232ea6b6b88c3c58486e139074ea056ff (diff) | |
Implemented strict typechecking for tools, specifically tool inputs
Diffstat (limited to 'src/client/views/nodes/chatbot/tools/SearchTool.ts')
| -rw-r--r-- | src/client/views/nodes/chatbot/tools/SearchTool.ts | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/src/client/views/nodes/chatbot/tools/SearchTool.ts b/src/client/views/nodes/chatbot/tools/SearchTool.ts index 103abcdbe..c5cf951e7 100644 --- a/src/client/views/nodes/chatbot/tools/SearchTool.ts +++ b/src/client/views/nodes/chatbot/tools/SearchTool.ts @@ -2,8 +2,21 @@ import { v4 as uuidv4 } from 'uuid'; import { Networking } from '../../../../Network'; import { BaseTool } from './BaseTool'; import { Observation } from '../types/types'; +import { ParametersType } from './ToolTypes'; -export class SearchTool extends BaseTool<{ query: { type: string | string[]; description: string; required: boolean } }> { +const searchToolParams = [ + { + name: 'query', + type: 'string[]', + description: 'The search query or queries to use for finding websites', + required: true, + max_inputs: 3, + }, +] as const; + +type SearchToolParamsType = typeof searchToolParams; + +export class SearchTool extends BaseTool<SearchToolParamsType> { private _addLinkedUrlDoc: (url: string, id: string) => void; private _max_results: number; @@ -11,13 +24,7 @@ export class SearchTool extends BaseTool<{ query: { type: string | string[]; des super( 'searchTool', 'Search the web to find a wide range of websites related to a query or multiple queries', - { - query: { - type: 'string', - description: 'The search query or queries to use for finding websites', - required: true, - }, - }, + searchToolParams, 'Provide up to 3 search queries to find a broad range of websites.', 'Returns a list of websites and their overviews based on the search queries.' ); @@ -25,13 +32,16 @@ export class SearchTool extends BaseTool<{ query: { type: string | string[]; des this._max_results = max_results; } - async execute(args: { query: string | string[] }): Promise<Observation[]> { - const queries = Array.isArray(args.query) ? args.query : [args.query]; - const allResults = []; + async execute(args: ParametersType<SearchToolParamsType>): Promise<Observation[]> { + const queries = args.query; + const allResults: Observation[] = []; for (const query of queries) { try { - const { results } = await Networking.PostToServer('/getWebSearchResults', { query, max_results: this._max_results }); + const { results } = await Networking.PostToServer('/getWebSearchResults', { + query, + max_results: this._max_results, + }); const data = results.map((result: { url: string; snippet: string }) => { const id = uuidv4(); return { @@ -42,7 +52,10 @@ export class SearchTool extends BaseTool<{ query: { type: string | string[]; des allResults.push(...data); } catch (error) { console.log(error); - allResults.push({ type: 'text', text: `An error occurred while performing the web search for query: ${query}` }); + allResults.push({ + type: 'text', + text: `An error occurred while performing the web search for query: ${query}`, + }); } } |
