diff options
| author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-15 13:16:32 -0400 |
|---|---|---|
| committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-08-15 13:16:32 -0400 |
| commit | 6f9b8f9b393d411a17f7954b6cc36618efe698e2 (patch) | |
| tree | 8090d9d0bafdfe3e97b8fd8914da9d1264e4172c /src/client/views/nodes/ChatBox/Agent.ts | |
| parent | 0c8001c61a55540cdeeb6ae249fdd2835580121c (diff) | |
implemented search tool and other tools but scraping doesn't work
Diffstat (limited to 'src/client/views/nodes/ChatBox/Agent.ts')
| -rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index 69b83c1b5..825cd831b 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -10,6 +10,11 @@ import { Vectorstore } from './vectorstore/Vectorstore'; import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources'; import dotenv from 'dotenv'; import { ChatBox } from './ChatBox'; +import { DataAnalysisTool } from './tools/DataAnalysisTool'; +import { string } from 'cohere-ai/core/schemas'; +import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool'; +import { SearchTool } from './tools/SearchTool'; +import { add } from 'lodash'; dotenv.config(); export class Agent { @@ -20,17 +25,22 @@ export class Agent { private vectorstore: Vectorstore; private _history: () => string; private _summaries: () => string; + private _csvData: () => { filename: string; id: string; text: string }[]; - constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string) { + constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) { console.log(process.env.OPENAI_KEY); this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); this.vectorstore = _vectorstore; this._history = history; this._summaries = summaries; + this._csvData = csvData; this.tools = { - wikipedia: new WikipediaTool(), + //wikipedia: new WikipediaTool(addLinkedUrlDoc), calculate: new CalculateTool(), rag: new RAGTool(this.vectorstore), + dataAnalysis: new DataAnalysisTool(csvData), + websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc), + searchTool: new SearchTool(addLinkedUrlDoc), no_tool: new NoTool(), }; } @@ -44,13 +54,13 @@ export class Agent { console.log(`System prompt: ${systemPrompt}`); this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `<step0 role="user"><query>${question}</query></step>` }); + this.interMessages.push({ role: 'user', content: `<step1 role="user"><query>${question}</query></step>` }); const parser = new XMLParser(); const builder = new XMLBuilder(); let currentAction: string | undefined; - for (let i = 1; i < maxTurns; i++) { + for (let i = 3; i < maxTurns; i += 2) { console.log(`Turn ${i}/${maxTurns}`); const result = await this.execute(); @@ -74,12 +84,10 @@ export class Agent { currentAction = step[key] as string; console.log(`Action: ${currentAction}`); if (this.tools[currentAction]) { - i++; - console.log(builder.build({ action_rules: this.tools[currentAction].getActionRule(true) })); const nextPrompt = [ { type: 'text', - text: `<step${i} role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule(true) }) + `<\step>`, + text: `<step${i} role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `<\step>`, }, ]; this.interMessages.push({ role: 'user', content: nextPrompt }); @@ -87,7 +95,6 @@ export class Agent { break; } else { console.log('Error: No valid action'); - i++; this.interMessages.push({ role: 'user', content: `<step${i}>No valid action, try again.</step>` }); break; } @@ -101,8 +108,7 @@ export class Agent { // const rootTagName = stepElement.tagName; // const match = rootTagName.match(/step(\d+)/); // const currentStep = match ? parseInt(match[1]) + 1 : 1; - i++; - const nextPrompt = [{ type: 'text', text: `<step${i}<observation>` }, ...observation, { type: 'text', text: '</observation></step>' }]; + const nextPrompt = [{ type: 'text', text: `<step${i}> <observation>` }, ...observation, { type: 'text', text: '</observation></step>' }]; console.log(observation); this.interMessages.push({ role: 'user', content: nextPrompt }); break; |
