diff --git a/archon-ui-main/src/components/settings/RAGSettings.tsx b/archon-ui-main/src/components/settings/RAGSettings.tsx index 8cb721a7..a60925bb 100644 --- a/archon-ui-main/src/components/settings/RAGSettings.tsx +++ b/archon-ui-main/src/components/settings/RAGSettings.tsx @@ -9,6 +9,103 @@ import { credentialsService } from '../../services/credentialsService'; import OllamaModelDiscoveryModal from './OllamaModelDiscoveryModal'; import OllamaModelSelectionModal from './OllamaModelSelectionModal'; +type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter'; + +interface ProviderModels { + chatModel: string; + embeddingModel: string; +} + +type ProviderModelMap = Record; + +// Provider model persistence helpers +const PROVIDER_MODELS_KEY = 'archon_provider_models'; + +const getDefaultModels = (provider: ProviderKey): ProviderModels => { + const chatDefaults: Record = { + openai: 'gpt-4o-mini', + anthropic: 'claude-3-5-sonnet-20241022', + google: 'gemini-1.5-flash', + grok: 'grok-3-mini', // Updated to use grok-3-mini as default + openrouter: 'openai/gpt-4o-mini', + ollama: 'llama3:8b' + }; + + const embeddingDefaults: Record = { + openai: 'text-embedding-3-small', + anthropic: 'text-embedding-3-small', // Fallback to OpenAI + google: 'text-embedding-004', + grok: 'text-embedding-3-small', // Fallback to OpenAI + openrouter: 'text-embedding-3-small', + ollama: 'nomic-embed-text' + }; + + return { + chatModel: chatDefaults[provider], + embeddingModel: embeddingDefaults[provider] + }; +}; + +const saveProviderModels = (providerModels: ProviderModelMap): void => { + try { + localStorage.setItem(PROVIDER_MODELS_KEY, JSON.stringify(providerModels)); + } catch (error) { + console.error('Failed to save provider models:', error); + } +}; + +const loadProviderModels = (): ProviderModelMap => { + try { + const saved = localStorage.getItem(PROVIDER_MODELS_KEY); + if (saved) { + return JSON.parse(saved); + } + } catch (error) { + console.error('Failed to load provider models:', error); + } + + // Return defaults for all providers if nothing saved + const providers: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok']; + const defaultModels: ProviderModelMap = {} as ProviderModelMap; + + providers.forEach(provider => { + defaultModels[provider] = getDefaultModels(provider); + }); + + return defaultModels; +}; + +// Static color styles mapping (prevents Tailwind JIT purging) +const colorStyles: Record = { + openai: 'border-green-500 bg-green-500/10', + google: 'border-blue-500 bg-blue-500/10', + openrouter: 'border-cyan-500 bg-cyan-500/10', + ollama: 'border-purple-500 bg-purple-500/10', + anthropic: 'border-orange-500 bg-orange-500/10', + grok: 'border-yellow-500 bg-yellow-500/10', +}; + +const providerAlertStyles: Record = { + openai: 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800 text-green-800 dark:text-green-300', + google: 'bg-blue-50 dark:bg-blue-900/20 border-blue-200 dark:border-blue-800 text-blue-800 dark:text-blue-300', + openrouter: 'bg-cyan-50 dark:bg-cyan-900/20 border-cyan-200 dark:border-cyan-800 text-cyan-800 dark:text-cyan-300', + ollama: 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800 text-purple-800 dark:text-purple-300', + anthropic: 'bg-orange-50 dark:bg-orange-900/20 border-orange-200 dark:border-orange-800 text-orange-800 dark:text-orange-300', + grok: 'bg-yellow-50 dark:bg-yellow-900/20 border-yellow-200 dark:border-yellow-800 text-yellow-800 dark:text-yellow-300', 
+}; + +const providerAlertMessages: Record = { + openai: 'Configure your OpenAI API key in the credentials section to use GPT models.', + google: 'Configure your Google API key in the credentials section to use Gemini models.', + openrouter: 'Configure your OpenRouter API key in the credentials section to use models.', + ollama: 'Configure your Ollama instances in this panel to connect local models.', + anthropic: 'Configure your Anthropic API key in the credentials section to use Claude models.', + grok: 'Configure your Grok API key in the credentials section to use Grok models.', +}; + +const isProviderKey = (value: unknown): value is ProviderKey => + typeof value === 'string' && ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'].includes(value); + interface RAGSettingsProps { ragSettings: { MODEL_CHOICE: string; @@ -19,8 +116,10 @@ interface RAGSettingsProps { USE_RERANKING: boolean; LLM_PROVIDER?: string; LLM_BASE_URL?: string; + LLM_INSTANCE_NAME?: string; EMBEDDING_MODEL?: string; OLLAMA_EMBEDDING_URL?: string; + OLLAMA_EMBEDDING_INSTANCE_NAME?: string; // Crawling Performance Settings CRAWL_BATCH_SIZE?: number; CRAWL_MAX_CONCURRENT?: number; @@ -57,7 +156,10 @@ export const RAGSettings = ({ // Model selection modals state const [showLLMModelSelectionModal, setShowLLMModelSelectionModal] = useState(false); const [showEmbeddingModelSelectionModal, setShowEmbeddingModelSelectionModal] = useState(false); - + + // Provider-specific model persistence state + const [providerModels, setProviderModels] = useState(() => loadProviderModels()); + // Instance configurations const [llmInstanceConfig, setLLMInstanceConfig] = useState({ name: '', @@ -113,6 +215,25 @@ export const RAGSettings = ({ } }, [ragSettings.OLLAMA_EMBEDDING_URL, ragSettings.OLLAMA_EMBEDDING_INSTANCE_NAME]); + // Provider model persistence effects + useEffect(() => { + // Update provider models when current models change + const currentProvider = ragSettings.LLM_PROVIDER as ProviderKey; + if (currentProvider && ragSettings.MODEL_CHOICE && ragSettings.EMBEDDING_MODEL) { + setProviderModels(prev => { + const updated = { + ...prev, + [currentProvider]: { + chatModel: ragSettings.MODEL_CHOICE, + embeddingModel: ragSettings.EMBEDDING_MODEL + } + }; + saveProviderModels(updated); + return updated; + }); + } + }, [ragSettings.MODEL_CHOICE, ragSettings.EMBEDDING_MODEL, ragSettings.LLM_PROVIDER]); + // Load API credentials for status checking useEffect(() => { const loadApiCredentials = async () => { @@ -197,58 +318,27 @@ export const RAGSettings = ({ }>({}); // Test connection to external providers - const testProviderConnection = async (provider: string, apiKey: string): Promise => { + const testProviderConnection = async (provider: string): Promise => { setProviderConnectionStatus(prev => ({ ...prev, [provider]: { ...prev[provider], checking: true } })); try { - switch (provider) { - case 'openai': - // Test OpenAI connection with a simple completion request - const openaiResponse = await fetch('https://api.openai.com/v1/models', { - method: 'GET', - headers: { - 'Authorization': `Bearer ${apiKey}`, - 'Content-Type': 'application/json' - } - }); - - if (openaiResponse.ok) { - setProviderConnectionStatus(prev => ({ - ...prev, - openai: { connected: true, checking: false, lastChecked: new Date() } - })); - return true; - } else { - throw new Error(`OpenAI API returned ${openaiResponse.status}`); - } + // Use server-side API endpoint for secure connectivity testing + const response = await 
fetch(`/api/providers/${provider}/status`); + const result = await response.json(); - case 'google': - // Test Google Gemini connection - const googleResponse = await fetch(`https://generativelanguage.googleapis.com/v1/models?key=${apiKey}`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json' - } - }); - - if (googleResponse.ok) { - setProviderConnectionStatus(prev => ({ - ...prev, - google: { connected: true, checking: false, lastChecked: new Date() } - })); - return true; - } else { - throw new Error(`Google API returned ${googleResponse.status}`); - } + const isConnected = result.ok && result.reason === 'connected'; - default: - return false; - } + setProviderConnectionStatus(prev => ({ + ...prev, + [provider]: { connected: isConnected, checking: false, lastChecked: new Date() } + })); + + return isConnected; } catch (error) { - console.error(`Failed to test ${provider} connection:`, error); + console.error(`Error testing ${provider} connection:`, error); setProviderConnectionStatus(prev => ({ ...prev, [provider]: { connected: false, checking: false, lastChecked: new Date() } @@ -260,37 +350,27 @@ export const RAGSettings = ({ // Test provider connections when API credentials change useEffect(() => { const testConnections = async () => { - const providers = ['openai', 'google']; - + // Test all supported providers + const providers = ['openai', 'google', 'anthropic', 'openrouter', 'grok']; + for (const provider of providers) { - const keyName = provider === 'openai' ? 'OPENAI_API_KEY' : 'GOOGLE_API_KEY'; - const apiKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === keyName); - const keyValue = apiKey ? apiCredentials[apiKey] : undefined; - - if (keyValue && keyValue.trim().length > 0) { - // Don't test if we've already checked recently (within last 30 seconds) - const lastChecked = providerConnectionStatus[provider]?.lastChecked; - const now = new Date(); - const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity; - - if (timeSinceLastCheck > 30000) { // 30 seconds - console.log(`🔄 Testing ${provider} connection...`); - await testProviderConnection(provider, keyValue); - } - } else { - // No API key, mark as disconnected - setProviderConnectionStatus(prev => ({ - ...prev, - [provider]: { connected: false, checking: false, lastChecked: new Date() } - })); + // Don't test if we've already checked recently (within last 30 seconds) + const lastChecked = providerConnectionStatus[provider]?.lastChecked; + const now = new Date(); + const timeSinceLastCheck = lastChecked ? 
now.getTime() - lastChecked.getTime() : Infinity; + + if (timeSinceLastCheck > 30000) { // 30 seconds + console.log(`🔄 Testing ${provider} connection...`); + await testProviderConnection(provider); } } }; - // Only test if we have credentials loaded - if (Object.keys(apiCredentials).length > 0) { - testConnections(); - } + // Test connections periodically (every 60 seconds) + testConnections(); + const interval = setInterval(testConnections, 60000); + + return () => clearInterval(interval); }, [apiCredentials]); // Test when credentials change // Ref to track if initial test has been run (will be used after function definitions) @@ -662,24 +742,41 @@ export const RAGSettings = ({ if (llmStatus.online || embeddingStatus.online) return 'partial'; return 'missing'; case 'anthropic': - // Check if Anthropic API key is configured (case insensitive) - const anthropicKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'ANTHROPIC_API_KEY'); - const hasAnthropicKey = anthropicKey && apiCredentials[anthropicKey] && apiCredentials[anthropicKey].trim().length > 0; - return hasAnthropicKey ? 'configured' : 'missing'; + // Use server-side connection status + const anthropicConnected = providerConnectionStatus['anthropic']?.connected || false; + const anthropicChecking = providerConnectionStatus['anthropic']?.checking || false; + if (anthropicChecking) return 'partial'; + return anthropicConnected ? 'configured' : 'missing'; case 'grok': - // Check if Grok API key is configured (case insensitive) - const grokKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'GROK_API_KEY'); - const hasGrokKey = grokKey && apiCredentials[grokKey] && apiCredentials[grokKey].trim().length > 0; - return hasGrokKey ? 'configured' : 'missing'; + // Use server-side connection status + const grokConnected = providerConnectionStatus['grok']?.connected || false; + const grokChecking = providerConnectionStatus['grok']?.checking || false; + if (grokChecking) return 'partial'; + return grokConnected ? 'configured' : 'missing'; case 'openrouter': - // Check if OpenRouter API key is configured (case insensitive) - const openRouterKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'OPENROUTER_API_KEY'); - const hasOpenRouterKey = openRouterKey && apiCredentials[openRouterKey] && apiCredentials[openRouterKey].trim().length > 0; - return hasOpenRouterKey ? 'configured' : 'missing'; + // Use server-side connection status + const openRouterConnected = providerConnectionStatus['openrouter']?.connected || false; + const openRouterChecking = providerConnectionStatus['openrouter']?.checking || false; + if (openRouterChecking) return 'partial'; + return openRouterConnected ? 'configured' : 'missing'; default: return 'missing'; } - };; + }; + + const selectedProviderKey = isProviderKey(ragSettings.LLM_PROVIDER) + ? (ragSettings.LLM_PROVIDER as ProviderKey) + : undefined; + const selectedProviderStatus = selectedProviderKey ? getProviderStatus(selectedProviderKey) : undefined; + const shouldShowProviderAlert = Boolean( + selectedProviderKey && selectedProviderStatus === 'missing' + ); + const providerAlertClassName = shouldShowProviderAlert && selectedProviderKey + ? providerAlertStyles[selectedProviderKey] + : ''; + const providerAlertMessage = shouldShowProviderAlert && selectedProviderKey + ? 
providerAlertMessages[selectedProviderKey] + : ''; // Test Ollama connectivity when Settings page loads (scenario 4: page load) // This useEffect is placed after function definitions to ensure access to manualTestConnection @@ -750,55 +847,32 @@ export const RAGSettings = ({ {[ { key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' }, { key: 'google', name: 'Google', logo: '/img/google-logo.svg', color: 'blue' }, + { key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' }, { key: 'ollama', name: 'Ollama', logo: '/img/Ollama.png', color: 'purple' }, { key: 'anthropic', name: 'Anthropic', logo: '/img/claude-logo.svg', color: 'orange' }, - { key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' }, - { key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' } + { key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' } ].map(provider => ( ))} @@ -1223,19 +1290,9 @@ export const RAGSettings = ({ )} - {ragSettings.LLM_PROVIDER === 'anthropic' && ( -
          <div className="…">
-            <p className="…">
-              Configure your Anthropic API key in the credentials section to use Claude models.
-            </p>
-          </div>
-        )}
-
-        {ragSettings.LLM_PROVIDER === 'groq' && (
-          <div className="…">
-            <p className="…">
-              Groq provides fast inference with Llama, Mixtral, and Gemma models.
-            </p>
-          </div>
+        {shouldShowProviderAlert && (
+          <div className={providerAlertClassName}>
+            <p>
+              {providerAlertMessage}
+            </p>
+          </div>
)} @@ -1853,94 +1910,56 @@ export const RAGSettings = ({ function getDisplayedChatModel(ragSettings: any): string { const provider = ragSettings.LLM_PROVIDER || 'openai'; const modelChoice = ragSettings.MODEL_CHOICE; - - // Check if the stored model is appropriate for the current provider - const isModelAppropriate = (model: string, provider: string): boolean => { - if (!model) return false; - - switch (provider) { - case 'openai': - return model.startsWith('gpt-') || model.startsWith('o1-') || model.includes('text-davinci') || model.includes('text-embedding'); - case 'anthropic': - return model.startsWith('claude-'); - case 'google': - return model.startsWith('gemini-') || model.startsWith('text-embedding-'); - case 'grok': - return model.startsWith('grok-'); - case 'ollama': - return !model.startsWith('gpt-') && !model.startsWith('claude-') && !model.startsWith('gemini-') && !model.startsWith('grok-'); - case 'openrouter': - return model.includes('/') || model.startsWith('anthropic/') || model.startsWith('openai/'); - default: - return false; - } - }; - - // Use stored model if it's appropriate for the provider, otherwise use default - const useStoredModel = modelChoice && isModelAppropriate(modelChoice, provider); - + + // Always prioritize user input to allow editing + if (modelChoice !== undefined && modelChoice !== null) { + return modelChoice; + } + + // Only use defaults when there's no stored value switch (provider) { case 'openai': - return useStoredModel ? modelChoice : 'gpt-4o-mini'; + return 'gpt-4o-mini'; case 'anthropic': - return useStoredModel ? modelChoice : 'claude-3-5-sonnet-20241022'; + return 'claude-3-5-sonnet-20241022'; case 'google': - return useStoredModel ? modelChoice : 'gemini-1.5-flash'; + return 'gemini-1.5-flash'; case 'grok': - return useStoredModel ? modelChoice : 'grok-2-latest'; + return 'grok-3-mini'; case 'ollama': - return useStoredModel ? modelChoice : ''; + return ''; case 'openrouter': - return useStoredModel ? modelChoice : 'anthropic/claude-3.5-sonnet'; + return 'anthropic/claude-3.5-sonnet'; default: - return useStoredModel ? 
modelChoice : 'gpt-4o-mini'; + return 'gpt-4o-mini'; } } function getDisplayedEmbeddingModel(ragSettings: any): string { const provider = ragSettings.LLM_PROVIDER || 'openai'; const embeddingModel = ragSettings.EMBEDDING_MODEL; - - // Check if the stored embedding model is appropriate for the current provider - const isEmbeddingModelAppropriate = (model: string, provider: string): boolean => { - if (!model) return false; - - switch (provider) { - case 'openai': - return model.startsWith('text-embedding-') || model.includes('ada-'); - case 'anthropic': - return false; // Claude doesn't provide embedding models - case 'google': - return model.startsWith('text-embedding-') || model.startsWith('textembedding-') || model.includes('embedding'); - case 'grok': - return false; // Grok doesn't provide embedding models - case 'ollama': - return !model.startsWith('text-embedding-') || model.includes('embed') || model.includes('arctic'); - case 'openrouter': - return model.startsWith('text-embedding-') || model.includes('/'); - default: - return false; - } - }; - - // Use stored model if it's appropriate for the provider, otherwise use default - const useStoredModel = embeddingModel && isEmbeddingModelAppropriate(embeddingModel, provider); - + + // Always prioritize user input to allow editing + if (embeddingModel !== undefined && embeddingModel !== null && embeddingModel !== '') { + return embeddingModel; + } + + // Provide appropriate defaults based on LLM provider switch (provider) { case 'openai': - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; - case 'anthropic': - return 'Not available - Claude does not provide embedding models'; + return 'text-embedding-3-small'; case 'google': - return useStoredModel ? embeddingModel : 'text-embedding-004'; - case 'grok': - return 'Not available - Grok does not provide embedding models'; + return 'text-embedding-004'; case 'ollama': - return useStoredModel ? embeddingModel : ''; + return ''; case 'openrouter': - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; + return 'text-embedding-3-small'; // Default to OpenAI embedding for OpenRouter + case 'anthropic': + return 'text-embedding-3-small'; // Use OpenAI embeddings with Claude + case 'grok': + return 'text-embedding-3-small'; // Use OpenAI embeddings with Grok default: - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; + return 'text-embedding-3-small'; } } @@ -2035,4 +2054,4 @@ const CustomCheckbox = ({
); -}; \ No newline at end of file +}; diff --git a/python/src/server/api_routes/__init__.py b/python/src/server/api_routes/__init__.py index 04df2992..8a39ef26 100644 --- a/python/src/server/api_routes/__init__.py +++ b/python/src/server/api_routes/__init__.py @@ -14,6 +14,7 @@ from .internal_api import router as internal_router from .knowledge_api import router as knowledge_router from .mcp_api import router as mcp_router from .projects_api import router as projects_router +from .providers_api import router as providers_router from .settings_api import router as settings_router __all__ = [ @@ -23,4 +24,5 @@ __all__ = [ "projects_router", "agent_chat_router", "internal_router", + "providers_router", ] diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 56725838..1f26dace 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -18,6 +18,8 @@ from urllib.parse import urlparse from fastapi import APIRouter, File, Form, HTTPException, UploadFile from pydantic import BaseModel +# Basic validation - simplified inline version + # Import unified logging from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info from ..services.crawler_manager import get_crawler @@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None: logger.info("🔑 Starting API key validation...") try: + # Basic provider validation if not provider: provider = "openai" + else: + # Simple provider validation + allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + if provider not in allowed_providers: + raise HTTPException( + status_code=400, + detail={ + "error": "Invalid provider name", + "message": f"Provider '{provider}' not supported", + "error_type": "validation_error" + } + ) - logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...") - - # Test API key with minimal embedding request - this will fail if key is invalid - from ..services.embeddings.embedding_service import create_embedding - test_result = await create_embedding(text="test") - - if not test_result: - logger.error(f"❌ {provider.title()} API key validation failed - no embedding returned") - raise HTTPException( - status_code=401, - detail={ - "error": f"Invalid {provider.title()} API key", - "message": f"Please verify your {provider.title()} API key in Settings.", - "error_type": "authentication_failed", - "provider": provider - } - ) + # Basic sanitization for logging + safe_provider = provider[:20] # Limit length + logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...") + + try: + # Test API key with minimal embedding request using provider-scoped configuration + from ..services.embeddings.embedding_service import create_embedding + + test_result = await create_embedding(text="test", provider=provider) + + if not test_result: + logger.error( + f"❌ {provider.title()} API key validation failed - no embedding returned" + ) + raise HTTPException( + status_code=401, + detail={ + "error": f"Invalid {provider.title()} API key", + "message": f"Please verify your {provider.title()} API key in Settings.", + "error_type": "authentication_failed", + "provider": provider, + }, + ) + except Exception as e: + logger.error( + f"❌ {provider.title()} API key validation failed: {e}", + exc_info=True, + ) + raise HTTPException( + status_code=401, + detail={ + "error": f"Invalid {provider.title()} API key", + 
"message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}", + "error_type": "authentication_failed", + "provider": provider, + }, + ) logger.info(f"✅ {provider.title()} API key validation successful") diff --git a/python/src/server/api_routes/providers_api.py b/python/src/server/api_routes/providers_api.py new file mode 100644 index 00000000..9c405ecd --- /dev/null +++ b/python/src/server/api_routes/providers_api.py @@ -0,0 +1,154 @@ +""" +Provider status API endpoints for testing connectivity + +Handles server-side provider connectivity testing without exposing API keys to frontend. +""" + +import httpx +from fastapi import APIRouter, HTTPException, Path + +from ..config.logfire_config import logfire +from ..services.credential_service import credential_service +# Provider validation - simplified inline version + +router = APIRouter(prefix="/api/providers", tags=["providers"]) + + +async def test_openai_connection(api_key: str) -> bool: + """Test OpenAI API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.openai.com/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"OpenAI connectivity test failed: {e}") + return False + + +async def test_google_connection(api_key: str) -> bool: + """Test Google AI API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://generativelanguage.googleapis.com/v1/models", + headers={"x-goog-api-key": api_key} + ) + return response.status_code == 200 + except Exception: + logfire.warning("Google AI connectivity test failed") + return False + + +async def test_anthropic_connection(api_key: str) -> bool: + """Test Anthropic API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.anthropic.com/v1/models", + headers={ + "x-api-key": api_key, + "anthropic-version": "2023-06-01" + } + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"Anthropic connectivity test failed: {e}") + return False + + +async def test_openrouter_connection(api_key: str) -> bool: + """Test OpenRouter API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://openrouter.ai/api/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"OpenRouter connectivity test failed: {e}") + return False + + +async def test_grok_connection(api_key: str) -> bool: + """Test Grok API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.x.ai/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"Grok connectivity test failed: {e}") + return False + + +PROVIDER_TESTERS = { + "openai": test_openai_connection, + "google": test_google_connection, + "anthropic": test_anthropic_connection, + "openrouter": test_openrouter_connection, + "grok": test_grok_connection, +} + + +@router.get("/{provider}/status") +async def get_provider_status( + provider: str = Path( + ..., + description="Provider name to test connectivity for", + regex="^[a-z0-9_]+$", + max_length=20 + ) +): + """Test provider connectivity using server-side API key 
(secure)""" + try: + # Basic provider validation + allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + if provider not in allowed_providers: + raise HTTPException( + status_code=400, + detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}" + ) + + # Basic sanitization for logging + safe_provider = provider[:20] # Limit length + logfire.info(f"Testing {safe_provider} connectivity server-side") + + if provider not in PROVIDER_TESTERS: + raise HTTPException( + status_code=400, + detail=f"Provider '{provider}' not supported for connectivity testing" + ) + + # Get API key server-side (never expose to client) + key_name = f"{provider.upper()}_API_KEY" + api_key = await credential_service.get_credential(key_name, decrypt=True) + + if not api_key or not isinstance(api_key, str) or not api_key.strip(): + logfire.info(f"No API key configured for {safe_provider}") + return {"ok": False, "reason": "no_key"} + + # Test connectivity using server-side key + tester = PROVIDER_TESTERS[provider] + is_connected = await tester(api_key) + + logfire.info(f"{safe_provider} connectivity test result: {is_connected}") + return { + "ok": is_connected, + "reason": "connected" if is_connected else "connection_failed", + "provider": provider # Echo back validated provider name + } + + except HTTPException: + # Re-raise HTTP exceptions (they're already properly formatted) + raise + except Exception as e: + # Basic error sanitization for logging + safe_error = str(e)[:100] # Limit length + logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}") + raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"}) diff --git a/python/src/server/main.py b/python/src/server/main.py index bec14a71..ba0b19cb 100644 --- a/python/src/server/main.py +++ b/python/src/server/main.py @@ -26,6 +26,7 @@ from .api_routes.mcp_api import router as mcp_router from .api_routes.ollama_api import router as ollama_router from .api_routes.progress_api import router as progress_router from .api_routes.projects_api import router as projects_router +from .api_routes.providers_api import router as providers_router # Import modular API routers from .api_routes.settings_api import router as settings_router @@ -186,6 +187,7 @@ app.include_router(progress_router) app.include_router(agent_chat_router) app.include_router(internal_router) app.include_router(bug_report_router) +app.include_router(providers_router) # Root endpoint diff --git a/python/src/server/services/crawling/code_extraction_service.py b/python/src/server/services/crawling/code_extraction_service.py index f52b7e28..1a540f57 100644 --- a/python/src/server/services/crawling/code_extraction_service.py +++ b/python/src/server/services/crawling/code_extraction_service.py @@ -139,6 +139,7 @@ class CodeExtractionService: source_id: str, progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
@@ -204,7 +205,7 @@ class CodeExtractionService: # Generate summaries for code blocks summary_results = await self._generate_code_summaries( - all_code_blocks, summary_callback, cancellation_check + all_code_blocks, summary_callback, cancellation_check, provider ) # Prepare code examples for storage @@ -223,7 +224,7 @@ class CodeExtractionService: # Store code examples in database return await self._store_code_examples( - storage_data, url_to_full_document, storage_callback + storage_data, url_to_full_document, storage_callback, provider ) async def _extract_code_blocks_from_documents( @@ -1523,6 +1524,7 @@ class CodeExtractionService: all_code_blocks: list[dict[str, Any]], progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> list[dict[str, str]]: """ Generate summaries for all code blocks. @@ -1587,7 +1589,7 @@ class CodeExtractionService: try: results = await generate_code_summaries_batch( - code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback + code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider ) # Ensure all results are valid dicts @@ -1667,6 +1669,7 @@ class CodeExtractionService: storage_data: dict[str, list[Any]], url_to_full_document: dict[str, str], progress_callback: Callable | None = None, + provider: str | None = None, ) -> int: """ Store code examples in the database. @@ -1709,7 +1712,7 @@ class CodeExtractionService: batch_size=20, url_to_full_document=url_to_full_document, progress_callback=storage_progress_callback, - provider=None, # Use configured provider + provider=provider, ) # Report completion of code extraction/storage phase diff --git a/python/src/server/services/crawling/crawling_service.py b/python/src/server/services/crawling/crawling_service.py index 55cd6d92..69c65719 100644 --- a/python/src/server/services/crawling/crawling_service.py +++ b/python/src/server/services/crawling/crawling_service.py @@ -475,12 +475,24 @@ class CrawlingService: ) try: + # Extract provider from request or use credential service default + provider = request.get("provider") + if not provider: + try: + from ..credential_service import credential_service + provider_config = await credential_service.get_active_provider("llm") + provider = provider_config.get("provider", "openai") + except Exception as e: + logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai") + provider = "openai" + code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples( crawl_results, storage_results["url_to_full_document"], storage_results["source_id"], code_progress_callback, self._check_cancellation, + provider, ) except RuntimeError as e: # Code extraction failed, continue crawl with warning diff --git a/python/src/server/services/crawling/document_storage_operations.py b/python/src/server/services/crawling/document_storage_operations.py index aaf211a7..88ed8e80 100644 --- a/python/src/server/services/crawling/document_storage_operations.py +++ b/python/src/server/services/crawling/document_storage_operations.py @@ -351,6 +351,7 @@ class DocumentStorageOperations: source_id: str, progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
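To keep the provider plumbing in the surrounding hunks easy to follow, here is a condensed, illustrative restatement of how crawling_service.py now resolves the provider before handing it to extract_and_store_code_examples; the `credential_service` argument is a stand-in for the real service, and this is a sketch rather than code from the patch.

```python
async def resolve_llm_provider(request: dict, credential_service) -> str:
    """Mirror of the fallback in crawling_service.py: explicit request value,
    then the active LLM provider from credential_service, then 'openai'."""
    provider = request.get("provider")
    if provider:
        return provider
    try:
        provider_config = await credential_service.get_active_provider("llm")
        return provider_config.get("provider", "openai")
    except Exception:
        return "openai"


# The resolved name is then threaded unchanged through the call chain:
#   DocumentStorageOperations.extract_and_store_code_examples(..., provider)
#     -> CodeExtractionService._generate_code_summaries / _store_code_examples(provider=provider)
```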
@@ -361,12 +362,13 @@ class DocumentStorageOperations: source_id: The unique source_id for all documents progress_callback: Optional callback for progress updates cancellation_check: Optional function to check for cancellation + provider: Optional LLM provider to use for code summaries Returns: Number of code examples stored """ result = await self.code_extraction_service.extract_and_store_code_examples( - crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check + crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider ) return result diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py index e72ca8a4..62fbb47a 100644 --- a/python/src/server/services/credential_service.py +++ b/python/src/server/services/credential_service.py @@ -36,6 +36,44 @@ class CredentialItem: description: str | None = None +def _detect_embedding_provider_from_model(embedding_model: str) -> str: + """ + Detect the appropriate embedding provider based on model name. + + Args: + embedding_model: The embedding model name + + Returns: + Provider name: 'google', 'openai', or 'openai' (default) + """ + if not embedding_model: + return "openai" # Default + + model_lower = embedding_model.lower() + + # Google embedding models + google_patterns = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding", + "gemini-embedding", + "multimodalembedding" + ] + + if any(pattern in model_lower for pattern in google_patterns): + return "google" + + # OpenAI embedding models (and default for unknown) + openai_patterns = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large" + ] + + # Default to OpenAI for OpenAI models or unknown models + return "openai" + + class CredentialService: """Service for managing application credentials and configuration.""" @@ -239,6 +277,14 @@ class CredentialService: self._rag_cache_timestamp = None logger.debug(f"Invalidated RAG settings cache due to update of {key}") + # Also invalidate provider service cache to ensure immediate effect + try: + from .llm_provider_service import clear_provider_cache + clear_provider_cache() + logger.debug("Also cleared LLM provider service cache") + except Exception as e: + logger.warning(f"Failed to clear provider service cache: {e}") + # Also invalidate LLM provider service cache for provider config try: from . import llm_provider_service @@ -281,6 +327,14 @@ class CredentialService: self._rag_cache_timestamp = None logger.debug(f"Invalidated RAG settings cache due to deletion of {key}") + # Also invalidate provider service cache to ensure immediate effect + try: + from .llm_provider_service import clear_provider_cache + clear_provider_cache() + logger.debug("Also cleared LLM provider service cache") + except Exception as e: + logger.warning(f"Failed to clear provider service cache: {e}") + # Also invalidate LLM provider service cache for provider config try: from . 
import llm_provider_service @@ -419,8 +473,33 @@ class CredentialService: # Get RAG strategy settings (where UI saves provider selection) rag_settings = await self.get_credentials_by_category("rag_strategy") - # Get the selected provider - provider = rag_settings.get("LLM_PROVIDER", "openai") + # Get the selected provider based on service type + if service_type == "embedding": + # Get the LLM provider setting to determine embedding provider + llm_provider = rag_settings.get("LLM_PROVIDER", "openai") + embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small") + + # Determine embedding provider based on LLM provider + if llm_provider == "google": + provider = "google" + elif llm_provider == "ollama": + provider = "ollama" + elif llm_provider == "openrouter": + # OpenRouter supports both OpenAI and Google embedding models + provider = _detect_embedding_provider_from_model(embedding_model) + elif llm_provider in ["anthropic", "grok"]: + # Anthropic and Grok support both OpenAI and Google embedding models + provider = _detect_embedding_provider_from_model(embedding_model) + else: + # Default case (openai, or unknown providers) + provider = "openai" + + logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'") + else: + provider = rag_settings.get("LLM_PROVIDER", "openai") + # Ensure provider is a valid string, not a boolean or other type + if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"): + provider = "openai" # Get API key for this provider api_key = await self._get_provider_api_key(provider) @@ -464,6 +543,9 @@ class CredentialService: key_mapping = { "openai": "OPENAI_API_KEY", "google": "GOOGLE_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "grok": "GROK_API_KEY", "ollama": None, # No API key needed } @@ -478,6 +560,12 @@ class CredentialService: return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1") elif provider == "google": return "https://generativelanguage.googleapis.com/v1beta/openai/" + elif provider == "openrouter": + return "https://openrouter.ai/api/v1" + elif provider == "anthropic": + return "https://api.anthropic.com/v1" + elif provider == "grok": + return "https://api.x.ai/v1" return None # Use default for OpenAI async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool: @@ -485,7 +573,7 @@ class CredentialService: try: # For now, we'll update the RAG strategy settings return await self.set_credential( - "llm_provider", + "LLM_PROVIDER", provider, category="rag_strategy", description=f"Active {service_type} provider", diff --git a/python/src/server/services/embeddings/contextual_embedding_service.py b/python/src/server/services/embeddings/contextual_embedding_service.py index 76f3c59b..29b36395 100644 --- a/python/src/server/services/embeddings/contextual_embedding_service.py +++ b/python/src/server/services/embeddings/contextual_embedding_service.py @@ -10,7 +10,13 @@ import os import openai from ...config.logfire_config import search_logger -from ..llm_provider_service import get_llm_client +from ..credential_service import credential_service +from ..llm_provider_service import ( + extract_message_text, + get_llm_client, + prepare_chat_completion_params, + requires_max_completion_tokens, +) from ..threading_service import get_threading_service @@ -32,8 +38,6 @@ async def generate_contextual_embedding( """ # Model choice is a RAG setting, get from 
credential service try: - from ...services.credential_service import credential_service - model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano") except Exception as e: # Fallback to environment variable or default @@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do # Get model from provider configuration model = await _get_model_choice(provider) - response = await client.chat.completions.create( - model=model, - messages=[ + # Prepare parameters and convert max_tokens for GPT-5/reasoning models + params = { + "model": model, + "messages": [ { "role": "system", "content": "You are a helpful assistant that provides concise contextual information.", }, {"role": "user", "content": prompt}, ], - temperature=0.3, - max_tokens=200, - ) + "temperature": 0.3, + "max_tokens": 1200 if requires_max_completion_tokens(model) else 200, # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process) + } + final_params = prepare_chat_completion_params(model, params) + response = await client.chat.completions.create(**final_params) - context = response.choices[0].message.content.strip() + choice = response.choices[0] if response.choices else None + context, _, _ = extract_message_text(choice) + context = context.strip() contextual_text = f"{context}\n---\n{chunk}" return contextual_text, True @@ -111,7 +120,7 @@ async def process_chunk_with_context( async def _get_model_choice(provider: str | None = None) -> str: - """Get model choice from credential service.""" + """Get model choice from credential service with centralized defaults.""" from ..credential_service import credential_service # Get the active provider configuration @@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str: model = provider_config.get("chat_model", "").strip() # Strip whitespace provider_name = provider_config.get("provider", "openai") - # Handle empty model case - fallback to provider-specific defaults or explicit config + # Handle empty model case - use centralized defaults if not model: - search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic") - + search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults") + + # Special handling for Ollama to check specific credential if provider_name == "ollama": - # Try to get OLLAMA_CHAT_MODEL specifically try: ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL") if ollama_model and ollama_model.strip(): model = ollama_model.strip() search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}") else: - # Use a sensible Ollama default + # Use default for Ollama model = "llama3.2:latest" - search_logger.info(f"Using Ollama default model: {model}") + search_logger.info(f"Using Ollama default: {model}") except Exception as e: search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}") model = "llama3.2:latest" - search_logger.info(f"Using Ollama fallback model: {model}") - elif provider_name == "google": - model = "gemini-1.5-flash" + search_logger.info(f"Using Ollama fallback: {model}") else: - # OpenAI or other providers - model = "gpt-4o-mini" - + # Use provider-specific defaults + provider_defaults = { + "openai": "gpt-4o-mini", + "openrouter": "anthropic/claude-3.5-sonnet", + "google": "gemini-1.5-flash", + "anthropic": "claude-3-5-haiku-20241022", + "grok": "grok-3-mini" + } + model = provider_defaults.get(provider_name, "gpt-4o-mini") + 
search_logger.debug(f"Using default model for provider {provider_name}: {model}") search_logger.debug(f"Using model from credential service: {model}") return model @@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch( model_choice = await _get_model_choice(provider) # Build batch prompt for ALL chunks at once - batch_prompt = ( - "Process the following chunks and provide contextual information for each:\\n\\n" - ) + batch_prompt = "Process the following chunks and provide contextual information for each:\n\n" for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)): # Use only 2000 chars of document context to save tokens doc_preview = doc[:2000] if len(doc) > 2000 else doc - batch_prompt += f"CHUNK {i + 1}:\\n" - batch_prompt += f"\\n{doc_preview}\\n\\n" - batch_prompt += f"\\n{chunk[:500]}\\n\\n\\n" # Limit chunk preview + batch_prompt += f"CHUNK {i + 1}:\n" + batch_prompt += f"\n{doc_preview}\n\n" + batch_prompt += f"\n{chunk[:500]}\n\n\n" # Limit chunk preview - batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc." + batch_prompt += ( + "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. " + "Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc." + ) # Make single API call for ALL chunks - response = await client.chat.completions.create( - model=model_choice, - messages=[ + # Prepare parameters and convert max_tokens for GPT-5/reasoning models + batch_params = { + "model": model_choice, + "messages": [ { "role": "system", "content": "You are a helpful assistant that generates contextual information for document chunks.", }, {"role": "user", "content": batch_prompt}, ], - temperature=0, - max_tokens=100 * len(chunks), # Limit response size - ) + "temperature": 0, + "max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks), # Much more tokens for reasoning models (GPT-5 needs extra reasoning space) + } + final_batch_params = prepare_chat_completion_params(model_choice, batch_params) + response = await client.chat.completions.create(**final_batch_params) # Parse response - response_text = response.choices[0].message.content + choice = response.choices[0] if response.choices else None + response_text, _, _ = extract_message_text(choice) + if not response_text: + search_logger.error( + "Empty response from LLM when generating contextual embeddings batch" + ) + return [(chunk, False) for chunk in chunks] # Extract contexts from response - lines = response_text.strip().split("\\n") + lines = response_text.strip().split("\n") chunk_contexts = {} for line in lines: @@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch( except Exception as e: search_logger.error(f"Error in contextual embedding batch: {e}") # Return non-contextual for all chunks - return [(chunk, False) for chunk in chunks] \ No newline at end of file + return [(chunk, False) for chunk in chunks] diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py index d697abf9..4f825f1d 100644 --- a/python/src/server/services/embeddings/embedding_service.py +++ b/python/src/server/services/embeddings/embedding_service.py @@ -13,7 +13,7 @@ import openai from ...config.logfire_config import safe_span, search_logger from ..credential_service 
import credential_service -from ..llm_provider_service import get_embedding_model, get_llm_client +from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model from ..threading_service import get_threading_service from .embedding_exceptions import ( EmbeddingAPIError, @@ -152,34 +152,56 @@ async def create_embeddings_batch( if not texts: return EmbeddingBatchResult() + result = EmbeddingBatchResult() + # Validate that all items in texts are strings validated_texts = [] for i, text in enumerate(texts): - if not isinstance(text, str): - search_logger.error( - f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True - ) - # Try to convert to string - try: - validated_texts.append(str(text)) - except Exception as e: - search_logger.error( - f"Failed to convert text at index {i} to string: {e}", exc_info=True - ) - validated_texts.append("") # Use empty string as fallback - else: + if isinstance(text, str): validated_texts.append(text) + continue + + search_logger.error( + f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True + ) + try: + converted = str(text) + validated_texts.append(converted) + except Exception as conversion_error: + search_logger.error( + f"Failed to convert text at index {i} to string: {conversion_error}", + exc_info=True, + ) + result.add_failure( + repr(text), + EmbeddingAPIError("Invalid text type", original_error=conversion_error), + batch_index=None, + ) texts = validated_texts - - result = EmbeddingBatchResult() threading_service = get_threading_service() with safe_span( "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts) ) as span: try: - async with get_llm_client(provider=provider, use_embedding_provider=True) as client: + # Intelligent embedding provider routing based on model type + # Get the embedding model first to determine the correct provider + embedding_model = await get_embedding_model(provider=provider) + + # Route to correct provider based on model type + if is_google_embedding_model(embedding_model): + embedding_provider = "google" + search_logger.info(f"Routing to Google for embedding model: {embedding_model}") + elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower(): + embedding_provider = "openai" + search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}") + else: + # Keep original provider for ollama and other providers + embedding_provider = provider + search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}") + + async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client: # Load batch size and dimensions from settings try: rag_settings = await credential_service.get_credentials_by_category( @@ -220,7 +242,8 @@ async def create_embeddings_batch( while retry_count < max_retries: try: # Create embeddings for this batch - embedding_model = await get_embedding_model(provider=provider) + embedding_model = await get_embedding_model(provider=embedding_provider) + response = await client.embeddings.create( model=embedding_model, input=batch, diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py index 10655b22..00197926 100644 --- a/python/src/server/services/llm_provider_service.py +++ b/python/src/server/services/llm_provider_service.py @@ -5,6 +5,7 @@ Provides a unified interface for creating OpenAI-compatible clients for 
differen Supports OpenAI, Ollama, and Google Gemini. """ +import inspect import time from contextlib import asynccontextmanager from typing import Any @@ -16,31 +17,305 @@ from .credential_service import credential_service logger = get_logger(__name__) -# Settings cache with TTL -_settings_cache: dict[str, tuple[Any, float]] = {} + +# Basic validation functions to avoid circular imports +def _is_valid_provider(provider: str) -> bool: + """Basic provider validation.""" + if not provider or not isinstance(provider, str): + return False + return provider.lower() in {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + + +def _sanitize_for_log(text: str) -> str: + """Basic text sanitization for logging.""" + if not text: + return "" + import re + sanitized = re.sub(r"sk-[a-zA-Z0-9-_]{20,}", "[REDACTED]", text) + sanitized = re.sub(r"xai-[a-zA-Z0-9-_]{20,}", "[REDACTED]", sanitized) + return sanitized[:100] + + +# Secure settings cache with TTL and validation +_settings_cache: dict[str, tuple[Any, float, str]] = {} # value, timestamp, checksum _CACHE_TTL_SECONDS = 300 # 5 minutes +_cache_access_log: list[dict] = [] # Track cache access patterns for security monitoring + + +def _calculate_cache_checksum(value: Any) -> str: + """Calculate checksum for cache entry integrity validation.""" + import hashlib + import json + + # Convert value to JSON string for consistent hashing + try: + value_str = json.dumps(value, sort_keys=True, default=str) + return hashlib.sha256(value_str.encode()).hexdigest()[:16] # First 16 chars for efficiency + except Exception: + # Fallback for non-serializable objects + return hashlib.sha256(str(value).encode()).hexdigest()[:16] + + +def _log_cache_access(key: str, action: str, hit: bool = None, security_event: str = None) -> None: + """Log cache access for security monitoring.""" + + access_entry = { + "timestamp": time.time(), + "key": _sanitize_for_log(key), + "action": action, # "get", "set", "invalidate", "clear" + "hit": hit, # For get operations + "security_event": security_event # "checksum_mismatch", "expired", etc. 
+ } + + # Keep only last 100 access entries to prevent memory growth + _cache_access_log.append(access_entry) + if len(_cache_access_log) > 100: + _cache_access_log.pop(0) + + # Log security events at warning level + if security_event: + safe_key = _sanitize_for_log(key) + logger.warning(f"Cache security event: {security_event} for key '{safe_key}'") def _get_cached_settings(key: str) -> Any | None: - """Get cached settings if not expired.""" - if key in _settings_cache: - value, timestamp = _settings_cache[key] - if time.time() - timestamp < _CACHE_TTL_SECONDS: + """Get cached settings if not expired and valid.""" + + try: + if key in _settings_cache: + value, timestamp, stored_checksum = _settings_cache[key] + current_time = time.time() + + # Check expiration with strict TTL enforcement + if current_time - timestamp >= _CACHE_TTL_SECONDS: + # Expired, remove from cache + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="expired") + return None + + # Verify cache entry integrity + current_checksum = _calculate_cache_checksum(value) + if current_checksum != stored_checksum: + # Cache tampering detected, remove entry + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="checksum_mismatch") + logger.error(f"Cache integrity violation detected for key: {_sanitize_for_log(key)}") + return None + + # Additional validation for provider configurations + if "provider_config" in key and isinstance(value, dict): + # Basic validation: check required fields + if not value.get("provider") or not _is_valid_provider(value.get("provider")): + # Invalid configuration in cache, remove it + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="invalid_config") + return None + + _log_cache_access(key, "get", hit=True) return value - else: - # Expired, remove from cache - del _settings_cache[key] - return None + + _log_cache_access(key, "get", hit=False) + return None + + except Exception as e: + # Cache access error, log and return None for safety + _log_cache_access(key, "get", hit=False, security_event=f"access_error: {str(e)}") + return None def _set_cached_settings(key: str, value: Any) -> None: - """Cache settings with current timestamp.""" - _settings_cache[key] = (value, time.time()) + """Cache settings with current timestamp and integrity checksum.""" + + try: + # Validate provider configurations before caching + if "provider_config" in key and isinstance(value, dict): + # Basic validation: check required fields + if not value.get("provider") or not _is_valid_provider(value.get("provider")): + _log_cache_access(key, "set", security_event="invalid_config_rejected") + logger.warning(f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}") + return + + # Calculate integrity checksum + checksum = _calculate_cache_checksum(value) + + # Store with timestamp and checksum + _settings_cache[key] = (value, time.time(), checksum) + _log_cache_access(key, "set") + + except Exception as e: + _log_cache_access(key, "set", security_event=f"set_error: {str(e)}") + logger.error(f"Failed to cache settings for key {_sanitize_for_log(key)}: {e}") +def clear_provider_cache() -> None: + """Clear the provider configuration cache to force refresh on next request.""" + global _settings_cache + + cache_size_before = len(_settings_cache) + _settings_cache.clear() + _log_cache_access("*", "clear") + logger.debug(f"Provider configuration cache cleared ({cache_size_before} entries removed)") + + +def 
invalidate_provider_cache(provider: str = None) -> None: + """ + Invalidate specific provider cache entries or all cache entries. + + Args: + provider: Optional provider name to invalidate. If None, clears all cache. + """ + global _settings_cache + + if provider is None: + # Clear entire cache + cache_size_before = len(_settings_cache) + _settings_cache.clear() + _log_cache_access("*", "invalidate") + logger.debug(f"All provider cache entries invalidated ({cache_size_before} entries)") + else: + # Validate provider name before processing + if not _is_valid_provider(provider): + _log_cache_access(provider, "invalidate", security_event="invalid_provider_name") + logger.warning(f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}") + return + + # Clear specific provider entries + keys_to_remove = [] + for key in _settings_cache.keys(): + if provider in key: + keys_to_remove.append(key) + + for key in keys_to_remove: + del _settings_cache[key] + _log_cache_access(key, "invalidate") + + safe_provider = _sanitize_for_log(provider) + logger.debug(f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed") + + +def get_cache_stats() -> dict[str, Any]: + """ + Get cache statistics with security metrics for monitoring and debugging. + + Returns: + Dictionary containing cache statistics and security metrics + """ + global _settings_cache, _cache_access_log + current_time = time.time() + + stats = { + "total_entries": len(_settings_cache), + "fresh_entries": 0, + "stale_entries": 0, + "cache_hit_potential": 0.0, + "security_metrics": { + "integrity_violations": 0, + "expired_access_attempts": 0, + "invalid_config_rejections": 0, + "access_errors": 0, + "total_security_events": 0 + }, + "access_patterns": { + "recent_cache_hits": 0, + "recent_cache_misses": 0, + "hit_rate": 0.0 + } + } + + # Analyze cache entries + for _key, (_value, timestamp, _checksum) in _settings_cache.items(): + age = current_time - timestamp + if age < _CACHE_TTL_SECONDS: + stats["fresh_entries"] += 1 + else: + stats["stale_entries"] += 1 + + if stats["total_entries"] > 0: + stats["cache_hit_potential"] = stats["fresh_entries"] / stats["total_entries"] + + # Analyze security events from access log + recent_threshold = current_time - 3600 # Last hour + recent_hits = 0 + recent_misses = 0 + + for access in _cache_access_log: + if access["timestamp"] >= recent_threshold: + if access["action"] == "get": + if access["hit"]: + recent_hits += 1 + else: + recent_misses += 1 + + # Count security events + if access["security_event"]: + stats["security_metrics"]["total_security_events"] += 1 + + if "checksum_mismatch" in access["security_event"]: + stats["security_metrics"]["integrity_violations"] += 1 + elif "expired" in access["security_event"]: + stats["security_metrics"]["expired_access_attempts"] += 1 + elif "invalid_config" in access["security_event"]: + stats["security_metrics"]["invalid_config_rejections"] += 1 + elif "error" in access["security_event"]: + stats["security_metrics"]["access_errors"] += 1 + + # Calculate hit rate + total_recent_access = recent_hits + recent_misses + if total_recent_access > 0: + stats["access_patterns"]["hit_rate"] = recent_hits / total_recent_access + + stats["access_patterns"]["recent_cache_hits"] = recent_hits + stats["access_patterns"]["recent_cache_misses"] = recent_misses + + return stats + + +def get_cache_security_report() -> dict[str, Any]: + """ + Get detailed security report for cache monitoring. 
+ + Returns: + Detailed security analysis of cache operations + """ + global _cache_access_log + current_time = time.time() + + report = { + "timestamp": current_time, + "analysis_period_hours": 1, + "security_events": [], + "recommendations": [] + } + + # Extract security events from last hour + recent_threshold = current_time - 3600 + security_events = [ + access for access in _cache_access_log + if access["timestamp"] >= recent_threshold and access["security_event"] + ] + + report["security_events"] = security_events + + # Generate recommendations based on security events + if len(security_events) > 10: + report["recommendations"].append("High number of security events detected - investigate potential attacks") + + integrity_violations = sum(1 for event in security_events if "checksum_mismatch" in event.get("security_event", "")) + if integrity_violations > 0: + report["recommendations"].append(f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks") + + invalid_configs = sum(1 for event in security_events if "invalid_config" in event.get("security_event", "")) + if invalid_configs > 3: + report["recommendations"].append(f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources") + + return report @asynccontextmanager -async def get_llm_client(provider: str | None = None, use_embedding_provider: bool = False, - instance_type: str | None = None, base_url: str | None = None): +async def get_llm_client( + provider: str | None = None, + use_embedding_provider: bool = False, + instance_type: str | None = None, + base_url: str | None = None, +): """ Create an async OpenAI-compatible client based on the configured provider. @@ -58,6 +333,8 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo openai.AsyncOpenAI: An OpenAI-compatible client configured for the selected provider """ client = None + provider_name: str | None = None + api_key = None try: # Get provider configuration from database settings @@ -77,7 +354,11 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo logger.debug("Using cached rag_strategy settings") # For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide - base_url = credential_service._get_provider_base_url(provider, rag_settings) if provider != "ollama" else None + base_url = ( + credential_service._get_provider_base_url(provider, rag_settings) + if provider != "ollama" + else None + ) else: # Get configured provider from database service_type = "embedding" if use_embedding_provider else "llm" @@ -97,45 +378,61 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo # For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide base_url = provider_config["base_url"] if provider_name != "ollama" else None - logger.info(f"Creating LLM client for provider: {provider_name}") + # Comprehensive provider validation with security checks + if not _is_valid_provider(provider_name): + raise ValueError(f"Unsupported LLM provider: {provider_name}") + + # Validate API key format for security (prevent injection) + if api_key: + if len(api_key.strip()) == 0: + api_key = None # Treat empty strings as None + elif len(api_key) > 500: # Reasonable API key length limit + raise ValueError("API key length exceeds security limits") + # Additional security: check for suspicious patterns + if any(char in api_key for char in ['\n', '\r', '\t', '\0']): + raise 
ValueError("API key contains invalid characters") + + # Sanitize provider name for logging + safe_provider_name = _sanitize_for_log(provider_name) if provider_name else "unknown" + logger.info(f"Creating LLM client for provider: {safe_provider_name}") if provider_name == "openai": - if not api_key: - # Check if Ollama instances are available as fallback - logger.warning("OpenAI API key not found, attempting Ollama fallback") - try: - # Try to get an optimal Ollama instance for fallback - ollama_base_url = await _get_optimal_ollama_instance( - instance_type="embedding" if use_embedding_provider else "chat", - use_embedding_provider=use_embedding_provider - ) - if ollama_base_url: - logger.info(f"Falling back to Ollama instance: {ollama_base_url}") - provider_name = "ollama" - api_key = "ollama" # Ollama doesn't need a real API key - base_url = ollama_base_url - # Create Ollama client after fallback - client = openai.AsyncOpenAI( - api_key="ollama", - base_url=ollama_base_url, - ) - logger.info(f"Ollama fallback client created successfully with base URL: {ollama_base_url}") - else: - raise ValueError("OpenAI API key not found and no Ollama instances available") - except Exception as ollama_error: - logger.error(f"Ollama fallback failed: {ollama_error}") - raise ValueError("OpenAI API key not found and Ollama fallback failed") from ollama_error - else: - # Only create OpenAI client if we have an API key (didn't fallback to Ollama) + if api_key: client = openai.AsyncOpenAI(api_key=api_key) logger.info("OpenAI client created successfully") + else: + logger.warning("OpenAI API key not found, attempting Ollama fallback") + try: + ollama_base_url = await _get_optimal_ollama_instance( + instance_type="embedding" if use_embedding_provider else "chat", + use_embedding_provider=use_embedding_provider, + base_url_override=base_url, + ) + + if not ollama_base_url: + raise RuntimeError("No Ollama base URL resolved") + + client = openai.AsyncOpenAI( + api_key="ollama", + base_url=ollama_base_url, + ) + logger.info( + f"Ollama fallback client created successfully with base URL: {ollama_base_url}" + ) + provider_name = "ollama" + api_key = "ollama" + base_url = ollama_base_url + except Exception as fallback_error: + raise ValueError( + "OpenAI API key not found and Ollama fallback failed" + ) from fallback_error elif provider_name == "ollama": - # Enhanced Ollama client creation with multi-instance support + # For Ollama, get the optimal instance based on usage ollama_base_url = await _get_optimal_ollama_instance( instance_type=instance_type, use_embedding_provider=use_embedding_provider, - base_url_override=base_url + base_url_override=base_url, ) # Ollama requires an API key in the client but doesn't actually use it @@ -155,19 +452,101 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo ) logger.info("Google Gemini client created successfully") + elif provider_name == "openrouter": + if not api_key: + raise ValueError("OpenRouter API key not found") + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://openrouter.ai/api/v1", + ) + logger.info("OpenRouter client created successfully") + + elif provider_name == "anthropic": + if not api_key: + raise ValueError("Anthropic API key not found") + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://api.anthropic.com/v1", + ) + logger.info("Anthropic client created successfully") + + elif provider_name == "grok": + if not api_key: + raise ValueError("Grok API key not found 
- set GROK_API_KEY environment variable") + + # Enhanced Grok API key validation (secure - no key fragments logged) + key_format_valid = api_key.startswith("xai-") + key_length_valid = len(api_key) >= 20 + + if not key_format_valid: + logger.warning("Grok API key format validation failed - should start with 'xai-'") + + if not key_length_valid: + logger.warning("Grok API key validation failed - insufficient length") + + logger.debug( + f"Grok API key validation: format_valid={key_format_valid}, length_valid={key_length_valid}" + ) + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://api.x.ai/v1", + ) + logger.info("Grok client created successfully") + else: raise ValueError(f"Unsupported LLM provider: {provider_name}") - yield client - except Exception as e: logger.error( - f"Error creating LLM client for provider {provider_name if 'provider_name' in locals() else 'unknown'}: {e}" + f"Error creating LLM client for provider {provider_name if provider_name else 'unknown'}: {e}" ) raise + + try: + yield client finally: - # Cleanup if needed - pass + if client is not None: + safe_provider = _sanitize_for_log(provider_name) if provider_name else "unknown" + + try: + close_method = getattr(client, "aclose", None) + if callable(close_method): + if inspect.iscoroutinefunction(close_method): + await close_method() + else: + maybe_coro = close_method() + if inspect.isawaitable(maybe_coro): + await maybe_coro + else: + close_method = getattr(client, "close", None) + if callable(close_method): + if inspect.iscoroutinefunction(close_method): + await close_method() + else: + close_result = close_method() + if inspect.isawaitable(close_result): + await close_result + logger.debug(f"Closed LLM client for provider: {safe_provider}") + except RuntimeError as close_error: + if "Event loop is closed" in str(close_error): + logger.error( + f"Failed to close LLM client cleanly for provider {safe_provider}: event loop already closed", + exc_info=True, + ) + else: + logger.error( + f"Runtime error closing LLM client for provider {safe_provider}: {close_error}", + exc_info=True, + ) + except Exception as close_error: + logger.error( + f"Unexpected error while closing LLM client for provider {safe_provider}: {close_error}", + exc_info=True, + ) + async def _get_optimal_ollama_instance(instance_type: str | None = None, @@ -250,9 +629,20 @@ async def get_embedding_model(provider: str | None = None) -> str: provider_name = provider_config["provider"] custom_model = provider_config["embedding_model"] - # Use custom model if specified - if custom_model: - return custom_model + # Comprehensive provider validation for embeddings + if not _is_valid_provider(provider_name): + safe_provider = _sanitize_for_log(provider_name) + logger.warning(f"Invalid embedding provider: {safe_provider}, falling back to OpenAI") + provider_name = "openai" + # Use custom model if specified (with validation) + if custom_model and len(custom_model.strip()) > 0: + custom_model = custom_model.strip() + # Basic model name validation (check length and basic characters) + if len(custom_model) <= 100 and not any(char in custom_model for char in ['\n', '\r', '\t', '\0']): + return custom_model + else: + safe_model = _sanitize_for_log(custom_model) + logger.warning(f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default") # Return provider-specific defaults if provider_name == "openai": @@ -261,8 +651,20 @@ async def get_embedding_model(provider: str | None = None) -> str: # Ollama 
default embedding model return "nomic-embed-text" elif provider_name == "google": - # Google's embedding model + # Google's latest embedding model return "text-embedding-004" + elif provider_name == "openrouter": + # OpenRouter supports both OpenAI and Google embedding models + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" + elif provider_name == "anthropic": + # Anthropic supports OpenAI and Google embedding models through their API + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" + elif provider_name == "grok": + # Grok supports OpenAI and Google embedding models through their API + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" else: # Fallback to OpenAI's model return "text-embedding-3-small" @@ -273,6 +675,463 @@ async def get_embedding_model(provider: str | None = None) -> str: return "text-embedding-3-small" +def is_openai_embedding_model(model: str) -> bool: + """Check if a model is an OpenAI embedding model.""" + if not model: + return False + + model_lower = model.strip().lower() + + # Known OpenAI embeddings + base_models = { + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + } + + if model_lower in base_models: + return True + + # Strip common vendor prefixes like "openai/" or "openrouter/" + for separator in ("/", ":"): + if separator in model_lower: + candidate = model_lower.split(separator)[-1] + if candidate in base_models: + return True + + # Fallback substring detection for custom naming conventions + return any(base in model_lower for base in base_models) + + +def is_google_embedding_model(model: str) -> bool: + """Check if a model is a Google embedding model.""" + if not model: + return False + + model_lower = model.lower() + google_patterns = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding-002", + "gemini-embedding-001", + "multimodalembedding@001" + ] + + return any(pattern in model_lower for pattern in google_patterns) + + +def is_valid_embedding_model_for_provider(model: str, provider: str) -> bool: + """ + Validate if an embedding model is compatible with a provider. + + Args: + model: The embedding model name + provider: The provider name + + Returns: + bool: True if the model is compatible with the provider + """ + if not model or not provider: + return False + + provider_lower = provider.lower() + + if provider_lower == "openai": + return is_openai_embedding_model(model) + elif provider_lower == "google": + return is_google_embedding_model(model) + elif provider_lower in ["openrouter", "anthropic", "grok"]: + # These providers support both OpenAI and Google models + return is_openai_embedding_model(model) or is_google_embedding_model(model) + elif provider_lower == "ollama": + # Ollama has its own models, check common ones + model_lower = model.lower() + ollama_patterns = ["nomic-embed", "all-minilm", "mxbai-embed", "embed"] + return any(pattern in model_lower for pattern in ollama_patterns) + else: + # For unknown providers, assume OpenAI compatibility + return is_openai_embedding_model(model) + + +def get_supported_embedding_models(provider: str) -> list[str]: + """ + Get list of supported embedding models for a provider. 
+ + Args: + provider: The provider name + + Returns: + List of supported embedding model names + """ + if not provider: + return [] + + provider_lower = provider.lower() + + openai_models = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large" + ] + + google_models = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding-002", + "gemini-embedding-001", + "multimodalembedding@001" + ] + + if provider_lower == "openai": + return openai_models + elif provider_lower == "google": + return google_models + elif provider_lower in ["openrouter", "anthropic", "grok"]: + # These providers support both OpenAI and Google models + return openai_models + google_models + elif provider_lower == "ollama": + return ["nomic-embed-text", "all-minilm", "mxbai-embed-large"] + else: + # For unknown providers, assume OpenAI compatibility + return openai_models + + +def is_reasoning_model(model_name: str) -> bool: + """ + Unified check for reasoning models across providers. + + Normalizes vendor prefixes (openai/, openrouter/, x-ai/, deepseek/) before checking + known reasoning families (OpenAI GPT-5, o1, o3; xAI Grok; DeepSeek-R; etc.). + """ + if not model_name: + return False + + model_lower = model_name.lower() + + # Normalize vendor prefixes (e.g., openai/gpt-5-nano, openrouter/x-ai/grok-4) + if "/" in model_lower: + parts = model_lower.split("/") + # Drop known vendor prefixes while keeping the final model identifier + known_prefixes = {"openai", "openrouter", "x-ai", "deepseek", "anthropic"} + filtered_parts = [part for part in parts if part not in known_prefixes] + if filtered_parts: + model_lower = filtered_parts[-1] + else: + model_lower = parts[-1] + + if ":" in model_lower: + model_lower = model_lower.split(":", 1)[-1] + + reasoning_prefixes = ( + "gpt-5", + "o1", + "o3", + "o4", + "grok", + "deepseek-r", + "deepseek-reasoner", + "deepseek-chat-r", + ) + + return model_lower.startswith(reasoning_prefixes) + + +def _extract_reasoning_strings(value: Any) -> list[str]: + """Convert reasoning payload fragments into plain-text strings.""" + + if value is None: + return [] + + if isinstance(value, str): + text = value.strip() + return [text] if text else [] + + if isinstance(value, (list, tuple, set)): + collected: list[str] = [] + for item in value: + collected.extend(_extract_reasoning_strings(item)) + return collected + + if isinstance(value, dict): + candidates = [] + for key in ("text", "summary", "content", "message", "value"): + if value.get(key): + candidates.extend(_extract_reasoning_strings(value[key])) + # Some providers nest reasoning parts under "parts" + if value.get("parts"): + candidates.extend(_extract_reasoning_strings(value["parts"])) + return candidates + + # Handle pydantic-style objects with attributes + for attr in ("text", "summary", "content", "value"): + if hasattr(value, attr): + attr_value = getattr(value, attr) + if attr_value: + return _extract_reasoning_strings(attr_value) + + return [] + + +def _get_message_attr(message: Any, attribute: str) -> Any: + """Safely access message attributes that may be dict keys or properties.""" + + if hasattr(message, attribute): + return getattr(message, attribute) + if isinstance(message, dict): + return message.get(attribute) + return None + + +def extract_message_text(choice: Any) -> tuple[str, str, bool]: + """Extract primary content and reasoning text from a chat completion choice.""" + + if not choice: + return "", "", False + + message = _get_message_attr(choice, "message") + if 
message is None: + return "", "", False + + raw_content = _get_message_attr(message, "content") + content_text = raw_content.strip() if isinstance(raw_content, str) else "" + + reasoning_fragments: list[str] = [] + for attr in ("reasoning", "reasoning_details", "reasoning_content"): + reasoning_value = _get_message_attr(message, attr) + if reasoning_value: + reasoning_fragments.extend(_extract_reasoning_strings(reasoning_value)) + + reasoning_text = "\n".join(fragment for fragment in reasoning_fragments if fragment) + reasoning_text = reasoning_text.strip() + + # If content looks like reasoning text but no reasoning field, detect it + if content_text and not reasoning_text and _is_reasoning_text(content_text): + reasoning_text = content_text + # Try to extract structured data from reasoning text + extracted_json = extract_json_from_reasoning(content_text) + if extracted_json: + content_text = extracted_json + else: + content_text = "" + + if not content_text and reasoning_text: + content_text = reasoning_text + + has_reasoning = bool(reasoning_text) + + return content_text, reasoning_text, has_reasoning + + +def _is_reasoning_text(text: str) -> bool: + """Detect if text appears to be reasoning/thinking output rather than structured content.""" + if not text or len(text) < 10: + return False + + text_lower = text.lower().strip() + + # Common reasoning text patterns + reasoning_indicators = [ + "okay, let's see", "let me think", "first, i need to", "looking at this", + "step by step", "analyzing", "breaking this down", "considering", + "let me work through", "i should", "thinking about", "examining" + ] + + return any(indicator in text_lower for indicator in reasoning_indicators) + + +def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: + """Extract JSON content from reasoning text, with synthesis fallback.""" + if not reasoning_text: + return "" + + import json + import re + + # Try to find JSON blocks in markdown + json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```' + json_matches = re.findall(json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE) + + for match in json_matches: + try: + # Validate it's proper JSON + json.loads(match.strip()) + return match.strip() + except json.JSONDecodeError: + continue + + # Try to find standalone JSON objects + json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' + json_matches = re.findall(json_pattern, reasoning_text, re.DOTALL) + + for match in json_matches: + try: + parsed = json.loads(match.strip()) + # Ensure it has expected structure + if isinstance(parsed, dict) and any(key in parsed for key in ["example_name", "summary", "name", "title"]): + return match.strip() + except json.JSONDecodeError: + continue + + # If no JSON found, synthesize from reasoning content + return synthesize_json_from_reasoning(reasoning_text, context_code, language) + + +def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: + """Generate JSON structure from reasoning text when no JSON is found.""" + if not reasoning_text and not context_code: + return "" + + import json + import re + + # Extract key concepts and actions from reasoning text and code context + text_lower = reasoning_text.lower() if reasoning_text else "" + code_lower = context_code.lower() if context_code else "" + combined_text = f"{text_lower} {code_lower}" + + # Common action patterns in reasoning text and code + action_patterns = [ + (r'\b(?:parse|parsing|parsed)\b', 'Parse'), + 
+        (r'\b(?:create|creating|created)\b', 'Create'),
+        (r'\b(?:analyze|analyzing|analyzed)\b', 'Analyze'),
+        (r'\b(?:extract|extracting|extracted)\b', 'Extract'),
+        (r'\b(?:generate|generating|generated)\b', 'Generate'),
+        (r'\b(?:process|processing|processed)\b', 'Process'),
+        (r'\b(?:load|loading|loaded)\b', 'Load'),
+        (r'\b(?:handle|handling|handled)\b', 'Handle'),
+        (r'\b(?:manage|managing|managed)\b', 'Manage'),
+        (r'\b(?:build|building|built)\b', 'Build'),
+        (r'\b(?:define|defining|defined)\b', 'Define'),
+        (r'\b(?:implement|implementing|implemented)\b', 'Implement'),
+        (r'\b(?:fetch|fetching|fetched)\b', 'Fetch'),
+        (r'\b(?:connect|connecting|connected)\b', 'Connect'),
+        (r'\b(?:validate|validating|validated)\b', 'Validate'),
+    ]
+
+    # Technology/concept patterns
+    tech_patterns = [
+        (r'\bjson\b', 'JSON'),
+        (r'\bapi\b', 'API'),
+        (r'\bfile\b', 'File'),
+        (r'\bdata\b', 'Data'),
+        (r'\bcode\b', 'Code'),
+        (r'\btext\b', 'Text'),
+        (r'\bcontent\b', 'Content'),
+        (r'\bresponse\b', 'Response'),
+        (r'\brequest\b', 'Request'),
+        (r'\bconfig\b', 'Config'),
+        (r'\bllm\b', 'LLM'),
+        (r'\bmodel\b', 'Model'),
+        (r'\bexample\b', 'Example'),
+        (r'\bcontext\b', 'Context'),
+        (r'\basync\b', 'Async'),
+        (r'\bfunction\b', 'Function'),
+        (r'\bclass\b', 'Class'),
+        (r'\bprint\b', 'Output'),
+        (r'\breturn\b', 'Return'),
+    ]
+
+    # Extract actions and technologies from combined text
+    detected_actions = []
+    detected_techs = []
+
+    for pattern, action in action_patterns:
+        if re.search(pattern, combined_text):
+            detected_actions.append(action)
+
+    for pattern, tech in tech_patterns:
+        if re.search(pattern, combined_text):
+            detected_techs.append(tech)
+
+    # Generate example name
+    if detected_actions and detected_techs:
+        example_name = f"{detected_actions[0]} {detected_techs[0]}"
+    elif detected_actions:
+        example_name = f"{detected_actions[0]} Code"
+    elif detected_techs:
+        example_name = f"Handle {detected_techs[0]}"
+    elif language:
+        example_name = f"Process {language.title()}"
+    else:
+        example_name = "Code Processing"
+
+    # Limit to 4 words as per requirements
+    example_name_words = example_name.split()
+    if len(example_name_words) > 4:
+        example_name = " ".join(example_name_words[:4])
+
+    # Generate summary from reasoning content
+    reasoning_lines = reasoning_text.split('\n')
+    meaningful_lines = [line.strip() for line in reasoning_lines if line.strip() and len(line.strip()) > 10]
+
+    if meaningful_lines:
+        # Take first meaningful sentence for summary base
+        first_line = meaningful_lines[0]
+        if len(first_line) > 100:
+            first_line = first_line[:100] + "..."
+
+        # Create contextual summary (search the reasoning text with the compiled patterns;
+        # a plain substring test against the raw regex strings would never match)
+        if context_code and any(re.search(pattern, text_lower) for pattern, _ in tech_patterns):
+            summary = f"This code demonstrates {detected_techs[0].lower() if detected_techs else 'data'} processing functionality. {first_line}"
+        else:
+            summary = f"Code example showing {detected_actions[0].lower() if detected_actions else 'processing'} operations. {first_line}"
+    else:
+        # Fallback summary
+        summary = f"Code example demonstrating {example_name.lower()} functionality for {language or 'general'} development."
+
+    # Ensure summary is not too long
+    if len(summary) > 300:
+        summary = summary[:297] + "..."
+
+    # Create JSON structure
+    result = {
+        "example_name": example_name,
+        "summary": summary
+    }
+
+    return json.dumps(result)
+
+
+def prepare_chat_completion_params(model: str, params: dict) -> dict:
+    """
+    Convert parameters for compatibility with reasoning models (GPT-5, o1, o3 series).
+ + OpenAI made several API changes for reasoning models: + 1. max_tokens → max_completion_tokens + 2. temperature must be 1.0 (default) - custom values not supported + + This ensures compatibility with OpenAI's API changes for newer models + while maintaining backward compatibility for existing models. + + Args: + model: The model name being used + params: Dictionary of API parameters + + Returns: + Dictionary with converted parameters for the model + """ + if not model or not params: + return params + + # Make a copy to avoid modifying the original + updated_params = params.copy() + + reasoning_model = is_reasoning_model(model) + + # Convert max_tokens to max_completion_tokens for reasoning models + if reasoning_model and "max_tokens" in updated_params: + max_tokens_value = updated_params.pop("max_tokens") + updated_params["max_completion_tokens"] = max_tokens_value + logger.debug(f"Converted max_tokens to max_completion_tokens for model {model}") + + # Remove custom temperature for reasoning models (they only support default temperature=1.0) + if reasoning_model and "temperature" in updated_params: + original_temp = updated_params.pop("temperature") + logger.debug(f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)") + + return updated_params + + async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]: """ Get the embedding model with intelligent routing for multi-instance setups. @@ -383,3 +1242,9 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N "error_message": str(e), "validation_timestamp": time.time() } + + + +def requires_max_completion_tokens(model_name: str) -> bool: + """Backward compatible alias for previous API.""" + return is_reasoning_model(model_name) diff --git a/python/src/server/services/provider_discovery_service.py b/python/src/server/services/provider_discovery_service.py index ccd811dd..2ea3bc32 100644 --- a/python/src/server/services/provider_discovery_service.py +++ b/python/src/server/services/provider_discovery_service.py @@ -2,7 +2,7 @@ Provider Discovery Service Discovers available models, checks provider health, and provides model specifications -for OpenAI, Google Gemini, Ollama, and Anthropic providers. +for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers. 
""" import time @@ -359,6 +359,36 @@ class ProviderDiscoveryService: return models + async def discover_grok_models(self, api_key: str) -> list[ModelSpec]: + """Discover available Grok models.""" + cache_key = f"grok_models_{hash(api_key)}" + cached = self._get_cached_result(cache_key) + if cached: + return cached + + models = [] + try: + # Grok model specifications + model_specs = [ + ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"), + ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"), + ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"), + ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"), + ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"), + ] + + # Test connectivity - Grok doesn't have a models list endpoint, + # so we'll just return the known models if API key is provided + if api_key: + models = model_specs + self._cache_result(cache_key, models) + logger.info(f"Discovered {len(models)} Grok models") + + except Exception as e: + logger.error(f"Error discovering Grok models: {e}") + + return models + async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus: """Check health and connectivity status of a provider.""" start_time = time.time() @@ -456,6 +486,23 @@ class ProviderDiscoveryService: last_checked=time.time() ) + elif provider == "grok": + api_key = config.get("api_key") + if not api_key: + return ProviderStatus(provider, False, None, "API key not configured") + + # Grok doesn't have a health check endpoint, so we'll assume it's available + # if API key is provided. In a real implementation, you might want to make a + # small test request to verify the key is valid. + response_time = (time.time() - start_time) * 1000 + return ProviderStatus( + provider="grok", + is_available=True, + response_time_ms=response_time, + models_available=5, # Known model count + last_checked=time.time() + ) + else: return ProviderStatus(provider, False, None, f"Unknown provider: {provider}") @@ -496,6 +543,11 @@ class ProviderDiscoveryService: if anthropic_key: providers["anthropic"] = await self.discover_anthropic_models(anthropic_key) + # Grok + grok_key = await credential_service.get_credential("GROK_API_KEY") + if grok_key: + providers["grok"] = await self.discover_grok_models(grok_key) + except Exception as e: logger.error(f"Error getting all available models: {e}") diff --git a/python/src/server/services/source_management_service.py b/python/src/server/services/source_management_service.py index c7bcdb66..f8a27023 100644 --- a/python/src/server/services/source_management_service.py +++ b/python/src/server/services/source_management_service.py @@ -11,7 +11,7 @@ from supabase import Client from ..config.logfire_config import get_logger, search_logger from .client_manager import get_supabase_client -from .llm_provider_service import get_llm_client +from .llm_provider_service import extract_message_text, get_llm_client logger = get_logger(__name__) @@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. 
Please provide a ) # Extract the generated summary with proper error handling - if not response or not response.choices or len(response.choices) == 0: - search_logger.error(f"Empty or invalid response from LLM for {source_id}") - return default_summary - - message_content = response.choices[0].message.content - if message_content is None: - search_logger.error(f"LLM returned None content for {source_id}") - return default_summary - - summary = message_content.strip() - - # Ensure the summary is not too long - if len(summary) > max_length: - summary = summary[:max_length] + "..." + if not response or not response.choices or len(response.choices) == 0: + search_logger.error(f"Empty or invalid response from LLM for {source_id}") + return default_summary + + choice = response.choices[0] + summary_text, _, _ = extract_message_text(choice) + if not summary_text: + search_logger.error(f"LLM returned None content for {source_id}") + return default_summary + + summary = summary_text.strip() + + # Ensure the summary is not too long + if len(summary) > max_length: + summary = summary[:max_length] + "..." return summary @@ -187,7 +188,9 @@ Generate only the title, nothing else.""" ], ) - generated_title = response.choices[0].message.content.strip() + choice = response.choices[0] + generated_title, _, _ = extract_message_text(choice) + generated_title = generated_title.strip() # Clean up the title generated_title = generated_title.strip("\"'") if len(generated_title) < 50: # Sanity check diff --git a/python/src/server/services/storage/code_storage_service.py b/python/src/server/services/storage/code_storage_service.py index ece5ea10..a993bc70 100644 --- a/python/src/server/services/storage/code_storage_service.py +++ b/python/src/server/services/storage/code_storage_service.py @@ -8,6 +8,7 @@ import asyncio import json import os import re +import time from collections import defaultdict, deque from collections.abc import Callable from difflib import SequenceMatcher @@ -17,25 +18,104 @@ from urllib.parse import urlparse from supabase import Client from ...config.logfire_config import search_logger +from ..credential_service import credential_service from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch from ..embeddings.embedding_service import create_embeddings_batch +from ..llm_provider_service import ( + extract_json_from_reasoning, + extract_message_text, + get_llm_client, + prepare_chat_completion_params, + synthesize_json_from_reasoning, +) -def _get_model_choice() -> str: - """Get MODEL_CHOICE with direct fallback.""" +def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str: + """Return the best-effort JSON object from an LLM response.""" + + if not raw_response: + return raw_response + + cleaned = raw_response.strip() + + # Check if this looks like reasoning text first + if _is_reasoning_text_response(cleaned): + # Try intelligent extraction from reasoning text with context + extracted = extract_json_from_reasoning(cleaned, context_code, language) + if extracted: + return extracted + # extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so\ + fallback_json = synthesize_json_from_reasoning("", context_code, language) + if fallback_json: + return fallback_json + # If all else fails, return a minimal valid JSON object to avoid downstream errors + return '{"example_name": "Code Example", "summary": "Code example extracted from context."}' + + + if cleaned.startswith("```"): + lines = 
cleaned.splitlines() + # Drop opening fence + lines = lines[1:] + # Drop closing fence if present + if lines and lines[-1].strip().startswith("```"): + lines = lines[:-1] + cleaned = "\n".join(lines).strip() + + # Trim any leading/trailing text outside the outermost JSON braces + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end >= start: + cleaned = cleaned[start : end + 1] + + return cleaned.strip() + + +REASONING_STARTERS = [ + "okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this", + "i need to", "analyzing", "let me work through", "thinking about", "let me see" +] + +def _is_reasoning_text_response(text: str) -> bool: + """Detect if response is reasoning text rather than direct JSON.""" + if not text or len(text) < 20: + return False + + text_lower = text.lower().strip() + + # Check if it's clearly not JSON (starts with reasoning text) + starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS) + + # Check if it lacks immediate JSON structure + lacks_immediate_json = not text_lower.lstrip().startswith('{') + + return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS)) +async def _get_model_choice() -> str: + """Get MODEL_CHOICE with provider-aware defaults from centralized service.""" try: - # Direct cache/env fallback - from ..credential_service import credential_service + # Get the active provider configuration + provider_config = await credential_service.get_active_provider("llm") + active_provider = provider_config.get("provider", "openai") + model = provider_config.get("chat_model") - if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache: - model = credential_service._cache["MODEL_CHOICE"] - else: - model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano") - search_logger.debug(f"Using model choice: {model}") + # If no custom model is set, use provider-specific defaults + if not model or model.strip() == "": + # Provider-specific defaults + provider_defaults = { + "openai": "gpt-4o-mini", + "openrouter": "anthropic/claude-3.5-sonnet", + "google": "gemini-1.5-flash", + "ollama": "llama3.2:latest", + "anthropic": "claude-3-5-haiku-20241022", + "grok": "grok-3-mini" + } + model = provider_defaults.get(active_provider, "gpt-4o-mini") + search_logger.debug(f"Using default model for provider {active_provider}: {model}") + + search_logger.debug(f"Using model for provider {active_provider}: {model}") return model except Exception as e: search_logger.warning(f"Error getting model choice: {e}, using default") - return "gpt-4.1-nano" + return "gpt-4o-mini" def _get_max_workers() -> int: @@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str, return best_block + def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]: """ Extract code blocks from markdown content along with context. 
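
As a rough illustration of how the response-sanitizing helpers above are meant to compose, the following sketch (illustrative only, not part of the patch; the sample strings are invented) shows the two paths a raw completion can take:

# Hypothetical inputs, used only to illustrate the helpers defined above.
fenced = '```json\n{"example_name": "Parse JSON", "summary": "Parses a config file."}\n```'
reasoning = "Okay, let's see. First, I need to parse the JSON config before anything else."

# A fenced response is unwrapped down to the bare JSON object.
assert _extract_json_payload(fenced).startswith('{')

# Reasoning-style prose is detected and routed through extract_json_from_reasoning(),
# which synthesizes a fallback JSON payload when no object can be recovered from the text.
assert _is_reasoning_text_response(reasoning)
print(_extract_json_payload(reasoning, context_code="json.load(open('config.json'))", language="python"))
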
@@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d """ # Load all code extraction settings with direct fallback try: - from ...services.credential_service import credential_service - def _get_setting_fallback(key: str, default: str) -> str: if credential_service._cache_initialized and key in credential_service._cache: return credential_service._cache[key] @@ -507,7 +586,7 @@ def generate_code_example_summary( A dictionary with 'summary' and 'example_name' """ import asyncio - + # Run the async version in the current thread return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider)) @@ -518,13 +597,22 @@ async def _generate_code_example_summary_async( """ Async version of generate_code_example_summary using unified LLM provider service. """ - from ..llm_provider_service import get_llm_client - - # Get model choice from credential service (RAG setting) - model_choice = _get_model_choice() - # Create the prompt - prompt = f""" + # Get model choice from credential service (RAG setting) + model_choice = await _get_model_choice() + + # If provider is not specified, get it from credential service + if provider is None: + try: + provider_config = await credential_service.get_active_provider("llm") + provider = provider_config.get("provider", "openai") + search_logger.debug(f"Auto-detected provider from credential service: {provider}") + except Exception as e: + search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai") + provider = "openai" + + # Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries + base_prompt = f""" {context_before[-500:] if len(context_before) > 500 else context_before} @@ -548,6 +636,16 @@ Format your response as JSON: "summary": "2-3 sentence description of what the code demonstrates" }} """ + guard_prompt = ( + base_prompt + + "\n\nImportant: Respond with a valid JSON object that exactly matches the keys " + '{"example_name": string, "summary": string}. Do not include commentary, ' + "markdown fences, or reasoning notes." + ) + strict_prompt = ( + guard_prompt + + "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content." 
+ ) try: # Use unified LLM provider service @@ -555,25 +653,261 @@ Format your response as JSON: search_logger.info( f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}" ) - - response = await client.chat.completions.create( - model=model_choice, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.", - }, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - max_tokens=500, - temperature=0.3, + + provider_lower = provider.lower() + is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower()) + + supports_response_format_base = ( + provider_lower in {"openai", "google", "anthropic"} + or (provider_lower == "openrouter" and model_choice.startswith("openai/")) ) - response_content = response.choices[0].message.content.strip() + last_response_obj = None + last_elapsed_time = None + last_response_content = "" + last_json_error: json.JSONDecodeError | None = None + + for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)): + request_params = { + "model": model_choice, + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.", + }, + {"role": "user", "content": current_prompt}, + ], + "max_tokens": 2000, + "temperature": 0.3, + } + + should_use_response_format = False + if enforce_json: + if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"): + should_use_response_format = True + else: + if supports_response_format_base: + should_use_response_format = True + + if should_use_response_format: + request_params["response_format"] = {"type": "json_object"} + + if is_grok_model: + unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"] + for param in unsupported_params: + if param in request_params: + removed_value = request_params.pop(param) + search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}") + + supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"] + for param in list(request_params.keys()): + if param not in supported_params: + search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models") + + start_time = time.time() + max_retries = 3 if is_grok_model else 1 + retry_delay = 1.0 + response_content_local = "" + reasoning_text_local = "" + json_error_occurred = False + + for attempt in range(max_retries): + try: + if is_grok_model and attempt > 0: + search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay") + await asyncio.sleep(retry_delay) + + final_params = prepare_chat_completion_params(model_choice, request_params) + response = await client.chat.completions.create(**final_params) + last_response_obj = response + + choice = response.choices[0] if response.choices else None + message = choice.message if choice and hasattr(choice, "message") else None + response_content_local = "" + reasoning_text_local = "" + + if choice: + response_content_local, reasoning_text_local, _ = extract_message_text(choice) + + # Enhanced logging for response analysis + if message and reasoning_text_local: + content_preview = response_content_local[:100] if response_content_local else "None" + reasoning_preview = reasoning_text_local[:100] if 
reasoning_text_local else "None" + search_logger.debug( + f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'" + ) + + if response_content_local: + last_response_content = response_content_local.strip() + + # Pre-validate response before processing + if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')): + # Very minimal response - likely "Okay\nOkay" type + search_logger.debug(f"Minimal response detected: {repr(last_response_content)}") + # Generate fallback directly from context + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + try: + result = json.loads(fallback_json) + final_result = { + "example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": result.get("summary", "Code example for demonstration purposes."), + } + search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}") + return final_result + except json.JSONDecodeError: + pass # Continue to normal error handling + else: + # Even synthesis failed - provide hardcoded fallback for minimal responses + final_result = { + "example_name": f"Code Example{f' ({language})' if language else ''}", + "summary": "Code example extracted from development context.", + } + search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}") + return final_result + + payload = _extract_json_payload(last_response_content, code, language) + if payload != last_response_content: + search_logger.debug( + f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..." + ) + + try: + result = json.loads(payload) + + if not result.get("example_name") or not result.get("summary"): + search_logger.warning(f"Incomplete response from LLM: {result}") + + final_result = { + "example_name": result.get( + "example_name", f"Code Example{f' ({language})' if language else ''}" + ), + "summary": result.get("summary", "Code example for demonstration purposes."), + } + + search_logger.info( + f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}" + ) + return final_result + + except json.JSONDecodeError as json_error: + last_json_error = json_error + json_error_occurred = True + snippet = last_response_content[:200] + if not enforce_json: + # Check if this was reasoning text that couldn't be parsed + if _is_reasoning_text_response(last_response_content): + search_logger.debug( + f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}" + ) + else: + search_logger.warning( + f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}" + ) + break + else: + search_logger.error( + f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. 
Response snippet: {repr(snippet)}" + ) + break + + elif is_grok_model and attempt < max_retries - 1: + search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...") + retry_delay *= 2 + continue + else: + break + + except Exception as e: + if is_grok_model and attempt < max_retries - 1: + search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...") + retry_delay *= 2 + continue + else: + raise + + if is_grok_model: + elapsed_time = time.time() - start_time + last_elapsed_time = elapsed_time + search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s") + + if json_error_occurred: + if not enforce_json: + continue + else: + break + + if response_content_local: + # We would have returned already on success; if we reach here, parsing failed but we are not retrying + continue + + response_content = last_response_content + response = last_response_obj + elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0 + + if last_json_error is not None and response_content: + search_logger.error( + f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling." + ) + response_content = "" + + if not response_content: + search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})") + if is_grok_model: + search_logger.error("Grok empty response debugging:") + search_logger.error(f" - Request took: {elapsed_time:.2f}s") + search_logger.error(f" - Response status: {getattr(response, 'status_code', 'N/A')}") + search_logger.error(f" - Response headers: {getattr(response, 'headers', 'N/A')}") + search_logger.error(f" - Full response: {response}") + search_logger.error(f" - Response choices length: {len(response.choices) if response.choices else 0}") + if response.choices: + search_logger.error(f" - First choice: {response.choices[0]}") + search_logger.error(f" - Message content: '{response.choices[0].message.content}'") + search_logger.error(f" - Message role: {response.choices[0].message.role}") + search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability") + + # Implement fallback for Grok failures + search_logger.warning("Attempting fallback to OpenAI due to Grok failure...") + try: + # Use OpenAI as fallback with similar parameters + fallback_params = { + "model": "gpt-4o-mini", + "messages": request_params["messages"], + "temperature": request_params.get("temperature", 0.1), + "max_tokens": request_params.get("max_tokens", 500), + } + + async with get_llm_client(provider="openai") as fallback_client: + fallback_response = await fallback_client.chat.completions.create(**fallback_params) + fallback_content = fallback_response.choices[0].message.content + if fallback_content and fallback_content.strip(): + search_logger.info("gpt-4o-mini fallback succeeded") + response_content = fallback_content.strip() + else: + search_logger.error("gpt-4o-mini fallback also returned empty response") + raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed") + + except Exception as fallback_error: + search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}") + raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error + else: + search_logger.debug(f"Full response object: {response}") + raise ValueError("Empty response from LLM") + + if not response_content: + # This should not happen after fallback logic, but safety check + raise 
ValueError("No valid response content after all attempts") + + response_content = response_content.strip() search_logger.debug(f"LLM API response: {repr(response_content[:200])}...") - result = json.loads(response_content) + payload = _extract_json_payload(response_content, code, language) + if payload != response_content: + search_logger.debug( + f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..." + ) + + result = json.loads(payload) # Validate the response has the required fields if not result.get("example_name") or not result.get("summary"): @@ -595,12 +929,38 @@ Format your response as JSON: search_logger.error( f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}" ) + # Try to generate context-aware fallback + try: + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + fallback_result = json.loads(fallback_json) + search_logger.info(f"Generated context-aware fallback summary") + return { + "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": fallback_result.get("summary", "Code example for demonstration purposes."), + } + except Exception: + pass # Fall through to generic fallback + return { "example_name": f"Code Example{f' ({language})' if language else ''}", "summary": "Code example for demonstration purposes.", } except Exception as e: search_logger.error(f"Error generating code summary using unified LLM provider: {e}") + # Try to generate context-aware fallback + try: + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + fallback_result = json.loads(fallback_json) + search_logger.info(f"Generated context-aware fallback summary after error") + return { + "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": fallback_result.get("summary", "Code example for demonstration purposes."), + } + except Exception: + pass # Fall through to generic fallback + return { "example_name": f"Code Example{f' ({language})' if language else ''}", "summary": "Code example for demonstration purposes.", @@ -608,7 +968,7 @@ Format your response as JSON: async def generate_code_summaries_batch( - code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None + code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None ) -> list[dict[str, str]]: """ Generate summaries for multiple code blocks with rate limiting and proper worker management. 
@@ -617,6 +977,7 @@ async def generate_code_summaries_batch( code_blocks: List of code block dictionaries max_workers: Maximum number of concurrent API requests progress_callback: Optional callback for progress updates (async function) + provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic') Returns: List of summary dictionaries @@ -627,8 +988,6 @@ async def generate_code_summaries_batch( # Get max_workers from settings if not provided if max_workers is None: try: - from ...services.credential_service import credential_service - if ( credential_service._cache_initialized and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache @@ -663,6 +1022,7 @@ async def generate_code_summaries_batch( block["context_before"], block["context_after"], block.get("language", ""), + provider, ) # Update progress @@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase( except Exception as e: search_logger.error(f"Error deleting existing code examples for {url}: {e}") - # Check if contextual embeddings are enabled + # Check if contextual embeddings are enabled (use proper async method like document storage) try: - from ..credential_service import credential_service - - use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS") - if isinstance(use_contextual_embeddings, str): - use_contextual_embeddings = use_contextual_embeddings.lower() == "true" - elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get( - "is_encrypted" - ): - # Handle encrypted value - encrypted_value = use_contextual_embeddings.get("encrypted_value") - if encrypted_value: - try: - decrypted = credential_service._decrypt_value(encrypted_value) - use_contextual_embeddings = decrypted.lower() == "true" - except: - use_contextual_embeddings = False - else: - use_contextual_embeddings = False + raw_value = await credential_service.get_credential( + "USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True + ) + if isinstance(raw_value, str): + use_contextual_embeddings = raw_value.lower() == "true" else: - use_contextual_embeddings = bool(use_contextual_embeddings) - except: + use_contextual_embeddings = bool(raw_value) + except Exception as e: + search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}") # Fallback to environment variable use_contextual_embeddings = ( os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true" @@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase( # Use only successful embeddings valid_embeddings = result.embeddings successful_texts = result.texts_processed - + # Get model information for tracking from ..llm_provider_service import get_embedding_model - from ..credential_service import credential_service - + # Get embedding model name embedding_model_name = await get_embedding_model(provider=provider) - + # Get LLM chat model (used for code summaries and contextual embeddings if enabled) llm_chat_model = None try: @@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase( llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini") else: # For code summaries, we use MODEL_CHOICE - llm_chat_model = _get_model_choice() + llm_chat_model = await _get_model_choice() except Exception as e: search_logger.warning(f"Failed to get LLM chat model: {e}") llm_chat_model = "gpt-4o-mini" # Default fallback @@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase( positions_by_text[text].append(original_indices[k]) # Map successful texts back to their original indices - for 
embedding, text in zip(valid_embeddings, successful_texts, strict=False): + for embedding, text in zip(valid_embeddings, successful_texts, strict=True): # Get the next available index for this text (handles duplicates) if positions_by_text[text]: orig_idx = positions_by_text[text].popleft() # Original j index in [i, batch_end) @@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase( # Determine the correct embedding column based on dimension embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist()) embedding_column = None - + if embedding_dim == 768: embedding_column = "embedding_768" elif embedding_dim == 1024: @@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase( elif embedding_dim == 3072: embedding_column = "embedding_3072" else: - # Default to closest supported dimension - search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536") - embedding_column = "embedding_1536" - + # Skip unsupported dimensions to avoid corrupting the schema + search_logger.error( + f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch" + ) + continue + batch_data.append({ "url": urls[idx], "chunk_number": chunk_numbers[idx], @@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase( f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}" ) search_logger.info(f"Retrying in {retry_delay} seconds...") - import time - - time.sleep(retry_delay) + await asyncio.sleep(retry_delay) retry_delay *= 2 # Exponential backoff else: # Final attempt failed diff --git a/python/tests/test_async_llm_provider_service.py b/python/tests/test_async_llm_provider_service.py index e52c2242..db02c874 100644 --- a/python/tests/test_async_llm_provider_service.py +++ b/python/tests/test_async_llm_provider_service.py @@ -33,6 +33,12 @@ class AsyncContextManager: class TestAsyncLLMProviderService: """Test suite for async LLM provider service functions""" + @staticmethod + def _make_mock_client(): + client = MagicMock() + client.aclose = AsyncMock() + return client + @pytest.fixture(autouse=True) def clear_cache(self): """Clear cache before each test""" @@ -98,7 +104,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -121,7 +127,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -143,7 +149,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -166,7 +172,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client(provider="openai") as client: @@ -194,7 +200,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() 
+ mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client(use_embedding_provider=True) as client: @@ -225,7 +231,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client # Should fallback to Ollama instead of raising an error @@ -426,7 +432,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client # First call should hit the credential service @@ -464,7 +470,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client client_ref = None @@ -474,6 +480,7 @@ class TestAsyncLLMProviderService: # After context manager exits, should still have reference to client assert client_ref == mock_client + mock_client.aclose.assert_awaited_once() @pytest.mark.asyncio async def test_multiple_providers_in_sequence(self, mock_credential_service): @@ -494,7 +501,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client for config in configs: diff --git a/python/tests/test_code_extraction_source_id.py b/python/tests/test_code_extraction_source_id.py index 7de851f5..05405ee7 100644 --- a/python/tests/test_code_extraction_source_id.py +++ b/python/tests/test_code_extraction_source_id.py @@ -104,13 +104,15 @@ class TestCodeExtractionSourceId: ) # Verify the correct source_id was passed (now with cancellation_check parameter) - mock_extract.assert_called_once_with( - crawl_results, - url_to_full_document, - source_id, # This should be the third argument - None, - None # cancellation_check parameter - ) + mock_extract.assert_called_once() + args, kwargs = mock_extract.call_args + assert args[0] == crawl_results + assert args[1] == url_to_full_document + assert args[2] == source_id + assert args[3] is None + assert args[4] is None + if len(args) > 5: + assert args[5] is None assert result == 5 @pytest.mark.asyncio @@ -174,4 +176,4 @@ class TestCodeExtractionSourceId: import inspect source = inspect.getsource(module) assert "from urllib.parse import urlparse" not in source, \ - "Should not import urlparse since we don't extract domain from URL anymore" \ No newline at end of file + "Should not import urlparse since we don't extract domain from URL anymore"
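
The parameter-conversion rules added in llm_provider_service.py lend themselves to a focused unit test. A sketch along these lines (illustrative only, not part of the patch; the test name is invented) would pin down the documented behavior:

from src.server.services.llm_provider_service import prepare_chat_completion_params


def test_prepare_chat_completion_params_for_reasoning_models():
    params = {"max_tokens": 500, "temperature": 0.3}

    # Reasoning models: max_tokens is renamed and the custom temperature is dropped.
    assert prepare_chat_completion_params("gpt-5", params) == {"max_completion_tokens": 500}

    # Non-reasoning models pass through unchanged, and the input dict is never mutated.
    assert prepare_chat_completion_params("gpt-4o-mini", params) == params
    assert params == {"max_tokens": 500, "temperature": 0.3}
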