feat: Enhance embedding management and model configuration

- Updated embedding_manager.py to include backend parameter in model configuration.
- Modified model_manager.py to utilize cache_name for ONNX models.
- Refactored hybrid_search.py to improve embedder initialization based on backend type.
- Added backend column to vector_store.py for better model configuration management.
- Implemented migration for existing database to include backend information.
- Enhanced API settings implementation with comprehensive provider and endpoint management.
- Introduced LiteLLM integration guide detailing configuration and usage.
- Added examples for LiteLLM usage in TypeScript.
This commit is contained in:
catlog22
2025-12-24 14:03:59 +08:00
parent 9b926d1a1e
commit b00113d212
22 changed files with 5507 additions and 706 deletions

View File

@@ -0,0 +1,441 @@
/**
* LiteLLM API Config Manager
* Manages provider credentials, endpoint configurations, and model discovery
*/
import { join } from 'path';
import { readFileSync, writeFileSync, existsSync, mkdirSync } from 'fs';
import { homedir } from 'os';
// ===========================
// Type Definitions
// ===========================
export type ProviderType =
| 'openai'
| 'anthropic'
| 'google'
| 'cohere'
| 'azure'
| 'bedrock'
| 'vertexai'
| 'huggingface'
| 'ollama'
| 'custom';
export interface ProviderCredential {
id: string;
name: string;
type: ProviderType;
apiKey?: string;
baseUrl?: string;
apiVersion?: string;
region?: string;
projectId?: string;
organizationId?: string;
enabled: boolean;
metadata?: Record<string, any>;
createdAt: string;
updatedAt: string;
}
export interface EndpointConfig {
id: string;
name: string;
providerId: string;
model: string;
alias?: string;
temperature?: number;
maxTokens?: number;
topP?: number;
enabled: boolean;
metadata?: Record<string, any>;
createdAt: string;
updatedAt: string;
}
export interface ModelInfo {
id: string;
name: string;
provider: ProviderType;
contextWindow: number;
supportsFunctions: boolean;
supportsStreaming: boolean;
inputCostPer1k?: number;
outputCostPer1k?: number;
}
export interface LiteLLMApiConfig {
version: string;
providers: ProviderCredential[];
endpoints: EndpointConfig[];
}
// ===========================
// Model Definitions
// ===========================
export const PROVIDER_MODELS: Record<ProviderType, ModelInfo[]> = {
openai: [
{
id: 'gpt-4-turbo',
name: 'GPT-4 Turbo',
provider: 'openai',
contextWindow: 128000,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.01,
outputCostPer1k: 0.03,
},
{
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai',
contextWindow: 8192,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.03,
outputCostPer1k: 0.06,
},
{
id: 'gpt-3.5-turbo',
name: 'GPT-3.5 Turbo',
provider: 'openai',
contextWindow: 16385,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.0005,
outputCostPer1k: 0.0015,
},
],
anthropic: [
{
id: 'claude-3-opus-20240229',
name: 'Claude 3 Opus',
provider: 'anthropic',
contextWindow: 200000,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.015,
outputCostPer1k: 0.075,
},
{
id: 'claude-3-sonnet-20240229',
name: 'Claude 3 Sonnet',
provider: 'anthropic',
contextWindow: 200000,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.003,
outputCostPer1k: 0.015,
},
{
id: 'claude-3-haiku-20240307',
name: 'Claude 3 Haiku',
provider: 'anthropic',
contextWindow: 200000,
supportsFunctions: true,
supportsStreaming: true,
inputCostPer1k: 0.00025,
outputCostPer1k: 0.00125,
},
],
google: [
{
id: 'gemini-pro',
name: 'Gemini Pro',
provider: 'google',
contextWindow: 32768,
supportsFunctions: true,
supportsStreaming: true,
},
{
id: 'gemini-pro-vision',
name: 'Gemini Pro Vision',
provider: 'google',
contextWindow: 16384,
supportsFunctions: false,
supportsStreaming: true,
},
],
cohere: [
{
id: 'command',
name: 'Command',
provider: 'cohere',
contextWindow: 4096,
supportsFunctions: false,
supportsStreaming: true,
},
{
id: 'command-light',
name: 'Command Light',
provider: 'cohere',
contextWindow: 4096,
supportsFunctions: false,
supportsStreaming: true,
},
],
azure: [],
bedrock: [],
vertexai: [],
huggingface: [],
ollama: [],
custom: [],
};
// ===========================
// Config File Management
// ===========================
const CONFIG_DIR = join(homedir(), '.claude', 'litellm');
const CONFIG_FILE = join(CONFIG_DIR, 'config.json');
function ensureConfigDir(): void {
if (!existsSync(CONFIG_DIR)) {
mkdirSync(CONFIG_DIR, { recursive: true });
}
}
function loadConfig(): LiteLLMApiConfig {
ensureConfigDir();
if (!existsSync(CONFIG_FILE)) {
const defaultConfig: LiteLLMApiConfig = {
version: '1.0.0',
providers: [],
endpoints: [],
};
saveConfig(defaultConfig);
return defaultConfig;
}
try {
const content = readFileSync(CONFIG_FILE, 'utf-8');
return JSON.parse(content);
} catch (err) {
throw new Error(`Failed to load config: ${(err as Error).message}`);
}
}
function saveConfig(config: LiteLLMApiConfig): void {
ensureConfigDir();
try {
writeFileSync(CONFIG_FILE, JSON.stringify(config, null, 2), 'utf-8');
} catch (err) {
throw new Error(`Failed to save config: ${(err as Error).message}`);
}
}
// ===========================
// Provider Management
// ===========================
export function getAllProviders(): ProviderCredential[] {
const config = loadConfig();
return config.providers;
}
export function getProvider(id: string): ProviderCredential | null {
const config = loadConfig();
return config.providers.find((p) => p.id === id) || null;
}
export function createProvider(
data: Omit<ProviderCredential, 'id' | 'createdAt' | 'updatedAt'>
): ProviderCredential {
const config = loadConfig();
const now = new Date().toISOString();
const provider: ProviderCredential = {
...data,
id: `provider-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`,
createdAt: now,
updatedAt: now,
};
config.providers.push(provider);
saveConfig(config);
return provider;
}
export function updateProvider(
id: string,
updates: Partial<ProviderCredential>
): ProviderCredential | null {
const config = loadConfig();
const index = config.providers.findIndex((p) => p.id === id);
if (index === -1) {
return null;
}
const updated: ProviderCredential = {
...config.providers[index],
...updates,
id,
updatedAt: new Date().toISOString(),
};
config.providers[index] = updated;
saveConfig(config);
return updated;
}
export function deleteProvider(id: string): { success: boolean } {
const config = loadConfig();
const index = config.providers.findIndex((p) => p.id === id);
if (index === -1) {
return { success: false };
}
config.providers.splice(index, 1);
// Also delete endpoints using this provider
config.endpoints = config.endpoints.filter((e) => e.providerId !== id);
saveConfig(config);
return { success: true };
}
export async function testProviderConnection(
providerId: string
): Promise<{ success: boolean; error?: string }> {
const provider = getProvider(providerId);
if (!provider) {
return { success: false, error: 'Provider not found' };
}
if (!provider.enabled) {
return { success: false, error: 'Provider is disabled' };
}
// Basic validation
if (!provider.apiKey && provider.type !== 'ollama' && provider.type !== 'custom') {
return { success: false, error: 'API key is required for this provider type' };
}
// TODO: Implement actual provider connection testing using litellm-client
// For now, just validate the configuration
return { success: true };
}
// ===========================
// Endpoint Management
// ===========================
export function getAllEndpoints(): EndpointConfig[] {
const config = loadConfig();
return config.endpoints;
}
export function getEndpoint(id: string): EndpointConfig | null {
const config = loadConfig();
return config.endpoints.find((e) => e.id === id) || null;
}
export function createEndpoint(
data: Omit<EndpointConfig, 'id' | 'createdAt' | 'updatedAt'>
): EndpointConfig {
const config = loadConfig();
// Validate provider exists
const provider = config.providers.find((p) => p.id === data.providerId);
if (!provider) {
throw new Error('Provider not found');
}
const now = new Date().toISOString();
const endpoint: EndpointConfig = {
...data,
id: `endpoint-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`,
createdAt: now,
updatedAt: now,
};
config.endpoints.push(endpoint);
saveConfig(config);
return endpoint;
}
export function updateEndpoint(
id: string,
updates: Partial<EndpointConfig>
): EndpointConfig | null {
const config = loadConfig();
const index = config.endpoints.findIndex((e) => e.id === id);
if (index === -1) {
return null;
}
// Validate provider if being updated
if (updates.providerId) {
const provider = config.providers.find((p) => p.id === updates.providerId);
if (!provider) {
throw new Error('Provider not found');
}
}
const updated: EndpointConfig = {
...config.endpoints[index],
...updates,
id,
updatedAt: new Date().toISOString(),
};
config.endpoints[index] = updated;
saveConfig(config);
return updated;
}
export function deleteEndpoint(id: string): { success: boolean } {
const config = loadConfig();
const index = config.endpoints.findIndex((e) => e.id === id);
if (index === -1) {
return { success: false };
}
config.endpoints.splice(index, 1);
saveConfig(config);
return { success: true };
}
// ===========================
// Model Discovery
// ===========================
export function getModelsForProviderType(providerType: ProviderType): ModelInfo[] | null {
return PROVIDER_MODELS[providerType] || null;
}
export function getAllModels(): Record<ProviderType, ModelInfo[]> {
return PROVIDER_MODELS;
}
// ===========================
// Config Access
// ===========================
export function getFullConfig(): LiteLLMApiConfig {
return loadConfig();
}
export function resetConfig(): void {
const defaultConfig: LiteLLMApiConfig = {
version: '1.0.0',
providers: [],
endpoints: [],
};
saveConfig(defaultConfig);
}

