aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client/apis/gpt/GPT.ts75
-rw-r--r--src/client/util/CurrentUserUtils.ts8
-rw-r--r--src/client/views/MainView.tsx6
-rw-r--r--src/client/views/collections/collectionFreeForm/ImageLabelHandler.scss44
-rw-r--r--src/client/views/collections/collectionFreeForm/ImageLabelHandler.tsx120
-rw-r--r--src/client/views/collections/collectionFreeForm/MarqueeOptionsMenu.tsx3
-rw-r--r--src/client/views/collections/collectionFreeForm/MarqueeView.tsx129
7 files changed, 345 insertions, 40 deletions
diff --git a/src/client/apis/gpt/GPT.ts b/src/client/apis/gpt/GPT.ts
index 6857246da..455352068 100644
--- a/src/client/apis/gpt/GPT.ts
+++ b/src/client/apis/gpt/GPT.ts
@@ -83,8 +83,51 @@ const gptAPICall = async (inputTextIn: string, callType: GPTCallType, prompt?: a
return 'Error connecting with API.';
}
};
+const gptImageCall = async (prompt: string, n?: number) => {
+ try {
+ const configuration: ClientOptions = {
+ apiKey: process.env.OPENAI_KEY,
+ dangerouslyAllowBrowser: true,
+ };
+
+ const openai = new OpenAI(configuration);
+ const response = await openai.images.generate({
+ prompt: prompt,
+ n: n ?? 1,
+ size: '1024x1024',
+ });
+ return response.data.map((data: any) => data.url);
+ // return response.data.data[0].url;
+ } catch (err) {
+ console.error(err);
+ }
+ return undefined;
+};
+
+const gptGetEmbedding = async (src: string): Promise<number[]> => {
+ try {
+ const configuration: ClientOptions = {
+ apiKey: process.env.OPENAI_KEY,
+ dangerouslyAllowBrowser: true,
+ };
+ const openai = new OpenAI(configuration);
+ const embeddingResponse = await openai.embeddings.create({
+ model: 'text-embedding-3-large',
+ input: [src],
+ encoding_format: 'float',
+ dimensions: 256,
+ });
-const gptImageLabel = async (imgUrl: string): Promise<string> => {
+ // Assume the embeddingResponse structure is correct; adjust based on actual API response
+ const embedding = embeddingResponse.data[0].embedding;
+ return embedding;
+ } catch (err) {
+ console.log(err);
+ return [];
+ }
+};
+
+const gptImageLabel = async (src: string): Promise<string> => {
try {
const configuration: ClientOptions = {
apiKey: process.env.OPENAI_KEY,
@@ -98,11 +141,12 @@ const gptImageLabel = async (imgUrl: string): Promise<string> => {
{
role: 'user',
content: [
- { type: 'text', text: 'Describe this image in 3-5 words' },
+ { type: 'text', text: 'Give three to five labels to describe this image.' },
{
type: 'image_url',
image_url: {
- url: `${imgUrl}`,
+ url: `${src}`,
+ detail: 'low',
},
},
],
@@ -112,7 +156,7 @@ const gptImageLabel = async (imgUrl: string): Promise<string> => {
if (response.choices[0].message.content) {
return response.choices[0].message.content;
} else {
- return ':(';
+ return 'Missing labels';
}
} catch (err) {
console.log(err);
@@ -120,25 +164,4 @@ const gptImageLabel = async (imgUrl: string): Promise<string> => {
}
};
-const gptImageCall = async (prompt: string, n?: number) => {
- try {
- const configuration: ClientOptions = {
- apiKey: process.env.OPENAI_KEY,
- dangerouslyAllowBrowser: true,
- };
-
- const openai = new OpenAI(configuration);
- const response = await openai.images.generate({
- prompt: prompt,
- n: n ?? 1,
- size: '1024x1024',
- });
- return response.data.map((data: any) => data.url);
- // return response.data.data[0].url;
- } catch (err) {
- console.error(err);
- }
- return undefined;
-};
-
-export { gptAPICall, gptImageCall, gptImageLabel, GPTCallType };
+export { gptAPICall, gptImageCall, GPTCallType, gptImageLabel, gptGetEmbedding };
diff --git a/src/client/util/CurrentUserUtils.ts b/src/client/util/CurrentUserUtils.ts
index 4fd6df799..8ce001158 100644
--- a/src/client/util/CurrentUserUtils.ts
+++ b/src/client/util/CurrentUserUtils.ts
@@ -668,13 +668,13 @@ pie title Minerals in my tap water
static labelTools(): Button[] {
return [
{ title: "AI", icon:"robot", toolTip:"Add AI labels", btnType: ButtonType.ToggleButton, expertMode: false, toolType:"chat", funcs: {hidden:`showFreeform ("chat", true)`},scripts: { onClick: '{ return showFreeform(this.toolType, _readOnly_);}'}},
- { title: "AIs", icon:"AI Sort", toolTip:"Filter AI labels", subMenu: this.cardGroupTools("chat"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("chat", true)`, linearView_IsOpen: `SelectionManager_selectedDocType(this.toolType, this.expertMode)`} },
+ { title: "AIs", icon:"AI Sort", toolTip:"Filter AI labels", subMenu: this.cardGroupTools("chat"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("chat", true)`, linearView_IsOpen: `SelectedDocType(this.toolType, this.expertMode)`} },
{ title: "Like", icon:"heart", toolTip:"Add Like labels", btnType: ButtonType.ToggleButton, expertMode: false, toolType:"like", funcs: {hidden:`showFreeform ("like", true)`},scripts: { onClick: '{ return showFreeform(this.toolType, _readOnly_);}'}},
- { title: "Likes", icon:"Likes", toolTip:"Filter likes", width: 10, subMenu: this.cardGroupTools("heart"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("like", true)`, linearView_IsOpen: `SelectionManager_selectedDocType(this.toolType, this.expertMode)`} },
+ { title: "Likes", icon:"Likes", toolTip:"Filter likes", width: 10, subMenu: this.cardGroupTools("heart"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("like", true)`, linearView_IsOpen: `SelectedDocType(this.toolType, this.expertMode)`} },
{ title: "Star", icon:"star", toolTip:"Add Star labels", btnType: ButtonType.ToggleButton, expertMode: false, toolType:"star", funcs: {hidden:`showFreeform ("star", true)`},scripts: { onClick: '{ return showFreeform(this.toolType, _readOnly_);}'}},
- { title: "Stars", icon:"Stars", toolTip:"Filter stars", width: 80, subMenu: this.cardGroupTools("star"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("star", true)`, linearView_IsOpen: `SelectionManager_selectedDocType(this.toolType, this.expertMode)`} },
+ { title: "Stars", icon:"Stars", toolTip:"Filter stars", width: 80, subMenu: this.cardGroupTools("star"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("star", true)`, linearView_IsOpen: `SelectedDocType(this.toolType, this.expertMode)`} },
{ title: "Idea", icon:"satellite", toolTip:"Add Idea labels", btnType: ButtonType.ToggleButton, expertMode: false, toolType:"idea", funcs: {hidden:`showFreeform ("idea", true)`},scripts: { onClick: '{ return showFreeform(this.toolType, _readOnly_);}'}},
- { title: "Ideas", icon:"Ideas", toolTip:"Filter ideas", width: 80, subMenu: this.cardGroupTools("satellite"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("idea", true)`, linearView_IsOpen: `SelectionManager_selectedDocType(this.toolType, this.expertMode)`} },
+ { title: "Ideas", icon:"Ideas", toolTip:"Filter ideas", width: 80, subMenu: this.cardGroupTools("satellite"), expertMode: false, toolType:CollectionViewType.Card, funcs: {hidden:`!showFreeform("idea", true)`, linearView_IsOpen: `SelectedDocType(this.toolType, this.expertMode)`} },
]
}
static cardGroupTools(icon: string): Button[] {
diff --git a/src/client/views/MainView.tsx b/src/client/views/MainView.tsx
index 33c343176..a53e24f8f 100644
--- a/src/client/views/MainView.tsx
+++ b/src/client/views/MainView.tsx
@@ -54,6 +54,7 @@ import { CollectionMenu } from './collections/CollectionMenu';
import { TabDocView } from './collections/TabDocView';
import './collections/TreeView.scss';
import { CollectionFreeFormView } from './collections/collectionFreeForm';
+import { ImageLabelHandler } from './collections/collectionFreeForm/ImageLabelHandler';
import { MarqueeOptionsMenu } from './collections/collectionFreeForm/MarqueeOptionsMenu';
import { CollectionLinearView } from './collections/collectionLinear';
import { LinkMenu } from './linking/LinkMenu';
@@ -76,8 +77,8 @@ import { AnchorMenu } from './pdf/AnchorMenu';
import { GPTPopup } from './pdf/GPTPopup/GPTPopup';
import { TopBar } from './topbar/TopBar';
-const _global = (window /* browser */ || global) /* node */ as any;
const { LEFT_MENU_WIDTH, TOPBAR_HEIGHT } = require('./global/globalCssVariables.module.scss'); // prettier-ignore
+const _global = (window /* browser */ || global) /* node */ as any;
@observer
export class MainView extends ObservableReactComponent<{}> {
@@ -542,7 +543,7 @@ export class MainView extends ObservableReactComponent<{}> {
fa.faHourglassHalf,
fa.faRobot,
fa.faSatellite,
- fa.faStar
+ fa.faStar,
]
);
}
@@ -1083,6 +1084,7 @@ export class MainView extends ObservableReactComponent<{}> {
<PreviewCursor />
<TaskCompletionBox />
<ContextMenu />
+ <ImageLabelHandler />
<AnchorMenu />
<MapAnchorMenu />
<DirectionsAnchorMenu />
diff --git a/src/client/views/collections/collectionFreeForm/ImageLabelHandler.scss b/src/client/views/collections/collectionFreeForm/ImageLabelHandler.scss
new file mode 100644
index 000000000..e7413bf8e
--- /dev/null
+++ b/src/client/views/collections/collectionFreeForm/ImageLabelHandler.scss
@@ -0,0 +1,44 @@
+#label-handler {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+
+ > div:first-child {
+ display: flex; // Puts the input and button on the same row
+ align-items: center; // Vertically centers items in the flex container
+
+ input {
+ color: black;
+ }
+
+ .IconButton {
+ margin-left: 8px; // Adds space between the input and the icon button
+ width: 19px;
+ }
+ }
+
+ > div:not(:first-of-type) {
+ display: flex;
+ flex-direction: column;
+ align-items: center; // Centers the content vertically in the flex container
+ width: 100%;
+
+ > div {
+ display: flex;
+ justify-content: space-between; // Puts the content and delete button on opposite ends
+ align-items: center;
+ width: 100%;
+ margin-top: 8px; // Adds space between label rows
+
+ p {
+ text-align: center; // Centers the text of the paragraph
+ flex-grow: 1; // Allows the paragraph to grow and occupy the available space
+ }
+
+ .IconButton {
+ // Styling for the delete button
+ margin-left: auto; // Pushes the button to the far right
+ }
+ }
+ }
+}
diff --git a/src/client/views/collections/collectionFreeForm/ImageLabelHandler.tsx b/src/client/views/collections/collectionFreeForm/ImageLabelHandler.tsx
new file mode 100644
index 000000000..46bc3d946
--- /dev/null
+++ b/src/client/views/collections/collectionFreeForm/ImageLabelHandler.tsx
@@ -0,0 +1,120 @@
+import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
+import { IconButton } from 'browndash-components';
+import { action, makeObservable, observable } from 'mobx';
+import { observer } from 'mobx-react';
+import React from 'react';
+import { SettingsManager } from '../../../util/SettingsManager';
+import { ObservableReactComponent } from '../../ObservableReactComponent';
+import { MarqueeOptionsMenu } from './MarqueeOptionsMenu';
+import './ImageLabelHandler.scss';
+
+@observer
+export class ImageLabelHandler extends ObservableReactComponent<{}> {
+ static Instance: ImageLabelHandler;
+
+ @observable _display: boolean = false;
+ @observable _pageX: number = 0;
+ @observable _pageY: number = 0;
+ @observable _yRelativeToTop: boolean = true;
+ @observable _currentLabel: string = '';
+ @observable _labelGroups: string[] = [];
+
+ constructor(props: any) {
+ super(props);
+ makeObservable(this);
+ ImageLabelHandler.Instance = this;
+ console.log('Instantiated label handler!');
+ }
+
+ @action
+ displayLabelHandler = (x: number, y: number) => {
+ this._pageX = x;
+ this._pageY = y;
+ this._display = true;
+ this._labelGroups = [];
+ };
+
+ @action
+ hideLabelhandler = () => {
+ this._display = false;
+ this._labelGroups = [];
+ };
+
+ @action
+ addLabel = (label: string) => {
+ label = label.toUpperCase().trim();
+ if (label.length > 0) {
+ if (!this._labelGroups.includes(label)) {
+ this._labelGroups = [...this._labelGroups, label];
+ }
+ }
+ };
+
+ @action
+ removeLabel = (label: string) => {
+ label = label.toUpperCase();
+ this._labelGroups = this._labelGroups.filter(group => group !== label);
+ };
+
+ @action
+ groupImages = () => {
+ MarqueeOptionsMenu.Instance.groupImages();
+ this._display = false;
+ };
+
+ render() {
+ if (this._display) {
+ return (
+ <div
+ id="label-handler"
+ className="contextMenu-cont"
+ style={{
+ display: this._display ? '' : 'none',
+ left: this._pageX,
+ ...(this._yRelativeToTop ? { top: Math.max(0, this._pageY) } : { bottom: this._pageY }),
+ background: SettingsManager.userBackgroundColor,
+ color: SettingsManager.userColor,
+ }}>
+ <div>
+ <IconButton tooltip={'Cancel'} onPointerDown={this.hideLabelhandler} icon={<FontAwesomeIcon icon="eye-slash" />} color={MarqueeOptionsMenu.Instance.userColor} style={{ width: '19px' }} />
+ <input aria-label="label-input" id="new-label" type="text" style={{ color: 'black' }} />
+ <IconButton
+ tooltip={'Add Label'}
+ onPointerDown={() => {
+ const input = document.getElementById('new-label') as HTMLInputElement;
+ const newLabel = input.value;
+ this.addLabel(newLabel);
+ this._currentLabel = '';
+ input.value = '';
+ }}
+ icon={<FontAwesomeIcon icon="plus" />}
+ color={MarqueeOptionsMenu.Instance.userColor}
+ style={{ width: '19px' }}
+ />
+ <IconButton tooltip={'Group Images'} onPointerDown={this.groupImages} icon={<FontAwesomeIcon icon="object-group" />} color={MarqueeOptionsMenu.Instance.userColor} style={{ width: '19px' }} />
+ </div>
+ <div>
+ {this._labelGroups.map(group => {
+ return (
+ <div>
+ <p>{group}</p>
+ <IconButton
+ tooltip={'Remove Label'}
+ onPointerDown={() => {
+ this.removeLabel(group);
+ }}
+ icon={'x'}
+ color={MarqueeOptionsMenu.Instance.userColor}
+ style={{ width: '19px' }}
+ />
+ </div>
+ );
+ })}
+ </div>
+ </div>
+ );
+ } else {
+ return <></>;
+ }
+ }
+}
diff --git a/src/client/views/collections/collectionFreeForm/MarqueeOptionsMenu.tsx b/src/client/views/collections/collectionFreeForm/MarqueeOptionsMenu.tsx
index adac5a102..f02cd9d45 100644
--- a/src/client/views/collections/collectionFreeForm/MarqueeOptionsMenu.tsx
+++ b/src/client/views/collections/collectionFreeForm/MarqueeOptionsMenu.tsx
@@ -18,6 +18,8 @@ export class MarqueeOptionsMenu extends AntimodeMenu<AntimodeMenuProps> {
public showMarquee: () => void = unimplementedFunction;
public hideMarquee: () => void = unimplementedFunction;
public pinWithView: (e: KeyboardEvent | React.PointerEvent | undefined) => void = unimplementedFunction;
+ public classifyImages: (e: React.MouseEvent | undefined) => void = unimplementedFunction;
+ public groupImages: () => void = unimplementedFunction;
public isShown = () => this._opacity > 0;
constructor(props: any) {
super(props);
@@ -37,6 +39,7 @@ export class MarqueeOptionsMenu extends AntimodeMenu<AntimodeMenuProps> {
<IconButton tooltip="Summarize Documents" onPointerDown={this.summarize} icon={<FontAwesomeIcon icon="compress-arrows-alt" />} color={this.userColor} />
<IconButton tooltip="Delete Documents" onPointerDown={this.delete} icon={<FontAwesomeIcon icon="trash-alt" />} color={this.userColor} />
<IconButton tooltip="Pin selected region" onPointerDown={this.pinWithView} icon={<FontAwesomeIcon icon="map-pin" />} color={this.userColor} />
+ <IconButton tooltip="Classify Images" onPointerDown={this.classifyImages} icon={<FontAwesomeIcon icon="object-group" />} color={this.userColor} />
</>
);
return this.getElement(buttons);
diff --git a/src/client/views/collections/collectionFreeForm/MarqueeView.tsx b/src/client/views/collections/collectionFreeForm/MarqueeView.tsx
index b96444024..768270c09 100644
--- a/src/client/views/collections/collectionFreeForm/MarqueeView.tsx
+++ b/src/client/views/collections/collectionFreeForm/MarqueeView.tsx
@@ -4,15 +4,16 @@ import { observer } from 'mobx-react';
import * as React from 'react';
import { ClientUtils, lightOrDark, returnFalse } from '../../../../ClientUtils';
import { intersectRect } from '../../../../Utils';
-import { Doc, Opt } from '../../../../fields/Doc';
+import { Doc, NumListCast, Opt } from '../../../../fields/Doc';
import { AclAdmin, AclAugment, AclEdit, DocData } from '../../../../fields/DocSymbols';
import { Id } from '../../../../fields/FieldSymbols';
import { InkData, InkField, InkTool } from '../../../../fields/InkField';
import { List } from '../../../../fields/List';
import { RichTextField } from '../../../../fields/RichTextField';
import { Cast, FieldValue, NumCast, StrCast } from '../../../../fields/Types';
-import { ImageField } from '../../../../fields/URLField';
+import { ImageField, URLField } from '../../../../fields/URLField';
import { GetEffectiveAcl } from '../../../../fields/util';
+import { gptGetEmbedding, gptImageLabel } from '../../../apis/gpt/GPT';
import { CognitiveServices } from '../../../cognitive_services/CognitiveServices';
import { DocUtils } from '../../../documents/DocUtils';
import { DocumentType } from '../../../documents/DocumentTypes';
@@ -29,9 +30,10 @@ import { OpenWhere } from '../../nodes/OpenWhere';
import { pasteImageBitmap } from '../../nodes/WebBoxRenderer';
import { FormattedTextBox } from '../../nodes/formattedText/FormattedTextBox';
import { SubCollectionViewProps } from '../CollectionSubView';
+import { ImageLabelHandler } from './ImageLabelHandler';
import { MarqueeOptionsMenu } from './MarqueeOptionsMenu';
import './MarqueeView.scss';
-
+import { ImageUtility } from '../../nodes/generativeFill/generativeFillUtils/ImageHandler';
interface MarqueeViewProps {
getContainerTransform: () => Transform;
getTransform: () => Transform;
@@ -61,11 +63,13 @@ export class MarqueeView extends ObservableReactComponent<SubCollectionViewProps
}
private _commandExecuted = false;
+ private _selectedDocs: Doc[] = [];
@observable _lastX: number = 0;
@observable _lastY: number = 0;
@observable _downX: number = 0;
@observable _downY: number = 0;
@observable _visible: boolean = false; // selection rentangle for marquee selection/free hand lasso is visible
+ @observable _labelsVisibile: boolean = false;
@observable _lassoPts: [number, number][] = [];
@observable _lassoFreehand: boolean = false;
@@ -267,6 +271,8 @@ export class MarqueeView extends ObservableReactComponent<SubCollectionViewProps
MarqueeOptionsMenu.Instance.hideMarquee = this.hideMarquee;
MarqueeOptionsMenu.Instance.jumpTo(e.clientX, e.clientY);
MarqueeOptionsMenu.Instance.pinWithView = this.pinWithView;
+ MarqueeOptionsMenu.Instance.classifyImages = this.classifyImages;
+ MarqueeOptionsMenu.Instance.groupImages = this.groupImages;
document.addEventListener('pointerdown', hideMarquee, true);
document.addEventListener('wheel', hideMarquee, true);
} else {
@@ -419,6 +425,102 @@ export class MarqueeView extends ObservableReactComponent<SubCollectionViewProps
this.hideMarquee();
});
+ /**
+ * Classifies images and assigns the labels as document fields.
+ * TODO: Turn into lists of labels instead of individual fields.
+ */
+ @undoBatch
+ classifyImages = action(async (e: React.MouseEvent | undefined) => {
+ const selected = this.marqueeSelect(false, DocumentType.IMG);
+ this._selectedDocs = selected;
+
+ const imagePromises = selected.map(doc => {
+ const href = (doc['data'] as URLField).url.href;
+ const hrefParts = href.split('.');
+ const hrefComplete = `${hrefParts[0]}_o.${hrefParts[1]}`;
+ return ImageUtility.urlToBase64(hrefComplete).then(hrefBase64 =>
+ !hrefBase64
+ ? undefined
+ : gptImageLabel(hrefBase64).then(response => {
+ const labels = response.split('\n');
+ doc.image_labels = new List<string>(Array.from(labels!));
+ return Promise.all(labels!.map(label => gptGetEmbedding(label))).then(embeddings => {
+ return { doc, embeddings };
+ });
+ })
+ );
+ });
+
+ const docsAndEmbeddings = await Promise.all(imagePromises);
+ docsAndEmbeddings
+ .filter(d => d)
+ .map(d => d!)
+ .forEach(docAndEmbedding => {
+ if (Array.isArray(docAndEmbedding.embeddings)) {
+ let doc = docAndEmbedding.doc;
+ for (let i = 0; i < 3; i++) {
+ doc[`label_embedding_${i + 1}`] = new List<number>(docAndEmbedding.embeddings[i]);
+ }
+ }
+ });
+
+ if (e) {
+ ImageLabelHandler.Instance.displayLabelHandler(e.pageX, e.pageY);
+ }
+ });
+
+ /**
+ * Groups images to most similar labels.
+ */
+ @undoBatch
+ groupImages = action(async () => {
+ const labelGroups: string[] = ImageLabelHandler.Instance._labelGroups;
+ const labelToCollection: Map<string, Doc> = new Map();
+ const labelToEmbedding: Map<string, number[]> = new Map();
+ var similarity = require('compute-cosine-similarity');
+
+ // Create new collections associated with each label and get the embeddings for the labels.
+ for (const label of labelGroups) {
+ const newCollection = this.getCollection([], undefined, false);
+ newCollection._freeform_panX = this.Bounds.left + this.Bounds.width / 2;
+ newCollection._freeform_panY = this.Bounds.top + this.Bounds.height / 2;
+ labelToCollection.set(label, newCollection);
+ this._props.addDocument?.(newCollection);
+ const labelEmbedding = await gptGetEmbedding(label);
+ if (Array.isArray(labelEmbedding)) {
+ labelToEmbedding.set(label, labelEmbedding);
+ }
+ }
+
+ // For each image, loop through the labels, and calculate similarity. Associate it with the
+ // most similar one.
+ this._selectedDocs.forEach(doc => {
+ let mostSimilarLabel: string | undefined;
+ let maxSimilarity: number = 0;
+ const embeddingAsList1 = NumListCast(doc.label_embedding_1);
+ const embeddingAsList2 = NumListCast(doc.label_embedding_2);
+ const embeddingAsList3 = NumListCast(doc.label_embedding_3);
+
+ labelGroups.forEach(label => {
+ let curSimilarity1 = similarity(labelToEmbedding.get(label)!, Array.from(embeddingAsList1));
+ let curSimilarity2 = similarity(labelToEmbedding.get(label)!, Array.from(embeddingAsList2));
+ let curSimilarity3 = similarity(labelToEmbedding.get(label)!, Array.from(embeddingAsList3));
+ let maxCurSimilarity = Math.max(curSimilarity1, curSimilarity2, curSimilarity3);
+ if (maxCurSimilarity >= 0.3 && maxCurSimilarity > maxSimilarity) {
+ mostSimilarLabel = label;
+ maxSimilarity = maxCurSimilarity;
+ }
+
+ console.log('Doc with labels ' + doc.image_labels + 'has similarity score ' + maxCurSimilarity + ' to ' + mostSimilarLabel);
+ });
+
+ if (mostSimilarLabel) {
+ Doc.AddDocToList(labelToCollection.get(mostSimilarLabel)!, undefined, doc);
+ this._props.removeDocument?.(doc);
+ }
+ });
+ });
+
@undoBatch
syntaxHighlight = action((e: KeyboardEvent | React.PointerEvent | undefined) => {
const selected = this.marqueeSelect(false);
@@ -579,7 +681,10 @@ export class MarqueeView extends ObservableReactComponent<SubCollectionViewProps
return false;
}
- marqueeSelect(selectBackgrounds: boolean = false) {
+ /**
+ * When this is called, returns the list of documents that have been selected by the marquee box.
+ */
+ marqueeSelect(selectBackgrounds: boolean = false, docType: DocumentType | undefined = undefined) {
const selection: Doc[] = [];
const selectFunc = (doc: Doc) => {
const layoutDoc = Doc.Layout(doc);
@@ -589,11 +694,19 @@ export class MarqueeView extends ObservableReactComponent<SubCollectionViewProps
} else {
(this.touchesLine(bounds) || this.boundingShape(bounds)) && selection.push(doc);
}
+ console.log(doc['type']);
};
- this._props
- .activeDocuments()
- .filter(doc => !doc.z && !doc._lockedPosition)
- .map(selectFunc);
+ if (docType) {
+ this._props
+ .activeDocuments()
+ .filter(doc => !doc.z && !doc._lockedPosition && doc['type'] === docType)
+ .map(selectFunc);
+ } else {
+ this._props
+ .activeDocuments()
+ .filter(doc => !doc.z && !doc._lockedPosition)
+ .map(selectFunc);
+ }
if (!selection.length && selectBackgrounds)
this._props
.activeDocuments()