diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client/views/nodes/ChatBox/Agent.ts | 2 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/AnswerParser.ts | 93 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/ChatBox.tsx | 20 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/MessageComponent.tsx | 6 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/tools/RAGTool.ts | 3 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/types.ts | 1 | ||||
-rw-r--r-- | src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts | 14 |
7 files changed, 99 insertions, 40 deletions
diff --git a/src/client/views/nodes/ChatBox/Agent.ts b/src/client/views/nodes/ChatBox/Agent.ts index b9d137270..fddfdfcb1 100644 --- a/src/client/views/nodes/ChatBox/Agent.ts +++ b/src/client/views/nodes/ChatBox/Agent.ts @@ -22,7 +22,7 @@ export class Agent { private _summaries: () => string; constructor(_vectorstore: Vectorstore, summaries: () => string, history: () => string) { - this.client = new OpenAI({ apiKey: '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', dangerouslyAllowBrowser: true }); + this.client = new OpenAI({ apiKey: process.env.OPENAI_KEY, dangerouslyAllowBrowser: true }); this.vectorstore = _vectorstore; this._history = history; this._summaries = summaries; diff --git a/src/client/views/nodes/ChatBox/AnswerParser.ts b/src/client/views/nodes/ChatBox/AnswerParser.ts index dd7ec3499..9956792d8 100644 --- a/src/client/views/nodes/ChatBox/AnswerParser.ts +++ b/src/client/views/nodes/ChatBox/AnswerParser.ts @@ -4,62 +4,95 @@ import { v4 as uuid } from 'uuid'; export class AnswerParser { static parse(xml: string): AssistantMessage { const answerRegex = /<answer>([\s\S]*?)<\/answer>/; + const citationsRegex = /<citations>([\s\S]*?)<\/citations>/; const citationRegex = /<citation index="([^"]+)" chunk_id="([^"]+)" type="([^"]+)">([\s\S]*?)<\/citation>/g; const followUpQuestionsRegex = /<follow_up_questions>([\s\S]*?)<\/follow_up_questions>/; const questionRegex = /<question>(.*?)<\/question>/g; const groundedTextRegex = /<grounded_text citation_index="([^"]+)">([\s\S]*?)<\/grounded_text>/g; const answerMatch = answerRegex.exec(xml); + const citationsMatch = citationsRegex.exec(xml); const followUpQuestionsMatch = followUpQuestionsRegex.exec(xml); if (!answerMatch) { throw new Error('Invalid XML: Missing <answer> tag.'); } - const rawTextContent = answerMatch[1].trim(); + let rawTextContent = answerMatch[1].trim(); let content: AssistantMessage['content'] = []; let citations: Citation[] = []; let contentIndex = 0; + // Remove citations and follow-up questions from rawTextContent + if (citationsMatch) { + rawTextContent = rawTextContent.replace(citationsMatch[0], '').trim(); + } + if (followUpQuestionsMatch) { + rawTextContent = rawTextContent.replace(followUpQuestionsMatch[0], '').trim(); + } + // Parse citations let citationMatch; const citationMap = new Map<string, string>(); - while ((citationMatch = citationRegex.exec(rawTextContent)) !== null) { - const [_, index, chunk_id, type, direct_text] = citationMatch; - const citation_id = uuid(); - citationMap.set(index, citation_id); - citations.push({ - direct_text: direct_text.trim(), - type: getChunkType(type), - chunk_id, - citation_id, - }); + if (citationsMatch) { + const citationsContent = citationsMatch[1]; + while ((citationMatch = citationRegex.exec(citationsContent)) !== null) { + const [_, index, chunk_id, type, direct_text] = citationMatch; + const citation_id = uuid(); + citationMap.set(index, citation_id); + citations.push({ + direct_text: direct_text.trim(), + type: getChunkType(type), + chunk_id, + citation_id, + }); + } } - // Parse grounded text content - const parseGroundedText = (text: string): AssistantMessage['content'] => { - const result: AssistantMessage['content'] = []; - let lastIndex = 0; - let match; + // Parse text content (normal and grounded) + let lastIndex = 0; + let match; - while ((match = groundedTextRegex.exec(text)) !== null) { - const [fullMatch, citationIndex, groundedText] = match; - const citation_ids = citationIndex.split(',').map(index => citationMap.get(index) || ''); - - result.push({ - index: contentIndex++, - type: TEXT_TYPE.GROUNDED, - text: groundedText.trim(), - citation_ids, - }); + while ((match = groundedTextRegex.exec(rawTextContent)) !== null) { + const [fullMatch, citationIndex, groundedText] = match; - lastIndex = match.index + fullMatch.length; + // Add normal text before the grounded text + if (match.index > lastIndex) { + const normalText = rawTextContent.slice(lastIndex, match.index).trim(); + if (normalText) { + content.push({ + index: contentIndex++, + type: TEXT_TYPE.NORMAL, + text: normalText, + citation_ids: null, + }); + } } - return result; - }; + // Add grounded text + const citation_ids = citationIndex.split(',').map(index => citationMap.get(index) || ''); + content.push({ + index: contentIndex++, + type: TEXT_TYPE.GROUNDED, + text: groundedText.trim(), + citation_ids, + }); + + lastIndex = match.index + fullMatch.length; + } - content = parseGroundedText(rawTextContent); + // Add any remaining normal text after the last grounded text + if (lastIndex < rawTextContent.length) { + const remainingText = rawTextContent.slice(lastIndex).trim(); + if (remainingText) { + content.push({ + index: contentIndex++, + type: TEXT_TYPE.NORMAL, + text: remainingText, + citation_ids: null, + }); + } + } let followUpQuestions: string[] = []; if (followUpQuestionsMatch) { diff --git a/src/client/views/nodes/ChatBox/ChatBox.tsx b/src/client/views/nodes/ChatBox/ChatBox.tsx index 9e604073d..6269b8768 100644 --- a/src/client/views/nodes/ChatBox/ChatBox.tsx +++ b/src/client/views/nodes/ChatBox/ChatBox.tsx @@ -54,7 +54,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { } else { this.vectorstore_id = StrCast(this.dataDoc.vectorstore_id); } - this.vectorstore = new Vectorstore(this.vectorstore_id); + this.vectorstore = new Vectorstore(this.vectorstore_id, this.retrieveDocIds); this.agent = new Agent(this.vectorstore, this.retrieveSummaries, this.retrieveFormattedHistory); reaction( @@ -113,7 +113,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { } catch (err) { console.error('Error:', err); runInAction(() => { - this.history.push({ role: ASSISTANT_ROLE.USER, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); + this.history.push({ role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Sorry, I encountered an error while processing your request.', citation_ids: null }] }); }); } finally { runInAction(() => { @@ -224,7 +224,7 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { } else { runInAction(() => { this.history.push({ - role: ASSISTANT_ROLE.USER, + role: ASSISTANT_ROLE.ASSISTANT, content: [{ index: 0, type: TEXT_TYPE.NORMAL, text: 'Welcome to the Document Analyser Assistant! Link a document or ask questions to get started.', citation_ids: null }], }); }); @@ -265,6 +265,16 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { } @computed + get docIds() { + return LinkManager.Instance.getAllRelatedLinks(this.Document) + .map(d => DocCast(LinkManager.getOppositeAnchor(d, this.Document))) + .map(d => DocCast(d?.annotationOn, d)) + .filter(d => d) + .filter(d => d.ai_doc_id) + .map(d => StrCast(d.ai_doc_id)); + } + + @computed get summaries(): string { return ( LinkManager.Instance.getAllRelatedLinks(this.Document) @@ -295,6 +305,10 @@ export class ChatBox extends ViewBoxAnnotatableComponent<FieldViewProps>() { return this.formattedHistory; }; + retrieveDocIds = () => { + return this.docIds; + }; + @action handleFollowUpClick = (question: string) => { console.log('Follow-up question clicked:', question); diff --git a/src/client/views/nodes/ChatBox/MessageComponent.tsx b/src/client/views/nodes/ChatBox/MessageComponent.tsx index d24a55d23..07bfd4e3d 100644 --- a/src/client/views/nodes/ChatBox/MessageComponent.tsx +++ b/src/client/views/nodes/ChatBox/MessageComponent.tsx @@ -49,6 +49,12 @@ const MessageComponentBox: React.FC<MessageComponentProps> = function ({ message })} </span> ); + } else if (item.type === TEXT_TYPE.NORMAL) { + return ( + <span key={i} className="normal-text"> + {item.text} + </span> + ); } else { return <span key={i}>{item.text}</span>; } diff --git a/src/client/views/nodes/ChatBox/tools/RAGTool.ts b/src/client/views/nodes/ChatBox/tools/RAGTool.ts index c7175326c..23b93b0f0 100644 --- a/src/client/views/nodes/ChatBox/tools/RAGTool.ts +++ b/src/client/views/nodes/ChatBox/tools/RAGTool.ts @@ -3,6 +3,7 @@ import { Vectorstore } from '../vectorstore/VectorstoreUpload'; import { Chunk } from '../types'; import * as fs from 'fs'; import { Networking } from '../../../../Network'; +import { file } from 'jszip'; export class RAGTool extends BaseTool<{ hypothetical_document_chunk: string }> { constructor(private vectorstore: Vectorstore) { @@ -62,7 +63,7 @@ export class RAGTool extends BaseTool<{ hypothetical_document_chunk: string }> { 6. Structural Integrity Checks: a. Ensure all opening tags have corresponding closing tags. - b. Verify that all grounded_text tags have valid citation_index attributes. + b. Verify that all grounded_text tags have valid citation_index attributes (they should be equal to the associated citation(s) index field—not their chunk_id field). c. Check that all cited indices in grounded_text tags have corresponding citations. Example of grounded_text usage: diff --git a/src/client/views/nodes/ChatBox/types.ts b/src/client/views/nodes/ChatBox/types.ts index 10c80c05a..c2fb095f0 100644 --- a/src/client/views/nodes/ChatBox/types.ts +++ b/src/client/views/nodes/ChatBox/types.ts @@ -59,6 +59,7 @@ export interface Chunk { type: CHUNK_TYPE; original_document: string; file_path: string; + doc_id: string; location: string; start_page: number; end_page: number; diff --git a/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts b/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts index ab0b6e617..0737e2392 100644 --- a/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts +++ b/src/client/views/nodes/ChatBox/vectorstore/VectorstoreUpload.ts @@ -17,9 +17,10 @@ export class Vectorstore { private cohere: CohereClient; private indexName: string = 'pdf-chatbot'; private id: string; + private file_ids: string[] = []; documents: AI_Document[] = []; - constructor(id: string) { + constructor(id: string, doc_ids: () => string[]) { const pineconeApiKey = process.env.PINECONE_API_KEY; if (!pineconeApiKey) { throw new Error('PINECONE_API_KEY is not defined.'); @@ -32,6 +33,7 @@ export class Vectorstore { token: process.env.COHERE_API_KEY, }); this.id = id; + this.file_ids = doc_ids(); this.initializeIndex(); } @@ -63,7 +65,7 @@ export class Vectorstore { console.log('Already in progress.'); return; } - console.log(`Document already added: ${doc.file_name}`); + if (!this.file_ids.includes(StrCast(doc.ai_doc_id))) this.file_ids.push(StrCast(doc.ai_doc_id)); } else { doc.ai_document_status = 'PROGRESS'; console.log(doc); @@ -79,6 +81,8 @@ export class Vectorstore { await this.indexDocument(JSON.parse(JSON.stringify(document_json, (key, value) => (value === null || value === undefined ? undefined : value)))); console.log(`Document added: ${document_json.file_name}`); doc.summary = document_json.summary; + doc.ai_doc_id = document_json.doc_id; + this.file_ids.push(document_json.doc_id); doc.ai_purpose = document_json.purpose; if (doc.vectorstore_id === undefined || doc.vectorstore_id === null || doc.vectorstore_id === '' || doc.vectorstore_id === '[]') { doc.vectorstore_id = JSON.stringify([this.id]); @@ -125,7 +129,7 @@ export class Vectorstore { ({ id: chunk.id, values: chunk.values, - metadata: { ...chunk.metadata, vectorstore_id: this.id } as RecordMetadata, + metadata: { ...chunk.metadata } as RecordMetadata, }) as PineconeRecord ); await this.index.upsert(pineconeRecords); @@ -157,7 +161,7 @@ export class Vectorstore { const queryResponse: QueryResponse<RecordMetadata> = await this.index.query({ vector: queryEmbedding, filter: { - vectorstore_id: this.id, + doc_id: { $in: this.file_ids }, }, topK, includeValues: true, @@ -169,7 +173,7 @@ export class Vectorstore { ({ id: match.id, values: match.values as number[], - metadata: match.metadata as { text: string; type: string; original_document: string; file_path: string; location: string; start_page: number; end_page: number }, + metadata: match.metadata as { text: string; type: string; original_document: string; file_path: string; doc_id: string; location: string; start_page: number; end_page: number }, }) as Chunk ); } catch (error) { |