View File

@@ -25,10 +25,33 @@ export interface ModelInfo {
}
/**
* Predefined models for each provider
* Embedding model information metadata
*/
export interface EmbeddingModelInfo {
/** Model identifier (used in API calls) */
id: string;
/** Human-readable display name */
name: string;
/** Embedding dimensions */
dimensions: number;
/** Maximum input tokens */
maxTokens: number;
/** Provider identifier */
provider: string;
}
/**
* Predefined models for each API format
* Used for UI selection and validation
* Note: Most providers use OpenAI-compatible format
*/
export const PROVIDER_MODELS: Record<ProviderType, ModelInfo[]> = {
// OpenAI-compatible format (used by OpenAI, DeepSeek, Ollama, etc.)
openai: [
{
id: 'gpt-4o',
@@ -49,19 +72,32 @@ export const PROVIDER_MODELS: Record<ProviderType, ModelInfo[]> = {
supportsCaching: true
},
{
id: 'o1-mini',
name: 'O1 Mini',
contextWindow: 128000,
supportsCaching: true
id: 'deepseek-chat',
name: 'DeepSeek Chat',
contextWindow: 64000,
supportsCaching: false
},
{
id: 'gpt-4-turbo',
name: 'GPT-4 Turbo',
id: 'deepseek-coder',
name: 'DeepSeek Coder',
contextWindow: 64000,
supportsCaching: false
},
{
id: 'llama3.2',
name: 'Llama 3.2',
contextWindow: 128000,
supportsCaching: false
},
{
id: 'qwen2.5-coder',
name: 'Qwen 2.5 Coder',
contextWindow: 32000,
supportsCaching: false
}
],
// Anthropic format
anthropic: [
{
id: 'claude-sonnet-4-20250514',
@@ -89,135 +125,7 @@ export const PROVIDER_MODELS: Record<ProviderType, ModelInfo[]> = {
}
],
ollama: [
{
id: 'llama3.2',
name: 'Llama 3.2',
contextWindow: 128000,
supportsCaching: false
},
{
id: 'llama3.1',
name: 'Llama 3.1',
contextWindow: 128000,
supportsCaching: false
},
{
id: 'qwen2.5-coder',
name: 'Qwen 2.5 Coder',
contextWindow: 32000,
supportsCaching: false
},
{
id: 'codellama',
name: 'Code Llama',
contextWindow: 16000,
supportsCaching: false
},
{
id: 'mistral',
name: 'Mistral',
contextWindow: 32000,
supportsCaching: false
}
],
azure: [
{
id: 'gpt-4o',
name: 'GPT-4o (Azure)',
contextWindow: 128000,
supportsCaching: true
},
{
id: 'gpt-4o-mini',
name: 'GPT-4o Mini (Azure)',
contextWindow: 128000,
supportsCaching: true
},
{
id: 'gpt-4-turbo',
name: 'GPT-4 Turbo (Azure)',
contextWindow: 128000,
supportsCaching: false
},
{
id: 'gpt-35-turbo',
name: 'GPT-3.5 Turbo (Azure)',
contextWindow: 16000,
supportsCaching: false
}
],
google: [
{
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash Experimental',
contextWindow: 1048576,
supportsCaching: true
},
{
id: 'gemini-1.5-pro',
name: 'Gemini 1.5 Pro',
contextWindow: 2097152,
supportsCaching: true
},
{
id: 'gemini-1.5-flash',
name: 'Gemini 1.5 Flash',
contextWindow: 1048576,
supportsCaching: true
},
{
id: 'gemini-1.0-pro',
name: 'Gemini 1.0 Pro',
contextWindow: 32000,
supportsCaching: false
}
],
mistral: [
{
id: 'mistral-large-latest',
name: 'Mistral Large',
contextWindow: 128000,
supportsCaching: false
},
{
id: 'mistral-medium-latest',
name: 'Mistral Medium',
contextWindow: 32000,
supportsCaching: false
},
{
id: 'mistral-small-latest',
name: 'Mistral Small',
contextWindow: 32000,
supportsCaching: false
},
{
id: 'codestral-latest',
name: 'Codestral',
contextWindow: 32000,
supportsCaching: false
}
],
deepseek: [
{
id: 'deepseek-chat',
name: 'DeepSeek Chat',
contextWindow: 64000,
supportsCaching: false
},
{
id: 'deepseek-coder',
name: 'DeepSeek Coder',
contextWindow: 64000,
supportsCaching: false
}
],
// Custom format
custom: [
{
id: 'custom-model',
@@ -237,6 +145,61 @@ export function getModelsForProvider(providerType: ProviderType): ModelInfo[] {
return PROVIDER_MODELS[providerType] || [];
}
/**
* Predefined embedding models for each API format
* Used for UI selection and validation
*/
export const EMBEDDING_MODELS: Record<ProviderType, EmbeddingModelInfo[]> = {
// OpenAI embedding models
openai: [
{
id: 'text-embedding-3-small',
name: 'Text Embedding 3 Small',
dimensions: 1536,
maxTokens: 8191,
provider: 'openai'
},
{
id: 'text-embedding-3-large',
name: 'Text Embedding 3 Large',
dimensions: 3072,
maxTokens: 8191,
provider: 'openai'
},
{
id: 'text-embedding-ada-002',
name: 'Ada 002',
dimensions: 1536,
maxTokens: 8191,
provider: 'openai'
}
],
// Anthropic doesn't have embedding models
anthropic: [],
// Custom embedding models
custom: [
{
id: 'custom-embedding',
name: 'Custom Embedding',
dimensions: 1536,
maxTokens: 8192,
provider: 'custom'
}
]
};
/**
* Get embedding models for a specific provider
* @param providerType - Provider type to get embedding models for
* @returns Array of embedding model information
*/
export function getEmbeddingModelsForProvider(providerType: ProviderType): EmbeddingModelInfo[] {
return EMBEDDING_MODELS[providerType] || [];
}
/**
* Get model information by ID within a provider
* @param providerType - Provider type