diff options
Diffstat (limited to 'src/server')
| -rw-r--r-- | src/server/ApiManagers/AssistantManager.ts | 4 | ||||
| -rw-r--r-- | src/server/ApiManagers/FireflyManager.ts | 23 | ||||
| -rw-r--r-- | src/server/DashUploadUtils.ts | 8 | ||||
| -rw-r--r-- | src/server/apis/google/GoogleApiServerUtils.ts | 93 | ||||
| -rw-r--r-- | src/server/chunker/pdf_chunker.py | 20 |
5 files changed, 73 insertions, 75 deletions
diff --git a/src/server/ApiManagers/AssistantManager.ts b/src/server/ApiManagers/AssistantManager.ts index e859e5c5f..af25722a4 100644 --- a/src/server/ApiManagers/AssistantManager.ts +++ b/src/server/ApiManagers/AssistantManager.ts @@ -539,7 +539,7 @@ export default class AssistantManager extends ApiManager { // Spawn the Python process and track its progress/output // eslint-disable-next-line no-use-before-define - spawnPythonProcess(jobId, file_name, public_path); + spawnPythonProcess(jobId, public_path); // Send the job ID back to the client for tracking res.send({ jobId }); @@ -696,7 +696,7 @@ export default class AssistantManager extends ApiManager { * @param file_name The name of the file to process. * @param file_path The filepath of the file to process. */ -function spawnPythonProcess(jobId: string, file_name: string, file_path: string) { +function spawnPythonProcess(jobId: string, file_path: string) { const venvPath = path.join(__dirname, '../chunker/venv'); const requirementsPath = path.join(__dirname, '../chunker/requirements.txt'); const pythonScriptPath = path.join(__dirname, '../chunker/pdf_chunker.py'); diff --git a/src/server/ApiManagers/FireflyManager.ts b/src/server/ApiManagers/FireflyManager.ts index 160a94d40..e75ede9df 100644 --- a/src/server/ApiManagers/FireflyManager.ts +++ b/src/server/ApiManagers/FireflyManager.ts @@ -132,7 +132,8 @@ export default class FireflyManager extends ApiManager { ], body: body, }) - .then(response2 => response2.json().then(json => ({ seed: json.outputs?.[0]?.seed, url: json.outputs?.[0]?.image?.url }))) + .then(response2 => response2.json()) + .then(json => (json.error_code ? json : { seed: json.outputs?.[0]?.seed, url: json.outputs?.[0]?.image?.url })) .catch(error => { console.error('Error:', error); return undefined; @@ -297,13 +298,13 @@ export default class FireflyManager extends ApiManager { _invalid(res, styleUrl.message); throw new Error('Error uploading images to dropbox'); } - this.uploadImageToDropbox(req.body.structure, req.user as DashUserModel) - .then(structureUrl => { - if (structureUrl instanceof Error) { - _invalid(res, structureUrl.message); + this.uploadImageToDropbox(req.body.structureUrl, req.user as DashUserModel) + .then(dropboxStructureUrl => { + if (dropboxStructureUrl instanceof Error) { + _invalid(res, dropboxStructureUrl.message); throw new Error('Error uploading images to dropbox'); } - return { styleUrl, structureUrl }; + return { styleUrl, structureUrl: dropboxStructureUrl }; }) .then(uploads => this.generateImageFromStructure(req.body.prompt, req.body.width, req.body.height, uploads.structureUrl, req.body.strength, req.body.presets, uploads.styleUrl) @@ -332,10 +333,12 @@ export default class FireflyManager extends ApiManager { subscription: '/queryFireflyImage', secureHandler: ({ req, res }) => this.generateImage(req.body.prompt, req.body.width, req.body.height, req.body.seed).then(img => - DashUploadUtils.UploadImage(img?.url ?? '', undefined, img?.seed).then(info => { - if (info instanceof Error) _invalid(res, info.message); - else _success(res, info); - }) + img.error_code + ? _invalid(res, img.message) + : DashUploadUtils.UploadImage(img?.url ?? '', undefined, img?.seed).then(info => { + if (info instanceof Error) _invalid(res, info.message); + else _success(res, info); + }) ), }); diff --git a/src/server/DashUploadUtils.ts b/src/server/DashUploadUtils.ts index 2177c5d97..a2747257a 100644 --- a/src/server/DashUploadUtils.ts +++ b/src/server/DashUploadUtils.ts @@ -23,6 +23,7 @@ import { AcceptableMedia, Upload } from './SharedMediaTypes'; import { Directory, clientPathToFile, filesDirectory, pathToDirectory, publicDirectory, serverPathToFile } from './SocketData'; import { resolvedServerUrl } from './server_Initialization'; import { Worker, isMainThread, parentPort } from 'worker_threads'; +import requestImageSize from '../client/util/request-image-size'; // Create an array to store worker threads enum workertasks { @@ -62,9 +63,6 @@ if (isMainThread) { } } -// eslint-disable-next-line @typescript-eslint/no-require-imports -const requestImageSize = require('../client/util/request-image-size'); - export enum SizeSuffix { Small = '_s', Medium = '_m', @@ -349,14 +347,14 @@ export namespace DashUploadUtils { imgReadStream.push(null); await Promise.all( sizes.map(({ suffix }) => - new Promise<unknown>(res => + new Promise<void>(res => imgReadStream.pipe(createWriteStream(writtenFiles[suffix] = InjectSize(outputPath, suffix))).on('close', res) ) )); // prettier-ignore } else { await Promise.all( sizes.map(({ suffix }) => - new Promise<unknown>(res => + new Promise<void>(res => request.get(imgSourcePath).pipe(createWriteStream(writtenFiles[suffix] = InjectSize(outputPath, suffix))).on('close', res) ) )); // prettier-ignore diff --git a/src/server/apis/google/GoogleApiServerUtils.ts b/src/server/apis/google/GoogleApiServerUtils.ts index 21c405bee..7373df473 100644 --- a/src/server/apis/google/GoogleApiServerUtils.ts +++ b/src/server/apis/google/GoogleApiServerUtils.ts @@ -1,7 +1,7 @@ +/* eslint-disable no-use-before-define */ import { GaxiosResponse } from 'gaxios'; import { Credentials, OAuth2Client, OAuth2ClientOptions } from 'google-auth-library'; import { google } from 'googleapis'; -import * as qs from 'query-string'; import * as request from 'request-promise'; import { Opt } from '../../../fields/Doc'; import { Database } from '../../database'; @@ -21,7 +21,6 @@ const scope = ['documents.readonly', 'documents', 'presentations', 'presentation * This namespace manages server side authentication for Google API queries, either * from the standard v1 APIs or the Google Photos REST API. */ - export namespace GoogleApiServerUtils { /** * As we expand out to more Google APIs that are accessible from @@ -71,29 +70,29 @@ export namespace GoogleApiServerUtils { /** * A briefer format for the response from a 'googleapis' API request */ - export type ApiResponse = Promise<GaxiosResponse>; + export type ApiResponse = Promise<GaxiosResponse<unknown>>; /** * A generic form for a handler that executes some request on the endpoint */ - export type ApiRouter = (endpoint: Endpoint, parameters: any) => ApiResponse; + export type ApiRouter = (endpoint: Endpoint, parameters: Record<string, unknown>) => ApiResponse; /** * A generic form for the asynchronous function that actually submits the - * request to the API and returns the corresporing response. Helpful when + * request to the API and returns the corresponding response. Helpful when * making an extensible endpoint definition. */ - export type ApiHandler = (parameters: any, methodOptions?: any) => ApiResponse; + export type ApiHandler = (parameters: Record<string, unknown>, methodOptions?: Record<string, unknown>) => ApiResponse; /** * A literal union type indicating the valid actions for these 'googleapis' - * requestions + * requests */ export type Action = 'create' | 'retrieve' | 'update'; /** * An interface defining any entity on which one can invoke - * anuy of the following handlers. All 'googleapis' wrappers + * any of the following handlers. All 'googleapis' wrappers * such as google.docs().documents and google.slides().presentations * satisfy this interface. */ @@ -109,7 +108,7 @@ export namespace GoogleApiServerUtils { * of needless duplicate clients that would arise from * making one new client instance per request. */ - const authenticationClients = new Map<String, OAuth2Client>(); + const authenticationClients = new Map<string, OAuth2Client>(); /** * This function receives the target sector ("which G-Suite app's API am I interested in?") @@ -120,23 +119,21 @@ export namespace GoogleApiServerUtils { * @returns the relevant 'googleapis' wrapper, if any */ export async function GetEndpoint(sector: string, userId: string): Promise<Endpoint | void> { - return new Promise(async resolve => { - const auth = await retrieveOAuthClient(userId); - if (!auth) { - return resolve(); - } - let routed: Opt<Endpoint>; - const parameters: any = { auth, version: 'v1' }; - switch (sector) { - case Service.Documents: - routed = google.docs(parameters).documents; - break; - case Service.Slides: - routed = google.slides(parameters).presentations; - break; - } - resolve(routed); - }); + const auth = await retrieveOAuthClient(userId); + if (!auth) { + return; + } + let routed: Opt<Endpoint>; + const parameters: { version: 'v1' } = { /* auth, */ version: 'v1' }; ///* auth: OAuth2Client;*/ + switch (sector) { + case Service.Documents: + routed = google.docs(parameters).documents; + break; + case Service.Slides: + routed = google.slides(parameters).presentations; + break; + } + return routed; } /** @@ -149,19 +146,17 @@ export namespace GoogleApiServerUtils { * security. */ export async function retrieveOAuthClient(userId: string): Promise<OAuth2Client | void> { - return new Promise(async resolve => { - const { credentials, refreshed } = await retrieveCredentials(userId); - if (!credentials) { - return resolve(); - } - let client = authenticationClients.get(userId); - if (!client) { - authenticationClients.set(userId, (client = generateClient(credentials))); - } else if (refreshed) { - client.setCredentials(credentials); - } - resolve(client); - }); + const { credentials, refreshed } = await retrieveCredentials(userId); + if (!credentials) { + return; + } + let client = authenticationClients.get(userId); + if (!client) { + authenticationClients.set(userId, (client = generateClient(credentials))); + } else if (refreshed) { + client.setCredentials(credentials); + } + return client; } /** @@ -173,7 +168,9 @@ export namespace GoogleApiServerUtils { */ function generateClient(credentials?: Credentials): OAuth2Client { const client = new google.auth.OAuth2(oAuthOptions); - credentials && client.setCredentials(credentials); + if (credentials) { + client.setCredentials(credentials); + } return client; } @@ -206,7 +203,7 @@ export namespace GoogleApiServerUtils { */ export async function processNewUser(userId: string, authenticationCode: string): Promise<EnrichedCredentials> { const credentials = await new Promise<Credentials>((resolve, reject) => { - worker.getToken(authenticationCode, async (err, credentials) => { + worker.getToken(authenticationCode, (err, credentials) => { if (err || !credentials) { reject(err); return; @@ -221,7 +218,7 @@ export namespace GoogleApiServerUtils { /** * This type represents the union of the full set of OAuth2 credentials - * and all of a Google user's publically available information. This is the strucure + * and all of a Google user's publicly available information. This is the structure * of the JSON object we ultimately store in the googleAuthentication table of the database. */ export type EnrichedCredentials = Credentials & { userInfo: UserInfo }; @@ -297,15 +294,15 @@ export namespace GoogleApiServerUtils { async function refreshAccessToken(credentials: Credentials, userId: string): Promise<Credentials> { const headerParameters = { headers: { 'Content-Type': 'application/x-www-form-urlencoded' } }; const { client_id, client_secret } = GoogleCredentialsLoader.ProjectCredentials; - const url = `https://oauth2.googleapis.com/token?${qs.stringify({ - refreshToken: credentials.refresh_token, + const params = new URLSearchParams({ + refresh_token: credentials.refresh_token!, client_id, client_secret, grant_type: 'refresh_token', - })}`; - const { access_token, expires_in } = await new Promise<any>(async resolve => { - const response = await request.post(url, headerParameters); - resolve(JSON.parse(response)); + }); + const url = `https://oauth2.googleapis.com/token?${params.toString()}`; + const { access_token, expires_in } = await new Promise<{ access_token: string; expires_in: number }>(resolve => { + request.post(url, headerParameters).then(response => resolve(JSON.parse(response))); }); // expires_in is in seconds, but we're building the new expiry date in milliseconds const expiry_date = new Date().getTime() + expires_in * 1000; diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py index a9dbcbb0c..697550f2e 100644 --- a/src/server/chunker/pdf_chunker.py +++ b/src/server/chunker/pdf_chunker.py @@ -21,7 +21,7 @@ import json import os import uuid # For generating unique IDs from enum import Enum # Enums for types like document type and purpose -import cohere # Embedding client +import openai import numpy as np from PyPDF2 import PdfReader # PDF text extraction from openai import OpenAI # OpenAI client for text completion @@ -35,8 +35,8 @@ warnings.filterwarnings('ignore', message="torch.load") dotenv.load_dotenv() # Load environment variables # Fix for newer versions of PIL -if parse(PIL.__version__) >= parse('10.0.0'): - Image.LINEAR = Image.BILINEAR +# if parse(PIL.__version__) >= parse('10.0.0'): +# Image.LINEAR = Image.BILINEAR # Global dictionary to track progress of document processing jobs current_progress = {} @@ -727,19 +727,19 @@ class Document: """ Embed the text chunks using the Cohere API. """ - co = cohere.Client(os.getenv("COHERE_API_KEY")) # Initialize Cohere client with API key + openai = OpenAI() # Initialize Cohere client with API key batch_size = 90 # Batch size for embedding chunks_len = len(self.chunks) # Total number of chunks to embed for i in tqdm(range(0, chunks_len, batch_size), desc="Embedding Chunks"): batch = self.chunks[i: min(i + batch_size, chunks_len)] # Get batch of chunks texts = [chunk['metadata']['text'] for chunk in batch] # Extract text from each chunk - chunk_embs_batch = co.embed( - texts=texts, - model="embed-english-v3.0", # Use Cohere's embedding model - input_type="search_document" # Specify input type + chunk_embs_batch = openai.embeddings.create( + model="text-embedding-3-large", + input=texts, + encoding_format="float" ) - for j, emb in enumerate(chunk_embs_batch.embeddings): - self.chunks[i + j]['values'] = emb # Store the embeddings in the corresponding chunks + for j, data_val in enumerate(chunk_embs_batch.data): + self.chunks[i + j]['values'] = data_val.embedding # Store the embeddings in the corresponding chunks def _generate_summary(self) -> str: """ |
