aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/ChatBox/Agent.ts
diff options
context:
space:
mode:
authorA.J. Shulman <Shulman.aj@gmail.com>2024-08-15 13:16:32 -0400
committerA.J. Shulman <Shulman.aj@gmail.com>2024-08-15 13:16:32 -0400
commit6f9b8f9b393d411a17f7954b6cc36618efe698e2 (patch)
tree8090d9d0bafdfe3e97b8fd8914da9d1264e4172c /src/client/views/nodes/ChatBox/Agent.ts
parent0c8001c61a55540cdeeb6ae249fdd2835580121c (diff)
implemented search tool and other tools but scraping doesn't work
Diffstat (limited to 'src/client/views/nodes/ChatBox/Agent.ts')
-rw-r--r--src/client/views/nodes/ChatBox/Agent.ts26
1 files changed, 16 insertions, 10 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts
index 69b83c1b5..825cd831b 100644
--- a/src/client/views/nodes/ChatBox/Agent.ts
+++ b/src/client/views/nodes/ChatBox/Agent.ts
@@ -10,6 +10,11 @@ import { Vectorstore } from './vectorstore/Vectorstore';
import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources';
import dotenv from 'dotenv';
import { ChatBox } from './ChatBox';
+import { DataAnalysisTool } from './tools/DataAnalysisTool';
+import { string } from 'cohere-ai/core/schemas';
+import { WebsiteInfoScraperTool } from './tools/WebsiteInfoScraperTool';
+import { SearchTool } from './tools/SearchTool';
+import { add } from 'lodash';
dotenv.config();
export class Agent {
@@ -20,17 +25,22 @@ export class Agent {
private vectorstore: Vectorstore;
private _history: () => string;
private _summaries: () => string;
+ private _csvData: () => { filename: string; id: string; text: string }[];
- constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string) {
+ constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string, csvData: () => { filename: string; id: string; text: string }[], addLinkedUrlDoc: (url: string, id: string) => void) {
console.log(process.env.OPENAI_KEY);
this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true });
this.vectorstore = _vectorstore;
this._history = history;
this._summaries = summaries;
+ this._csvData = csvData;
this.tools = {
- wikipedia: new WikipediaTool(),
+ //wikipedia: new WikipediaTool(addLinkedUrlDoc),
calculate: new CalculateTool(),
rag: new RAGTool(this.vectorstore),
+ dataAnalysis: new DataAnalysisTool(csvData),
+ websiteInfoScraper: new WebsiteInfoScraperTool(addLinkedUrlDoc),
+ searchTool: new SearchTool(addLinkedUrlDoc),
no_tool: new NoTool(),
};
}
@@ -44,13 +54,13 @@ export class Agent {
console.log(`System prompt: ${systemPrompt}`);
this.interMessages = [{ role: 'system', content: systemPrompt }];
- this.interMessages.push({ role: 'user', content: `<step0 role="user"><query>${question}</query></step>` });
+ this.interMessages.push({ role: 'user', content: `<step1 role="user"><query>${question}</query></step>` });
const parser = new XMLParser();
const builder = new XMLBuilder();
let currentAction: string | undefined;
- for (let i = 1; i < maxTurns; i++) {
+ for (let i = 3; i < maxTurns; i += 2) {
console.log(`Turn ${i}/${maxTurns}`);
const result = await this.execute();
@@ -74,12 +84,10 @@ export class Agent {
currentAction = step[key] as string;
console.log(`Action: ${currentAction}`);
if (this.tools[currentAction]) {
- i++;
- console.log(builder.build({ action_rules: this.tools[currentAction].getActionRule(true) }));
const nextPrompt = [
{
type: 'text',
- text: `<step${i} role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule(true) }) + `<\step>`,
+ text: `<step${i} role="user">` + builder.build({ action_rules: this.tools[currentAction].getActionRule() }) + `<\step>`,
},
];
this.interMessages.push({ role: 'user', content: nextPrompt });
@@ -87,7 +95,6 @@ export class Agent {
break;
} else {
console.log('Error: No valid action');
- i++;
this.interMessages.push({ role: 'user', content: `<step${i}>No valid action, try again.</step>` });
break;
}
@@ -101,8 +108,7 @@ export class Agent {
// const rootTagName = stepElement.tagName;
// const match = rootTagName.match(/step(\d+)/);
// const currentStep = match ? parseInt(match[1]) + 1 : 1;
- i++;
- const nextPrompt = [{ type: 'text', text: `<step${i}<observation>` }, ...observation, { type: 'text', text: '</observation></step>' }];
+ const nextPrompt = [{ type: 'text', text: `<step${i}> <observation>` }, ...observation, { type: 'text', text: '</observation></step>' }];
console.log(observation);
this.interMessages.push({ role: 'user', content: nextPrompt });
break;