diff options
author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:35:11 -0400 |
---|---|---|
committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-07-10 16:35:11 -0400 |
commit | aa8b1248408846d6a158f8df1c76fa3015ce3aac (patch) | |
tree | c0bf0d85b3f09a59e001bdc93963fc413222f942 /src | |
parent | cab0311e2fd9a6379628c000d11ddcd805e01f64 (diff) |
Fixing bugs and attempting to get it to work
Diffstat (limited to 'src')
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 16 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBot.ts | 2 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 51 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/MessageComponent.tsx | 127 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/BaseTool.ts | 4 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/CalculateTool.ts | 2 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/RAGTool.ts | 29 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/WikipediaTool.ts | 2 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 7 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts | 1 |
10 files changed, 107 insertions, 134 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index f20a75a8d..4c2838540 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -2,10 +2,17 @@ import OpenAI from 'openai'; import { Tool, AgentMessage } from './types'; import { getReactPrompt } from './prompts'; import { XMLParser, XMLBuilder } from 'fast-xml-parser'; +import { WikipediaTool } from './tools/WikipediaTool'; +import { CalculateTool } from './tools/CalculateTool'; +import { RAGTool } from './tools/RAGTool'; +import { Vectorstore } from './vectorstore/VectorstoreUpload'; +import { ChatCompletionAssistantMessageParam, ChatCompletionMessageParam } from 'openai/resources'; +import dotenv from 'dotenv'; +dotenv.config(); export class Agent { private client: OpenAI; - private tools: Record<string, Tool>; + private tools: Record<string, Tool<any>>; private messages: AgentMessage[] = []; private interMessages: AgentMessage[] = []; private summaries: string; @@ -69,7 +76,7 @@ export class Agent { const actionInput = new XMLBuilder().build({ action_input: step.action_input }); console.log(`Action input: ${actionInput}`); try { - const observation = await this.processAction(action, step.action_input); + const observation = await this.processAction(step.action, step.action_input); const nextPrompt = [{ type: 'text', text: '<observation>' }, ...observation, { type: 'text', text: '</observation>' }]; this.interMessages.push({ role: 'user', content: nextPrompt }); } catch (e) { @@ -97,10 +104,11 @@ export class Agent { private async execute(): Promise<string> { const completion = await this.client.chat.completions.create({ model: 'gpt-4', - messages: this.interMessages, + messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, }); - return completion.choices[0].message.content; + if (completion.choices[0].message.content) return completion.choices[0].message.content; + else throw new Error('No completion content found'); } private async processAction(action: string, actionInput: any): Promise<any> { diff --git a/src/client/views/nodes/ChatBox/ChatBot.ts b/src/client/views/nodes/ChatBox/ChatBot.ts index 31b4ea9e3..8b5e0982c 100644 --- a/src/client/views/nodes/ChatBox/ChatBot.ts +++ b/src/client/views/nodes/ChatBox/ChatBot.ts @@ -1,5 +1,7 @@ import { Agent } from './Agent'; import { Vectorstore } from './vectorstore/VectorstoreUpload'; +import dotenv from 'dotenv'; +dotenv.config(); export class ChatBot { private agent: Agent; diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index 73f35f501..3ecb2d340 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -1,33 +1,24 @@ -import { MathJaxContext } from 'better-react-mathjax'; import { action, makeObservable, observable, observe, reaction, runInAction } from 'mobx'; import { observer } from 'mobx-react'; import OpenAI, { ClientOptions } from 'openai'; -import { ImageFile, Message } from 'openai/resources/beta/threads/messages'; -import { RunStep } from 'openai/resources/beta/threads/runs/steps'; import * as React from 'react'; import { Doc } from '../../../../fields/Doc'; -import { Id } from '../../../../fields/FieldSymbols'; import { CsvCast, DocCast, PDFCast, StrCast } from '../../../../fields/Types'; -import { CsvField } from '../../../../fields/URLField'; import { Networking } from '../../../Network'; -import { DocUtils } from '../../../documents/DocUtils'; import { DocumentType } from '../../../documents/DocumentTypes'; import { Docs } from '../../../documents/Documents'; -import { DocumentManager } from '../../../util/DocumentManager'; import { LinkManager } from '../../../util/LinkManager'; import { ViewBoxAnnotatableComponent } from '../../DocComponent'; import { FieldView, FieldViewProps } from '../FieldView'; import './ChatBox.scss'; import MessageComponent from './MessageComponent'; -import { ASSISTANT_ROLE, AssistantMessage, AI_Document, convertToAIDocument } from './types'; -import { Annotation } from 'mobx/dist/internal'; -import { FormEvent } from 'react'; -import { url } from 'inspector'; +import { ASSISTANT_ROLE, AssistantMessage, AI_Document, convertToAIDocument, Citation } from './types'; import { Vectorstore } from './vectorstore/VectorstoreUpload'; -import { DocumentView } from '../DocumentView'; import { CollectionFreeFormDocumentView } from '../CollectionFreeFormDocumentView'; import { CollectionFreeFormView } from '../../collections/collectionFreeForm'; import { ChatBot } from './ChatBot'; +import dotenv from 'dotenv'; +dotenv.config(); @observer export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { @@ -151,7 +142,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { const answerElement = xmlDoc.querySelector('answer'); const followUpQuestionsElement = xmlDoc.querySelector('follow_up_questions'); - const text = answerElement ? answerElement.textContent || '' : ''; + const text = answerElement ? answerElement.innerHTML || '' : ''; // Use innerHTML to preserve citation tags const followUpQuestions = followUpQuestionsElement ? Array.from(followUpQuestionsElement.querySelectorAll('question')).map(q => q.textContent || '') : []; return { @@ -161,6 +152,19 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { }; } + @action + updateMessageCitations = (index: number, citations: Citation[]) => { + if (this.history[index]) { + this.history[index].citations = citations; + } + }; + + @action + handleCitationClick = (citation: Citation) => { + console.log('Citation clicked:', citation); + // You can implement additional functionality here, such as showing a modal with the full citation content + }; + // @action // uploadLinks = async (linkedDocs: Doc[]) => { // if (this.isInitializing) { @@ -259,7 +263,6 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { }; render() { return ( - /** <MathJaxContext config={this.mathJaxConfig}> **/ <div className="chatBox"> {this.isInitializing && <div className="initializing-overlay">Initializing...</div>} <div @@ -271,29 +274,16 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { }}> <div className="messages"> {this.history.map((message, index) => ( - <MessageComponent - key={index} - message={message} - toggleToolLogs={this.toggleToolLogs} - expandedLogIndex={this.expandedScratchpadIndex} - index={index} - showModal={() => {}} // Implement this method if needed - goToLinkedDoc={() => {}} // Implement this method if needed - setCurrentFile={() => {}} // Implement this method if needed - onFollowUpClick={this.handleFollowUpClick} - /> + <MessageComponent key={index} message={message} index={index} onFollowUpClick={this.handleFollowUpClick} onCitationClick={this.handleCitationClick} updateMessageCitations={this.updateMessageCitations} /> ))} {this.current_message && ( <MessageComponent key={this.history.length} message={this.current_message} - toggleToolLogs={this.toggleToolLogs} - expandedLogIndex={this.expandedScratchpadIndex} index={this.history.length} - showModal={() => {}} // Implement this method if needed - goToLinkedDoc={() => {}} // Implement this method if needed - setCurrentFile={() => {}} // Implement this method if needed onFollowUpClick={this.handleFollowUpClick} + onCitationClick={this.handleCitationClick} + updateMessageCitations={this.updateMessageCitations} /> )} </div> @@ -305,7 +295,6 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { </button> </form> </div> - /** </MathJaxContext> **/ ); } } diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index 15c0811fb..1baf6d7d5 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -1,110 +1,73 @@ -/* eslint-disable react/require-default-props */ import React from 'react'; import { observer } from 'mobx-react'; -import { MathJax, MathJaxContext } from 'better-react-mathjax'; -import ReactMarkdown from 'react-markdown'; -import { AssistantMessage } from './types'; +import { AssistantMessage, CHUNK_TYPE, Citation } from './types'; import { TbInfoCircleFilled } from 'react-icons/tb'; interface MessageComponentProps { message: AssistantMessage; - toggleToolLogs: (index: number) => void; - expandedLogIndex: number | null; index: number; - showModal: () => void; - goToLinkedDoc: (url: string) => void; - setCurrentFile: (file: { url: string }) => void; - onFollowUpClick: (question: string) => void; // New prop - isCurrent?: boolean; + onFollowUpClick: (question: string) => void; + onCitationClick: (citation: Citation) => void; + updateMessageCitations: (index: number, citations: Citation[]) => void; } -const MessageComponent: React.FC<MessageComponentProps> = function ({ - message, - toggleToolLogs, - expandedLogIndex, - goToLinkedDoc, - index, - showModal, - setCurrentFile, - onFollowUpClick, // New prop - isCurrent = false, -}) { - const LinkRenderer = ({ href, children }: { href: string; children: React.ReactNode }) => { - const regex = /([a-zA-Z0-9_.!-]+)~~~(citation|file_path)/; - const matches = href.match(regex); - const url = matches ? matches[1] : href; - const linkType = matches ? matches[2] : null; - if (linkType === 'citation') { - children = <TbInfoCircleFilled />; - } - const style = { - color: 'lightblue', - verticalAlign: linkType === 'citation' ? 'super' : 'baseline', - fontSize: linkType === 'citation' ? 'smaller' : 'inherit', - }; +const MessageComponent: React.FC<MessageComponentProps> = function ({ message, index, onFollowUpClick, onCitationClick, updateMessageCitations }) { + const LinkRenderer = ({ children }: { children: React.ReactNode }) => { + const text = children as string; + const citationRegex = /<citation chunk_id="([^"]*)" type="([^"]*)">([^<]*)<\/citation>/g; + const parts = []; + let lastIndex = 0; + let match; + const citations: Citation[] = []; - return ( - <a - href="#" - onClick={e => { - e.preventDefault(); - if (linkType === 'citation') { - goToLinkedDoc(url); - } else if (linkType === 'file_path') { - showModal(); - setCurrentFile({ url }); - } - }} - style={style}> - {children} - </a> - ); - }; + while ((match = citationRegex.exec(text)) !== null) { + const [fullMatch, chunkId, type, content] = match; + const citation: Citation = { chunk_id: chunkId, type: type as CHUNK_TYPE, text: content }; + citations.push(citation); + + parts.push(text.slice(lastIndex, match.index)); + parts.push( + <a + key={chunkId} + href="#" + onClick={e => { + e.preventDefault(); + onCitationClick(citation); + }} + style={{ + color: 'lightblue', + verticalAlign: 'super', + fontSize: 'smaller', + }}> + <TbInfoCircleFilled /> + </a> + ); + lastIndex = match.index + fullMatch.length; + } - const parseMessage = (text: string) => { - const answerMatch = text.match(/<answer>([\s\S]*?)<\/answer>/); - const followUpMatch = text.match(/<follow_up_question>([\s\S]*?)<\/follow_up_question>/); + parts.push(text.slice(lastIndex)); - const answer = answerMatch ? answerMatch[1] : text; - const followUpQuestions = followUpMatch - ? followUpMatch[1] - .split('\n') - .filter(q => q.trim()) - .map(q => q.replace(/^\d+\.\s*/, '').trim()) - : []; + // Update the message's citations in the ChatBox's history + updateMessageCitations(index, citations); - return { answer, followUpQuestions }; + return <>{parts}</>; }; - const { answer, followUpQuestions } = parseMessage(message.text); - console.log('Parsed answer:', answer); - console.log('Parsed follow-up questions:', followUpQuestions); return ( <div className={`message ${message.role}`}> - <ReactMarkdown components={{ a: LinkRenderer }}>{answer}</ReactMarkdown> - {message.image && <img src={message.image} alt="" />} - {followUpQuestions.length > 0 && ( + <div> + <LinkRenderer>{message.text}</LinkRenderer> + </div> + {message.follow_up_questions && message.follow_up_questions.length > 0 && ( <div className="follow-up-questions"> <h4>Follow-up Questions:</h4> - {followUpQuestions.map((question, idx) => ( + {message.follow_up_questions.map((question, idx) => ( <button key={idx} className="follow-up-button" onClick={() => onFollowUpClick(question)}> {question} </button> ))} </div> )} - <div className="message-footer"> - {message.tool_logs && ( - <button className="toggle-logs-button" onClick={() => toggleToolLogs(index)}> - {expandedLogIndex === index ? 'Hide Code Interpreter Logs' : 'Show Code Interpreter Logs'} - </button> - )} - {expandedLogIndex === index && ( - <div className="tool-logs"> - <pre>{message.tool_logs}</pre> - </div> - )} - </div> </div> ); }; diff --git a/src/client/views/nodes/ChatBox/tools/BaseTool.ts b/src/client/views/nodes/ChatBox/tools/BaseTool.ts index 3511d9528..903161bd5 100644 --- a/src/client/views/nodes/ChatBox/tools/BaseTool.ts +++ b/src/client/views/nodes/ChatBox/tools/BaseTool.ts @@ -1,6 +1,6 @@ import { Tool } from '../types'; -export abstract class BaseTool implements Tool { +export abstract class BaseTool<T extends Record<string, any> = Record<string, any>> implements Tool<T> { constructor( public name: string, public description: string, @@ -9,7 +9,7 @@ export abstract class BaseTool implements Tool { public briefSummary: string ) {} - abstract execute(args: Record<string, any>): Promise<any>; + abstract execute(args: T): Promise<any>; getActionRule(): Record<string, any> { return { diff --git a/src/client/views/nodes/ChatBox/tools/CalculateTool.ts b/src/client/views/nodes/ChatBox/tools/CalculateTool.ts index b881d90fa..818332c44 100644 --- a/src/client/views/nodes/ChatBox/tools/CalculateTool.ts +++ b/src/client/views/nodes/ChatBox/tools/CalculateTool.ts @@ -1,6 +1,6 @@ import { BaseTool } from './BaseTool'; -export class CalculateTool extends BaseTool { +export class CalculateTool extends BaseTool<{ expression: string }> { constructor() { super( 'calculate', diff --git a/src/client/views/nodes/ChatBox/tools/RAGTool.ts b/src/client/views/nodes/ChatBox/tools/RAGTool.ts index 84d5430e7..185efa0ba 100644 --- a/src/client/views/nodes/ChatBox/tools/RAGTool.ts +++ b/src/client/views/nodes/ChatBox/tools/RAGTool.ts @@ -1,8 +1,9 @@ import { BaseTool } from './BaseTool'; import { Vectorstore } from '../vectorstore/VectorstoreUpload'; import { Chunk } from '../types'; +import * as fs from 'fs'; -export class RAGTool extends BaseTool { +export class RAGTool extends BaseTool<{ hypothetical_document_chunk: string }> { constructor( private vectorstore: Vectorstore, summaries: string @@ -59,16 +60,26 @@ export class RAGTool extends BaseTool { for (const chunk of relevantChunks) { content.push({ type: 'text', - text: `<chunk chunk_id=${chunk.id} chunk_type=${chunk.metadata.type === 'image' ? 'image' : 'text'}>`, + text: `<chunk chunk_id=${chunk.id} chunk_type=${chunk.metadata.type === 'image' || chunk.metadata.type === 'table' ? 'image' : 'text'}>`, }); - if (chunk.metadata.type === 'image') { - // Implement image loading and base64 encoding here - // For now, we'll just add a placeholder - content.push({ - type: 'image_url', - image_url: { url: chunk.metadata.file_path }, - }); + if (chunk.metadata.type === 'image' || chunk.metadata.type === 'table') { + try { + const imageBuffer = fs.readFileSync(chunk.metadata.file_path); + const base64Image = imageBuffer.toString('base64'); + if (base64Image) { + content.push({ + type: 'image_url', + image_url: { + url: `data:image/jpeg;base64,${base64Image}`, + }, + }); + } else { + console.log(`Failed to encode image for chunk ${chunk.id}`); + } + } catch (error) { + console.error(`Error reading image file for chunk ${chunk.id}:`, error); + } } content.push({ type: 'text', text: `${chunk.metadata.text}\n</chunk>\n` }); diff --git a/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts b/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts index 0aef58f61..8ef2830d4 100644 --- a/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts +++ b/src/client/views/nodes/ChatBox/tools/WikipediaTool.ts @@ -1,7 +1,7 @@ import { BaseTool } from './BaseTool'; import axios from 'axios'; -export class WikipediaTool extends BaseTool { +export class WikipediaTool extends BaseTool<{ title: string }> { constructor() { super( 'wikipedia', diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index c60973be3..0270b6256 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -6,6 +6,7 @@ export enum ASSISTANT_ROLE { export enum CHUNK_TYPE { TEXT = 'text', IMAGE = 'image', + TABLE = 'table', } export interface AssistantMessage { @@ -18,9 +19,7 @@ export interface AssistantMessage { export interface Citation { text: string; type: CHUNK_TYPE; - span: [number, number]; chunk_id: string; - direct_text?: string; } export interface Chunk { @@ -46,13 +45,13 @@ export interface AI_Document { type: string; } -export interface Tool { +export interface Tool<T extends Record<string, any> = Record<string, any>> { name: string; description: string; parameters: Record<string, any>; useRules: string; briefSummary: string; - execute: (args: Record<string, any>) => Promise<any>; + execute: (args: T) => Promise<any>; getActionRule: () => Record<string, any>; } diff --git a/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts b/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts index d16e117b6..1f483ad61 100644 --- a/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts +++ b/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts @@ -2,6 +2,7 @@ import { Pinecone, Index, IndexList, PineconeRecord, RecordMetadata, QueryRespon import { CohereClient } from 'cohere-ai'; import { EmbedResponse } from 'cohere-ai/api'; import dotenv from 'dotenv'; + import { Chunk, AI_Document } from '../types'; dotenv.config(); |