// Perplexica/src/lib/providers.ts
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { ChatOllama } from "@langchain/community/chat_models/ollama";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { HuggingFaceTransformersEmbeddings } from "./huggingfaceTransformer";
import { getGroqApiKey, getOllamaApiEndpoint, getOpenaiApiKey } from "../config";
import logger from "../utils/logger";
2024-04-20 11:18:52 +05:30
export const getAvailableChatModelProviders = async () => {
2024-04-20 11:18:52 +05:30
const openAIApiKey = getOpenaiApiKey();
2024-05-01 19:43:06 +05:30
const groqApiKey = getGroqApiKey();
2024-04-20 11:18:52 +05:30
const ollamaEndpoint = getOllamaApiEndpoint();
const models = {};
if (openAIApiKey) {
2024-04-21 20:52:47 +05:30
try {
2024-07-05 14:36:50 +08:00
models["openai"] = {
"GPT-3.5 turbo": new ChatOpenAI({
2024-04-21 20:52:47 +05:30
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "gpt-3.5-turbo",
2024-04-21 20:52:47 +05:30
temperature: 0.7,
}),
2024-07-05 14:36:50 +08:00
"GPT-4": new ChatOpenAI({
2024-04-21 20:52:47 +05:30
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "gpt-4",
2024-04-21 20:52:47 +05:30
temperature: 0.7,
}),
2024-07-05 14:36:50 +08:00
"GPT-4 turbo": new ChatOpenAI({
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "gpt-4-turbo",
temperature: 0.7,
2024-04-21 20:52:47 +05:30
}),
2024-07-05 14:36:50 +08:00
"GPT-4 omni": new ChatOpenAI({
2024-05-14 19:33:54 +05:30
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "gpt-4o",
2024-05-14 19:33:54 +05:30
temperature: 0.7,
}),
2024-04-21 20:52:47 +05:30
};
} catch (err) {
2024-04-30 12:18:18 +05:30
logger.error(`Error loading OpenAI models: ${err}`);
2024-04-21 20:52:47 +05:30
}
2024-04-20 11:18:52 +05:30
}
2024-05-01 19:43:06 +05:30
if (groqApiKey) {
try {
2024-07-05 14:36:50 +08:00
models["groq"] = {
"LLaMA3 8b": new ChatOpenAI(
2024-05-01 19:43:06 +05:30
{
openAIApiKey: groqApiKey,
2024-07-05 14:36:50 +08:00
modelName: "llama3-8b-8192",
2024-05-01 19:43:06 +05:30
temperature: 0.7,
},
{
2024-07-05 14:36:50 +08:00
baseURL: "https://api.groq.com/openai/v1",
2024-05-01 19:43:06 +05:30
},
),
2024-07-05 14:36:50 +08:00
"LLaMA3 70b": new ChatOpenAI(
2024-05-01 19:43:06 +05:30
{
openAIApiKey: groqApiKey,
2024-07-05 14:36:50 +08:00
modelName: "llama3-70b-8192",
2024-05-01 19:43:06 +05:30
temperature: 0.7,
},
{
2024-07-05 14:36:50 +08:00
baseURL: "https://api.groq.com/openai/v1",
2024-05-01 19:43:06 +05:30
},
),
2024-07-05 14:36:50 +08:00
"Mixtral 8x7b": new ChatOpenAI(
2024-05-01 19:43:06 +05:30
{
openAIApiKey: groqApiKey,
2024-07-05 14:36:50 +08:00
modelName: "mixtral-8x7b-32768",
2024-05-01 19:43:06 +05:30
temperature: 0.7,
},
{
2024-07-05 14:36:50 +08:00
baseURL: "https://api.groq.com/openai/v1",
2024-05-01 19:43:06 +05:30
},
),
2024-07-05 14:36:50 +08:00
"Gemma 7b": new ChatOpenAI(
2024-05-01 19:43:06 +05:30
{
openAIApiKey: groqApiKey,
2024-07-05 14:36:50 +08:00
modelName: "gemma-7b-it",
2024-05-01 19:43:06 +05:30
temperature: 0.7,
},
{
2024-07-05 14:36:50 +08:00
baseURL: "https://api.groq.com/openai/v1",
2024-05-01 19:43:06 +05:30
},
),
};
} catch (err) {
logger.error(`Error loading Groq models: ${err}`);
}
}
2024-04-20 11:18:52 +05:30
if (ollamaEndpoint) {
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
2024-07-05 14:36:50 +08:00
"Content-Type": "application/json",
},
});
2024-04-20 11:18:52 +05:30
// eslint-disable-next-line @typescript-eslint/no-explicit-any
2024-04-20 11:18:52 +05:30
const { models: ollamaModels } = (await response.json()) as any;
2024-07-05 14:36:50 +08:00
models["ollama"] = ollamaModels.reduce((acc, model) => {
2024-04-20 11:18:52 +05:30
acc[model.model] = new ChatOllama({
baseUrl: ollamaEndpoint,
model: model.model,
temperature: 0.7,
});
return acc;
}, {});
} catch (err) {
logger.error(`Error loading Ollama models: ${err}`);
}
}
2024-07-05 14:36:50 +08:00
models["custom_openai"] = {};
return models;
};
export const getAvailableEmbeddingModelProviders = async () => {
const openAIApiKey = getOpenaiApiKey();
const ollamaEndpoint = getOllamaApiEndpoint();
const models = {};
if (openAIApiKey) {
try {
2024-07-05 14:36:50 +08:00
models["openai"] = {
"Text embedding 3 small": new OpenAIEmbeddings({
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "text-embedding-3-small",
}),
2024-07-05 14:36:50 +08:00
"Text embedding 3 large": new OpenAIEmbeddings({
openAIApiKey,
2024-07-05 14:36:50 +08:00
modelName: "text-embedding-3-large",
}),
};
} catch (err) {
logger.error(`Error loading OpenAI embeddings: ${err}`);
}
}
2024-04-20 11:18:52 +05:30
if (ollamaEndpoint) {
try {
const response = await fetch(`${ollamaEndpoint}/api/tags`, {
headers: {
2024-07-05 14:36:50 +08:00
"Content-Type": "application/json",
},
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const { models: ollamaModels } = (await response.json()) as any;
2024-07-05 14:36:50 +08:00
models["ollama"] = ollamaModels.reduce((acc, model) => {
acc[model.model] = new OllamaEmbeddings({
2024-04-20 11:18:52 +05:30
baseUrl: ollamaEndpoint,
model: model.model,
2024-04-20 11:18:52 +05:30
});
return acc;
}, {});
2024-04-20 11:18:52 +05:30
} catch (err) {
logger.error(`Error loading Ollama embeddings: ${err}`);
2024-04-20 11:18:52 +05:30
}
}
try {
2024-07-05 14:36:50 +08:00
models["local"] = {
"BGE Small": new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/bge-small-en-v1.5",
}),
2024-07-05 14:36:50 +08:00
"GTE Small": new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/gte-small",
}),
2024-07-05 14:36:50 +08:00
"Bert Multilingual": new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/bert-base-multilingual-uncased",
}),
};
2024-05-09 20:43:04 +05:30
} catch (err) {
logger.error(`Error loading local embeddings: ${err}`);
}
2024-04-20 11:18:52 +05:30
return models;
};