about | summary | refs | log | tree | commit | diff
path: root/src/client/views/nodes/chatbot/tools/ImageCreationTool.ts
diff options
context:
space:
mode:
author: sharkiecodes <lanyi_stroud@brown.edu> 2025-07-22 21:05:47 -0400
committer: sharkiecodes <lanyi_stroud@brown.edu> 2025-07-22 21:05:47 -0400
commit: 16e7cfcac3d41bd86ef953f131bb0fecba11f299 (patch)
tree: a2b91bf30e75e513b4913ac88ec3158e512665cf /src/client/views/nodes/chatbot/tools/ImageCreationTool.ts
parent: 8ff34d5335093c4ff85473227f39b3e83133d999 (diff)
adjusted agent to include UI control tool
Diffstat (limited to 'src/client/views/nodes/chatbot/tools/ImageCreationTool.ts')
-rw-r--r-- src/client/views/nodes/chatbot/tools/ImageCreationTool.ts 139
1 files changed, 115 insertions, 24 deletions
diff --git a/src/client/views/nodes/chatbot/tools/ImageCreationTool.ts b/src/client/views/nodes/chatbot/tools/ImageCreationTool.ts
index c5b1e028b..fa98e2472 100644
--- a/src/client/views/nodes/chatbot/tools/ImageCreationTool.ts
+++ b/src/client/views/nodes/chatbot/tools/ImageCreationTool.ts
@@ -6,6 +6,13 @@ import { Observation } from '../types/types';
import { BaseTool } from './BaseTool';
import { Upload } from '../../../../../server/SharedMediaTypes';
import { List } from '../../../../../fields/List';
+import { SmartDrawHandler } from '../../../smartdraw/SmartDrawHandler';
+import { FireflyImageDimensions } from '../../../smartdraw/FireflyConstants';
+import { gptImageCall } from '../../../../apis/gpt/GPT';
+import { ClientUtils } from '../../../../../ClientUtils';
+import { Doc } from '../../../../../fields/Doc';
+import { DocumentViewInternal } from '../../DocumentView';
+import { OpenWhere } from '../../OpenWhere';
const imageCreationToolParams = [
{
@@ -14,6 +21,18 @@ const imageCreationToolParams = [
description: 'The prompt for the image to be created. This should be a string that describes the image to be created in extreme detail for an AI image generator.',
required: true,
},
+ {
+ name: 'engine',
+ type: 'string',
+ description: 'The image generation engine to use. Options: "firefly" (default), "dalle". If not specified, defaults to "firefly".',
+ required: false,
+ },
+ {
+ name: 'aspect_ratio',
+ type: 'string',
+ description: 'Aspect ratio for the image (Firefly only). Options: "square" (default), "landscape", "portrait", "widescreen".',
+ required: false,
+ },
] as const;
type ImageCreationToolParamsType = typeof imageCreationToolParams;
@@ -22,7 +41,7 @@ const imageCreationToolInfo: ToolInfo<ImageCreationToolParamsType> = {
name: 'imageCreationTool',
citationRules: 'No citation needed. Cannot cite image generation for a response.',
parameterRules: imageCreationToolParams,
- description: 'Create an image of any style, content, or design, based on a prompt. The prompt should be a detailed description of the image to be created.',
+ description: 'Create an image of any style, content, or design, based on a prompt. Uses Firefly by default for better quality and control. Use "dalle" engine explicitly only if requested. The prompt should be a detailed description of the image to be created.',
};
export class ImageCreationTool extends BaseTool<ImageCreationToolParamsType> {
@@ -35,37 +54,109 @@ export class ImageCreationTool extends BaseTool<ImageCreationToolParamsType> {
}
+ /**
+  * Entry point of the tool: generates an image for the given prompt with the
+  * selected engine.
+  *
+  * @param args - Tool arguments. `image_prompt` is the text prompt; `engine`
+  *   defaults to 'firefly' and `aspect_ratio` defaults to 'square' when omitted.
+  * @returns The observations produced by the chosen generator, or a single
+  *   text observation describing the error if generation throws.
+  */
async execute(args: ParametersType<ImageCreationToolParamsType>): Promise<Observation[]> {
- const image_prompt = args.image_prompt;
+ const { image_prompt, engine = 'firefly', aspect_ratio = 'square' } = args;
+
+ console.log(`Generating image with ${engine} for prompt: ${image_prompt}`);
- console.log(`Generating image for prompt: ${image_prompt}`);
- // Create an array of promises, each one handling a search for a query
try {
- const { result, url } = (await Networking.PostToServer('/generateImage', {
- image_prompt,
- })) as { result: Upload.FileInformation & Upload.InspectionResults; url: string };
- console.log('Image generation result:', result);
- this._createImage(result, { text: RTFCast(image_prompt), ai: 'dall-e-3', tags: new List<string>(['@ai']) });
- return url
- ? [
- {
- type: 'image_url',
- image_url: { url },
- },
- ]
- : [
- {
- type: 'text',
- text: `An error occurred while generating image.`,
- },
- ];
+ // Engine name is compared case-insensitively; anything other than 'dalle' uses Firefly.
+ if (engine.toLowerCase() === 'dalle') {
+ // Use DALL-E for image generation
+ return await this.generateWithDalle(image_prompt);
+ } else {
+ // Default to Firefly
+ return await this.generateWithFirefly(image_prompt, aspect_ratio);
+ }
} catch (error) {
- console.log(error);
+ console.error('ImageCreationTool error:', error);
+ // Report the failure as a text observation instead of rethrowing.
return [
{
type: 'text',
- text: `An error occurred while generating image.`,
+ text: `Error generating image: ${error}`,
},
];
}
}
+
+ /**
+  * Generates an image through the GPT image API, uploads the first returned
+  * URL to the server via '/uploadRemoteImage', and passes the uploaded file
+  * info to `this._createImage`.
+  *
+  * @param prompt - Text prompt describing the image to generate.
+  * @returns An `image_url` observation pointing at the uploaded image, or a
+  *   text observation when no image URL was produced or the upload response
+  *   contains no usable file entry.
+  * @throws Rethrows any error raised during generation or upload after logging it.
+  */
+ private async generateWithDalle(prompt: string): Promise<Observation[]> {
+     try {
+         // Call GPT image API directly
+         const imageUrls = await gptImageCall(prompt);
+         const remoteUrl = imageUrls?.[0];
+         if (!remoteUrl) {
+             return [
+                 {
+                     type: 'text',
+                     text: 'Failed to generate image with DALL-E',
+                 },
+             ];
+         }
+
+         // Upload the remote image to our server
+         const uploadRes = await Networking.PostToServer('/uploadRemoteImage', { sources: [remoteUrl] });
+         const fileInfo = Array.isArray(uploadRes) ? (uploadRes as Upload.FileInformation[])[0] : undefined;
+         const clientPath = fileInfo?.accessPaths?.agnostic?.client;
+         // Guard against an empty or malformed upload response so the failure is
+         // reported clearly rather than as a TypeError on an undefined entry.
+         if (!fileInfo || !clientPath) {
+             return [
+                 {
+                     type: 'text',
+                     text: 'Failed to upload DALL-E image to server',
+                 },
+             ];
+         }
+         const source = ClientUtils.prepend(clientPath);
+
+         // Create image document with DALL-E metadata
+         this._createImage(fileInfo as Upload.FileInformation & Upload.InspectionResults, {
+             text: RTFCast(prompt),
+             ai: 'dall-e-3',
+             tags: new List<string>(['@ai']),
+             title: prompt,
+             _width: 400,
+             _height: 400,
+         });
+
+         return [
+             {
+                 type: 'image_url',
+                 image_url: { url: source },
+             },
+         ];
+     } catch (error) {
+         console.error('DALL-E generation error:', error);
+         throw error;
+     }
+ }
+
+ /**
+  * Generates an image with Firefly via `SmartDrawHandler.CreateWithFirefly`
+  * and, when a `Doc` is returned, opens it with `OpenWhere.addRight`.
+  *
+  * @param prompt - Text prompt describing the image to generate.
+  * @param aspectRatio - One of 'square', 'landscape', 'portrait' or
+  *   'widescreen' (case-insensitive). Any other value falls back to square.
+  * @returns A text observation reporting success or failure.
+  * @throws Rethrows any error raised during generation after logging it.
+  */
+ private async generateWithFirefly(prompt: string, aspectRatio: string): Promise<Observation[]> {
+     try {
+         // Map aspect ratio string to FireflyImageDimensions enum. A Map is used so
+         // that names inherited from Object.prototype (e.g. "constructor") are
+         // never mistaken for a valid aspect ratio.
+         const dimensionMap = new Map<string, FireflyImageDimensions>([
+             ['square', FireflyImageDimensions.Square],
+             ['landscape', FireflyImageDimensions.Landscape],
+             ['portrait', FireflyImageDimensions.Portrait],
+             ['widescreen', FireflyImageDimensions.Widescreen],
+         ]);
+
+         // `??` falls back only when the key is missing, so a falsy enum member is kept.
+         const dimensions = dimensionMap.get(aspectRatio.toLowerCase()) ?? FireflyImageDimensions.Square;
+
+         // Use SmartDrawHandler to create Firefly image
+         const doc = await SmartDrawHandler.CreateWithFirefly(prompt, dimensions);
+
+         if (!(doc instanceof Doc)) {
+             return [
+                 {
+                     type: 'text',
+                     text: 'Failed to generate image with Firefly',
+                 },
+             ];
+         }
+
+         // Open the document in a new tab
+         DocumentViewInternal.addDocTabFunc(doc, OpenWhere.addRight);
+
+         return [
+             {
+                 type: 'text',
+                 text: `Created image with Firefly: "${prompt}". The image has been opened in a new tab.`,
+             },
+         ];
+     } catch (error) {
+         console.error('Firefly generation error:', error);
+         throw error;
+     }
+ }
}