Add status updates when generating chat completions, add system theme support, add custom OpenAI embedding model support, and fix various bugs.

jakepresent 2025-08-16 13:13:46 -04:00
parent cc821315f8
commit 1077c1b703
13 changed files with 347 additions and 66 deletions

View file

@@ -61,10 +61,22 @@ const handleEmitterEvents = async (
 ) => {
   let recievedMessage = '';
   let sources: any[] = [];
+  let sentGeneratingStatus = false;

-  stream.on('data', (data) => {
+  stream.on('data', (data: string) => {
     const parsedData = JSON.parse(data);
     if (parsedData.type === 'response') {
+      if (!sentGeneratingStatus) {
+        writer.write(
+          encoder.encode(
+            JSON.stringify({
+              type: 'status',
+              data: 'Generating answer...',
+            }) + '\n',
+          ),
+        );
+        sentGeneratingStatus = true;
+      }
       writer.write(
         encoder.encode(
           JSON.stringify({
@@ -77,6 +89,17 @@ const handleEmitterEvents = async (
       recievedMessage += parsedData.data;
     } else if (parsedData.type === 'sources') {
+      if (!sentGeneratingStatus) {
+        writer.write(
+          encoder.encode(
+            JSON.stringify({
+              type: 'status',
+              data: 'Generating answer...',
+            }) + '\n',
+          ),
+        );
+        sentGeneratingStatus = true;
+      }
       writer.write(
         encoder.encode(
           JSON.stringify({
@@ -114,8 +137,16 @@ const handleEmitterEvents = async (
       })
       .execute();
   });

-  stream.on('error', (data) => {
+  stream.on('error', (data: string) => {
     const parsedData = JSON.parse(data);
+    writer.write(
+      encoder.encode(
+        JSON.stringify({
+          type: 'status',
+          data: 'Chat completion failed.',
+        }) + '\n',
+      ),
+    );
     writer.write(
       encoder.encode(
         JSON.stringify({
@@ -218,6 +249,28 @@ export const POST = async (req: Request) => {
       body.embeddingModel?.name || Object.keys(embeddingProvider)[0]
     ];

+  const selectedChatProviderKey =
+    body.chatModel?.provider || Object.keys(chatModelProviders)[0];
+  const selectedChatModelKey =
+    body.chatModel?.name || Object.keys(chatModelProvider)[0];
+  const selectedEmbeddingProviderKey =
+    body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0];
+  const selectedEmbeddingModelKey =
+    body.embeddingModel?.name || Object.keys(embeddingProvider)[0];
+
+  console.log('[Models] Chat request', {
+    chatProvider: selectedChatProviderKey,
+    chatModel: selectedChatModelKey,
+    embeddingProvider: selectedEmbeddingProviderKey,
+    embeddingModel: selectedEmbeddingModelKey,
+    ...(selectedChatProviderKey === 'custom_openai'
+      ? { chatBaseURL: getCustomOpenaiApiUrl() }
+      : {}),
+    ...(selectedEmbeddingProviderKey === 'custom_openai'
+      ? { embeddingBaseURL: getCustomOpenaiApiUrl() }
+      : {}),
+  });
+
   let llm: BaseChatModel | undefined;
   let embedding = embeddingModel.model;
@@ -272,11 +325,54 @@ export const POST = async (req: Request) => {
     );
   }

+  const llmProxy = new Proxy(llm as any, {
+    get(target, prop, receiver) {
+      if (
+        prop === 'invoke' ||
+        prop === 'stream' ||
+        prop === 'streamEvents' ||
+        prop === 'generate'
+      ) {
+        return (...args: any[]) => {
+          console.log('[Models] Chat model call', {
+            provider: selectedChatProviderKey,
+            model: selectedChatModelKey,
+            method: String(prop),
+          });
+          return (target as any)[prop](...args);
+        };
+      }
+      return Reflect.get(target, prop, receiver);
+    },
+  });
+
+  const embeddingProxy = new Proxy(embedding as any, {
+    get(target, prop, receiver) {
+      if (prop === 'embedQuery' || prop === 'embedDocuments') {
+        return (...args: any[]) => {
+          console.log('[Models] Embedding model call', {
+            provider: selectedEmbeddingProviderKey,
+            model: selectedEmbeddingModelKey,
+            method: String(prop),
+            size:
+              prop === 'embedDocuments'
+                ? Array.isArray(args[0])
+                  ? args[0].length
+                  : undefined
+                : undefined,
+          });
+          return (target as any)[prop](...args);
+        };
+      }
+      return Reflect.get(target, prop, receiver);
+    },
+  });
+
   const stream = await handler.searchAndAnswer(
     message.content,
     history,
-    llm,
-    embedding,
+    llmProxy as any,
+    embeddingProxy as any,
     body.optimizationMode,
     body.files,
     body.systemInstructions,
@@ -286,6 +382,18 @@ export const POST = async (req: Request) => {
   const writer = responseStream.writable.getWriter();
   const encoder = new TextEncoder();

+  writer.write(
+    encoder.encode(
+      JSON.stringify({
+        type: 'status',
+        data:
+          body.focusMode === 'writingAssistant'
+            ? 'Waiting for chat completion...'
+            : 'Searching web...',
+      }) + '\n',
+    ),
+  );
+
   handleEmitterEvents(stream, writer, encoder, aiMessageId, message.chatId);
   handleHistorySave(message, humanMessageId, body.focusMode, body.files);
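
Taken together, these changes make the chat route speak a small status protocol over the existing newline-delimited JSON stream. As a rough sketch, a web-search request now produces frames along these lines (payloads abbreviated; the sources and messageEnd frames carry more fields than shown):

{"type":"status","data":"Searching web..."}
{"type":"status","data":"Generating answer..."}
{"type":"sources","data":[ ... ]}
{"type":"message","data":" ...chunk of the answer... "}
{"type":"messageEnd"}

The sentGeneratingStatus guard ensures 'Generating answer...' is emitted exactly once, whichever of the response or sources events arrives first.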

View file

@@ -75,6 +75,19 @@ export const POST = async (req: Request) => {
     body.embeddingModel?.name ||
     Object.keys(embeddingModelProviders[embeddingModelProvider])[0];

+  console.log('[Models] Search request', {
+    chatProvider: chatModelProvider,
+    chatModel,
+    embeddingProvider: embeddingModelProvider,
+    embeddingModel,
+    ...(chatModelProvider === 'custom_openai'
+      ? { chatBaseURL: getCustomOpenaiApiUrl() }
+      : {}),
+    ...(embeddingModelProvider === 'custom_openai'
+      ? { embeddingBaseURL: getCustomOpenaiApiUrl() }
+      : {}),
+  });
+
   let llm: BaseChatModel | undefined;
   let embeddings: Embeddings | undefined;
@@ -118,11 +131,54 @@ export const POST = async (req: Request) => {
     return Response.json({ message: 'Invalid focus mode' }, { status: 400 });
   }

+  const llmProxy = new Proxy(llm as any, {
+    get(target, prop, receiver) {
+      if (
+        prop === 'invoke' ||
+        prop === 'stream' ||
+        prop === 'streamEvents' ||
+        prop === 'generate'
+      ) {
+        return (...args: any[]) => {
+          console.log('[Models] Chat model call', {
+            provider: chatModelProvider,
+            model: chatModel,
+            method: String(prop),
+          });
+          return (target as any)[prop](...args);
+        };
+      }
+      return Reflect.get(target, prop, receiver);
+    },
+  });
+
+  const embeddingProxy = new Proxy(embeddings as any, {
+    get(target, prop, receiver) {
+      if (prop === 'embedQuery' || prop === 'embedDocuments') {
+        return (...args: any[]) => {
+          console.log('[Models] Embedding model call', {
+            provider: embeddingModelProvider,
+            model: embeddingModel,
+            method: String(prop),
+            size:
+              prop === 'embedDocuments'
+                ? Array.isArray(args[0])
+                  ? args[0].length
+                  : undefined
+                : undefined,
+          });
+          return (target as any)[prop](...args);
+        };
+      }
+      return Reflect.get(target, prop, receiver);
+    },
+  });
+
   const emitter = await searchHandler.searchAndAnswer(
     body.query,
     history,
-    llm,
-    embeddings,
+    llmProxy as any,
+    embeddingProxy as any,
     body.optimizationMode,
     [],
     body.systemInstructions || '',
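
These logging proxies duplicate the ones added to the chat route. If the pattern spreads further, it could be factored into a shared helper; a hypothetical sketch (withCallLogging is not part of this commit):

// Hypothetical helper, not in this commit: wraps the listed methods of a
// model object so each call is logged before being forwarded unchanged.
function withCallLogging<T extends object>(
  target: T,
  label: string,
  methods: readonly string[],
): T {
  return new Proxy(target, {
    get(obj, prop, receiver) {
      if (typeof prop === 'string' && methods.includes(prop)) {
        return (...args: any[]) => {
          console.log(label, { method: prop });
          return (obj as any)[prop](...args);
        };
      }
      return Reflect.get(obj, prop, receiver);
    },
  });
}

// Usage equivalent to the proxies above:
// const llmProxy = withCallLogging(llm as any, '[Models] Chat model call',
//   ['invoke', 'stream', 'streamEvents', 'generate']);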

View file

@@ -3,6 +3,7 @@ import fs from 'fs';
 import path from 'path';
 import crypto from 'crypto';
 import { getAvailableEmbeddingModelProviders } from '@/lib/providers';
+import { getCustomOpenaiApiUrl } from '@/lib/config';
 import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf';
 import { DocxLoader } from '@langchain/community/document_loaders/fs/docx';
 import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters';
@@ -46,6 +47,14 @@ export async function POST(req: Request) {
     const embeddingModel =
       embedding_model ?? Object.keys(embeddingModels[provider as string])[0];

+    console.log('[Models] Upload embeddings request', {
+      embeddingProvider: provider,
+      embeddingModel,
+      ...(provider === 'custom_openai'
+        ? { embeddingBaseURL: getCustomOpenaiApiUrl() }
+        : {}),
+    });
+
     let embeddingsModel =
       embeddingModels[provider as string]?.[embeddingModel as string]?.model;
     if (!embeddingsModel) {
@@ -55,6 +64,28 @@ export async function POST(req: Request) {
       );
     }

+    const loggedEmbeddings = new Proxy(embeddingsModel as any, {
+      get(target, prop, receiver) {
+        if (prop === 'embedQuery' || prop === 'embedDocuments') {
+          return (...args: any[]) => {
+            console.log('[Models] Upload embedding model call', {
+              provider,
+              model: embeddingModel,
+              method: String(prop),
+              size:
+                prop === 'embedDocuments'
+                  ? Array.isArray(args[0])
+                    ? args[0].length
+                    : undefined
+                  : undefined,
+            });
+            return (target as any)[prop](...args);
+          };
+        }
+        return Reflect.get(target, prop, receiver);
+      },
+    });
+
     const processedFiles: FileRes[] = [];

     await Promise.all(
@@ -98,7 +129,7 @@ export async function POST(req: Request) {
         }),
       );

-      const embeddings = await embeddingsModel.embedDocuments(
+      const embeddings = await loggedEmbeddings.embedDocuments(
         splitted.map((doc) => doc.pageContent),
       );
       const embeddingsDataPath = filePath.replace(

View file

@@ -16,6 +16,7 @@ const Chat = ({
   setFileIds,
   files,
   setFiles,
+  statusText,
 }: {
   messages: Message[];
   sendMessage: (message: string) => void;
@@ -26,6 +27,7 @@ const Chat = ({
   setFileIds: (fileIds: string[]) => void;
   files: File[];
   setFiles: (files: File[]) => void;
+  statusText?: string;
 }) => {
   const [dividerWidth, setDividerWidth] = useState(0);
   const dividerRef = useRef<HTMLDivElement | null>(null);
@@ -78,6 +80,7 @@ const Chat = ({
                 isLast={isLast}
                 rewrite={rewrite}
                 sendMessage={sendMessage}
+                statusText={statusText}
               />
               {!isLast && msg.role === 'assistant' && (
                 <div className="h-px w-full bg-light-secondary dark:bg-dark-secondary" />
@@ -85,7 +88,9 @@ const Chat = ({
           </Fragment>
         );
       })}
-      {loading && !messageAppeared && <MessageBoxLoading />}
+      {loading && !messageAppeared && (
+        <MessageBoxLoading statusText={statusText} />
+      )}
       <div ref={messageEnd} className="h-0" />
       {dividerWidth > 0 && (
         <div

View file

@@ -313,6 +313,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
   const [isMessagesLoaded, setIsMessagesLoaded] = useState(false);
   const [notFound, setNotFound] = useState(false);
+  const [statusText, setStatusText] = useState<string | undefined>(undefined);

   useEffect(() => {
     if (
@@ -367,6 +368,11 @@ const ChatWindow = ({ id }: { id?: string }) => {
     setLoading(true);
     setMessageAppeared(false);
+    setStatusText(
+      focusMode === 'writingAssistant'
+        ? 'Waiting for chat completion...'
+        : 'Searching web...'
+    );

     let sources: Document[] | undefined = undefined;
     let recievedMessage = '';
@@ -386,13 +392,19 @@ const ChatWindow = ({ id }: { id?: string }) => {
     ]);

     const messageHandler = async (data: any) => {
+      if (data.type === 'status') {
+        if (typeof data.data === 'string') setStatusText(data.data);
+        return;
+      }
       if (data.type === 'error') {
         toast.error(data.data);
+        setStatusText('Chat completion failed.');
         setLoading(false);
         return;
       }

       if (data.type === 'sources') {
+        setStatusText('Generating answer...');
         sources = data.data;
         if (!added) {
           setMessages((prevMessages) => [
@@ -412,6 +424,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
       }

       if (data.type === 'message') {
+        setStatusText('Generating answer...');
         if (!added) {
           setMessages((prevMessages) => [
             ...prevMessages,
@@ -442,6 +455,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
       }

       if (data.type === 'messageEnd') {
+        setStatusText(undefined);
         setChatHistory((prevHistory) => [
           ...prevHistory,
           ['human', message],
@@ -519,31 +533,61 @@ const ChatWindow = ({ id }: { id?: string }) => {
       }),
    });

-    if (!res.body) throw new Error('No response body');
+    if (!res.ok) {
+      const text = await res.text();
+      try {
+        const json = JSON.parse(text);
+        toast.error(json.message || `Request failed: ${res.status} ${res.statusText}`);
+      } catch {
+        toast.error(`Request failed: ${res.status} ${res.statusText}`);
+      }
+      setStatusText('Chat completion failed.');
+      setLoading(false);
+      return;
+    }
+
+    if (!res.body) {
+      toast.error('No response body');
+      setStatusText('Chat completion failed.');
+      setLoading(false);
+      return;
+    }

     const reader = res.body?.getReader();
     const decoder = new TextDecoder('utf-8');
     let partialChunk = '';

-    while (true) {
-      const { value, done } = await reader.read();
-      if (done) break;
-
-      partialChunk += decoder.decode(value, { stream: true });
-
-      try {
-        const messages = partialChunk.split('\n');
-        for (const msg of messages) {
-          if (!msg.trim()) continue;
-          const json = JSON.parse(msg);
-          messageHandler(json);
-        }
-        partialChunk = '';
-      } catch (error) {
-        console.warn('Incomplete JSON, waiting for next chunk...');
-      }
-    }
+    try {
+      while (true) {
+        const { value, done } = await reader.read();
+        if (done) break;
+
+        partialChunk += decoder.decode(value, { stream: true });
+
+        try {
+          const messages = partialChunk.split('\n');
+          for (const msg of messages) {
+            if (!msg.trim()) continue;
+            const json = JSON.parse(msg);
+            messageHandler(json);
+          }
+          partialChunk = '';
+        } catch (error) {
+          console.warn('Incomplete JSON, waiting for next chunk...');
+        }
+      }
+    } catch (e) {
+      console.error('Streaming error', e);
+      toast.error('Chat streaming failed.');
+      setStatusText('Chat completion failed.');
+      setLoading(false);
+      return;
+    }
+
+    // Fallback: if the stream ended without 'messageEnd' or explicit error,
+    // ensure the UI doesn't stay in a loading state indefinitely.
+    setStatusText(undefined);
+    setLoading(false);
   };

   const rewrite = (messageId: string) => {
@@ -605,6 +649,7 @@ const ChatWindow = ({ id }: { id?: string }) => {
           setFileIds={setFileIds}
           files={files}
           setFiles={setFiles}
+          statusText={statusText}
         />
       </>
     ) : (
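
One subtlety in the new streaming loop: when JSON.parse throws on a trailing partial line, the catch keeps the entire buffer, so frames that did parse in that pass can be parsed and handled again on the next chunk. A common alternative (a sketch only, not what this commit does) is to consume complete lines and retain just the unfinished tail:

// Sketch of an alternative NDJSON buffering strategy (not part of this commit).
let buffer = '';
const handleChunk = (chunk: string, onFrame: (frame: any) => void) => {
  buffer += chunk;
  const lines = buffer.split('\n');
  buffer = lines.pop() ?? ''; // keep the trailing partial line (or '')
  for (const line of lines) {
    if (!line.trim()) continue; // skip blank separators
    onFrame(JSON.parse(line)); // every remaining line is a complete frame
  }
};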

View file

@@ -36,7 +36,7 @@ const EmptyChat = ({
     <div className="flex flex-col items-center justify-center min-h-screen max-w-screen-sm mx-auto p-2 space-y-4">
       <div className="flex flex-col items-center justify-center w-full space-y-8">
         <h2 className="text-black/70 dark:text-white/70 text-3xl font-medium -mt-8">
-          Research begins here.
+          Ask away...
         </h2>
         <EmptyChatMessageInput
           sendMessage={sendMessage}

View file

@@ -1,3 +1,5 @@
+'use client';
+
 import { Check, ClipboardList } from 'lucide-react';
 import { Message } from '../ChatWindow';
 import { useState } from 'react';
@@ -13,11 +15,37 @@ const Copy = ({
   return (
     <button
-      onClick={() => {
-        const contentToCopy = `${initialMessage}${message.sources && message.sources.length > 0 && `\n\nCitations:\n${message.sources?.map((source: any, i: any) => `[${i + 1}] ${source.metadata.url}`).join(`\n`)}`}`;
-        navigator.clipboard.writeText(contentToCopy);
-        setCopied(true);
-        setTimeout(() => setCopied(false), 1000);
+      onClick={async () => {
+        const citations =
+          message.sources && message.sources.length > 0
+            ? `\n\nCitations:\n${message.sources
+                ?.map((source: any, i: number) => {
+                  const url = source?.metadata?.url ?? '';
+                  return `[${i + 1}] ${url}`;
+                })
+                .join('\n')}`
+            : '';
+        const contentToCopy = `${initialMessage}${citations}`;
+
+        try {
+          if (navigator?.clipboard && window.isSecureContext) {
+            await navigator.clipboard.writeText(contentToCopy);
+          } else {
+            const textArea = document.createElement('textarea');
+            textArea.value = contentToCopy;
+            textArea.style.position = 'fixed';
+            textArea.style.left = '-9999px';
+            document.body.appendChild(textArea);
+            textArea.focus();
+            textArea.select();
+            document.execCommand('copy');
+            document.body.removeChild(textArea);
+          }
+          setCopied(true);
+          setTimeout(() => setCopied(false), 1200);
+        } catch (err) {
+          console.error('Copy failed', err);
+        }
       }}
       className="p-2 text-black/70 dark:text-white/70 rounded-xl hover:bg-light-secondary dark:hover:bg-dark-secondary transition duration-200 hover:text-black dark:hover:text-white"
     >

View file

@@ -42,6 +42,7 @@ const MessageBox = ({
   isLast,
   rewrite,
   sendMessage,
+  statusText,
 }: {
   message: Message;
   messageIndex: number;
@@ -51,6 +52,7 @@ const MessageBox = ({
   isLast: boolean;
   rewrite: (messageId: string) => void;
   sendMessage: (message: string) => void;
+  statusText?: string;
 }) => {
   const [parsedMessage, setParsedMessage] = useState(message.content);
   const [speechMessage, setSpeechMessage] = useState(message.content);
@@ -182,7 +184,7 @@ const MessageBox = ({
                 size={20}
               />
               <h3 className="text-black dark:text-white font-medium text-xl">
-                Answer
+                {loading && isLast && statusText ? statusText : 'Answer'}
               </h3>
             </div>

View file

@ -1,9 +1,14 @@
const MessageBoxLoading = () => { const MessageBoxLoading = ({ statusText }: { statusText?: string }) => {
return ( return (
<div className="flex flex-col space-y-2 w-full lg:w-9/12 bg-light-primary dark:bg-dark-primary animate-pulse rounded-lg py-3"> <div className="flex flex-col space-y-2 w-full lg:w-9/12 bg-light-primary dark:bg-dark-primary animate-pulse rounded-lg py-3">
<div className="h-2 rounded-full w-full bg-light-secondary dark:bg-dark-secondary" /> <div className="h-2 rounded-full w-full bg-light-secondary dark:bg-dark-secondary" />
<div className="h-2 rounded-full w-9/12 bg-light-secondary dark:bg-dark-secondary" /> <div className="h-2 rounded-full w-9/12 bg-light-secondary dark:bg-dark-secondary" />
<div className="h-2 rounded-full w-10/12 bg-light-secondary dark:bg-dark-secondary" /> <div className="h-2 rounded-full w-10/12 bg-light-secondary dark:bg-dark-secondary" />
{statusText && (
<div className="mt-3 text-xs text-black/70 dark:text-white/70 not-italic animate-none">
{statusText}
</div>
)}
</div> </div>
); );
}; };

View file

@@ -1,13 +1,10 @@
 'use client';
 import { ThemeProvider } from 'next-themes';
+import type { ReactNode } from 'react';

-const ThemeProviderComponent = ({
-  children,
-}: {
-  children: React.ReactNode;
-}) => {
+const ThemeProviderComponent = ({ children }: { children: ReactNode }) => {
   return (
-    <ThemeProvider attribute="class" enableSystem={false} defaultTheme="dark">
+    <ThemeProvider attribute="class" enableSystem={true} defaultTheme="system">
       {children}
     </ThemeProvider>
   );

View file

@@ -1,44 +1,19 @@
 'use client';

 import { useTheme } from 'next-themes';
-import { useCallback, useEffect, useState } from 'react';
+import { useEffect, useState } from 'react';
+import type { ChangeEvent } from 'react';
 import Select from '../ui/Select';

 type Theme = 'dark' | 'light' | 'system';

 const ThemeSwitcher = ({ className }: { className?: string }) => {
   const [mounted, setMounted] = useState(false);
   const { theme, setTheme } = useTheme();

-  const isTheme = useCallback((t: Theme) => t === theme, [theme]);
-
-  const handleThemeSwitch = (theme: Theme) => {
-    setTheme(theme);
-  };
-
   useEffect(() => {
     setMounted(true);
   }, []);

-  useEffect(() => {
-    if (isTheme('system')) {
-      const preferDarkScheme = window.matchMedia(
-        '(prefers-color-scheme: dark)',
-      );
-
-      const detectThemeChange = (event: MediaQueryListEvent) => {
-        const theme: Theme = event.matches ? 'dark' : 'light';
-        setTheme(theme);
-      };
-
-      preferDarkScheme.addEventListener('change', detectThemeChange);
-
-      return () => {
-        preferDarkScheme.removeEventListener('change', detectThemeChange);
-      };
-    }
-  }, [isTheme, setTheme, theme]);
-
   // Avoid Hydration Mismatch
   if (!mounted) {
     return null;
@@ -48,8 +23,9 @@ const ThemeSwitcher = ({ className }: { className?: string }) => {
     <Select
       className={className}
       value={theme}
-      onChange={(e) => handleThemeSwitch(e.target.value as Theme)}
+      onChange={(e: ChangeEvent<HTMLSelectElement>) => setTheme(e.target.value as Theme)}
       options={[
+        { value: 'system', label: 'System' },
         { value: 'light', label: 'Light' },
         { value: 'dark', label: 'Dark' },
       ]}

View file

@@ -45,6 +45,7 @@ interface Config {
       API_URL: string;
       API_KEY: string;
       MODEL_NAME: string;
+      EMBEDDING_MODEL_NAME: string;
     };
   };
   API_ENDPOINTS: {
@@ -99,6 +100,9 @@ export const getCustomOpenaiApiUrl = () =>
 export const getCustomOpenaiModelName = () =>
   loadConfig().MODELS.CUSTOM_OPENAI.MODEL_NAME;

+export const getCustomOpenaiEmbeddingModelName = () =>
+  loadConfig().MODELS.CUSTOM_OPENAI.EMBEDDING_MODEL_NAME;
+
 export const getLMStudioApiEndpoint = () =>
   loadConfig().MODELS.LM_STUDIO.API_URL;
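
For reference, the new key would sit alongside the existing custom OpenAI settings in the config file. A sketch, assuming the TOML layout implied by the loadConfig().MODELS.CUSTOM_OPENAI paths above (all values are placeholders):

[MODELS.CUSTOM_OPENAI]
API_URL = "https://my-openai-compatible-host/v1"
API_KEY = "sk-placeholder"
MODEL_NAME = "my-chat-model"
EMBEDDING_MODEL_NAME = "my-embedding-model"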

View file

@@ -10,8 +10,9 @@ import {
   getCustomOpenaiApiKey,
   getCustomOpenaiApiUrl,
   getCustomOpenaiModelName,
+  getCustomOpenaiEmbeddingModelName,
 } from '../config';
-import { ChatOpenAI } from '@langchain/openai';
+import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
 import {
   loadOllamaChatModels,
   loadOllamaEmbeddingModels,
@@ -143,5 +144,28 @@ export const getAvailableEmbeddingModelProviders = async () => {
     }
   }

+  const customOpenAiApiKey = getCustomOpenaiApiKey();
+  const customOpenAiApiUrl = getCustomOpenaiApiUrl();
+  const customOpenAiEmbeddingModelName = getCustomOpenaiEmbeddingModelName();
+
+  models['custom_openai'] = {
+    ...(customOpenAiApiKey &&
+    customOpenAiApiUrl &&
+    customOpenAiEmbeddingModelName
+      ? {
+          [customOpenAiEmbeddingModelName]: {
+            displayName: customOpenAiEmbeddingModelName,
+            model: new OpenAIEmbeddings({
+              apiKey: customOpenAiApiKey,
+              modelName: customOpenAiEmbeddingModelName,
+              configuration: {
+                baseURL: customOpenAiApiUrl,
+              },
+            }) as unknown as Embeddings,
+          },
+        }
+      : {}),
+  };
+
   return models;
 };
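
With this registration in place, the existing provider lookups elsewhere in the codebase should pick the model up without further changes. A rough sketch of that consumption path, mirroring the Object.keys(...)[0] fallback pattern used in the routes above:

// Hypothetical consumer sketch; entry is {} when custom_openai is not configured.
const providers = await getAvailableEmbeddingModelProviders();
const entry = providers['custom_openai'];
const name = Object.keys(entry)[0]; // the configured embedding model name, if any
const embeddings = name ? entry[name].model : undefined;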