diff --git a/sample.config.toml b/sample.config.toml
index 1db2125..65297d7 100644
--- a/sample.config.toml
+++ b/sample.config.toml
@@ -21,6 +21,8 @@ MODEL_NAME = ""
 
 [MODELS.OLLAMA]
 API_URL = "" # Ollama API URL - http://host.docker.internal:11434
+API_KEY = ""
+MODEL_NAME = ""
 
 [MODELS.DEEPSEEK]
 API_KEY = ""
diff --git a/src/lib/config.ts b/src/lib/config.ts
index 78ad09c..86a349e 100644
--- a/src/lib/config.ts
+++ b/src/lib/config.ts
@@ -31,6 +31,8 @@ interface Config {
     };
     OLLAMA: {
       API_URL: string;
+      API_KEY: string;
+      MODEL_NAME: string;
     };
     DEEPSEEK: {
       API_KEY: string;
@@ -82,6 +84,8 @@ export const getSearxngApiEndpoint = () =>
   process.env.SEARXNG_API_URL || loadConfig().API_ENDPOINTS.SEARXNG;
 
 export const getOllamaApiEndpoint = () => loadConfig().MODELS.OLLAMA.API_URL;
+export const getOllamaApiKey = () => loadConfig().MODELS.OLLAMA.API_KEY;
+export const getOllamaModelName = () => loadConfig().MODELS.OLLAMA.MODEL_NAME;
 
 export const getDeepseekApiKey = () => loadConfig().MODELS.DEEPSEEK.API_KEY;
 
diff --git a/src/lib/providers/ollama.ts b/src/lib/providers/ollama.ts
index cca2142..5f867ba 100644
--- a/src/lib/providers/ollama.ts
+++ b/src/lib/providers/ollama.ts
@@ -1,5 +1,5 @@
 import axios from 'axios';
-import { getKeepAlive, getOllamaApiEndpoint } from '../config';
+import { getKeepAlive, getOllamaApiEndpoint, getOllamaApiKey, getOllamaModelName } from '../config';
 import { ChatModel, EmbeddingModel } from '.';
 
 export const PROVIDER_INFO = {
@@ -8,6 +8,20 @@ export const PROVIDER_INFO = {
 };
 import { ChatOllama } from '@langchain/community/chat_models/ollama';
 import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
+
+// Build optional auth headers; the OLLAMA_API_KEY env var takes precedence over the config-file key.
+const getOllamaHttpHeaders = (): Record<string, string> => {
+  const result: Record<string, string> = {};
+
+  if (getOllamaApiKey()) {
+    result['Authorization'] = `Bearer ${getOllamaApiKey()}`;
+  }
+  if (process.env.OLLAMA_API_KEY) {
+    result['Authorization'] = `Bearer ${process.env.OLLAMA_API_KEY}`;
+  }
+
+  return result;
+};
 
 export const loadOllamaChatModels = async () => {
   const ollamaApiEndpoint = getOllamaApiEndpoint();
@@ -18,6 +32,7 @@ export const loadOllamaChatModels = async () => {
   const res = await axios.get(`${ollamaApiEndpoint}/api/tags`, {
     headers: {
       'Content-Type': 'application/json',
+      ...getOllamaHttpHeaders(),
     },
   });
 
@@ -26,6 +41,9 @@ export const loadOllamaChatModels = async () => {
   const chatModels: Record<string, ChatModel> = {};
 
   models.forEach((model: any) => {
+    if (getOllamaModelName() && !model.model.startsWith(getOllamaModelName())) {
+      return; // Skip models that do not match the configured model name
+    }
     chatModels[model.model] = {
       displayName: model.name,
       model: new ChatOllama({
@@ -33,6 +51,7 @@ export const loadOllamaChatModels = async () => {
         model: model.model,
         temperature: 0.7,
         keepAlive: getKeepAlive(),
+        headers: getOllamaHttpHeaders(),
       }),
     };
   });
@@ -53,6 +72,7 @@ export const loadOllamaEmbeddingModels = async () => {
   const res = await axios.get(`${ollamaApiEndpoint}/api/tags`, {
     headers: {
       'Content-Type': 'application/json',
+      ...getOllamaHttpHeaders(),
     },
   });
 
@@ -66,6 +86,7 @@ export const loadOllamaEmbeddingModels = async () => {
       model: new OllamaEmbeddings({
         baseUrl: ollamaApiEndpoint,
         model: model.model,
+        headers: getOllamaHttpHeaders(),
       }),
     };
   });
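Note on header precedence: in getOllamaHttpHeaders above, the OLLAMA_API_KEY environment variable wins over the config-file key because it is assigned last. A minimal standalone sketch of that behavior; the function name and key values here are illustrative, not part of the diff:

    // Later assignment overwrites earlier: env key wins when both are set.
    const buildAuthHeaders = (configKey: string, envKey?: string): Record<string, string> => {
      const headers: Record<string, string> = {};
      if (configKey) headers['Authorization'] = `Bearer ${configKey}`;
      if (envKey) headers['Authorization'] = `Bearer ${envKey}`;
      return headers;
    };

    console.log(buildAuthHeaders('config-key', 'env-key')); // { Authorization: 'Bearer env-key' }
    console.log(buildAuthHeaders('config-key'));            // { Authorization: 'Bearer config-key' }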
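Note on model filtering: the new MODEL_NAME setting is a prefix match against the model IDs returned by Ollama's /api/tags, so a value like "llama3" would match every "llama3:*" tag while an empty value keeps all models. A small sketch of that filter under assumed tag values:

    // Hypothetical tag list; MODELS.OLLAMA.MODEL_NAME assumed to be 'llama3'.
    const tags = ['llama3:8b', 'llama3:70b', 'mistral:7b'];
    const configured = 'llama3';
    const visible = tags.filter((t) => !configured || t.startsWith(configured));
    console.log(visible); // ['llama3:8b', 'llama3:70b']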