diff options
author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:16:26 -0400 |
---|---|---|
committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:16:26 -0400 |
commit | cab0311e2fd9a6379628c000d11ddcd805e01f64 (patch) | |
tree | 60cb3f397426cb3931c13ebe3b8a1e8eb98480dd | |
parent | d0e09ff3526e4f6b9aad824fad1020d083a87631 (diff) |
first attempt at integrating everything
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 123 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBot.ts | 14 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 96 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/prompts.ts | 99 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/BaseTool.ts | 24 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/CalculateTool.ts | 25 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/RAGTool.ts | 81 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/WikipediaTool.ts | 33 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 15 |
9 files changed, 475 insertions, 35 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts new file mode 100644 index 000000000..f20a75a8d --- /dev/null +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -0,0 +1,123 @@ +import OpenAI from 'openai'; +import { Tool, AgentMessage } from './types'; +import { getReactPrompt } from './prompts'; +import { XMLParser, XMLBuilder } from 'fast-xml-parser'; + +export class Agent { + private client: OpenAI; + private tools: Record<string, Tool>; + private messages: AgentMessage[] = []; + private interMessages: AgentMessage[] = []; + private summaries: string; + + constructor(private vectorstore: Vectorstore) { + this.client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); + this.summaries = this.vectorstore ? this.vectorstore.getSummaries() : 'No documents available.'; + this.tools = { + wikipedia: new WikipediaTool(), + calculate: new CalculateTool(), + rag: new RAGTool(vectorstore, this.summaries), + }; + } + + private formatChatHistory(): string { + let history = '<chat_history>\n'; + for (const message of this.messages) { + if (message.role === 'user') { + history += `<user>${message.content}</user>\n`; + } else if (message.role === 'assistant') { + history += `<assistant>${message.content}</assistant>\n`; + } + } + history += '</chat_history>'; + return history; + } + + async askAgent(question: string, maxTurns: number = 5): Promise<string> { + console.log(`Starting query: ${question}`); + this.messages.push({ role: 'user', content: question }); + const chatHistory = this.formatChatHistory(); + console.log(`Chat history: ${chatHistory}`); + const systemPrompt = getReactPrompt(Object.values(this.tools), chatHistory); + console.log(`System prompt: ${systemPrompt}`); + this.interMessages = [{ role: 'system', content: systemPrompt }]; + + this.interMessages.push({ role: 'assistant', content: `<query>${question}</query>` }); + + for (let i = 0; i < maxTurns; i++) { + console.log(`Turn ${i + 1}/${maxTurns}`); + + const result = await this.execute(); + console.log(`Bot response: ${result}`); + this.interMessages.push({ role: 'assistant', content: result }); + + try { + const parser = new XMLParser(); + const parsedResult = parser.parse(result); + const step = parsedResult[`step${i + 1}`]; + + if (step.thought) console.log(`Thought: ${step.thought}`); + if (step.action) { + console.log(`Action: ${step.action}`); + const action = step.action; + const actionRules = new XMLBuilder().build({ + action_rules: this.tools[action].getActionRule(), + }); + this.interMessages.push({ role: 'user', content: actionRules }); + } + if (step.action_input) { + const actionInput = new XMLBuilder().build({ action_input: step.action_input }); + console.log(`Action input: ${actionInput}`); + try { + const observation = await this.processAction(action, step.action_input); + const nextPrompt = [{ type: 'text', text: '<observation>' }, ...observation, { type: 'text', text: '</observation>' }]; + this.interMessages.push({ role: 'user', content: nextPrompt }); + } catch (e) { + console.error(`Error processing action: ${e}`); + return `<error>${e}</error>`; + } + } + if (step.answer) { + console.log('Answer found. Ending query.'); + const answerContent = new XMLBuilder().build({ answer: step.answer }); + this.messages.push({ role: 'assistant', content: answerContent }); + this.interMessages = []; + return answerContent; + } + } catch (e) { + console.error('Error: Invalid XML response from bot'); + return '<error>Invalid response format.</error>'; + } + } + + console.log('Reached maximum turns. Ending query.'); + return '<error>Reached maximum turns without finding an answer</error>'; + } + + private async execute(): Promise<string> { + const completion = await this.client.chat.completions.create({ + model: 'gpt-4', + messages: this.interMessages, + temperature: 0, + }); + return completion.choices[0].message.content; + } + + private async processAction(action: string, actionInput: any): Promise<any> { + if (!(action in this.tools)) { + throw new Error(`Unknown action: ${action}`); + } + + const tool = this.tools[action]; + const args: Record<string, any> = {}; + for (const paramName in tool.parameters) { + if (actionInput[paramName] !== undefined) { + args[paramName] = actionInput[paramName]; + } else { + throw new Error(`Missing required parameter '${paramName}' for action '${action}'`); + } + } + + return await tool.execute(args); + } +} diff --git a/src/client/views/nodes/ChatBox/ChatBot.ts b/src/client/views/nodes/ChatBox/ChatBot.ts new file mode 100644 index 000000000..31b4ea9e3 --- /dev/null +++ b/src/client/views/nodes/ChatBox/ChatBot.ts @@ -0,0 +1,14 @@ +import { Agent } from './Agent'; +import { Vectorstore } from './vectorstore/VectorstoreUpload'; + +export class ChatBot { + private agent: Agent; + + constructor(vectorstore: Vectorstore) { + this.agent = new Agent(vectorstore); + } + + async ask(question: string): Promise<string> { + return await this.agent.askAgent(question); + } +} diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index 2283aad56..73f35f501 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -27,6 +27,7 @@ import { Vectorstore } from './vectorstore/VectorstoreUpload'; import { DocumentView } from '../DocumentView'; import { CollectionFreeFormDocumentView } from '../CollectionFreeFormDocumentView'; import { CollectionFreeFormView } from '../../collections/collectionFreeForm'; +import { ChatBot } from './ChatBot'; @observer export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { @@ -42,6 +43,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { private documents: AI_Document[] = []; private _oldWheel: any; private vectorstore: Vectorstore; + private chatbot: ChatBot; // Add the ChatBot instance public static LayoutString(fieldKey: string) { return FieldView.LayoutString(ChatBox, fieldKey); @@ -55,6 +57,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { this.openai = this.initializeOpenAI(); this.getOtherDocs(); this.vectorstore = new Vectorstore(); + this.chatbot = new ChatBot(this.vectorstore); // Initialize the ChatBot reaction( () => this.history.map((msg: AssistantMessage) => ({ role: msg.role, text: msg.text, follow_up_questions: msg.follow_up_questions, citations: msg.citations })), @@ -65,7 +68,11 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { } getOtherDocs = async () => { - const visible_docs = (CollectionFreeFormDocumentView.from(this._props.DocumentView?.())?._props.parent as CollectionFreeFormView)?.childDocs.filter(doc => doc != this.Document); + const visible_docs = (CollectionFreeFormDocumentView.from(this._props.DocumentView?.())?._props.parent as CollectionFreeFormView)?.childDocs + .filter(doc => doc != this.Document) + .map(d => DocCast(d?.annotationOn, d)) + .filter(d => d); + console.log('All Docs:', visible_docs); visible_docs?.forEach(async doc => { @@ -121,17 +128,39 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { runInAction(() => { this.history.push({ role: ASSISTANT_ROLE.USER, text: trimmedText }); }); - const { response } = await Networking.PostToServer('/askAgent', { input: trimmedText }); + this.isLoading = true; + const response = await this.chatbot.ask(trimmedText); // Use the chatbot to get the response runInAction(() => { - this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, text: response }); + this.history.push(this.parseAssistantResponse(response)); }); this.dataDoc.data = JSON.stringify(this.history); } catch (err) { console.error('Error:', err); + runInAction(() => { + this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, text: 'Sorry, I encountered an error while processing your request.' }); + }); + } finally { + this.isLoading = false; } } }; + parseAssistantResponse(response: string): AssistantMessage { + const parser = new DOMParser(); + const xmlDoc = parser.parseFromString(response, 'text/xml'); + const answerElement = xmlDoc.querySelector('answer'); + const followUpQuestionsElement = xmlDoc.querySelector('follow_up_questions'); + + const text = answerElement ? answerElement.textContent || '' : ''; + const followUpQuestions = followUpQuestionsElement ? Array.from(followUpQuestionsElement.querySelectorAll('question')).map(q => q.textContent || '') : []; + + return { + role: ASSISTANT_ROLE.ASSISTANT, + text, + follow_up_questions: followUpQuestions, + }; + } + // @action // uploadLinks = async (linkedDocs: Doc[]) => { // if (this.isInitializing) { @@ -241,42 +270,39 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { r?.addEventListener('wheel', this.onPassiveWheel, { passive: false }); }}> <div className="messages"> - { - //this.history.map((message, index) => ( - // <MessageComponent - // key={index} - // message={message} - // toggleToolLogs={this.toggleToolLogs} - // expandedLogIndex={this.expandedLogIndex} - // index={index} - // showModal={this.showModal} - // goToLinkedDoc={() => {}} - // setCurrentFile={this.setCurrentFile} - // onFollowUpClick={this.handleFollowUpClick} - // /> - //) - //) - } - { - //!this.current_message ? null : ( - // <MessageComponent - // key={this.history.length} - // message={this.current_message} - // toggleToolLogs={this.toggleToolLogs} - // expandedLogIndex={this.expandedLogIndex} - // index={this.history.length} - // showModal={this.showModal} - // goToLinkedDoc={() => {}} - // setCurrentFile={this.setCurrentFile} - // onFollowUpClick={this.handleFollowUpClick} - // /> - //) - } + {this.history.map((message, index) => ( + <MessageComponent + key={index} + message={message} + toggleToolLogs={this.toggleToolLogs} + expandedLogIndex={this.expandedScratchpadIndex} + index={index} + showModal={() => {}} // Implement this method if needed + goToLinkedDoc={() => {}} // Implement this method if needed + setCurrentFile={() => {}} // Implement this method if needed + onFollowUpClick={this.handleFollowUpClick} + /> + ))} + {this.current_message && ( + <MessageComponent + key={this.history.length} + message={this.current_message} + toggleToolLogs={this.toggleToolLogs} + expandedLogIndex={this.expandedScratchpadIndex} + index={this.history.length} + showModal={() => {}} // Implement this method if needed + goToLinkedDoc={() => {}} // Implement this method if needed + setCurrentFile={() => {}} // Implement this method if needed + onFollowUpClick={this.handleFollowUpClick} + /> + )} </div> </div> <form onSubmit={this.askGPT} className="chat-form"> <input type="text" name="messageInput" autoComplete="off" placeholder="Type a message..." value={this.inputValue} onChange={e => (this.inputValue = e.target.value)} /> - <button type="submit">Send</button> + <button type="submit" disabled={this.isLoading}> + {this.isLoading ? 'Thinking...' : 'Send'} + </button> </form> </div> /** </MathJaxContext> **/ diff --git a/src/client/views/nodes/ChatBox/prompts.ts b/src/client/views/nodes/ChatBox/prompts.ts new file mode 100644 index 000000000..8835265e4 --- /dev/null +++ b/src/client/views/nodes/ChatBox/prompts.ts @@ -0,0 +1,99 @@ +// prompts.ts + +import { Tool } from './types'; + +export function getReactPrompt(tools: Tool[], chatHistory: string): string { + const toolDescriptions = tools.map(tool => `${tool.name}:\n${tool.briefSummary}`).join('\n*****\n'); + + return ` + You run in a loop of Thought, Action, PAUSE, Action Input, Pause, Observation. + (this Thought/Action/PAUSE/Action Input/PAUSE/Observation can repeat N times) + Contain each stage of the loop within an XML element that specifies the stage type (e.g. <thought>content of the thought</thought>). + At the end of the loop, you output an Answer with the answer content contained within an XML element with an <answer> tag. At the end of the answer should be an array of 3 potential follow-up questions for the user to ask you next, contained within a <follow_up_questions> key. + Use <thought> to describe your thoughts about the question you have been asked. + Use <action> to specify run one of the actions available to you - then return a </pause> element. + Then, you will be provided with action rules within an <action_rules> element that specifies how you should structure the input to the action and what the output of that action will look like - then return another </pause> element. + Then, provide within an <action_input> element each parameter, with parameter names as element tags themselves with their values inside, following the structure defined in the action rules. + Observation, in an <observation> element will be the result of running those actions. + ********** + Your available actions are: + ***** + ${toolDescriptions} + ********** + Example: + You will be called with: + <query>What is the capital of France?</query> + + You will then output: + <step1> + <thought>I should look up France on Wikipedia</thought> + <action>wikipedia</action> + <pause/> + </step1> + + You will be called again with this: + <action_rules> + { + "wikipedia": { + "name": "wikipedia", + "description": "Search Wikipedia and return a summary", + "parameters": [ + { + "title": { + "type": "string", + "description": "The title of the Wikipedia article to search", + "required": "true" + } + } + ] + } + } + </action_rules> + + You will then output (back in valid XML with the parameters each being a tag): + <step2> + <action_input> + <title>France</title> + </action_input> + </step2> + + You will then be called again with this: + <observation>France is a country. The capital is Paris.</observation> + + You then output: + <step3> + <answer> + The capital of France is Paris + <follow_up_questions> + <question>Where in France is Paris located?</question> + <question>What are some major tourist attractions in Paris?</question> + <question>What are some other major cities in France?</question> + </follow_up_questions> + </answer> + </step3> + ********** + Here is the history of your conversation with the user (all loop steps are ommitted, so it is just the user query and final answer): + ${chatHistory} + Use context from the past conversation if necessary. + ********** + If the response is inadequate, repeat the loop, either trying a different tool or changing the parameters for the action input. + + !!!IMPORTANT When you have an Answer, Write your entire response inside an <answer> element (which itself should be inside the step element for the current step). After you finish the answer, provide an array of 3 follow-up questions inside a <follow_up_questions> array. These should relate to the query and the response and should aim to help the user better understand whatever they are looking for. + ********** + !!!IMPORTANT Every response, provide in full parsable and valid XML with the root element being the step number (e.g. <step1>), iterated every time you output something new. + `; +} + +export function getSummarizedChunksPrompt(chunks: string): string { + return `Please provide a comprehensive summary of what you think the document from which these chunks originated. + Ensure the summary captures the main ideas and key points from all provided chunks. Be concise and brief and only provide the summary in paragraph form. + + Text chunks: + \`\`\` + ${chunks} + \`\`\``; +} + +export function getSummarizedSystemPrompt(): string { + return 'You are an AI assistant tasked with summarizing a document. You are provided with important chunks from the document and provide a summary, as best you can, of what the document will contain overall. Be concise and brief with your response.'; +} diff --git a/src/client/views/nodes/ChatBox/tools/BaseTool.ts b/src/client/views/nodes/ChatBox/tools/BaseTool.ts new file mode 100644 index 000000000..3511d9528 --- /dev/null +++ b/src/client/views/nodes/ChatBox/tools/BaseTool.ts @@ -0,0 +1,24 @@ +import { Tool } from '../types'; + +export abstract class BaseTool implements Tool { + constructor( + public name: string, + public description: string, + public parameters: Record<string, any>, + public useRules: string, + public briefSummary: string + ) {} + + abstract execute(args: Record<string, any>): Promise<any>; + + getActionRule(): Record<string, any> { + return { + [this.name]: { + name: this.name, + useRules: this.useRules, + description: this.description, + parameters: this.parameters, + }, + }; + } +} diff --git a/src/client/views/nodes/ChatBox/tools/CalculateTool.ts b/src/client/views/nodes/ChatBox/tools/CalculateTool.ts new file mode 100644 index 000000000..b881d90fa --- /dev/null +++ b/src/client/views/nodes/ChatBox/tools/CalculateTool.ts @@ -0,0 +1,25 @@ +import { BaseTool } from './BaseTool'; + +export class CalculateTool extends BaseTool { + constructor() { + super( + 'calculate', + 'Perform a calculation', + { + expression: { + type: 'string', + description: 'The mathematical expression to evaluate', + required: 'true', + }, + }, + 'Provide a mathematical expression to calculate that would work with JavaScript eval().', + 'Runs a calculation and returns the number - uses JavaScript so be sure to use floating point syntax if necessary' + ); + } + + async execute(args: { expression: string }): Promise<any> { + // Note: Using eval() can be dangerous. Consider using a safer alternative. + const result = eval(args.expression); + return [{ type: 'text', text: result.toString() }]; + } +} diff --git a/src/client/views/nodes/ChatBox/tools/RAGTool.ts b/src/client/views/nodes/ChatBox/tools/RAGTool.ts new file mode 100644 index 000000000..84d5430e7 --- /dev/null +++ b/src/client/views/nodes/ChatBox/tools/RAGTool.ts @@ -0,0 +1,81 @@ +import { BaseTool } from './BaseTool'; +import { Vectorstore } from '../vectorstore/VectorstoreUpload'; +import { Chunk } from '../types'; + +export class RAGTool extends BaseTool { + constructor( + private vectorstore: Vectorstore, + summaries: string + ) { + super( + 'rag', + 'Perform a RAG search on user documents', + { + hypothetical_document_chunk: { + type: 'string', + description: + "Detailed version of the prompt that is effectively a hypothetical document chunk that would be ideal to embed and compare to the vectors of real document chunks to fetch the most relevant document chunks to answer the user's query", + required: 'true', + }, + }, + `Your task is to first provide a response to the user's prompt based on the information given in the chunks and considering the chat history. Follow these steps: + + 1. Carefully read and analyze the provided chunks, which may include text, images, or tables. Each chunk has an associated chunk_id. + + 2. Review the prompt and chat history to understand the context of the user's question or request. + + 3. Formulate a response that addresses the prompt using information from the relevant chunks. Your response should be informative and directly answer the user's question or request. + + 4. Use citations to support your response. Citations should contain direct textual references to the granular, specific part of the original chunk that applies to the situation—with no text ommitted. Citations should be in the following format: + - For text: <citation chunk_id="d980c2a7-cad3-4d7e-9eae-19bd2380bd02" type="text">relevant direct text from the chunk that the citation in referencing specifically</citation> + - For images or tables: <citation chunk_id="9ef37681-b57e-4424-b877-e1ebc326ff11" type="image"></citation> + + Place citations after the sentences they apply to. You can use multiple citations in a row. + + 5. If there's insufficient information in the provided chunks to answer the prompt sufficiently, ALWAYS respond with <answer>RAG not applicable</answer> + + Write your entire response, including follow-up questions, inside <answer> tags. Remember to use the citation format for both text and image references, and maintain a conversational tone throughout your response. + + !!!IMPORTANT Before you close the tag with </answer>, within the answer tags provide a set of 3 follow-up questions inside a <follow_up_questions> tag and individually within <question> tags. These should relate to the document, the current query, and the chat_history and should aim to help the user better understand whatever they are looking for. + Also, ensure that the answer tags are wrapped with the correct step tags as well.`, + `Performs a RAG (Retrieval-Augmented Generation) search on user documents and returns a + set of document chunks (either images or text) that can be used to provide a grounded response based on + user documents + + !!!IMPORTANT Use the RAG tool ANYTIME the question may potentially (even if you are not sure) relate to one of the user's documents. + Here are the summaries of the user's documents: + ${summaries}` + ); + } + + async execute(args: { hypothetical_document_chunk: string }): Promise<any> { + const relevantChunks = await this.vectorstore.retrieve(args.hypothetical_document_chunk); + return this.getFormattedChunks(relevantChunks); + } + + private getFormattedChunks(relevantChunks: Chunk[]): { type: string; text?: string; image_url?: { url: string } }[] { + const content: { type: string; text?: string; image_url?: { url: string } }[] = [{ type: 'text', text: '<chunks>' }]; + + for (const chunk of relevantChunks) { + content.push({ + type: 'text', + text: `<chunk chunk_id=${chunk.id} chunk_type=${chunk.metadata.type === 'image' ? 'image' : 'text'}>`, + }); + + if (chunk.metadata.type === 'image') { + // Implement image loading and base64 encoding here + // For now, we'll just add a placeholder + content.push({ + type: 'image_url', + image_url: { url: chunk.metadata.file_path }, + }); + } + + content.push({ type: 'text', text: `${chunk.metadata.text}\n</chunk>\n` }); + } + + content.push({ type: 'text', text: '</chunks>' }); + + return content; + } +} diff --git a/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts b/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts new file mode 100644 index 000000000..0aef58f61 --- /dev/null +++ b/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts @@ -0,0 +1,33 @@ +import { BaseTool } from './BaseTool'; +import axios from 'axios'; + +export class WikipediaTool extends BaseTool { + constructor() { + super( + 'wikipedia', + 'Search Wikipedia and return a summary', + { + title: { + type: 'string', + description: 'The title of the Wikipedia article to search', + required: 'true', + }, + }, + 'Provide simply the title you want to search on Wikipedia and nothing more. If re-using this tool, try a different title for different information.', + 'Returns a summary from searching an article title on Wikipedia' + ); + } + + async execute(args: { title: string }): Promise<any> { + const response = await axios.get('https://en.wikipedia.org/w/api.php', { + params: { + action: 'query', + list: 'search', + srsearch: args.title, + format: 'json', + }, + }); + const result = response.data.query.search[0].snippet; + return [{ type: 'text', text: result }]; + } +} diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index 7acb96c15..c60973be3 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -46,6 +46,21 @@ export interface AI_Document { type: string; } +export interface Tool { + name: string; + description: string; + parameters: Record<string, any>; + useRules: string; + briefSummary: string; + execute: (args: Record<string, any>) => Promise<any>; + getActionRule: () => Record<string, any>; +} + +export interface AgentMessage { + role: 'system' | 'user' | 'assistant'; + content: string | { type: string; text?: string; image_url?: { url: string } }[]; +} + export function convertToAIDocument(json: any): AI_Document { if (!json) { throw new Error('Invalid JSON object'); |