From 394ac1befaf706a068d9e434a1446691f71a3fb5 Mon Sep 17 00:00:00 2001 From: Josh Date: Mon, 22 Sep 2025 02:36:30 -0500 Subject: [PATCH 1/7] Feat:Openrouter/Anthropic/grok-support (#231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Anthropic and Grok provider support * feat: Add crucial GPT-5 and reasoning model support for OpenRouter - Add requires_max_completion_tokens() function for GPT-5, o1, o3, Grok-3 series - Add prepare_chat_completion_params() for reasoning model compatibility - Implement max_tokens → max_completion_tokens conversion for reasoning models - Add temperature handling for reasoning models (must be 1.0 default) - Enhanced provider validation and API key security in provider endpoints - Streamlined retry logic (3→2 attempts) for faster issue detection - Add failure tracking and circuit breaker analysis for debugging - Support OpenRouter format detection (openai/gpt-5-nano, openai/o1-mini) - Improved Grok provider empty response handling with structured fallbacks - Enhanced contextual embedding with provider-aware model selection Core provider functionality: - OpenRouter, Grok, Anthropic provider support with full embedding integration - Provider-specific model defaults and validation - Secure API connectivity testing endpoints - Provider context passing for code generation workflows 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * Fully working model providers, addressing security and code-related concerns, thoroughly hardening our code * Added multi-provider support and embeddings model support and cleaned up the PR; still need to fix the health check, asyncio task errors, and a contextual embeddings error * Fixed contextual embeddings issue * - Added inspect-aware shutdown handling so get_llm_client always closes the underlying AsyncOpenAI / httpx.AsyncClient while the loop is still alive, with defensive logging if shutdown happens late (python/src/server/services/llm_provider_service.py:14, python/src/server/services/llm_provider_service.py:520). * - Restructured get_llm_client so client creation and usage live in separate try/finally blocks; fallback clients now close without logging a spurious "Error creating LLM client" message when downstream code raises (python/src/server/services/llm_provider_service.py:335-556). - Close logic now sanitizes provider names consistently and awaits whichever aclose/close coroutine the SDK exposes, keeping the loop shut down cleanly (python/src/server/services/llm_provider_service.py:530-559). Robust JSON Parsing - Added _extract_json_payload to strip code fences / extra text returned by Ollama before json.loads runs, averting the markdown-induced decode errors seen in the logs (python/src/server/services/storage/code_storage_service.py:40-63). - Swapped the direct parse call for the sanitized payload, emitting a debug preview when cleanup alters the content (python/src/server/services/storage/code_storage_service.py:858-864). * Added provider connection support * Added a warning when a provider API key is not configured * Updated get_llm_client so missing OpenAI keys automatically fall back to Ollama (matching existing tests) and so unsupported providers still raise the legacy ValueError the suite expects. The fallback now reuses _get_optimal_ollama_instance and rethrows ValueError("OpenAI API key not found and Ollama fallback failed") when it can't connect. 
Adjusted test_code_extraction_source_id.py to accept the new optional argument on the mocked extractor (and confirm it's None when present). * Resolved a few needed CodeRabbit suggestions - Updated the knowledge API key validation to call create_embedding with the provider argument and removed the hard-coded OpenAI fallback (python/src/server/api_routes/knowledge_api.py). - Broadened embedding provider detection so prefixed OpenRouter/OpenAI model names route through the correct client (python/src/server/services/embeddings/embedding_service.py, python/src/server/services/llm_provider_service.py). - Removed the duplicate helper definitions from llm_provider_service.py, eliminating the stray docstring that was causing the import-time syntax error. * Updated via CodeRabbit PR review; CodeRabbit in my IDE found no issues and no nitpicks with the updates! What was done: The credential service now persists the provider under the uppercase key LLM_PROVIDER, matching the read path (no new EMBEDDING_PROVIDER usage introduced). Embedding batch creation stops inserting blank strings, logging failures and skipping invalid items before they ever hit the provider (python/src/server/services/embeddings/embedding_service.py). Contextual embedding prompts use real newline characters everywhere, both when constructing the batch prompt and when parsing the model's response (python/src/server/services/embeddings/contextual_embedding_service.py). Embedding provider routing already recognizes OpenRouter-prefixed OpenAI models via is_openai_embedding_model; no further change needed there. Embedding insertion now skips unsupported vector dimensions instead of forcing them into the 1536-dimension column, and the backoff loop uses await asyncio.sleep so we no longer block the event loop (python/src/server/services/storage/code_storage_service.py). RAG settings props were extended to include LLM_INSTANCE_NAME and OLLAMA_EMBEDDING_INSTANCE_NAME, and the debug log no longer prints API-key prefixes (the rest of the TanStack refactor/EMBEDDING_PROVIDER support remains deferred). * Test fix * Enhanced OpenRouter's parsing logic to automatically detect reasoning models and parse responses whether or not they come back as JSON. This commit creates a robust way for Archon's parsing to work thoroughly with OpenRouter automatically, regardless of the model you're using, to ensure proper functionality without breaking any generation capabilities! (Illustrative sketches of the payload cleanup and the reasoning-model parameter handling follow below.) 
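The payload cleanup mentioned above (_extract_json_payload stripping code fences and surrounding prose before json.loads runs) could look roughly like the sketch below. The helper name comes from the patch description; the body is an illustrative approximation, not the exact implementation in code_storage_service.py.

import json
import re


def _extract_json_payload(raw: str) -> str:
    """Strip Markdown code fences and surrounding prose so json.loads sees only the JSON payload."""
    text = raw.strip()
    # Prefer the contents of a fenced ```json ... ``` block when one is present.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    # Otherwise fall back to the outermost {...} or [...] span in the response.
    if not text.startswith(("{", "[")):
        span = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
        if span:
            text = span.group(1)
    return text


# Usage: data = json.loads(_extract_json_payload(model_output))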
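Likewise, the GPT-5/o1/o3/Grok-3 bullets describe a max_tokens → max_completion_tokens conversion and a forced default temperature for reasoning models. A minimal sketch of that parameter preparation, assuming the function names from the patch but not its exact model list or detection logic, might be:

def requires_max_completion_tokens(model: str) -> bool:
    """Heuristic: reasoning models reject `max_tokens` and need `max_completion_tokens`."""
    # Handle OpenRouter-prefixed names such as "openai/gpt-5-nano" or "openai/o1-mini".
    name = model.lower().split("/")[-1]
    return name.startswith(("gpt-5", "o1", "o3", "grok-3"))


def prepare_chat_completion_params(model: str, params: dict) -> dict:
    """Return chat-completion params adjusted for reasoning models, untouched otherwise."""
    prepared = dict(params)
    if requires_max_completion_tokens(model):
        if "max_tokens" in prepared:
            prepared["max_completion_tokens"] = prepared.pop("max_tokens")
        prepared["temperature"] = 1.0  # reasoning models only accept the default temperature
    return prepared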
--------- Co-authored-by: Chillbruhhh Co-authored-by: Claude --- .../src/components/settings/RAGSettings.tsx | 435 ++++---- python/src/server/api_routes/__init__.py | 2 + python/src/server/api_routes/knowledge_api.py | 69 +- python/src/server/api_routes/providers_api.py | 154 +++ python/src/server/main.py | 2 + .../crawling/code_extraction_service.py | 11 +- .../services/crawling/crawling_service.py | 12 + .../crawling/document_storage_operations.py | 4 +- .../src/server/services/credential_service.py | 94 +- .../contextual_embedding_service.py | 102 +- .../services/embeddings/embedding_service.py | 59 +- .../server/services/llm_provider_service.py | 973 +++++++++++++++++- .../services/provider_discovery_service.py | 54 +- .../services/source_management_service.py | 35 +- .../services/storage/code_storage_service.py | 491 +++++++-- .../tests/test_async_llm_provider_service.py | 25 +- .../tests/test_code_extraction_source_id.py | 18 +- 17 files changed, 2090 insertions(+), 450 deletions(-) create mode 100644 python/src/server/api_routes/providers_api.py diff --git a/archon-ui-main/src/components/settings/RAGSettings.tsx b/archon-ui-main/src/components/settings/RAGSettings.tsx index 8cb721a7..a60925bb 100644 --- a/archon-ui-main/src/components/settings/RAGSettings.tsx +++ b/archon-ui-main/src/components/settings/RAGSettings.tsx @@ -9,6 +9,103 @@ import { credentialsService } from '../../services/credentialsService'; import OllamaModelDiscoveryModal from './OllamaModelDiscoveryModal'; import OllamaModelSelectionModal from './OllamaModelSelectionModal'; +type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter'; + +interface ProviderModels { + chatModel: string; + embeddingModel: string; +} + +type ProviderModelMap = Record; + +// Provider model persistence helpers +const PROVIDER_MODELS_KEY = 'archon_provider_models'; + +const getDefaultModels = (provider: ProviderKey): ProviderModels => { + const chatDefaults: Record = { + openai: 'gpt-4o-mini', + anthropic: 'claude-3-5-sonnet-20241022', + google: 'gemini-1.5-flash', + grok: 'grok-3-mini', // Updated to use grok-3-mini as default + openrouter: 'openai/gpt-4o-mini', + ollama: 'llama3:8b' + }; + + const embeddingDefaults: Record = { + openai: 'text-embedding-3-small', + anthropic: 'text-embedding-3-small', // Fallback to OpenAI + google: 'text-embedding-004', + grok: 'text-embedding-3-small', // Fallback to OpenAI + openrouter: 'text-embedding-3-small', + ollama: 'nomic-embed-text' + }; + + return { + chatModel: chatDefaults[provider], + embeddingModel: embeddingDefaults[provider] + }; +}; + +const saveProviderModels = (providerModels: ProviderModelMap): void => { + try { + localStorage.setItem(PROVIDER_MODELS_KEY, JSON.stringify(providerModels)); + } catch (error) { + console.error('Failed to save provider models:', error); + } +}; + +const loadProviderModels = (): ProviderModelMap => { + try { + const saved = localStorage.getItem(PROVIDER_MODELS_KEY); + if (saved) { + return JSON.parse(saved); + } + } catch (error) { + console.error('Failed to load provider models:', error); + } + + // Return defaults for all providers if nothing saved + const providers: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok']; + const defaultModels: ProviderModelMap = {} as ProviderModelMap; + + providers.forEach(provider => { + defaultModels[provider] = getDefaultModels(provider); + }); + + return defaultModels; +}; + +// Static color styles mapping (prevents Tailwind JIT purging) +const colorStyles: 
Record = { + openai: 'border-green-500 bg-green-500/10', + google: 'border-blue-500 bg-blue-500/10', + openrouter: 'border-cyan-500 bg-cyan-500/10', + ollama: 'border-purple-500 bg-purple-500/10', + anthropic: 'border-orange-500 bg-orange-500/10', + grok: 'border-yellow-500 bg-yellow-500/10', +}; + +const providerAlertStyles: Record = { + openai: 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800 text-green-800 dark:text-green-300', + google: 'bg-blue-50 dark:bg-blue-900/20 border-blue-200 dark:border-blue-800 text-blue-800 dark:text-blue-300', + openrouter: 'bg-cyan-50 dark:bg-cyan-900/20 border-cyan-200 dark:border-cyan-800 text-cyan-800 dark:text-cyan-300', + ollama: 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800 text-purple-800 dark:text-purple-300', + anthropic: 'bg-orange-50 dark:bg-orange-900/20 border-orange-200 dark:border-orange-800 text-orange-800 dark:text-orange-300', + grok: 'bg-yellow-50 dark:bg-yellow-900/20 border-yellow-200 dark:border-yellow-800 text-yellow-800 dark:text-yellow-300', +}; + +const providerAlertMessages: Record = { + openai: 'Configure your OpenAI API key in the credentials section to use GPT models.', + google: 'Configure your Google API key in the credentials section to use Gemini models.', + openrouter: 'Configure your OpenRouter API key in the credentials section to use models.', + ollama: 'Configure your Ollama instances in this panel to connect local models.', + anthropic: 'Configure your Anthropic API key in the credentials section to use Claude models.', + grok: 'Configure your Grok API key in the credentials section to use Grok models.', +}; + +const isProviderKey = (value: unknown): value is ProviderKey => + typeof value === 'string' && ['openai', 'google', 'openrouter', 'ollama', 'anthropic', 'grok'].includes(value); + interface RAGSettingsProps { ragSettings: { MODEL_CHOICE: string; @@ -19,8 +116,10 @@ interface RAGSettingsProps { USE_RERANKING: boolean; LLM_PROVIDER?: string; LLM_BASE_URL?: string; + LLM_INSTANCE_NAME?: string; EMBEDDING_MODEL?: string; OLLAMA_EMBEDDING_URL?: string; + OLLAMA_EMBEDDING_INSTANCE_NAME?: string; // Crawling Performance Settings CRAWL_BATCH_SIZE?: number; CRAWL_MAX_CONCURRENT?: number; @@ -57,7 +156,10 @@ export const RAGSettings = ({ // Model selection modals state const [showLLMModelSelectionModal, setShowLLMModelSelectionModal] = useState(false); const [showEmbeddingModelSelectionModal, setShowEmbeddingModelSelectionModal] = useState(false); - + + // Provider-specific model persistence state + const [providerModels, setProviderModels] = useState(() => loadProviderModels()); + // Instance configurations const [llmInstanceConfig, setLLMInstanceConfig] = useState({ name: '', @@ -113,6 +215,25 @@ export const RAGSettings = ({ } }, [ragSettings.OLLAMA_EMBEDDING_URL, ragSettings.OLLAMA_EMBEDDING_INSTANCE_NAME]); + // Provider model persistence effects + useEffect(() => { + // Update provider models when current models change + const currentProvider = ragSettings.LLM_PROVIDER as ProviderKey; + if (currentProvider && ragSettings.MODEL_CHOICE && ragSettings.EMBEDDING_MODEL) { + setProviderModels(prev => { + const updated = { + ...prev, + [currentProvider]: { + chatModel: ragSettings.MODEL_CHOICE, + embeddingModel: ragSettings.EMBEDDING_MODEL + } + }; + saveProviderModels(updated); + return updated; + }); + } + }, [ragSettings.MODEL_CHOICE, ragSettings.EMBEDDING_MODEL, ragSettings.LLM_PROVIDER]); + // Load API credentials for status checking useEffect(() => { const 
loadApiCredentials = async () => { @@ -197,58 +318,27 @@ export const RAGSettings = ({ }>({}); // Test connection to external providers - const testProviderConnection = async (provider: string, apiKey: string): Promise => { + const testProviderConnection = async (provider: string): Promise => { setProviderConnectionStatus(prev => ({ ...prev, [provider]: { ...prev[provider], checking: true } })); try { - switch (provider) { - case 'openai': - // Test OpenAI connection with a simple completion request - const openaiResponse = await fetch('https://api.openai.com/v1/models', { - method: 'GET', - headers: { - 'Authorization': `Bearer ${apiKey}`, - 'Content-Type': 'application/json' - } - }); - - if (openaiResponse.ok) { - setProviderConnectionStatus(prev => ({ - ...prev, - openai: { connected: true, checking: false, lastChecked: new Date() } - })); - return true; - } else { - throw new Error(`OpenAI API returned ${openaiResponse.status}`); - } + // Use server-side API endpoint for secure connectivity testing + const response = await fetch(`/api/providers/${provider}/status`); + const result = await response.json(); - case 'google': - // Test Google Gemini connection - const googleResponse = await fetch(`https://generativelanguage.googleapis.com/v1/models?key=${apiKey}`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json' - } - }); - - if (googleResponse.ok) { - setProviderConnectionStatus(prev => ({ - ...prev, - google: { connected: true, checking: false, lastChecked: new Date() } - })); - return true; - } else { - throw new Error(`Google API returned ${googleResponse.status}`); - } + const isConnected = result.ok && result.reason === 'connected'; - default: - return false; - } + setProviderConnectionStatus(prev => ({ + ...prev, + [provider]: { connected: isConnected, checking: false, lastChecked: new Date() } + })); + + return isConnected; } catch (error) { - console.error(`Failed to test ${provider} connection:`, error); + console.error(`Error testing ${provider} connection:`, error); setProviderConnectionStatus(prev => ({ ...prev, [provider]: { connected: false, checking: false, lastChecked: new Date() } @@ -260,37 +350,27 @@ export const RAGSettings = ({ // Test provider connections when API credentials change useEffect(() => { const testConnections = async () => { - const providers = ['openai', 'google']; - + // Test all supported providers + const providers = ['openai', 'google', 'anthropic', 'openrouter', 'grok']; + for (const provider of providers) { - const keyName = provider === 'openai' ? 'OPENAI_API_KEY' : 'GOOGLE_API_KEY'; - const apiKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === keyName); - const keyValue = apiKey ? apiCredentials[apiKey] : undefined; - - if (keyValue && keyValue.trim().length > 0) { - // Don't test if we've already checked recently (within last 30 seconds) - const lastChecked = providerConnectionStatus[provider]?.lastChecked; - const now = new Date(); - const timeSinceLastCheck = lastChecked ? 
now.getTime() - lastChecked.getTime() : Infinity; - - if (timeSinceLastCheck > 30000) { // 30 seconds - console.log(`🔄 Testing ${provider} connection...`); - await testProviderConnection(provider, keyValue); - } - } else { - // No API key, mark as disconnected - setProviderConnectionStatus(prev => ({ - ...prev, - [provider]: { connected: false, checking: false, lastChecked: new Date() } - })); + // Don't test if we've already checked recently (within last 30 seconds) + const lastChecked = providerConnectionStatus[provider]?.lastChecked; + const now = new Date(); + const timeSinceLastCheck = lastChecked ? now.getTime() - lastChecked.getTime() : Infinity; + + if (timeSinceLastCheck > 30000) { // 30 seconds + console.log(`🔄 Testing ${provider} connection...`); + await testProviderConnection(provider); } } }; - // Only test if we have credentials loaded - if (Object.keys(apiCredentials).length > 0) { - testConnections(); - } + // Test connections periodically (every 60 seconds) + testConnections(); + const interval = setInterval(testConnections, 60000); + + return () => clearInterval(interval); }, [apiCredentials]); // Test when credentials change // Ref to track if initial test has been run (will be used after function definitions) @@ -662,24 +742,41 @@ export const RAGSettings = ({ if (llmStatus.online || embeddingStatus.online) return 'partial'; return 'missing'; case 'anthropic': - // Check if Anthropic API key is configured (case insensitive) - const anthropicKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'ANTHROPIC_API_KEY'); - const hasAnthropicKey = anthropicKey && apiCredentials[anthropicKey] && apiCredentials[anthropicKey].trim().length > 0; - return hasAnthropicKey ? 'configured' : 'missing'; + // Use server-side connection status + const anthropicConnected = providerConnectionStatus['anthropic']?.connected || false; + const anthropicChecking = providerConnectionStatus['anthropic']?.checking || false; + if (anthropicChecking) return 'partial'; + return anthropicConnected ? 'configured' : 'missing'; case 'grok': - // Check if Grok API key is configured (case insensitive) - const grokKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'GROK_API_KEY'); - const hasGrokKey = grokKey && apiCredentials[grokKey] && apiCredentials[grokKey].trim().length > 0; - return hasGrokKey ? 'configured' : 'missing'; + // Use server-side connection status + const grokConnected = providerConnectionStatus['grok']?.connected || false; + const grokChecking = providerConnectionStatus['grok']?.checking || false; + if (grokChecking) return 'partial'; + return grokConnected ? 'configured' : 'missing'; case 'openrouter': - // Check if OpenRouter API key is configured (case insensitive) - const openRouterKey = Object.keys(apiCredentials).find(key => key.toUpperCase() === 'OPENROUTER_API_KEY'); - const hasOpenRouterKey = openRouterKey && apiCredentials[openRouterKey] && apiCredentials[openRouterKey].trim().length > 0; - return hasOpenRouterKey ? 'configured' : 'missing'; + // Use server-side connection status + const openRouterConnected = providerConnectionStatus['openrouter']?.connected || false; + const openRouterChecking = providerConnectionStatus['openrouter']?.checking || false; + if (openRouterChecking) return 'partial'; + return openRouterConnected ? 'configured' : 'missing'; default: return 'missing'; } - };; + }; + + const selectedProviderKey = isProviderKey(ragSettings.LLM_PROVIDER) + ? 
(ragSettings.LLM_PROVIDER as ProviderKey) + : undefined; + const selectedProviderStatus = selectedProviderKey ? getProviderStatus(selectedProviderKey) : undefined; + const shouldShowProviderAlert = Boolean( + selectedProviderKey && selectedProviderStatus === 'missing' + ); + const providerAlertClassName = shouldShowProviderAlert && selectedProviderKey + ? providerAlertStyles[selectedProviderKey] + : ''; + const providerAlertMessage = shouldShowProviderAlert && selectedProviderKey + ? providerAlertMessages[selectedProviderKey] + : ''; // Test Ollama connectivity when Settings page loads (scenario 4: page load) // This useEffect is placed after function definitions to ensure access to manualTestConnection @@ -750,55 +847,32 @@ export const RAGSettings = ({ {[ { key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' }, { key: 'google', name: 'Google', logo: '/img/google-logo.svg', color: 'blue' }, + { key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' }, { key: 'ollama', name: 'Ollama', logo: '/img/Ollama.png', color: 'purple' }, { key: 'anthropic', name: 'Anthropic', logo: '/img/claude-logo.svg', color: 'orange' }, - { key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' }, - { key: 'openrouter', name: 'OpenRouter', logo: '/img/OpenRouter.png', color: 'cyan' } + { key: 'grok', name: 'Grok', logo: '/img/Grok.png', color: 'yellow' } ].map(provider => ( ))} @@ -1223,19 +1290,9 @@ export const RAGSettings = ({ )} - {ragSettings.LLM_PROVIDER === 'anthropic' && ( -
-

- Configure your Anthropic API key in the credentials section to use Claude models. -

-
- )} - - {ragSettings.LLM_PROVIDER === 'groq' && ( -
-

- Groq provides fast inference with Llama, Mixtral, and Gemma models. -

+ {shouldShowProviderAlert && ( +
+

{providerAlertMessage}

)} @@ -1853,94 +1910,56 @@ export const RAGSettings = ({ function getDisplayedChatModel(ragSettings: any): string { const provider = ragSettings.LLM_PROVIDER || 'openai'; const modelChoice = ragSettings.MODEL_CHOICE; - - // Check if the stored model is appropriate for the current provider - const isModelAppropriate = (model: string, provider: string): boolean => { - if (!model) return false; - - switch (provider) { - case 'openai': - return model.startsWith('gpt-') || model.startsWith('o1-') || model.includes('text-davinci') || model.includes('text-embedding'); - case 'anthropic': - return model.startsWith('claude-'); - case 'google': - return model.startsWith('gemini-') || model.startsWith('text-embedding-'); - case 'grok': - return model.startsWith('grok-'); - case 'ollama': - return !model.startsWith('gpt-') && !model.startsWith('claude-') && !model.startsWith('gemini-') && !model.startsWith('grok-'); - case 'openrouter': - return model.includes('/') || model.startsWith('anthropic/') || model.startsWith('openai/'); - default: - return false; - } - }; - - // Use stored model if it's appropriate for the provider, otherwise use default - const useStoredModel = modelChoice && isModelAppropriate(modelChoice, provider); - + + // Always prioritize user input to allow editing + if (modelChoice !== undefined && modelChoice !== null) { + return modelChoice; + } + + // Only use defaults when there's no stored value switch (provider) { case 'openai': - return useStoredModel ? modelChoice : 'gpt-4o-mini'; + return 'gpt-4o-mini'; case 'anthropic': - return useStoredModel ? modelChoice : 'claude-3-5-sonnet-20241022'; + return 'claude-3-5-sonnet-20241022'; case 'google': - return useStoredModel ? modelChoice : 'gemini-1.5-flash'; + return 'gemini-1.5-flash'; case 'grok': - return useStoredModel ? modelChoice : 'grok-2-latest'; + return 'grok-3-mini'; case 'ollama': - return useStoredModel ? modelChoice : ''; + return ''; case 'openrouter': - return useStoredModel ? modelChoice : 'anthropic/claude-3.5-sonnet'; + return 'anthropic/claude-3.5-sonnet'; default: - return useStoredModel ? 
modelChoice : 'gpt-4o-mini'; + return 'gpt-4o-mini'; } } function getDisplayedEmbeddingModel(ragSettings: any): string { const provider = ragSettings.LLM_PROVIDER || 'openai'; const embeddingModel = ragSettings.EMBEDDING_MODEL; - - // Check if the stored embedding model is appropriate for the current provider - const isEmbeddingModelAppropriate = (model: string, provider: string): boolean => { - if (!model) return false; - - switch (provider) { - case 'openai': - return model.startsWith('text-embedding-') || model.includes('ada-'); - case 'anthropic': - return false; // Claude doesn't provide embedding models - case 'google': - return model.startsWith('text-embedding-') || model.startsWith('textembedding-') || model.includes('embedding'); - case 'grok': - return false; // Grok doesn't provide embedding models - case 'ollama': - return !model.startsWith('text-embedding-') || model.includes('embed') || model.includes('arctic'); - case 'openrouter': - return model.startsWith('text-embedding-') || model.includes('/'); - default: - return false; - } - }; - - // Use stored model if it's appropriate for the provider, otherwise use default - const useStoredModel = embeddingModel && isEmbeddingModelAppropriate(embeddingModel, provider); - + + // Always prioritize user input to allow editing + if (embeddingModel !== undefined && embeddingModel !== null && embeddingModel !== '') { + return embeddingModel; + } + + // Provide appropriate defaults based on LLM provider switch (provider) { case 'openai': - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; - case 'anthropic': - return 'Not available - Claude does not provide embedding models'; + return 'text-embedding-3-small'; case 'google': - return useStoredModel ? embeddingModel : 'text-embedding-004'; - case 'grok': - return 'Not available - Grok does not provide embedding models'; + return 'text-embedding-004'; case 'ollama': - return useStoredModel ? embeddingModel : ''; + return ''; case 'openrouter': - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; + return 'text-embedding-3-small'; // Default to OpenAI embedding for OpenRouter + case 'anthropic': + return 'text-embedding-3-small'; // Use OpenAI embeddings with Claude + case 'grok': + return 'text-embedding-3-small'; // Use OpenAI embeddings with Grok default: - return useStoredModel ? embeddingModel : 'text-embedding-3-small'; + return 'text-embedding-3-small'; } } @@ -2035,4 +2054,4 @@ const CustomCheckbox = ({
); -}; \ No newline at end of file +}; diff --git a/python/src/server/api_routes/__init__.py b/python/src/server/api_routes/__init__.py index 04df2992..8a39ef26 100644 --- a/python/src/server/api_routes/__init__.py +++ b/python/src/server/api_routes/__init__.py @@ -14,6 +14,7 @@ from .internal_api import router as internal_router from .knowledge_api import router as knowledge_router from .mcp_api import router as mcp_router from .projects_api import router as projects_router +from .providers_api import router as providers_router from .settings_api import router as settings_router __all__ = [ @@ -23,4 +24,5 @@ __all__ = [ "projects_router", "agent_chat_router", "internal_router", + "providers_router", ] diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 56725838..1f26dace 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -18,6 +18,8 @@ from urllib.parse import urlparse from fastapi import APIRouter, File, Form, HTTPException, UploadFile from pydantic import BaseModel +# Basic validation - simplified inline version + # Import unified logging from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info from ..services.crawler_manager import get_crawler @@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None: logger.info("🔑 Starting API key validation...") try: + # Basic provider validation if not provider: provider = "openai" + else: + # Simple provider validation + allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + if provider not in allowed_providers: + raise HTTPException( + status_code=400, + detail={ + "error": "Invalid provider name", + "message": f"Provider '{provider}' not supported", + "error_type": "validation_error" + } + ) - logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...") - - # Test API key with minimal embedding request - this will fail if key is invalid - from ..services.embeddings.embedding_service import create_embedding - test_result = await create_embedding(text="test") - - if not test_result: - logger.error(f"❌ {provider.title()} API key validation failed - no embedding returned") - raise HTTPException( - status_code=401, - detail={ - "error": f"Invalid {provider.title()} API key", - "message": f"Please verify your {provider.title()} API key in Settings.", - "error_type": "authentication_failed", - "provider": provider - } - ) + # Basic sanitization for logging + safe_provider = provider[:20] # Limit length + logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...") + + try: + # Test API key with minimal embedding request using provider-scoped configuration + from ..services.embeddings.embedding_service import create_embedding + + test_result = await create_embedding(text="test", provider=provider) + + if not test_result: + logger.error( + f"❌ {provider.title()} API key validation failed - no embedding returned" + ) + raise HTTPException( + status_code=401, + detail={ + "error": f"Invalid {provider.title()} API key", + "message": f"Please verify your {provider.title()} API key in Settings.", + "error_type": "authentication_failed", + "provider": provider, + }, + ) + except Exception as e: + logger.error( + f"❌ {provider.title()} API key validation failed: {e}", + exc_info=True, + ) + raise HTTPException( + status_code=401, + detail={ + "error": f"Invalid {provider.title()} API key", + 
"message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}", + "error_type": "authentication_failed", + "provider": provider, + }, + ) logger.info(f"✅ {provider.title()} API key validation successful") diff --git a/python/src/server/api_routes/providers_api.py b/python/src/server/api_routes/providers_api.py new file mode 100644 index 00000000..9c405ecd --- /dev/null +++ b/python/src/server/api_routes/providers_api.py @@ -0,0 +1,154 @@ +""" +Provider status API endpoints for testing connectivity + +Handles server-side provider connectivity testing without exposing API keys to frontend. +""" + +import httpx +from fastapi import APIRouter, HTTPException, Path + +from ..config.logfire_config import logfire +from ..services.credential_service import credential_service +# Provider validation - simplified inline version + +router = APIRouter(prefix="/api/providers", tags=["providers"]) + + +async def test_openai_connection(api_key: str) -> bool: + """Test OpenAI API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.openai.com/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"OpenAI connectivity test failed: {e}") + return False + + +async def test_google_connection(api_key: str) -> bool: + """Test Google AI API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://generativelanguage.googleapis.com/v1/models", + headers={"x-goog-api-key": api_key} + ) + return response.status_code == 200 + except Exception: + logfire.warning("Google AI connectivity test failed") + return False + + +async def test_anthropic_connection(api_key: str) -> bool: + """Test Anthropic API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.anthropic.com/v1/models", + headers={ + "x-api-key": api_key, + "anthropic-version": "2023-06-01" + } + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"Anthropic connectivity test failed: {e}") + return False + + +async def test_openrouter_connection(api_key: str) -> bool: + """Test OpenRouter API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://openrouter.ai/api/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"OpenRouter connectivity test failed: {e}") + return False + + +async def test_grok_connection(api_key: str) -> bool: + """Test Grok API connectivity""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + "https://api.x.ai/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + return response.status_code == 200 + except Exception as e: + logfire.warning(f"Grok connectivity test failed: {e}") + return False + + +PROVIDER_TESTERS = { + "openai": test_openai_connection, + "google": test_google_connection, + "anthropic": test_anthropic_connection, + "openrouter": test_openrouter_connection, + "grok": test_grok_connection, +} + + +@router.get("/{provider}/status") +async def get_provider_status( + provider: str = Path( + ..., + description="Provider name to test connectivity for", + regex="^[a-z0-9_]+$", + max_length=20 + ) +): + """Test provider connectivity using server-side API key 
(secure)""" + try: + # Basic provider validation + allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + if provider not in allowed_providers: + raise HTTPException( + status_code=400, + detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}" + ) + + # Basic sanitization for logging + safe_provider = provider[:20] # Limit length + logfire.info(f"Testing {safe_provider} connectivity server-side") + + if provider not in PROVIDER_TESTERS: + raise HTTPException( + status_code=400, + detail=f"Provider '{provider}' not supported for connectivity testing" + ) + + # Get API key server-side (never expose to client) + key_name = f"{provider.upper()}_API_KEY" + api_key = await credential_service.get_credential(key_name, decrypt=True) + + if not api_key or not isinstance(api_key, str) or not api_key.strip(): + logfire.info(f"No API key configured for {safe_provider}") + return {"ok": False, "reason": "no_key"} + + # Test connectivity using server-side key + tester = PROVIDER_TESTERS[provider] + is_connected = await tester(api_key) + + logfire.info(f"{safe_provider} connectivity test result: {is_connected}") + return { + "ok": is_connected, + "reason": "connected" if is_connected else "connection_failed", + "provider": provider # Echo back validated provider name + } + + except HTTPException: + # Re-raise HTTP exceptions (they're already properly formatted) + raise + except Exception as e: + # Basic error sanitization for logging + safe_error = str(e)[:100] # Limit length + logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}") + raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"}) diff --git a/python/src/server/main.py b/python/src/server/main.py index bec14a71..ba0b19cb 100644 --- a/python/src/server/main.py +++ b/python/src/server/main.py @@ -26,6 +26,7 @@ from .api_routes.mcp_api import router as mcp_router from .api_routes.ollama_api import router as ollama_router from .api_routes.progress_api import router as progress_router from .api_routes.projects_api import router as projects_router +from .api_routes.providers_api import router as providers_router # Import modular API routers from .api_routes.settings_api import router as settings_router @@ -186,6 +187,7 @@ app.include_router(progress_router) app.include_router(agent_chat_router) app.include_router(internal_router) app.include_router(bug_report_router) +app.include_router(providers_router) # Root endpoint diff --git a/python/src/server/services/crawling/code_extraction_service.py b/python/src/server/services/crawling/code_extraction_service.py index f52b7e28..1a540f57 100644 --- a/python/src/server/services/crawling/code_extraction_service.py +++ b/python/src/server/services/crawling/code_extraction_service.py @@ -139,6 +139,7 @@ class CodeExtractionService: source_id: str, progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
@@ -204,7 +205,7 @@ class CodeExtractionService: # Generate summaries for code blocks summary_results = await self._generate_code_summaries( - all_code_blocks, summary_callback, cancellation_check + all_code_blocks, summary_callback, cancellation_check, provider ) # Prepare code examples for storage @@ -223,7 +224,7 @@ class CodeExtractionService: # Store code examples in database return await self._store_code_examples( - storage_data, url_to_full_document, storage_callback + storage_data, url_to_full_document, storage_callback, provider ) async def _extract_code_blocks_from_documents( @@ -1523,6 +1524,7 @@ class CodeExtractionService: all_code_blocks: list[dict[str, Any]], progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> list[dict[str, str]]: """ Generate summaries for all code blocks. @@ -1587,7 +1589,7 @@ class CodeExtractionService: try: results = await generate_code_summaries_batch( - code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback + code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider ) # Ensure all results are valid dicts @@ -1667,6 +1669,7 @@ class CodeExtractionService: storage_data: dict[str, list[Any]], url_to_full_document: dict[str, str], progress_callback: Callable | None = None, + provider: str | None = None, ) -> int: """ Store code examples in the database. @@ -1709,7 +1712,7 @@ class CodeExtractionService: batch_size=20, url_to_full_document=url_to_full_document, progress_callback=storage_progress_callback, - provider=None, # Use configured provider + provider=provider, ) # Report completion of code extraction/storage phase diff --git a/python/src/server/services/crawling/crawling_service.py b/python/src/server/services/crawling/crawling_service.py index 55cd6d92..69c65719 100644 --- a/python/src/server/services/crawling/crawling_service.py +++ b/python/src/server/services/crawling/crawling_service.py @@ -475,12 +475,24 @@ class CrawlingService: ) try: + # Extract provider from request or use credential service default + provider = request.get("provider") + if not provider: + try: + from ..credential_service import credential_service + provider_config = await credential_service.get_active_provider("llm") + provider = provider_config.get("provider", "openai") + except Exception as e: + logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai") + provider = "openai" + code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples( crawl_results, storage_results["url_to_full_document"], storage_results["source_id"], code_progress_callback, self._check_cancellation, + provider, ) except RuntimeError as e: # Code extraction failed, continue crawl with warning diff --git a/python/src/server/services/crawling/document_storage_operations.py b/python/src/server/services/crawling/document_storage_operations.py index aaf211a7..88ed8e80 100644 --- a/python/src/server/services/crawling/document_storage_operations.py +++ b/python/src/server/services/crawling/document_storage_operations.py @@ -351,6 +351,7 @@ class DocumentStorageOperations: source_id: str, progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, + provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
@@ -361,12 +362,13 @@ class DocumentStorageOperations: source_id: The unique source_id for all documents progress_callback: Optional callback for progress updates cancellation_check: Optional function to check for cancellation + provider: Optional LLM provider to use for code summaries Returns: Number of code examples stored """ result = await self.code_extraction_service.extract_and_store_code_examples( - crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check + crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider ) return result diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py index e72ca8a4..62fbb47a 100644 --- a/python/src/server/services/credential_service.py +++ b/python/src/server/services/credential_service.py @@ -36,6 +36,44 @@ class CredentialItem: description: str | None = None +def _detect_embedding_provider_from_model(embedding_model: str) -> str: + """ + Detect the appropriate embedding provider based on model name. + + Args: + embedding_model: The embedding model name + + Returns: + Provider name: 'google', 'openai', or 'openai' (default) + """ + if not embedding_model: + return "openai" # Default + + model_lower = embedding_model.lower() + + # Google embedding models + google_patterns = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding", + "gemini-embedding", + "multimodalembedding" + ] + + if any(pattern in model_lower for pattern in google_patterns): + return "google" + + # OpenAI embedding models (and default for unknown) + openai_patterns = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large" + ] + + # Default to OpenAI for OpenAI models or unknown models + return "openai" + + class CredentialService: """Service for managing application credentials and configuration.""" @@ -239,6 +277,14 @@ class CredentialService: self._rag_cache_timestamp = None logger.debug(f"Invalidated RAG settings cache due to update of {key}") + # Also invalidate provider service cache to ensure immediate effect + try: + from .llm_provider_service import clear_provider_cache + clear_provider_cache() + logger.debug("Also cleared LLM provider service cache") + except Exception as e: + logger.warning(f"Failed to clear provider service cache: {e}") + # Also invalidate LLM provider service cache for provider config try: from . import llm_provider_service @@ -281,6 +327,14 @@ class CredentialService: self._rag_cache_timestamp = None logger.debug(f"Invalidated RAG settings cache due to deletion of {key}") + # Also invalidate provider service cache to ensure immediate effect + try: + from .llm_provider_service import clear_provider_cache + clear_provider_cache() + logger.debug("Also cleared LLM provider service cache") + except Exception as e: + logger.warning(f"Failed to clear provider service cache: {e}") + # Also invalidate LLM provider service cache for provider config try: from . 
import llm_provider_service @@ -419,8 +473,33 @@ class CredentialService: # Get RAG strategy settings (where UI saves provider selection) rag_settings = await self.get_credentials_by_category("rag_strategy") - # Get the selected provider - provider = rag_settings.get("LLM_PROVIDER", "openai") + # Get the selected provider based on service type + if service_type == "embedding": + # Get the LLM provider setting to determine embedding provider + llm_provider = rag_settings.get("LLM_PROVIDER", "openai") + embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small") + + # Determine embedding provider based on LLM provider + if llm_provider == "google": + provider = "google" + elif llm_provider == "ollama": + provider = "ollama" + elif llm_provider == "openrouter": + # OpenRouter supports both OpenAI and Google embedding models + provider = _detect_embedding_provider_from_model(embedding_model) + elif llm_provider in ["anthropic", "grok"]: + # Anthropic and Grok support both OpenAI and Google embedding models + provider = _detect_embedding_provider_from_model(embedding_model) + else: + # Default case (openai, or unknown providers) + provider = "openai" + + logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'") + else: + provider = rag_settings.get("LLM_PROVIDER", "openai") + # Ensure provider is a valid string, not a boolean or other type + if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"): + provider = "openai" # Get API key for this provider api_key = await self._get_provider_api_key(provider) @@ -464,6 +543,9 @@ class CredentialService: key_mapping = { "openai": "OPENAI_API_KEY", "google": "GOOGLE_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "grok": "GROK_API_KEY", "ollama": None, # No API key needed } @@ -478,6 +560,12 @@ class CredentialService: return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1") elif provider == "google": return "https://generativelanguage.googleapis.com/v1beta/openai/" + elif provider == "openrouter": + return "https://openrouter.ai/api/v1" + elif provider == "anthropic": + return "https://api.anthropic.com/v1" + elif provider == "grok": + return "https://api.x.ai/v1" return None # Use default for OpenAI async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool: @@ -485,7 +573,7 @@ class CredentialService: try: # For now, we'll update the RAG strategy settings return await self.set_credential( - "llm_provider", + "LLM_PROVIDER", provider, category="rag_strategy", description=f"Active {service_type} provider", diff --git a/python/src/server/services/embeddings/contextual_embedding_service.py b/python/src/server/services/embeddings/contextual_embedding_service.py index 76f3c59b..29b36395 100644 --- a/python/src/server/services/embeddings/contextual_embedding_service.py +++ b/python/src/server/services/embeddings/contextual_embedding_service.py @@ -10,7 +10,13 @@ import os import openai from ...config.logfire_config import search_logger -from ..llm_provider_service import get_llm_client +from ..credential_service import credential_service +from ..llm_provider_service import ( + extract_message_text, + get_llm_client, + prepare_chat_completion_params, + requires_max_completion_tokens, +) from ..threading_service import get_threading_service @@ -32,8 +38,6 @@ async def generate_contextual_embedding( """ # Model choice is a RAG setting, get from 
credential service try: - from ...services.credential_service import credential_service - model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano") except Exception as e: # Fallback to environment variable or default @@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do # Get model from provider configuration model = await _get_model_choice(provider) - response = await client.chat.completions.create( - model=model, - messages=[ + # Prepare parameters and convert max_tokens for GPT-5/reasoning models + params = { + "model": model, + "messages": [ { "role": "system", "content": "You are a helpful assistant that provides concise contextual information.", }, {"role": "user", "content": prompt}, ], - temperature=0.3, - max_tokens=200, - ) + "temperature": 0.3, + "max_tokens": 1200 if requires_max_completion_tokens(model) else 200, # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process) + } + final_params = prepare_chat_completion_params(model, params) + response = await client.chat.completions.create(**final_params) - context = response.choices[0].message.content.strip() + choice = response.choices[0] if response.choices else None + context, _, _ = extract_message_text(choice) + context = context.strip() contextual_text = f"{context}\n---\n{chunk}" return contextual_text, True @@ -111,7 +120,7 @@ async def process_chunk_with_context( async def _get_model_choice(provider: str | None = None) -> str: - """Get model choice from credential service.""" + """Get model choice from credential service with centralized defaults.""" from ..credential_service import credential_service # Get the active provider configuration @@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str: model = provider_config.get("chat_model", "").strip() # Strip whitespace provider_name = provider_config.get("provider", "openai") - # Handle empty model case - fallback to provider-specific defaults or explicit config + # Handle empty model case - use centralized defaults if not model: - search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic") - + search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults") + + # Special handling for Ollama to check specific credential if provider_name == "ollama": - # Try to get OLLAMA_CHAT_MODEL specifically try: ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL") if ollama_model and ollama_model.strip(): model = ollama_model.strip() search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}") else: - # Use a sensible Ollama default + # Use default for Ollama model = "llama3.2:latest" - search_logger.info(f"Using Ollama default model: {model}") + search_logger.info(f"Using Ollama default: {model}") except Exception as e: search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}") model = "llama3.2:latest" - search_logger.info(f"Using Ollama fallback model: {model}") - elif provider_name == "google": - model = "gemini-1.5-flash" + search_logger.info(f"Using Ollama fallback: {model}") else: - # OpenAI or other providers - model = "gpt-4o-mini" - + # Use provider-specific defaults + provider_defaults = { + "openai": "gpt-4o-mini", + "openrouter": "anthropic/claude-3.5-sonnet", + "google": "gemini-1.5-flash", + "anthropic": "claude-3-5-haiku-20241022", + "grok": "grok-3-mini" + } + model = provider_defaults.get(provider_name, "gpt-4o-mini") + 
search_logger.debug(f"Using default model for provider {provider_name}: {model}") search_logger.debug(f"Using model from credential service: {model}") return model @@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch( model_choice = await _get_model_choice(provider) # Build batch prompt for ALL chunks at once - batch_prompt = ( - "Process the following chunks and provide contextual information for each:\\n\\n" - ) + batch_prompt = "Process the following chunks and provide contextual information for each:\n\n" for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)): # Use only 2000 chars of document context to save tokens doc_preview = doc[:2000] if len(doc) > 2000 else doc - batch_prompt += f"CHUNK {i + 1}:\\n" - batch_prompt += f"\\n{doc_preview}\\n\\n" - batch_prompt += f"\\n{chunk[:500]}\\n\\n\\n" # Limit chunk preview + batch_prompt += f"CHUNK {i + 1}:\n" + batch_prompt += f"\n{doc_preview}\n\n" + batch_prompt += f"\n{chunk[:500]}\n\n\n" # Limit chunk preview - batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc." + batch_prompt += ( + "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. " + "Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc." + ) # Make single API call for ALL chunks - response = await client.chat.completions.create( - model=model_choice, - messages=[ + # Prepare parameters and convert max_tokens for GPT-5/reasoning models + batch_params = { + "model": model_choice, + "messages": [ { "role": "system", "content": "You are a helpful assistant that generates contextual information for document chunks.", }, {"role": "user", "content": batch_prompt}, ], - temperature=0, - max_tokens=100 * len(chunks), # Limit response size - ) + "temperature": 0, + "max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks), # Much more tokens for reasoning models (GPT-5 needs extra reasoning space) + } + final_batch_params = prepare_chat_completion_params(model_choice, batch_params) + response = await client.chat.completions.create(**final_batch_params) # Parse response - response_text = response.choices[0].message.content + choice = response.choices[0] if response.choices else None + response_text, _, _ = extract_message_text(choice) + if not response_text: + search_logger.error( + "Empty response from LLM when generating contextual embeddings batch" + ) + return [(chunk, False) for chunk in chunks] # Extract contexts from response - lines = response_text.strip().split("\\n") + lines = response_text.strip().split("\n") chunk_contexts = {} for line in lines: @@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch( except Exception as e: search_logger.error(f"Error in contextual embedding batch: {e}") # Return non-contextual for all chunks - return [(chunk, False) for chunk in chunks] \ No newline at end of file + return [(chunk, False) for chunk in chunks] diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py index d697abf9..4f825f1d 100644 --- a/python/src/server/services/embeddings/embedding_service.py +++ b/python/src/server/services/embeddings/embedding_service.py @@ -13,7 +13,7 @@ import openai from ...config.logfire_config import safe_span, search_logger from ..credential_service 
import credential_service -from ..llm_provider_service import get_embedding_model, get_llm_client +from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model from ..threading_service import get_threading_service from .embedding_exceptions import ( EmbeddingAPIError, @@ -152,34 +152,56 @@ async def create_embeddings_batch( if not texts: return EmbeddingBatchResult() + result = EmbeddingBatchResult() + # Validate that all items in texts are strings validated_texts = [] for i, text in enumerate(texts): - if not isinstance(text, str): - search_logger.error( - f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True - ) - # Try to convert to string - try: - validated_texts.append(str(text)) - except Exception as e: - search_logger.error( - f"Failed to convert text at index {i} to string: {e}", exc_info=True - ) - validated_texts.append("") # Use empty string as fallback - else: + if isinstance(text, str): validated_texts.append(text) + continue + + search_logger.error( + f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True + ) + try: + converted = str(text) + validated_texts.append(converted) + except Exception as conversion_error: + search_logger.error( + f"Failed to convert text at index {i} to string: {conversion_error}", + exc_info=True, + ) + result.add_failure( + repr(text), + EmbeddingAPIError("Invalid text type", original_error=conversion_error), + batch_index=None, + ) texts = validated_texts - - result = EmbeddingBatchResult() threading_service = get_threading_service() with safe_span( "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts) ) as span: try: - async with get_llm_client(provider=provider, use_embedding_provider=True) as client: + # Intelligent embedding provider routing based on model type + # Get the embedding model first to determine the correct provider + embedding_model = await get_embedding_model(provider=provider) + + # Route to correct provider based on model type + if is_google_embedding_model(embedding_model): + embedding_provider = "google" + search_logger.info(f"Routing to Google for embedding model: {embedding_model}") + elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower(): + embedding_provider = "openai" + search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}") + else: + # Keep original provider for ollama and other providers + embedding_provider = provider + search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}") + + async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client: # Load batch size and dimensions from settings try: rag_settings = await credential_service.get_credentials_by_category( @@ -220,7 +242,8 @@ async def create_embeddings_batch( while retry_count < max_retries: try: # Create embeddings for this batch - embedding_model = await get_embedding_model(provider=provider) + embedding_model = await get_embedding_model(provider=embedding_provider) + response = await client.embeddings.create( model=embedding_model, input=batch, diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py index 10655b22..00197926 100644 --- a/python/src/server/services/llm_provider_service.py +++ b/python/src/server/services/llm_provider_service.py @@ -5,6 +5,7 @@ Provides a unified interface for creating OpenAI-compatible clients for 
differen Supports OpenAI, Ollama, and Google Gemini. """ +import inspect import time from contextlib import asynccontextmanager from typing import Any @@ -16,31 +17,305 @@ from .credential_service import credential_service logger = get_logger(__name__) -# Settings cache with TTL -_settings_cache: dict[str, tuple[Any, float]] = {} + +# Basic validation functions to avoid circular imports +def _is_valid_provider(provider: str) -> bool: + """Basic provider validation.""" + if not provider or not isinstance(provider, str): + return False + return provider.lower() in {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + + +def _sanitize_for_log(text: str) -> str: + """Basic text sanitization for logging.""" + if not text: + return "" + import re + sanitized = re.sub(r"sk-[a-zA-Z0-9-_]{20,}", "[REDACTED]", text) + sanitized = re.sub(r"xai-[a-zA-Z0-9-_]{20,}", "[REDACTED]", sanitized) + return sanitized[:100] + + +# Secure settings cache with TTL and validation +_settings_cache: dict[str, tuple[Any, float, str]] = {} # value, timestamp, checksum _CACHE_TTL_SECONDS = 300 # 5 minutes +_cache_access_log: list[dict] = [] # Track cache access patterns for security monitoring + + +def _calculate_cache_checksum(value: Any) -> str: + """Calculate checksum for cache entry integrity validation.""" + import hashlib + import json + + # Convert value to JSON string for consistent hashing + try: + value_str = json.dumps(value, sort_keys=True, default=str) + return hashlib.sha256(value_str.encode()).hexdigest()[:16] # First 16 chars for efficiency + except Exception: + # Fallback for non-serializable objects + return hashlib.sha256(str(value).encode()).hexdigest()[:16] + + +def _log_cache_access(key: str, action: str, hit: bool = None, security_event: str = None) -> None: + """Log cache access for security monitoring.""" + + access_entry = { + "timestamp": time.time(), + "key": _sanitize_for_log(key), + "action": action, # "get", "set", "invalidate", "clear" + "hit": hit, # For get operations + "security_event": security_event # "checksum_mismatch", "expired", etc. 
+ } + + # Keep only last 100 access entries to prevent memory growth + _cache_access_log.append(access_entry) + if len(_cache_access_log) > 100: + _cache_access_log.pop(0) + + # Log security events at warning level + if security_event: + safe_key = _sanitize_for_log(key) + logger.warning(f"Cache security event: {security_event} for key '{safe_key}'") def _get_cached_settings(key: str) -> Any | None: - """Get cached settings if not expired.""" - if key in _settings_cache: - value, timestamp = _settings_cache[key] - if time.time() - timestamp < _CACHE_TTL_SECONDS: + """Get cached settings if not expired and valid.""" + + try: + if key in _settings_cache: + value, timestamp, stored_checksum = _settings_cache[key] + current_time = time.time() + + # Check expiration with strict TTL enforcement + if current_time - timestamp >= _CACHE_TTL_SECONDS: + # Expired, remove from cache + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="expired") + return None + + # Verify cache entry integrity + current_checksum = _calculate_cache_checksum(value) + if current_checksum != stored_checksum: + # Cache tampering detected, remove entry + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="checksum_mismatch") + logger.error(f"Cache integrity violation detected for key: {_sanitize_for_log(key)}") + return None + + # Additional validation for provider configurations + if "provider_config" in key and isinstance(value, dict): + # Basic validation: check required fields + if not value.get("provider") or not _is_valid_provider(value.get("provider")): + # Invalid configuration in cache, remove it + del _settings_cache[key] + _log_cache_access(key, "get", hit=False, security_event="invalid_config") + return None + + _log_cache_access(key, "get", hit=True) return value - else: - # Expired, remove from cache - del _settings_cache[key] - return None + + _log_cache_access(key, "get", hit=False) + return None + + except Exception as e: + # Cache access error, log and return None for safety + _log_cache_access(key, "get", hit=False, security_event=f"access_error: {str(e)}") + return None def _set_cached_settings(key: str, value: Any) -> None: - """Cache settings with current timestamp.""" - _settings_cache[key] = (value, time.time()) + """Cache settings with current timestamp and integrity checksum.""" + + try: + # Validate provider configurations before caching + if "provider_config" in key and isinstance(value, dict): + # Basic validation: check required fields + if not value.get("provider") or not _is_valid_provider(value.get("provider")): + _log_cache_access(key, "set", security_event="invalid_config_rejected") + logger.warning(f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}") + return + + # Calculate integrity checksum + checksum = _calculate_cache_checksum(value) + + # Store with timestamp and checksum + _settings_cache[key] = (value, time.time(), checksum) + _log_cache_access(key, "set") + + except Exception as e: + _log_cache_access(key, "set", security_event=f"set_error: {str(e)}") + logger.error(f"Failed to cache settings for key {_sanitize_for_log(key)}: {e}") +def clear_provider_cache() -> None: + """Clear the provider configuration cache to force refresh on next request.""" + global _settings_cache + + cache_size_before = len(_settings_cache) + _settings_cache.clear() + _log_cache_access("*", "clear") + logger.debug(f"Provider configuration cache cleared ({cache_size_before} entries removed)") + + +def 
invalidate_provider_cache(provider: str = None) -> None: + """ + Invalidate specific provider cache entries or all cache entries. + + Args: + provider: Optional provider name to invalidate. If None, clears all cache. + """ + global _settings_cache + + if provider is None: + # Clear entire cache + cache_size_before = len(_settings_cache) + _settings_cache.clear() + _log_cache_access("*", "invalidate") + logger.debug(f"All provider cache entries invalidated ({cache_size_before} entries)") + else: + # Validate provider name before processing + if not _is_valid_provider(provider): + _log_cache_access(provider, "invalidate", security_event="invalid_provider_name") + logger.warning(f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}") + return + + # Clear specific provider entries + keys_to_remove = [] + for key in _settings_cache.keys(): + if provider in key: + keys_to_remove.append(key) + + for key in keys_to_remove: + del _settings_cache[key] + _log_cache_access(key, "invalidate") + + safe_provider = _sanitize_for_log(provider) + logger.debug(f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed") + + +def get_cache_stats() -> dict[str, Any]: + """ + Get cache statistics with security metrics for monitoring and debugging. + + Returns: + Dictionary containing cache statistics and security metrics + """ + global _settings_cache, _cache_access_log + current_time = time.time() + + stats = { + "total_entries": len(_settings_cache), + "fresh_entries": 0, + "stale_entries": 0, + "cache_hit_potential": 0.0, + "security_metrics": { + "integrity_violations": 0, + "expired_access_attempts": 0, + "invalid_config_rejections": 0, + "access_errors": 0, + "total_security_events": 0 + }, + "access_patterns": { + "recent_cache_hits": 0, + "recent_cache_misses": 0, + "hit_rate": 0.0 + } + } + + # Analyze cache entries + for _key, (_value, timestamp, _checksum) in _settings_cache.items(): + age = current_time - timestamp + if age < _CACHE_TTL_SECONDS: + stats["fresh_entries"] += 1 + else: + stats["stale_entries"] += 1 + + if stats["total_entries"] > 0: + stats["cache_hit_potential"] = stats["fresh_entries"] / stats["total_entries"] + + # Analyze security events from access log + recent_threshold = current_time - 3600 # Last hour + recent_hits = 0 + recent_misses = 0 + + for access in _cache_access_log: + if access["timestamp"] >= recent_threshold: + if access["action"] == "get": + if access["hit"]: + recent_hits += 1 + else: + recent_misses += 1 + + # Count security events + if access["security_event"]: + stats["security_metrics"]["total_security_events"] += 1 + + if "checksum_mismatch" in access["security_event"]: + stats["security_metrics"]["integrity_violations"] += 1 + elif "expired" in access["security_event"]: + stats["security_metrics"]["expired_access_attempts"] += 1 + elif "invalid_config" in access["security_event"]: + stats["security_metrics"]["invalid_config_rejections"] += 1 + elif "error" in access["security_event"]: + stats["security_metrics"]["access_errors"] += 1 + + # Calculate hit rate + total_recent_access = recent_hits + recent_misses + if total_recent_access > 0: + stats["access_patterns"]["hit_rate"] = recent_hits / total_recent_access + + stats["access_patterns"]["recent_cache_hits"] = recent_hits + stats["access_patterns"]["recent_cache_misses"] = recent_misses + + return stats + + +def get_cache_security_report() -> dict[str, Any]: + """ + Get detailed security report for cache monitoring. 
+ + Returns: + Detailed security analysis of cache operations + """ + global _cache_access_log + current_time = time.time() + + report = { + "timestamp": current_time, + "analysis_period_hours": 1, + "security_events": [], + "recommendations": [] + } + + # Extract security events from last hour + recent_threshold = current_time - 3600 + security_events = [ + access for access in _cache_access_log + if access["timestamp"] >= recent_threshold and access["security_event"] + ] + + report["security_events"] = security_events + + # Generate recommendations based on security events + if len(security_events) > 10: + report["recommendations"].append("High number of security events detected - investigate potential attacks") + + integrity_violations = sum(1 for event in security_events if "checksum_mismatch" in event.get("security_event", "")) + if integrity_violations > 0: + report["recommendations"].append(f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks") + + invalid_configs = sum(1 for event in security_events if "invalid_config" in event.get("security_event", "")) + if invalid_configs > 3: + report["recommendations"].append(f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources") + + return report @asynccontextmanager -async def get_llm_client(provider: str | None = None, use_embedding_provider: bool = False, - instance_type: str | None = None, base_url: str | None = None): +async def get_llm_client( + provider: str | None = None, + use_embedding_provider: bool = False, + instance_type: str | None = None, + base_url: str | None = None, +): """ Create an async OpenAI-compatible client based on the configured provider. @@ -58,6 +333,8 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo openai.AsyncOpenAI: An OpenAI-compatible client configured for the selected provider """ client = None + provider_name: str | None = None + api_key = None try: # Get provider configuration from database settings @@ -77,7 +354,11 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo logger.debug("Using cached rag_strategy settings") # For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide - base_url = credential_service._get_provider_base_url(provider, rag_settings) if provider != "ollama" else None + base_url = ( + credential_service._get_provider_base_url(provider, rag_settings) + if provider != "ollama" + else None + ) else: # Get configured provider from database service_type = "embedding" if use_embedding_provider else "llm" @@ -97,45 +378,61 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo # For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide base_url = provider_config["base_url"] if provider_name != "ollama" else None - logger.info(f"Creating LLM client for provider: {provider_name}") + # Comprehensive provider validation with security checks + if not _is_valid_provider(provider_name): + raise ValueError(f"Unsupported LLM provider: {provider_name}") + + # Validate API key format for security (prevent injection) + if api_key: + if len(api_key.strip()) == 0: + api_key = None # Treat empty strings as None + elif len(api_key) > 500: # Reasonable API key length limit + raise ValueError("API key length exceeds security limits") + # Additional security: check for suspicious patterns + if any(char in api_key for char in ['\n', '\r', '\t', '\0']): + raise 
ValueError("API key contains invalid characters") + + # Sanitize provider name for logging + safe_provider_name = _sanitize_for_log(provider_name) if provider_name else "unknown" + logger.info(f"Creating LLM client for provider: {safe_provider_name}") if provider_name == "openai": - if not api_key: - # Check if Ollama instances are available as fallback - logger.warning("OpenAI API key not found, attempting Ollama fallback") - try: - # Try to get an optimal Ollama instance for fallback - ollama_base_url = await _get_optimal_ollama_instance( - instance_type="embedding" if use_embedding_provider else "chat", - use_embedding_provider=use_embedding_provider - ) - if ollama_base_url: - logger.info(f"Falling back to Ollama instance: {ollama_base_url}") - provider_name = "ollama" - api_key = "ollama" # Ollama doesn't need a real API key - base_url = ollama_base_url - # Create Ollama client after fallback - client = openai.AsyncOpenAI( - api_key="ollama", - base_url=ollama_base_url, - ) - logger.info(f"Ollama fallback client created successfully with base URL: {ollama_base_url}") - else: - raise ValueError("OpenAI API key not found and no Ollama instances available") - except Exception as ollama_error: - logger.error(f"Ollama fallback failed: {ollama_error}") - raise ValueError("OpenAI API key not found and Ollama fallback failed") from ollama_error - else: - # Only create OpenAI client if we have an API key (didn't fallback to Ollama) + if api_key: client = openai.AsyncOpenAI(api_key=api_key) logger.info("OpenAI client created successfully") + else: + logger.warning("OpenAI API key not found, attempting Ollama fallback") + try: + ollama_base_url = await _get_optimal_ollama_instance( + instance_type="embedding" if use_embedding_provider else "chat", + use_embedding_provider=use_embedding_provider, + base_url_override=base_url, + ) + + if not ollama_base_url: + raise RuntimeError("No Ollama base URL resolved") + + client = openai.AsyncOpenAI( + api_key="ollama", + base_url=ollama_base_url, + ) + logger.info( + f"Ollama fallback client created successfully with base URL: {ollama_base_url}" + ) + provider_name = "ollama" + api_key = "ollama" + base_url = ollama_base_url + except Exception as fallback_error: + raise ValueError( + "OpenAI API key not found and Ollama fallback failed" + ) from fallback_error elif provider_name == "ollama": - # Enhanced Ollama client creation with multi-instance support + # For Ollama, get the optimal instance based on usage ollama_base_url = await _get_optimal_ollama_instance( instance_type=instance_type, use_embedding_provider=use_embedding_provider, - base_url_override=base_url + base_url_override=base_url, ) # Ollama requires an API key in the client but doesn't actually use it @@ -155,19 +452,101 @@ async def get_llm_client(provider: str | None = None, use_embedding_provider: bo ) logger.info("Google Gemini client created successfully") + elif provider_name == "openrouter": + if not api_key: + raise ValueError("OpenRouter API key not found") + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://openrouter.ai/api/v1", + ) + logger.info("OpenRouter client created successfully") + + elif provider_name == "anthropic": + if not api_key: + raise ValueError("Anthropic API key not found") + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://api.anthropic.com/v1", + ) + logger.info("Anthropic client created successfully") + + elif provider_name == "grok": + if not api_key: + raise ValueError("Grok API key not found 
- set GROK_API_KEY environment variable") + + # Enhanced Grok API key validation (secure - no key fragments logged) + key_format_valid = api_key.startswith("xai-") + key_length_valid = len(api_key) >= 20 + + if not key_format_valid: + logger.warning("Grok API key format validation failed - should start with 'xai-'") + + if not key_length_valid: + logger.warning("Grok API key validation failed - insufficient length") + + logger.debug( + f"Grok API key validation: format_valid={key_format_valid}, length_valid={key_length_valid}" + ) + + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=base_url or "https://api.x.ai/v1", + ) + logger.info("Grok client created successfully") + else: raise ValueError(f"Unsupported LLM provider: {provider_name}") - yield client - except Exception as e: logger.error( - f"Error creating LLM client for provider {provider_name if 'provider_name' in locals() else 'unknown'}: {e}" + f"Error creating LLM client for provider {provider_name if provider_name else 'unknown'}: {e}" ) raise + + try: + yield client finally: - # Cleanup if needed - pass + if client is not None: + safe_provider = _sanitize_for_log(provider_name) if provider_name else "unknown" + + try: + close_method = getattr(client, "aclose", None) + if callable(close_method): + if inspect.iscoroutinefunction(close_method): + await close_method() + else: + maybe_coro = close_method() + if inspect.isawaitable(maybe_coro): + await maybe_coro + else: + close_method = getattr(client, "close", None) + if callable(close_method): + if inspect.iscoroutinefunction(close_method): + await close_method() + else: + close_result = close_method() + if inspect.isawaitable(close_result): + await close_result + logger.debug(f"Closed LLM client for provider: {safe_provider}") + except RuntimeError as close_error: + if "Event loop is closed" in str(close_error): + logger.error( + f"Failed to close LLM client cleanly for provider {safe_provider}: event loop already closed", + exc_info=True, + ) + else: + logger.error( + f"Runtime error closing LLM client for provider {safe_provider}: {close_error}", + exc_info=True, + ) + except Exception as close_error: + logger.error( + f"Unexpected error while closing LLM client for provider {safe_provider}: {close_error}", + exc_info=True, + ) + async def _get_optimal_ollama_instance(instance_type: str | None = None, @@ -250,9 +629,20 @@ async def get_embedding_model(provider: str | None = None) -> str: provider_name = provider_config["provider"] custom_model = provider_config["embedding_model"] - # Use custom model if specified - if custom_model: - return custom_model + # Comprehensive provider validation for embeddings + if not _is_valid_provider(provider_name): + safe_provider = _sanitize_for_log(provider_name) + logger.warning(f"Invalid embedding provider: {safe_provider}, falling back to OpenAI") + provider_name = "openai" + # Use custom model if specified (with validation) + if custom_model and len(custom_model.strip()) > 0: + custom_model = custom_model.strip() + # Basic model name validation (check length and basic characters) + if len(custom_model) <= 100 and not any(char in custom_model for char in ['\n', '\r', '\t', '\0']): + return custom_model + else: + safe_model = _sanitize_for_log(custom_model) + logger.warning(f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default") # Return provider-specific defaults if provider_name == "openai": @@ -261,8 +651,20 @@ async def get_embedding_model(provider: str | None = None) -> str: # Ollama 
default embedding model return "nomic-embed-text" elif provider_name == "google": - # Google's embedding model + # Google's latest embedding model return "text-embedding-004" + elif provider_name == "openrouter": + # OpenRouter supports both OpenAI and Google embedding models + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" + elif provider_name == "anthropic": + # Anthropic supports OpenAI and Google embedding models through their API + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" + elif provider_name == "grok": + # Grok supports OpenAI and Google embedding models through their API + # Default to OpenAI's latest for compatibility + return "text-embedding-3-small" else: # Fallback to OpenAI's model return "text-embedding-3-small" @@ -273,6 +675,463 @@ async def get_embedding_model(provider: str | None = None) -> str: return "text-embedding-3-small" +def is_openai_embedding_model(model: str) -> bool: + """Check if a model is an OpenAI embedding model.""" + if not model: + return False + + model_lower = model.strip().lower() + + # Known OpenAI embeddings + base_models = { + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + } + + if model_lower in base_models: + return True + + # Strip common vendor prefixes like "openai/" or "openrouter/" + for separator in ("/", ":"): + if separator in model_lower: + candidate = model_lower.split(separator)[-1] + if candidate in base_models: + return True + + # Fallback substring detection for custom naming conventions + return any(base in model_lower for base in base_models) + + +def is_google_embedding_model(model: str) -> bool: + """Check if a model is a Google embedding model.""" + if not model: + return False + + model_lower = model.lower() + google_patterns = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding-002", + "gemini-embedding-001", + "multimodalembedding@001" + ] + + return any(pattern in model_lower for pattern in google_patterns) + + +def is_valid_embedding_model_for_provider(model: str, provider: str) -> bool: + """ + Validate if an embedding model is compatible with a provider. + + Args: + model: The embedding model name + provider: The provider name + + Returns: + bool: True if the model is compatible with the provider + """ + if not model or not provider: + return False + + provider_lower = provider.lower() + + if provider_lower == "openai": + return is_openai_embedding_model(model) + elif provider_lower == "google": + return is_google_embedding_model(model) + elif provider_lower in ["openrouter", "anthropic", "grok"]: + # These providers support both OpenAI and Google models + return is_openai_embedding_model(model) or is_google_embedding_model(model) + elif provider_lower == "ollama": + # Ollama has its own models, check common ones + model_lower = model.lower() + ollama_patterns = ["nomic-embed", "all-minilm", "mxbai-embed", "embed"] + return any(pattern in model_lower for pattern in ollama_patterns) + else: + # For unknown providers, assume OpenAI compatibility + return is_openai_embedding_model(model) + + +def get_supported_embedding_models(provider: str) -> list[str]: + """ + Get list of supported embedding models for a provider. 
+ + Args: + provider: The provider name + + Returns: + List of supported embedding model names + """ + if not provider: + return [] + + provider_lower = provider.lower() + + openai_models = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large" + ] + + google_models = [ + "text-embedding-004", + "text-embedding-005", + "text-multilingual-embedding-002", + "gemini-embedding-001", + "multimodalembedding@001" + ] + + if provider_lower == "openai": + return openai_models + elif provider_lower == "google": + return google_models + elif provider_lower in ["openrouter", "anthropic", "grok"]: + # These providers support both OpenAI and Google models + return openai_models + google_models + elif provider_lower == "ollama": + return ["nomic-embed-text", "all-minilm", "mxbai-embed-large"] + else: + # For unknown providers, assume OpenAI compatibility + return openai_models + + +def is_reasoning_model(model_name: str) -> bool: + """ + Unified check for reasoning models across providers. + + Normalizes vendor prefixes (openai/, openrouter/, x-ai/, deepseek/) before checking + known reasoning families (OpenAI GPT-5, o1, o3; xAI Grok; DeepSeek-R; etc.). + """ + if not model_name: + return False + + model_lower = model_name.lower() + + # Normalize vendor prefixes (e.g., openai/gpt-5-nano, openrouter/x-ai/grok-4) + if "/" in model_lower: + parts = model_lower.split("/") + # Drop known vendor prefixes while keeping the final model identifier + known_prefixes = {"openai", "openrouter", "x-ai", "deepseek", "anthropic"} + filtered_parts = [part for part in parts if part not in known_prefixes] + if filtered_parts: + model_lower = filtered_parts[-1] + else: + model_lower = parts[-1] + + if ":" in model_lower: + model_lower = model_lower.split(":", 1)[-1] + + reasoning_prefixes = ( + "gpt-5", + "o1", + "o3", + "o4", + "grok", + "deepseek-r", + "deepseek-reasoner", + "deepseek-chat-r", + ) + + return model_lower.startswith(reasoning_prefixes) + + +def _extract_reasoning_strings(value: Any) -> list[str]: + """Convert reasoning payload fragments into plain-text strings.""" + + if value is None: + return [] + + if isinstance(value, str): + text = value.strip() + return [text] if text else [] + + if isinstance(value, (list, tuple, set)): + collected: list[str] = [] + for item in value: + collected.extend(_extract_reasoning_strings(item)) + return collected + + if isinstance(value, dict): + candidates = [] + for key in ("text", "summary", "content", "message", "value"): + if value.get(key): + candidates.extend(_extract_reasoning_strings(value[key])) + # Some providers nest reasoning parts under "parts" + if value.get("parts"): + candidates.extend(_extract_reasoning_strings(value["parts"])) + return candidates + + # Handle pydantic-style objects with attributes + for attr in ("text", "summary", "content", "value"): + if hasattr(value, attr): + attr_value = getattr(value, attr) + if attr_value: + return _extract_reasoning_strings(attr_value) + + return [] + + +def _get_message_attr(message: Any, attribute: str) -> Any: + """Safely access message attributes that may be dict keys or properties.""" + + if hasattr(message, attribute): + return getattr(message, attribute) + if isinstance(message, dict): + return message.get(attribute) + return None + + +def extract_message_text(choice: Any) -> tuple[str, str, bool]: + """Extract primary content and reasoning text from a chat completion choice.""" + + if not choice: + return "", "", False + + message = _get_message_attr(choice, "message") + if 
message is None: + return "", "", False + + raw_content = _get_message_attr(message, "content") + content_text = raw_content.strip() if isinstance(raw_content, str) else "" + + reasoning_fragments: list[str] = [] + for attr in ("reasoning", "reasoning_details", "reasoning_content"): + reasoning_value = _get_message_attr(message, attr) + if reasoning_value: + reasoning_fragments.extend(_extract_reasoning_strings(reasoning_value)) + + reasoning_text = "\n".join(fragment for fragment in reasoning_fragments if fragment) + reasoning_text = reasoning_text.strip() + + # If content looks like reasoning text but no reasoning field, detect it + if content_text and not reasoning_text and _is_reasoning_text(content_text): + reasoning_text = content_text + # Try to extract structured data from reasoning text + extracted_json = extract_json_from_reasoning(content_text) + if extracted_json: + content_text = extracted_json + else: + content_text = "" + + if not content_text and reasoning_text: + content_text = reasoning_text + + has_reasoning = bool(reasoning_text) + + return content_text, reasoning_text, has_reasoning + + +def _is_reasoning_text(text: str) -> bool: + """Detect if text appears to be reasoning/thinking output rather than structured content.""" + if not text or len(text) < 10: + return False + + text_lower = text.lower().strip() + + # Common reasoning text patterns + reasoning_indicators = [ + "okay, let's see", "let me think", "first, i need to", "looking at this", + "step by step", "analyzing", "breaking this down", "considering", + "let me work through", "i should", "thinking about", "examining" + ] + + return any(indicator in text_lower for indicator in reasoning_indicators) + + +def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: + """Extract JSON content from reasoning text, with synthesis fallback.""" + if not reasoning_text: + return "" + + import json + import re + + # Try to find JSON blocks in markdown + json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```' + json_matches = re.findall(json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE) + + for match in json_matches: + try: + # Validate it's proper JSON + json.loads(match.strip()) + return match.strip() + except json.JSONDecodeError: + continue + + # Try to find standalone JSON objects + json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' + json_matches = re.findall(json_pattern, reasoning_text, re.DOTALL) + + for match in json_matches: + try: + parsed = json.loads(match.strip()) + # Ensure it has expected structure + if isinstance(parsed, dict) and any(key in parsed for key in ["example_name", "summary", "name", "title"]): + return match.strip() + except json.JSONDecodeError: + continue + + # If no JSON found, synthesize from reasoning content + return synthesize_json_from_reasoning(reasoning_text, context_code, language) + + +def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: + """Generate JSON structure from reasoning text when no JSON is found.""" + if not reasoning_text and not context_code: + return "" + + import json + import re + + # Extract key concepts and actions from reasoning text and code context + text_lower = reasoning_text.lower() if reasoning_text else "" + code_lower = context_code.lower() if context_code else "" + combined_text = f"{text_lower} {code_lower}" + + # Common action patterns in reasoning text and code + action_patterns = [ + (r'\b(?:parse|parsing|parsed)\b', 'Parse'), + 
(r'\b(?:create|creating|created)\b', 'Create'), + (r'\b(?:analyze|analyzing|analyzed)\b', 'Analyze'), + (r'\b(?:extract|extracting|extracted)\b', 'Extract'), + (r'\b(?:generate|generating|generated)\b', 'Generate'), + (r'\b(?:process|processing|processed)\b', 'Process'), + (r'\b(?:load|loading|loaded)\b', 'Load'), + (r'\b(?:handle|handling|handled)\b', 'Handle'), + (r'\b(?:manage|managing|managed)\b', 'Manage'), + (r'\b(?:build|building|built)\b', 'Build'), + (r'\b(?:define|defining|defined)\b', 'Define'), + (r'\b(?:implement|implementing|implemented)\b', 'Implement'), + (r'\b(?:fetch|fetching|fetched)\b', 'Fetch'), + (r'\b(?:connect|connecting|connected)\b', 'Connect'), + (r'\b(?:validate|validating|validated)\b', 'Validate'), + ] + + # Technology/concept patterns + tech_patterns = [ + (r'\bjson\b', 'JSON'), + (r'\bapi\b', 'API'), + (r'\bfile\b', 'File'), + (r'\bdata\b', 'Data'), + (r'\bcode\b', 'Code'), + (r'\btext\b', 'Text'), + (r'\bcontent\b', 'Content'), + (r'\bresponse\b', 'Response'), + (r'\brequest\b', 'Request'), + (r'\bconfig\b', 'Config'), + (r'\bllm\b', 'LLM'), + (r'\bmodel\b', 'Model'), + (r'\bexample\b', 'Example'), + (r'\bcontext\b', 'Context'), + (r'\basync\b', 'Async'), + (r'\bfunction\b', 'Function'), + (r'\bclass\b', 'Class'), + (r'\bprint\b', 'Output'), + (r'\breturn\b', 'Return'), + ] + + # Extract actions and technologies from combined text + detected_actions = [] + detected_techs = [] + + for pattern, action in action_patterns: + if re.search(pattern, combined_text): + detected_actions.append(action) + + for pattern, tech in tech_patterns: + if re.search(pattern, combined_text): + detected_techs.append(tech) + + # Generate example name + if detected_actions and detected_techs: + example_name = f"{detected_actions[0]} {detected_techs[0]}" + elif detected_actions: + example_name = f"{detected_actions[0]} Code" + elif detected_techs: + example_name = f"Handle {detected_techs[0]}" + elif language: + example_name = f"Process {language.title()}" + else: + example_name = "Code Processing" + + # Limit to 4 words as per requirements + example_name_words = example_name.split() + if len(example_name_words) > 4: + example_name = " ".join(example_name_words[:4]) + + # Generate summary from reasoning content + reasoning_lines = reasoning_text.split('\n') + meaningful_lines = [line.strip() for line in reasoning_lines if line.strip() and len(line.strip()) > 10] + + if meaningful_lines: + # Take first meaningful sentence for summary base + first_line = meaningful_lines[0] + if len(first_line) > 100: + first_line = first_line[:100] + "..." + + # Create contextual summary + if context_code and any(tech in text_lower for tech, _ in tech_patterns): + summary = f"This code demonstrates {detected_techs[0].lower() if detected_techs else 'data'} processing functionality. {first_line}" + else: + summary = f"Code example showing {detected_actions[0].lower() if detected_actions else 'processing'} operations. {first_line}" + else: + # Fallback summary + summary = f"Code example demonstrating {example_name.lower()} functionality for {language or 'general'} development." + + # Ensure summary is not too long + if len(summary) > 300: + summary = summary[:297] + "..." + + # Create JSON structure + result = { + "example_name": example_name, + "summary": summary + } + + return json.dumps(result) + + +def prepare_chat_completion_params(model: str, params: dict) -> dict: + """ + Convert parameters for compatibility with reasoning models (GPT-5, o1, o3 series). 
+ + OpenAI made several API changes for reasoning models: + 1. max_tokens → max_completion_tokens + 2. temperature must be 1.0 (default) - custom values not supported + + This ensures compatibility with OpenAI's API changes for newer models + while maintaining backward compatibility for existing models. + + Args: + model: The model name being used + params: Dictionary of API parameters + + Returns: + Dictionary with converted parameters for the model + """ + if not model or not params: + return params + + # Make a copy to avoid modifying the original + updated_params = params.copy() + + reasoning_model = is_reasoning_model(model) + + # Convert max_tokens to max_completion_tokens for reasoning models + if reasoning_model and "max_tokens" in updated_params: + max_tokens_value = updated_params.pop("max_tokens") + updated_params["max_completion_tokens"] = max_tokens_value + logger.debug(f"Converted max_tokens to max_completion_tokens for model {model}") + + # Remove custom temperature for reasoning models (they only support default temperature=1.0) + if reasoning_model and "temperature" in updated_params: + original_temp = updated_params.pop("temperature") + logger.debug(f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)") + + return updated_params + + async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]: """ Get the embedding model with intelligent routing for multi-instance setups. @@ -383,3 +1242,9 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N "error_message": str(e), "validation_timestamp": time.time() } + + + +def requires_max_completion_tokens(model_name: str) -> bool: + """Backward compatible alias for previous API.""" + return is_reasoning_model(model_name) diff --git a/python/src/server/services/provider_discovery_service.py b/python/src/server/services/provider_discovery_service.py index ccd811dd..2ea3bc32 100644 --- a/python/src/server/services/provider_discovery_service.py +++ b/python/src/server/services/provider_discovery_service.py @@ -2,7 +2,7 @@ Provider Discovery Service Discovers available models, checks provider health, and provides model specifications -for OpenAI, Google Gemini, Ollama, and Anthropic providers. +for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers. 
""" import time @@ -359,6 +359,36 @@ class ProviderDiscoveryService: return models + async def discover_grok_models(self, api_key: str) -> list[ModelSpec]: + """Discover available Grok models.""" + cache_key = f"grok_models_{hash(api_key)}" + cached = self._get_cached_result(cache_key) + if cached: + return cached + + models = [] + try: + # Grok model specifications + model_specs = [ + ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"), + ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"), + ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"), + ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"), + ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"), + ] + + # Test connectivity - Grok doesn't have a models list endpoint, + # so we'll just return the known models if API key is provided + if api_key: + models = model_specs + self._cache_result(cache_key, models) + logger.info(f"Discovered {len(models)} Grok models") + + except Exception as e: + logger.error(f"Error discovering Grok models: {e}") + + return models + async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus: """Check health and connectivity status of a provider.""" start_time = time.time() @@ -456,6 +486,23 @@ class ProviderDiscoveryService: last_checked=time.time() ) + elif provider == "grok": + api_key = config.get("api_key") + if not api_key: + return ProviderStatus(provider, False, None, "API key not configured") + + # Grok doesn't have a health check endpoint, so we'll assume it's available + # if API key is provided. In a real implementation, you might want to make a + # small test request to verify the key is valid. + response_time = (time.time() - start_time) * 1000 + return ProviderStatus( + provider="grok", + is_available=True, + response_time_ms=response_time, + models_available=5, # Known model count + last_checked=time.time() + ) + else: return ProviderStatus(provider, False, None, f"Unknown provider: {provider}") @@ -496,6 +543,11 @@ class ProviderDiscoveryService: if anthropic_key: providers["anthropic"] = await self.discover_anthropic_models(anthropic_key) + # Grok + grok_key = await credential_service.get_credential("GROK_API_KEY") + if grok_key: + providers["grok"] = await self.discover_grok_models(grok_key) + except Exception as e: logger.error(f"Error getting all available models: {e}") diff --git a/python/src/server/services/source_management_service.py b/python/src/server/services/source_management_service.py index c7bcdb66..f8a27023 100644 --- a/python/src/server/services/source_management_service.py +++ b/python/src/server/services/source_management_service.py @@ -11,7 +11,7 @@ from supabase import Client from ..config.logfire_config import get_logger, search_logger from .client_manager import get_supabase_client -from .llm_provider_service import get_llm_client +from .llm_provider_service import extract_message_text, get_llm_client logger = get_logger(__name__) @@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. 
Please provide a ) # Extract the generated summary with proper error handling - if not response or not response.choices or len(response.choices) == 0: - search_logger.error(f"Empty or invalid response from LLM for {source_id}") - return default_summary - - message_content = response.choices[0].message.content - if message_content is None: - search_logger.error(f"LLM returned None content for {source_id}") - return default_summary - - summary = message_content.strip() - - # Ensure the summary is not too long - if len(summary) > max_length: - summary = summary[:max_length] + "..." + if not response or not response.choices or len(response.choices) == 0: + search_logger.error(f"Empty or invalid response from LLM for {source_id}") + return default_summary + + choice = response.choices[0] + summary_text, _, _ = extract_message_text(choice) + if not summary_text: + search_logger.error(f"LLM returned None content for {source_id}") + return default_summary + + summary = summary_text.strip() + + # Ensure the summary is not too long + if len(summary) > max_length: + summary = summary[:max_length] + "..." return summary @@ -187,7 +188,9 @@ Generate only the title, nothing else.""" ], ) - generated_title = response.choices[0].message.content.strip() + choice = response.choices[0] + generated_title, _, _ = extract_message_text(choice) + generated_title = generated_title.strip() # Clean up the title generated_title = generated_title.strip("\"'") if len(generated_title) < 50: # Sanity check diff --git a/python/src/server/services/storage/code_storage_service.py b/python/src/server/services/storage/code_storage_service.py index ece5ea10..a993bc70 100644 --- a/python/src/server/services/storage/code_storage_service.py +++ b/python/src/server/services/storage/code_storage_service.py @@ -8,6 +8,7 @@ import asyncio import json import os import re +import time from collections import defaultdict, deque from collections.abc import Callable from difflib import SequenceMatcher @@ -17,25 +18,104 @@ from urllib.parse import urlparse from supabase import Client from ...config.logfire_config import search_logger +from ..credential_service import credential_service from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch from ..embeddings.embedding_service import create_embeddings_batch +from ..llm_provider_service import ( + extract_json_from_reasoning, + extract_message_text, + get_llm_client, + prepare_chat_completion_params, + synthesize_json_from_reasoning, +) -def _get_model_choice() -> str: - """Get MODEL_CHOICE with direct fallback.""" +def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str: + """Return the best-effort JSON object from an LLM response.""" + + if not raw_response: + return raw_response + + cleaned = raw_response.strip() + + # Check if this looks like reasoning text first + if _is_reasoning_text_response(cleaned): + # Try intelligent extraction from reasoning text with context + extracted = extract_json_from_reasoning(cleaned, context_code, language) + if extracted: + return extracted + # extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so\ + fallback_json = synthesize_json_from_reasoning("", context_code, language) + if fallback_json: + return fallback_json + # If all else fails, return a minimal valid JSON object to avoid downstream errors + return '{"example_name": "Code Example", "summary": "Code example extracted from context."}' + + + if cleaned.startswith("```"): + lines = 
cleaned.splitlines() + # Drop opening fence + lines = lines[1:] + # Drop closing fence if present + if lines and lines[-1].strip().startswith("```"): + lines = lines[:-1] + cleaned = "\n".join(lines).strip() + + # Trim any leading/trailing text outside the outermost JSON braces + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end >= start: + cleaned = cleaned[start : end + 1] + + return cleaned.strip() + + +REASONING_STARTERS = [ + "okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this", + "i need to", "analyzing", "let me work through", "thinking about", "let me see" +] + +def _is_reasoning_text_response(text: str) -> bool: + """Detect if response is reasoning text rather than direct JSON.""" + if not text or len(text) < 20: + return False + + text_lower = text.lower().strip() + + # Check if it's clearly not JSON (starts with reasoning text) + starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS) + + # Check if it lacks immediate JSON structure + lacks_immediate_json = not text_lower.lstrip().startswith('{') + + return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS)) +async def _get_model_choice() -> str: + """Get MODEL_CHOICE with provider-aware defaults from centralized service.""" try: - # Direct cache/env fallback - from ..credential_service import credential_service + # Get the active provider configuration + provider_config = await credential_service.get_active_provider("llm") + active_provider = provider_config.get("provider", "openai") + model = provider_config.get("chat_model") - if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache: - model = credential_service._cache["MODEL_CHOICE"] - else: - model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano") - search_logger.debug(f"Using model choice: {model}") + # If no custom model is set, use provider-specific defaults + if not model or model.strip() == "": + # Provider-specific defaults + provider_defaults = { + "openai": "gpt-4o-mini", + "openrouter": "anthropic/claude-3.5-sonnet", + "google": "gemini-1.5-flash", + "ollama": "llama3.2:latest", + "anthropic": "claude-3-5-haiku-20241022", + "grok": "grok-3-mini" + } + model = provider_defaults.get(active_provider, "gpt-4o-mini") + search_logger.debug(f"Using default model for provider {active_provider}: {model}") + + search_logger.debug(f"Using model for provider {active_provider}: {model}") return model except Exception as e: search_logger.warning(f"Error getting model choice: {e}, using default") - return "gpt-4.1-nano" + return "gpt-4o-mini" def _get_max_workers() -> int: @@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str, return best_block + def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]: """ Extract code blocks from markdown content along with context. 
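A minimal standalone sketch of the cleanup step above (an illustration only, not the module code — it skips the reasoning-text branch and the synthesis fallback): strip an optional Markdown fence, then keep only the outermost JSON object before handing the text to json.loads.

import json


def strip_to_json_object(raw: str) -> str:
    """Illustrative cleanup: drop code fences and trim to the outermost {...} span."""
    cleaned = raw.strip()

    # Remove a leading ``` fence (with or without a language tag) and a trailing fence.
    if cleaned.startswith("```"):
        lines = cleaned.splitlines()[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        cleaned = "\n".join(lines).strip()

    # Keep only the outermost JSON object so surrounding commentary cannot break json.loads.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end >= start:
        cleaned = cleaned[start:end + 1]
    return cleaned.strip()


# A reply wrapped in a fence plus commentary still parses after cleanup.
raw_reply = 'Sure!\n```json\n{"example_name": "Parse JSON", "summary": "Short demo."}\n```\nLet me know.'
print(json.loads(strip_to_json_object(raw_reply)))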
@@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d """ # Load all code extraction settings with direct fallback try: - from ...services.credential_service import credential_service - def _get_setting_fallback(key: str, default: str) -> str: if credential_service._cache_initialized and key in credential_service._cache: return credential_service._cache[key] @@ -507,7 +586,7 @@ def generate_code_example_summary( A dictionary with 'summary' and 'example_name' """ import asyncio - + # Run the async version in the current thread return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider)) @@ -518,13 +597,22 @@ async def _generate_code_example_summary_async( """ Async version of generate_code_example_summary using unified LLM provider service. """ - from ..llm_provider_service import get_llm_client - - # Get model choice from credential service (RAG setting) - model_choice = _get_model_choice() - # Create the prompt - prompt = f""" + # Get model choice from credential service (RAG setting) + model_choice = await _get_model_choice() + + # If provider is not specified, get it from credential service + if provider is None: + try: + provider_config = await credential_service.get_active_provider("llm") + provider = provider_config.get("provider", "openai") + search_logger.debug(f"Auto-detected provider from credential service: {provider}") + except Exception as e: + search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai") + provider = "openai" + + # Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries + base_prompt = f""" {context_before[-500:] if len(context_before) > 500 else context_before} @@ -548,6 +636,16 @@ Format your response as JSON: "summary": "2-3 sentence description of what the code demonstrates" }} """ + guard_prompt = ( + base_prompt + + "\n\nImportant: Respond with a valid JSON object that exactly matches the keys " + '{"example_name": string, "summary": string}. Do not include commentary, ' + "markdown fences, or reasoning notes." + ) + strict_prompt = ( + guard_prompt + + "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content." 
+ ) try: # Use unified LLM provider service @@ -555,25 +653,261 @@ Format your response as JSON: search_logger.info( f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}" ) - - response = await client.chat.completions.create( - model=model_choice, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.", - }, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - max_tokens=500, - temperature=0.3, + + provider_lower = provider.lower() + is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower()) + + supports_response_format_base = ( + provider_lower in {"openai", "google", "anthropic"} + or (provider_lower == "openrouter" and model_choice.startswith("openai/")) ) - response_content = response.choices[0].message.content.strip() + last_response_obj = None + last_elapsed_time = None + last_response_content = "" + last_json_error: json.JSONDecodeError | None = None + + for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)): + request_params = { + "model": model_choice, + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.", + }, + {"role": "user", "content": current_prompt}, + ], + "max_tokens": 2000, + "temperature": 0.3, + } + + should_use_response_format = False + if enforce_json: + if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"): + should_use_response_format = True + else: + if supports_response_format_base: + should_use_response_format = True + + if should_use_response_format: + request_params["response_format"] = {"type": "json_object"} + + if is_grok_model: + unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"] + for param in unsupported_params: + if param in request_params: + removed_value = request_params.pop(param) + search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}") + + supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"] + for param in list(request_params.keys()): + if param not in supported_params: + search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models") + + start_time = time.time() + max_retries = 3 if is_grok_model else 1 + retry_delay = 1.0 + response_content_local = "" + reasoning_text_local = "" + json_error_occurred = False + + for attempt in range(max_retries): + try: + if is_grok_model and attempt > 0: + search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay") + await asyncio.sleep(retry_delay) + + final_params = prepare_chat_completion_params(model_choice, request_params) + response = await client.chat.completions.create(**final_params) + last_response_obj = response + + choice = response.choices[0] if response.choices else None + message = choice.message if choice and hasattr(choice, "message") else None + response_content_local = "" + reasoning_text_local = "" + + if choice: + response_content_local, reasoning_text_local, _ = extract_message_text(choice) + + # Enhanced logging for response analysis + if message and reasoning_text_local: + content_preview = response_content_local[:100] if response_content_local else "None" + reasoning_preview = reasoning_text_local[:100] if 
reasoning_text_local else "None" + search_logger.debug( + f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'" + ) + + if response_content_local: + last_response_content = response_content_local.strip() + + # Pre-validate response before processing + if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')): + # Very minimal response - likely "Okay\nOkay" type + search_logger.debug(f"Minimal response detected: {repr(last_response_content)}") + # Generate fallback directly from context + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + try: + result = json.loads(fallback_json) + final_result = { + "example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": result.get("summary", "Code example for demonstration purposes."), + } + search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}") + return final_result + except json.JSONDecodeError: + pass # Continue to normal error handling + else: + # Even synthesis failed - provide hardcoded fallback for minimal responses + final_result = { + "example_name": f"Code Example{f' ({language})' if language else ''}", + "summary": "Code example extracted from development context.", + } + search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}") + return final_result + + payload = _extract_json_payload(last_response_content, code, language) + if payload != last_response_content: + search_logger.debug( + f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..." + ) + + try: + result = json.loads(payload) + + if not result.get("example_name") or not result.get("summary"): + search_logger.warning(f"Incomplete response from LLM: {result}") + + final_result = { + "example_name": result.get( + "example_name", f"Code Example{f' ({language})' if language else ''}" + ), + "summary": result.get("summary", "Code example for demonstration purposes."), + } + + search_logger.info( + f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}" + ) + return final_result + + except json.JSONDecodeError as json_error: + last_json_error = json_error + json_error_occurred = True + snippet = last_response_content[:200] + if not enforce_json: + # Check if this was reasoning text that couldn't be parsed + if _is_reasoning_text_response(last_response_content): + search_logger.debug( + f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}" + ) + else: + search_logger.warning( + f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}" + ) + break + else: + search_logger.error( + f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. 
Response snippet: {repr(snippet)}" + ) + break + + elif is_grok_model and attempt < max_retries - 1: + search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...") + retry_delay *= 2 + continue + else: + break + + except Exception as e: + if is_grok_model and attempt < max_retries - 1: + search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...") + retry_delay *= 2 + continue + else: + raise + + if is_grok_model: + elapsed_time = time.time() - start_time + last_elapsed_time = elapsed_time + search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s") + + if json_error_occurred: + if not enforce_json: + continue + else: + break + + if response_content_local: + # We would have returned already on success; if we reach here, parsing failed but we are not retrying + continue + + response_content = last_response_content + response = last_response_obj + elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0 + + if last_json_error is not None and response_content: + search_logger.error( + f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling." + ) + response_content = "" + + if not response_content: + search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})") + if is_grok_model: + search_logger.error("Grok empty response debugging:") + search_logger.error(f" - Request took: {elapsed_time:.2f}s") + search_logger.error(f" - Response status: {getattr(response, 'status_code', 'N/A')}") + search_logger.error(f" - Response headers: {getattr(response, 'headers', 'N/A')}") + search_logger.error(f" - Full response: {response}") + search_logger.error(f" - Response choices length: {len(response.choices) if response.choices else 0}") + if response.choices: + search_logger.error(f" - First choice: {response.choices[0]}") + search_logger.error(f" - Message content: '{response.choices[0].message.content}'") + search_logger.error(f" - Message role: {response.choices[0].message.role}") + search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability") + + # Implement fallback for Grok failures + search_logger.warning("Attempting fallback to OpenAI due to Grok failure...") + try: + # Use OpenAI as fallback with similar parameters + fallback_params = { + "model": "gpt-4o-mini", + "messages": request_params["messages"], + "temperature": request_params.get("temperature", 0.1), + "max_tokens": request_params.get("max_tokens", 500), + } + + async with get_llm_client(provider="openai") as fallback_client: + fallback_response = await fallback_client.chat.completions.create(**fallback_params) + fallback_content = fallback_response.choices[0].message.content + if fallback_content and fallback_content.strip(): + search_logger.info("gpt-4o-mini fallback succeeded") + response_content = fallback_content.strip() + else: + search_logger.error("gpt-4o-mini fallback also returned empty response") + raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed") + + except Exception as fallback_error: + search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}") + raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error + else: + search_logger.debug(f"Full response object: {response}") + raise ValueError("Empty response from LLM") + + if not response_content: + # This should not happen after fallback logic, but safety check + raise 
ValueError("No valid response content after all attempts") + + response_content = response_content.strip() search_logger.debug(f"LLM API response: {repr(response_content[:200])}...") - result = json.loads(response_content) + payload = _extract_json_payload(response_content, code, language) + if payload != response_content: + search_logger.debug( + f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..." + ) + + result = json.loads(payload) # Validate the response has the required fields if not result.get("example_name") or not result.get("summary"): @@ -595,12 +929,38 @@ Format your response as JSON: search_logger.error( f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}" ) + # Try to generate context-aware fallback + try: + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + fallback_result = json.loads(fallback_json) + search_logger.info(f"Generated context-aware fallback summary") + return { + "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": fallback_result.get("summary", "Code example for demonstration purposes."), + } + except Exception: + pass # Fall through to generic fallback + return { "example_name": f"Code Example{f' ({language})' if language else ''}", "summary": "Code example for demonstration purposes.", } except Exception as e: search_logger.error(f"Error generating code summary using unified LLM provider: {e}") + # Try to generate context-aware fallback + try: + fallback_json = synthesize_json_from_reasoning("", code, language) + if fallback_json: + fallback_result = json.loads(fallback_json) + search_logger.info(f"Generated context-aware fallback summary after error") + return { + "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"), + "summary": fallback_result.get("summary", "Code example for demonstration purposes."), + } + except Exception: + pass # Fall through to generic fallback + return { "example_name": f"Code Example{f' ({language})' if language else ''}", "summary": "Code example for demonstration purposes.", @@ -608,7 +968,7 @@ Format your response as JSON: async def generate_code_summaries_batch( - code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None + code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None ) -> list[dict[str, str]]: """ Generate summaries for multiple code blocks with rate limiting and proper worker management. 
@@ -617,6 +977,7 @@ async def generate_code_summaries_batch( code_blocks: List of code block dictionaries max_workers: Maximum number of concurrent API requests progress_callback: Optional callback for progress updates (async function) + provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic') Returns: List of summary dictionaries @@ -627,8 +988,6 @@ async def generate_code_summaries_batch( # Get max_workers from settings if not provided if max_workers is None: try: - from ...services.credential_service import credential_service - if ( credential_service._cache_initialized and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache @@ -663,6 +1022,7 @@ async def generate_code_summaries_batch( block["context_before"], block["context_after"], block.get("language", ""), + provider, ) # Update progress @@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase( except Exception as e: search_logger.error(f"Error deleting existing code examples for {url}: {e}") - # Check if contextual embeddings are enabled + # Check if contextual embeddings are enabled (use proper async method like document storage) try: - from ..credential_service import credential_service - - use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS") - if isinstance(use_contextual_embeddings, str): - use_contextual_embeddings = use_contextual_embeddings.lower() == "true" - elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get( - "is_encrypted" - ): - # Handle encrypted value - encrypted_value = use_contextual_embeddings.get("encrypted_value") - if encrypted_value: - try: - decrypted = credential_service._decrypt_value(encrypted_value) - use_contextual_embeddings = decrypted.lower() == "true" - except: - use_contextual_embeddings = False - else: - use_contextual_embeddings = False + raw_value = await credential_service.get_credential( + "USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True + ) + if isinstance(raw_value, str): + use_contextual_embeddings = raw_value.lower() == "true" else: - use_contextual_embeddings = bool(use_contextual_embeddings) - except: + use_contextual_embeddings = bool(raw_value) + except Exception as e: + search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}") # Fallback to environment variable use_contextual_embeddings = ( os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true" @@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase( # Use only successful embeddings valid_embeddings = result.embeddings successful_texts = result.texts_processed - + # Get model information for tracking from ..llm_provider_service import get_embedding_model - from ..credential_service import credential_service - + # Get embedding model name embedding_model_name = await get_embedding_model(provider=provider) - + # Get LLM chat model (used for code summaries and contextual embeddings if enabled) llm_chat_model = None try: @@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase( llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini") else: # For code summaries, we use MODEL_CHOICE - llm_chat_model = _get_model_choice() + llm_chat_model = await _get_model_choice() except Exception as e: search_logger.warning(f"Failed to get LLM chat model: {e}") llm_chat_model = "gpt-4o-mini" # Default fallback @@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase( positions_by_text[text].append(original_indices[k]) # Map successful texts back to their original indices - for 
embedding, text in zip(valid_embeddings, successful_texts, strict=False): + for embedding, text in zip(valid_embeddings, successful_texts, strict=True): # Get the next available index for this text (handles duplicates) if positions_by_text[text]: orig_idx = positions_by_text[text].popleft() # Original j index in [i, batch_end) @@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase( # Determine the correct embedding column based on dimension embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist()) embedding_column = None - + if embedding_dim == 768: embedding_column = "embedding_768" elif embedding_dim == 1024: @@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase( elif embedding_dim == 3072: embedding_column = "embedding_3072" else: - # Default to closest supported dimension - search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536") - embedding_column = "embedding_1536" - + # Skip unsupported dimensions to avoid corrupting the schema + search_logger.error( + f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch" + ) + continue + batch_data.append({ "url": urls[idx], "chunk_number": chunk_numbers[idx], @@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase( f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}" ) search_logger.info(f"Retrying in {retry_delay} seconds...") - import time - - time.sleep(retry_delay) + await asyncio.sleep(retry_delay) retry_delay *= 2 # Exponential backoff else: # Final attempt failed diff --git a/python/tests/test_async_llm_provider_service.py b/python/tests/test_async_llm_provider_service.py index e52c2242..db02c874 100644 --- a/python/tests/test_async_llm_provider_service.py +++ b/python/tests/test_async_llm_provider_service.py @@ -33,6 +33,12 @@ class AsyncContextManager: class TestAsyncLLMProviderService: """Test suite for async LLM provider service functions""" + @staticmethod + def _make_mock_client(): + client = MagicMock() + client.aclose = AsyncMock() + return client + @pytest.fixture(autouse=True) def clear_cache(self): """Clear cache before each test""" @@ -98,7 +104,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -121,7 +127,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -143,7 +149,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client() as client: @@ -166,7 +172,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client(provider="openai") as client: @@ -194,7 +200,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() 
+ mock_client = self._make_mock_client() mock_openai.return_value = mock_client async with get_llm_client(use_embedding_provider=True) as client: @@ -225,7 +231,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client # Should fallback to Ollama instead of raising an error @@ -426,7 +432,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client # First call should hit the credential service @@ -464,7 +470,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client client_ref = None @@ -474,6 +480,7 @@ class TestAsyncLLMProviderService: # After context manager exits, should still have reference to client assert client_ref == mock_client + mock_client.aclose.assert_awaited_once() @pytest.mark.asyncio async def test_multiple_providers_in_sequence(self, mock_credential_service): @@ -494,7 +501,7 @@ class TestAsyncLLMProviderService: with patch( "src.server.services.llm_provider_service.openai.AsyncOpenAI" ) as mock_openai: - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_openai.return_value = mock_client for config in configs: diff --git a/python/tests/test_code_extraction_source_id.py b/python/tests/test_code_extraction_source_id.py index 7de851f5..05405ee7 100644 --- a/python/tests/test_code_extraction_source_id.py +++ b/python/tests/test_code_extraction_source_id.py @@ -104,13 +104,15 @@ class TestCodeExtractionSourceId: ) # Verify the correct source_id was passed (now with cancellation_check parameter) - mock_extract.assert_called_once_with( - crawl_results, - url_to_full_document, - source_id, # This should be the third argument - None, - None # cancellation_check parameter - ) + mock_extract.assert_called_once() + args, kwargs = mock_extract.call_args + assert args[0] == crawl_results + assert args[1] == url_to_full_document + assert args[2] == source_id + assert args[3] is None + assert args[4] is None + if len(args) > 5: + assert args[5] is None assert result == 5 @pytest.mark.asyncio @@ -174,4 +176,4 @@ class TestCodeExtractionSourceId: import inspect source = inspect.getsource(module) assert "from urllib.parse import urlparse" not in source, \ - "Should not import urlparse since we don't extract domain from URL anymore" \ No newline at end of file + "Should not import urlparse since we don't extract domain from URL anymore" From 7a4c67cf904f1d4c68008b6b7df75234b94be1f9 Mon Sep 17 00:00:00 2001 From: Jonah Gray Date: Mon, 22 Sep 2025 04:23:20 -0400 Subject: [PATCH 2/7] fix: resolve TypeScript strict mode errors in providerErrorHandler.ts (#720) * fix: resolve TypeScript strict mode errors in providerErrorHandler.ts - Add proper type guards for error object property access - Create ErrorWithStatus and ErrorWithMessage interfaces - Implement hasStatusProperty() and hasMessageProperty() type guards - Replace unsafe object property access with type-safe checks - All 8 TypeScript strict mode errors now resolved - Maintains existing functionality for LLM provider error handling Fixes #686 * fix: apply biome linting improvements to 
providerErrorHandler.ts - Use optional chaining instead of logical AND for property access - Improve formatting for better readability - Maintain all existing functionality while addressing linter warnings * chore: remove .claude-flow directory - Remove unnecessary .claude-flow metrics files - Clean up repository structure * Add comprehensive test coverage for providerErrorHandler TypeScript strict mode fixes - Added 24 comprehensive tests for parseProviderError and getProviderErrorMessage - Tests cover all error scenarios: basic errors, status codes, structured provider errors, malformed JSON, null/undefined handling, and TypeScript strict mode compliance - Fixed null/undefined handling in parseProviderError to properly return fallback messages - All tests passing (24/24) ensuring TypeScript strict mode fixes work correctly - Validates error handling for OpenAI, Google AI, Anthropic, and other LLM providers Related to PR #720 TypeScript strict mode compliance --------- Co-authored-by: OmniNode CI --- .../knowledge/utils/providerErrorHandler.ts | 65 ++-- .../utils/tests/providerErrorHandler.test.ts | 281 ++++++++++++++++++ 2 files changed, 329 insertions(+), 17 deletions(-) create mode 100644 archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts diff --git a/archon-ui-main/src/features/knowledge/utils/providerErrorHandler.ts b/archon-ui-main/src/features/knowledge/utils/providerErrorHandler.ts index 655a08fd..588d49bc 100644 --- a/archon-ui-main/src/features/knowledge/utils/providerErrorHandler.ts +++ b/archon-ui-main/src/features/knowledge/utils/providerErrorHandler.ts @@ -10,31 +10,62 @@ export interface ProviderError extends Error { isProviderError?: boolean; } +// Type guards for error object properties +interface ErrorWithStatus { + statusCode?: number; + status?: number; +} + +interface ErrorWithMessage { + message?: string; +} + +// Type guard functions +function hasStatusProperty(obj: unknown): obj is ErrorWithStatus { + return typeof obj === "object" && obj !== null && ("statusCode" in obj || "status" in obj); +} + +function hasMessageProperty(obj: unknown): obj is ErrorWithMessage { + return typeof obj === "object" && obj !== null && "message" in obj; +} + /** * Parse backend error responses into provider-aware error objects */ export function parseProviderError(error: unknown): ProviderError { - const providerError = error as ProviderError; + // Handle null, undefined, and non-object types + if (!error || typeof error !== "object") { + const result: ProviderError = { + name: "Error", + } as ProviderError; - // Check if this is a structured provider error from backend - if (error && typeof error === "object") { - if (error.statusCode || error.status) { - providerError.statusCode = error.statusCode || error.status; + // Only set message for non-null/undefined values + if (error) { + result.message = String(error); } - // Parse backend error structure - if (error.message && error.message.includes("detail")) { - try { - const parsed = JSON.parse(error.message); - if (parsed.detail && parsed.detail.error_type) { - providerError.isProviderError = true; - providerError.provider = parsed.detail.provider || "LLM"; - providerError.errorType = parsed.detail.error_type; - providerError.message = parsed.detail.message || error.message; - } - } catch { - // If parsing fails, use message as-is + return result; + } + + const providerError = error as ProviderError; + + // Type-safe status code extraction + if (hasStatusProperty(error)) { + providerError.statusCode = 
error.statusCode || error.status; + } + + // Parse backend error structure with type safety + if (hasMessageProperty(error) && error.message && error.message.includes("detail")) { + try { + const parsed = JSON.parse(error.message); + if (parsed.detail?.error_type) { + providerError.isProviderError = true; + providerError.provider = parsed.detail.provider || "LLM"; + providerError.errorType = parsed.detail.error_type; + providerError.message = parsed.detail.message || error.message; } + } catch { + // If parsing fails, use message as-is } } diff --git a/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts b/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts new file mode 100644 index 00000000..193e2444 --- /dev/null +++ b/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts @@ -0,0 +1,281 @@ +import { describe, it, expect } from 'vitest'; +import { parseProviderError, getProviderErrorMessage, type ProviderError } from '../providerErrorHandler'; + +describe('providerErrorHandler', () => { + describe('parseProviderError', () => { + it('should handle basic Error objects', () => { + const error = new Error('Basic error message'); + const result = parseProviderError(error); + + expect(result.message).toBe('Basic error message'); + expect(result.isProviderError).toBeUndefined(); + }); + + it('should handle errors with statusCode property', () => { + const error = { statusCode: 401, message: 'Unauthorized' }; + const result = parseProviderError(error); + + expect(result.statusCode).toBe(401); + expect(result.message).toBe('Unauthorized'); + }); + + it('should handle errors with status property', () => { + const error = { status: 429, message: 'Rate limited' }; + const result = parseProviderError(error); + + expect(result.statusCode).toBe(429); + expect(result.message).toBe('Rate limited'); + }); + + it('should prioritize statusCode over status when both are present', () => { + const error = { statusCode: 401, status: 429, message: 'Auth error' }; + const result = parseProviderError(error); + + expect(result.statusCode).toBe(401); + }); + + it('should parse structured provider errors from backend', () => { + const error = { + message: JSON.stringify({ + detail: { + error_type: 'authentication_failed', + provider: 'OpenAI', + message: 'Invalid API key' + } + }) + }; + + const result = parseProviderError(error); + + expect(result.isProviderError).toBe(true); + expect(result.provider).toBe('OpenAI'); + expect(result.errorType).toBe('authentication_failed'); + expect(result.message).toBe('Invalid API key'); + }); + + it('should handle malformed JSON in message gracefully', () => { + const error = { + message: 'invalid json { detail' + }; + + const result = parseProviderError(error); + + expect(result.isProviderError).toBeUndefined(); + expect(result.message).toBe('invalid json { detail'); + }); + + it('should handle null and undefined inputs safely', () => { + expect(() => parseProviderError(null)).not.toThrow(); + expect(() => parseProviderError(undefined)).not.toThrow(); + + const nullResult = parseProviderError(null); + const undefinedResult = parseProviderError(undefined); + + expect(nullResult).toBeDefined(); + expect(undefinedResult).toBeDefined(); + }); + + it('should handle empty objects', () => { + const result = parseProviderError({}); + + expect(result).toBeDefined(); + expect(result.statusCode).toBeUndefined(); + expect(result.isProviderError).toBeUndefined(); + }); + + it('should handle primitive values', () => 
{ + expect(() => parseProviderError('string error')).not.toThrow(); + expect(() => parseProviderError(42)).not.toThrow(); + expect(() => parseProviderError(true)).not.toThrow(); + }); + + it('should handle structured errors without provider field', () => { + const error = { + message: JSON.stringify({ + detail: { + error_type: 'quota_exhausted', + message: 'Usage limit exceeded' + } + }) + }; + + const result = parseProviderError(error); + + expect(result.isProviderError).toBe(true); + expect(result.provider).toBe('LLM'); // Default fallback + expect(result.errorType).toBe('quota_exhausted'); + expect(result.message).toBe('Usage limit exceeded'); + }); + + it('should handle partial structured errors', () => { + const error = { + message: JSON.stringify({ + detail: { + error_type: 'rate_limit' + // Missing message field + } + }) + }; + + const result = parseProviderError(error); + + expect(result.isProviderError).toBe(true); + expect(result.errorType).toBe('rate_limit'); + expect(result.message).toBe(error.message); // Falls back to original message + }); + }); + + describe('getProviderErrorMessage', () => { + it('should return user-friendly message for authentication_failed', () => { + const error: ProviderError = { + name: 'Error', + message: 'Auth failed', + isProviderError: true, + provider: 'OpenAI', + errorType: 'authentication_failed' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Please verify your OpenAI API key in Settings.'); + }); + + it('should return user-friendly message for quota_exhausted', () => { + const error: ProviderError = { + name: 'Error', + message: 'Quota exceeded', + isProviderError: true, + provider: 'Google AI', + errorType: 'quota_exhausted' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Google AI quota exhausted. Please check your billing settings.'); + }); + + it('should return user-friendly message for rate_limit', () => { + const error: ProviderError = { + name: 'Error', + message: 'Rate limited', + isProviderError: true, + provider: 'Anthropic', + errorType: 'rate_limit' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Anthropic rate limit exceeded. Please wait and try again.'); + }); + + it('should return generic provider message for unknown error types', () => { + const error: ProviderError = { + name: 'Error', + message: 'Unknown error', + isProviderError: true, + provider: 'OpenAI', + errorType: 'unknown_error' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('OpenAI API error. 
Please check your configuration.'); + }); + + it('should use default provider when provider is missing', () => { + const error: ProviderError = { + name: 'Error', + message: 'Auth failed', + isProviderError: true, + errorType: 'authentication_failed' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Please verify your LLM API key in Settings.'); + }); + + it('should handle 401 status code for non-provider errors', () => { + const error = { statusCode: 401, message: 'Unauthorized' }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Please verify your API key in Settings.'); + }); + + it('should return original message for non-provider errors', () => { + const error = new Error('Network connection failed'); + + const result = getProviderErrorMessage(error); + expect(result).toBe('Network connection failed'); + }); + + it('should return default message when no message is available', () => { + const error = {}; + + const result = getProviderErrorMessage(error); + expect(result).toBe('An error occurred.'); + }); + + it('should handle complex error objects with structured backend response', () => { + const backendError = { + statusCode: 400, + message: JSON.stringify({ + detail: { + error_type: 'authentication_failed', + provider: 'OpenAI', + message: 'API key invalid or expired' + } + }) + }; + + const result = getProviderErrorMessage(backendError); + expect(result).toBe('Please verify your OpenAI API key in Settings.'); + }); + + it('should handle edge case: message contains "detail" but is not JSON', () => { + const error = { + message: 'Error detail: something went wrong' + }; + + const result = getProviderErrorMessage(error); + expect(result).toBe('Error detail: something went wrong'); + }); + + it('should handle null and undefined gracefully', () => { + expect(getProviderErrorMessage(null)).toBe('An error occurred.'); + expect(getProviderErrorMessage(undefined)).toBe('An error occurred.'); + }); + }); + + describe('TypeScript strict mode compliance', () => { + it('should handle type-safe property access', () => { + // Test that our type guards work properly + const errorWithStatus = { statusCode: 500 }; + const errorWithMessage = { message: 'test' }; + const errorWithBoth = { statusCode: 401, message: 'unauthorized' }; + + // These should not throw TypeScript errors and should work correctly + expect(() => parseProviderError(errorWithStatus)).not.toThrow(); + expect(() => parseProviderError(errorWithMessage)).not.toThrow(); + expect(() => parseProviderError(errorWithBoth)).not.toThrow(); + + const result1 = parseProviderError(errorWithStatus); + const result2 = parseProviderError(errorWithMessage); + const result3 = parseProviderError(errorWithBoth); + + expect(result1.statusCode).toBe(500); + expect(result2.message).toBe('test'); + expect(result3.statusCode).toBe(401); + expect(result3.message).toBe('unauthorized'); + }); + + it('should handle objects without expected properties safely', () => { + const objectWithoutStatus = { someOtherProperty: 'value' }; + const objectWithoutMessage = { anotherProperty: 42 }; + + expect(() => parseProviderError(objectWithoutStatus)).not.toThrow(); + expect(() => parseProviderError(objectWithoutMessage)).not.toThrow(); + + const result1 = parseProviderError(objectWithoutStatus); + const result2 = parseProviderError(objectWithoutMessage); + + expect(result1.statusCode).toBeUndefined(); + expect(result2.message).toBeUndefined(); + }); + }); +}); \ No newline at end of file From 
3ff3f7f2dce6571a4fafbb656501dbfbf68e0f56 Mon Sep 17 00:00:00 2001 From: Cole Medin Date: Mon, 22 Sep 2025 04:25:58 -0500 Subject: [PATCH 3/7] Migrations and version APIs (#718) * Preparing migration folder for the migration alert implementation * Migrations and version APIs initial * Touching up update instructions in README and UI * Unit tests for migrations and version APIs * Splitting up the Ollama migration scripts * Removing temporary PRPs --------- Co-authored-by: Rasmus Widing --- README.md | 12 +- .../components/MigrationStatusCard.tsx | 132 +++++ .../components/PendingMigrationsModal.tsx | 195 +++++++ .../migrations/hooks/useMigrationQueries.ts | 58 ++ .../migrations/services/migrationService.ts | 47 ++ .../settings/migrations/types/index.ts | 41 ++ .../version/components/UpdateBanner.tsx | 69 +++ .../version/components/VersionStatusCard.tsx | 98 ++++ .../version/hooks/useVersionQueries.ts | 59 ++ .../version/services/versionService.ts | 49 ++ .../features/settings/version/types/index.ts | 35 ++ archon-ui-main/src/pages/SettingsPage.tsx | 35 ++ docker-compose.yml | 1 + .../001_add_source_url_display_name.sql} | 0 .../002_add_hybrid_search_tsvector.sql} | 0 migration/0.1.0/003_ollama_add_columns.sql | 35 ++ migration/0.1.0/004_ollama_migrate_data.sql | 70 +++ .../0.1.0/005_ollama_create_functions.sql | 172 ++++++ .../006_ollama_create_indexes_optional.sql | 67 +++ .../007_add_priority_column_to_tasks.sql} | 0 .../0.1.0/008_add_migration_tracking.sql | 65 +++ migration/0.1.0/DB_UPGRADE_INSTRUCTIONS.md | 157 ++++++ migration/DB_UPGRADE_INSTRUCTIONS.md | 167 ------ migration/RESET_DB.sql | 11 +- migration/complete_setup.sql | 56 ++ migration/upgrade_database.sql | 518 ------------------ migration/validate_migration.sql | 287 ---------- python/src/server/api_routes/migration_api.py | 170 ++++++ python/src/server/api_routes/version_api.py | 121 ++++ python/src/server/config/version.py | 11 + python/src/server/main.py | 4 + .../src/server/services/migration_service.py | 233 ++++++++ python/src/server/services/version_service.py | 162 ++++++ python/src/server/utils/semantic_version.py | 107 ++++ .../server/api_routes/test_migration_api.py | 206 +++++++ .../server/api_routes/test_version_api.py | 147 +++++ .../server/services/test_migration_service.py | 271 +++++++++ .../server/services/test_version_service.py | 234 ++++++++ 38 files changed, 3124 insertions(+), 978 deletions(-) create mode 100644 archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx create mode 100644 archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx create mode 100644 archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts create mode 100644 archon-ui-main/src/features/settings/migrations/services/migrationService.ts create mode 100644 archon-ui-main/src/features/settings/migrations/types/index.ts create mode 100644 archon-ui-main/src/features/settings/version/components/UpdateBanner.tsx create mode 100644 archon-ui-main/src/features/settings/version/components/VersionStatusCard.tsx create mode 100644 archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts create mode 100644 archon-ui-main/src/features/settings/version/services/versionService.ts create mode 100644 archon-ui-main/src/features/settings/version/types/index.ts rename migration/{add_source_url_display_name.sql => 0.1.0/001_add_source_url_display_name.sql} (100%) rename migration/{add_hybrid_search_tsvector.sql => 0.1.0/002_add_hybrid_search_tsvector.sql} 
(100%) create mode 100644 migration/0.1.0/003_ollama_add_columns.sql create mode 100644 migration/0.1.0/004_ollama_migrate_data.sql create mode 100644 migration/0.1.0/005_ollama_create_functions.sql create mode 100644 migration/0.1.0/006_ollama_create_indexes_optional.sql rename migration/{add_priority_column_to_tasks.sql => 0.1.0/007_add_priority_column_to_tasks.sql} (100%) create mode 100644 migration/0.1.0/008_add_migration_tracking.sql create mode 100644 migration/0.1.0/DB_UPGRADE_INSTRUCTIONS.md delete mode 100644 migration/DB_UPGRADE_INSTRUCTIONS.md delete mode 100644 migration/upgrade_database.sql delete mode 100644 migration/validate_migration.sql create mode 100644 python/src/server/api_routes/migration_api.py create mode 100644 python/src/server/api_routes/version_api.py create mode 100644 python/src/server/config/version.py create mode 100644 python/src/server/services/migration_service.py create mode 100644 python/src/server/services/version_service.py create mode 100644 python/src/server/utils/semantic_version.py create mode 100644 python/tests/server/api_routes/test_migration_api.py create mode 100644 python/tests/server/api_routes/test_version_api.py create mode 100644 python/tests/server/services/test_migration_service.py create mode 100644 python/tests/server/services/test_version_service.py diff --git a/README.md b/README.md index d0440f1c..90f5f784 100644 --- a/README.md +++ b/README.md @@ -206,14 +206,18 @@ To upgrade Archon to the latest version: git pull ``` -2. **Check for migrations**: Look in the `migration/` folder for any SQL files newer than your last update. Check the file created dates to determine if you need to run them. You can run these in the SQL editor just like you did when you first set up Archon. We are also working on a way to make handling these migrations automatic! - -3. **Rebuild and restart**: +2. **Rebuild and restart containers**: ```bash docker compose up -d --build ``` + This rebuilds containers with the latest code and restarts all services. -This is the same command used for initial setup - it rebuilds containers with the latest code and restarts services. +3. **Check for database migrations**: + - Open the Archon settings in your browser: [http://localhost:3737/settings](http://localhost:3737/settings) + - Navigate to the **Database Migrations** section + - If there are pending migrations, the UI will display them with clear instructions + - Click on each migration to view and copy the SQL + - Run the SQL scripts in your Supabase SQL editor in the order shown ## What's Included diff --git a/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx b/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx new file mode 100644 index 00000000..2b29531c --- /dev/null +++ b/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx @@ -0,0 +1,132 @@ +/** + * Card component showing migration status + */ + +import { motion } from "framer-motion"; +import { AlertTriangle, CheckCircle, Database, RefreshCw } from "lucide-react"; +import React from "react"; +import { useMigrationStatus } from "../hooks/useMigrationQueries"; +import { PendingMigrationsModal } from "./PendingMigrationsModal"; + +export function MigrationStatusCard() { + const { data, isLoading, error, refetch } = useMigrationStatus(); + const [isModalOpen, setIsModalOpen] = React.useState(false); + + const handleRefresh = () => { + refetch(); + }; + + return ( + <> + +
+
+ +

Database Migrations

+
+ +
+ +
+
+ Applied Migrations + {data?.applied_count ?? 0} +
+ +
+ Pending Migrations +
+ {data?.pending_count ?? 0} + {data && data.pending_count > 0 && } +
+
+ +
+ Status +
+ {isLoading ? ( + <> + + Checking... + + ) : error ? ( + <> + + Error loading + + ) : data?.bootstrap_required ? ( + <> + + Setup required + + ) : data?.has_pending ? ( + <> + + Migrations pending + + ) : ( + <> + + Up to date + + )} +
+
+ + {data?.current_version && ( +
+ Database Version + {data.current_version} +
+ )} +
+ + {data?.has_pending && ( +
+

+ {data.bootstrap_required + ? "Initial database setup is required." + : `${data.pending_count} migration${data.pending_count === 1 ? " needs" : "s need"} to be applied.`}

+ +
+ )} + + {error && ( +
+

+ Failed to load migration status. Please check your database connection. +

+
+ )} +
+ + {/* Modal for viewing pending migrations */} + {data && ( + setIsModalOpen(false)} + migrations={data.pending_migrations} + onMigrationsApplied={refetch} + /> + )} + + ); +} diff --git a/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx b/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx new file mode 100644 index 00000000..f4bd23c0 --- /dev/null +++ b/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx @@ -0,0 +1,195 @@ +/** + * Modal for viewing and copying pending migration SQL + */ + +import { AnimatePresence, motion } from "framer-motion"; +import { CheckCircle, Copy, Database, ExternalLink, X } from "lucide-react"; +import React from "react"; +import { copyToClipboard } from "@/features/shared/utils/clipboard"; +import { useToast } from "@/features/ui/hooks/useToast"; +import type { PendingMigration } from "../types"; + +interface PendingMigrationsModalProps { + isOpen: boolean; + onClose: () => void; + migrations: PendingMigration[]; + onMigrationsApplied: () => void; +} + +export function PendingMigrationsModal({ + isOpen, + onClose, + migrations, + onMigrationsApplied, +}: PendingMigrationsModalProps) { + const { showToast } = useToast(); + const [copiedIndex, setCopiedIndex] = React.useState(null); + const [expandedIndex, setExpandedIndex] = React.useState(null); + + const handleCopy = async (sql: string, index: number) => { + const result = await copyToClipboard(sql); + if (result.success) { + setCopiedIndex(index); + showToast("SQL copied to clipboard", "success"); + setTimeout(() => setCopiedIndex(null), 2000); + } else { + showToast("Failed to copy SQL", "error"); + } + }; + + const handleCopyAll = async () => { + const allSql = migrations.map((m) => `-- ${m.name}\n${m.sql_content}`).join("\n\n"); + const result = await copyToClipboard(allSql); + if (result.success) { + showToast("All migration SQL copied to clipboard", "success"); + } else { + showToast("Failed to copy SQL", "error"); + } + }; + + if (!isOpen) return null; + + return ( + +
+ {/* Backdrop */} + + + {/* Modal */} + + {/* Header */} +
+
+ +

Pending Database Migrations

+
+ +
+ + {/* Instructions */} +
+

+ + How to Apply Migrations +

+
    +
  1. Copy the SQL for each migration below
  2. Open your Supabase dashboard SQL Editor
  3. Paste and execute each migration in order
  4. Click "Refresh Status" below to verify migrations were applied
+ {migrations.length > 1 && ( + + )} +
+ + {/* Migration List */} +
+ {migrations.length === 0 ? ( +
+ +

All migrations have been applied!

+
+ ) : ( +
+ {migrations.map((migration, index) => ( +
+
+
+
+

{migration.name}

+

+ Version: {migration.version} • {migration.file_path} +

+
+
+ + +
+
+ + {/* SQL Content */} + + {expandedIndex === index && ( + +
+                              {migration.sql_content}
+                            
+
+ )} +
+
+
+ ))} +
+ )} +
+ + {/* Footer */} +
+ + +
+
+
+
+ ); +} diff --git a/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts b/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts new file mode 100644 index 00000000..1c2a6d7e --- /dev/null +++ b/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts @@ -0,0 +1,58 @@ +/** + * TanStack Query hooks for migration tracking + */ + +import { useQuery } from "@tanstack/react-query"; +import { STALE_TIMES } from "@/features/shared/queryPatterns"; +import { useSmartPolling } from "@/features/ui/hooks/useSmartPolling"; +import { migrationService } from "../services/migrationService"; +import type { MigrationHistoryResponse, MigrationStatusResponse, PendingMigration } from "../types"; + +// Query key factory +export const migrationKeys = { + all: ["migrations"] as const, + status: () => [...migrationKeys.all, "status"] as const, + history: () => [...migrationKeys.all, "history"] as const, + pending: () => [...migrationKeys.all, "pending"] as const, +}; + +/** + * Hook to get comprehensive migration status + * Polls more frequently when migrations are pending + */ +export function useMigrationStatus() { + // Poll every 30 seconds when tab is visible + const { refetchInterval } = useSmartPolling(30000); + + return useQuery({ + queryKey: migrationKeys.status(), + queryFn: () => migrationService.getMigrationStatus(), + staleTime: STALE_TIMES.normal, // 30 seconds + refetchInterval, + }); +} + +/** + * Hook to get migration history + */ +export function useMigrationHistory() { + return useQuery({ + queryKey: migrationKeys.history(), + queryFn: () => migrationService.getMigrationHistory(), + staleTime: STALE_TIMES.rare, // 5 minutes - history doesn't change often + }); +} + +/** + * Hook to get pending migrations only + */ +export function usePendingMigrations() { + const { refetchInterval } = useSmartPolling(30000); + + return useQuery({ + queryKey: migrationKeys.pending(), + queryFn: () => migrationService.getPendingMigrations(), + staleTime: STALE_TIMES.normal, + refetchInterval, + }); +} diff --git a/archon-ui-main/src/features/settings/migrations/services/migrationService.ts b/archon-ui-main/src/features/settings/migrations/services/migrationService.ts new file mode 100644 index 00000000..93cb15eb --- /dev/null +++ b/archon-ui-main/src/features/settings/migrations/services/migrationService.ts @@ -0,0 +1,47 @@ +/** + * Service for database migration tracking and management + */ + +import { callAPIWithETag } from "@/features/shared/apiWithEtag"; +import type { MigrationHistoryResponse, MigrationStatusResponse, PendingMigration } from "../types"; + +export const migrationService = { + /** + * Get comprehensive migration status including pending and applied + */ + async getMigrationStatus(): Promise { + try { + const response = await callAPIWithETag("/api/migrations/status"); + return response as MigrationStatusResponse; + } catch (error) { + console.error("Error getting migration status:", error); + throw error; + } + }, + + /** + * Get history of applied migrations + */ + async getMigrationHistory(): Promise { + try { + const response = await callAPIWithETag("/api/migrations/history"); + return response as MigrationHistoryResponse; + } catch (error) { + console.error("Error getting migration history:", error); + throw error; + } + }, + + /** + * Get list of pending migrations only + */ + async getPendingMigrations(): Promise { + try { + const response = await callAPIWithETag("/api/migrations/pending"); + return response as 
PendingMigration[]; + } catch (error) { + console.error("Error getting pending migrations:", error); + throw error; + } + }, +}; diff --git a/archon-ui-main/src/features/settings/migrations/types/index.ts b/archon-ui-main/src/features/settings/migrations/types/index.ts new file mode 100644 index 00000000..7c08c6bf --- /dev/null +++ b/archon-ui-main/src/features/settings/migrations/types/index.ts @@ -0,0 +1,41 @@ +/** + * Type definitions for database migration tracking and management + */ + +export interface MigrationRecord { + version: string; + migration_name: string; + applied_at: string; + checksum?: string | null; +} + +export interface PendingMigration { + version: string; + name: string; + sql_content: string; + file_path: string; + checksum?: string | null; +} + +export interface MigrationStatusResponse { + pending_migrations: PendingMigration[]; + applied_migrations: MigrationRecord[]; + has_pending: boolean; + bootstrap_required: boolean; + current_version: string; + pending_count: number; + applied_count: number; +} + +export interface MigrationHistoryResponse { + migrations: MigrationRecord[]; + total_count: number; + current_version: string; +} + +export interface MigrationState { + status: MigrationStatusResponse | null; + isLoading: boolean; + error: Error | null; + selectedMigration: PendingMigration | null; +} diff --git a/archon-ui-main/src/features/settings/version/components/UpdateBanner.tsx b/archon-ui-main/src/features/settings/version/components/UpdateBanner.tsx new file mode 100644 index 00000000..c25f37a2 --- /dev/null +++ b/archon-ui-main/src/features/settings/version/components/UpdateBanner.tsx @@ -0,0 +1,69 @@ +/** + * Banner component that shows when an update is available + */ + +import { AnimatePresence, motion } from "framer-motion"; +import { ArrowUpCircle, ExternalLink, X } from "lucide-react"; +import React from "react"; +import { useVersionCheck } from "../hooks/useVersionQueries"; + +export function UpdateBanner() { + const { data, isLoading, error } = useVersionCheck(); + const [isDismissed, setIsDismissed] = React.useState(false); + + // Don't show banner if loading, error, no data, or no update available + if (isLoading || error || !data?.update_available || isDismissed) { + return null; + } + + return ( + + +
+
+ +
+

Update Available: v{data.latest}

+

You are currently running v{data.current}

+
+
+
+ {data.release_url && ( + + View Release + + + )} + + View Upgrade Instructions + + + +
+
+
+
+ ); +} diff --git a/archon-ui-main/src/features/settings/version/components/VersionStatusCard.tsx b/archon-ui-main/src/features/settings/version/components/VersionStatusCard.tsx new file mode 100644 index 00000000..250b85d1 --- /dev/null +++ b/archon-ui-main/src/features/settings/version/components/VersionStatusCard.tsx @@ -0,0 +1,98 @@ +/** + * Card component showing current version status + */ + +import { motion } from "framer-motion"; +import { AlertCircle, CheckCircle, Info, RefreshCw } from "lucide-react"; +import { useClearVersionCache, useVersionCheck } from "../hooks/useVersionQueries"; + +export function VersionStatusCard() { + const { data, isLoading, error, refetch } = useVersionCheck(); + const clearCache = useClearVersionCache(); + + const handleRefreshClick = async () => { + // Clear cache and then refetch + await clearCache.mutateAsync(); + refetch(); + }; + + return ( + +
+
+ +

Version Information

+
+ +
+ +
+
+ Current Version + {data?.current || "Loading..."} +
+ +
+ Latest Version + + {isLoading ? "Checking..." : error ? "Check failed" : data?.latest ? data.latest : "No releases found"} + +
+ +
+ Status +
+ {isLoading ? ( + <> + + Checking... + + ) : error ? ( + <> + + Error checking + + ) : data?.update_available ? ( + <> + + Update available + + ) : ( + <> + + Up to date + + )} +
+
+ + {data?.published_at && ( +
+ Released + {new Date(data.published_at).toLocaleDateString()} +
+ )} +
+ + {error && ( +
+

+ {data?.check_error || "Failed to check for updates. Please try again later."} +

+
+ )} +
+ ); +} diff --git a/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts b/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts new file mode 100644 index 00000000..e1aefbd8 --- /dev/null +++ b/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts @@ -0,0 +1,59 @@ +/** + * TanStack Query hooks for version checking + */ + +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { STALE_TIMES } from "@/features/shared/queryPatterns"; +import { useSmartPolling } from "@/features/ui/hooks/useSmartPolling"; +import { versionService } from "../services/versionService"; +import type { VersionCheckResponse } from "../types"; + +// Query key factory +export const versionKeys = { + all: ["version"] as const, + check: () => [...versionKeys.all, "check"] as const, + current: () => [...versionKeys.all, "current"] as const, +}; + +/** + * Hook to check for version updates + * Polls every 5 minutes when tab is visible + */ +export function useVersionCheck() { + // Smart polling: check every 5 minutes when tab is visible + const { refetchInterval } = useSmartPolling(300000); // 5 minutes + + return useQuery({ + queryKey: versionKeys.check(), + queryFn: () => versionService.checkVersion(), + staleTime: STALE_TIMES.rare, // 5 minutes + refetchInterval, + retry: false, // Don't retry on 404 or network errors + }); +} + +/** + * Hook to get current version without checking for updates + */ +export function useCurrentVersion() { + return useQuery({ + queryKey: versionKeys.current(), + queryFn: () => versionService.getCurrentVersion(), + staleTime: STALE_TIMES.static, // Never stale + }); +} + +/** + * Hook to clear version cache and force fresh check + */ +export function useClearVersionCache() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: () => versionService.clearCache(), + onSuccess: () => { + // Invalidate version queries to force fresh check + queryClient.invalidateQueries({ queryKey: versionKeys.all }); + }, + }); +} diff --git a/archon-ui-main/src/features/settings/version/services/versionService.ts b/archon-ui-main/src/features/settings/version/services/versionService.ts new file mode 100644 index 00000000..4ef45b82 --- /dev/null +++ b/archon-ui-main/src/features/settings/version/services/versionService.ts @@ -0,0 +1,49 @@ +/** + * Service for version checking and update management + */ + +import { callAPIWithETag } from "@/features/shared/apiWithEtag"; +import type { CurrentVersionResponse, VersionCheckResponse } from "../types"; + +export const versionService = { + /** + * Check for available Archon updates + */ + async checkVersion(): Promise { + try { + const response = await callAPIWithETag("/api/version/check"); + return response as VersionCheckResponse; + } catch (error) { + console.error("Error checking version:", error); + throw error; + } + }, + + /** + * Get current Archon version without checking for updates + */ + async getCurrentVersion(): Promise { + try { + const response = await callAPIWithETag("/api/version/current"); + return response as CurrentVersionResponse; + } catch (error) { + console.error("Error getting current version:", error); + throw error; + } + }, + + /** + * Clear version cache to force fresh check + */ + async clearCache(): Promise<{ message: string; success: boolean }> { + try { + const response = await callAPIWithETag("/api/version/clear-cache", { + method: "POST", + }); + return response as { message: string; success: boolean }; + } catch 
(error) { + console.error("Error clearing version cache:", error); + throw error; + } + }, +}; diff --git a/archon-ui-main/src/features/settings/version/types/index.ts b/archon-ui-main/src/features/settings/version/types/index.ts new file mode 100644 index 00000000..04da0860 --- /dev/null +++ b/archon-ui-main/src/features/settings/version/types/index.ts @@ -0,0 +1,35 @@ +/** + * Type definitions for version checking and update management + */ + +export interface ReleaseAsset { + name: string; + size: number; + download_count: number; + browser_download_url: string; + content_type: string; +} + +export interface VersionCheckResponse { + current: string; + latest: string | null; + update_available: boolean; + release_url: string | null; + release_notes: string | null; + published_at: string | null; + check_error?: string | null; + assets?: ReleaseAsset[] | null; + author?: string | null; +} + +export interface CurrentVersionResponse { + version: string; + timestamp: string; +} + +export interface VersionStatus { + isLoading: boolean; + error: Error | null; + data: VersionCheckResponse | null; + lastChecked: Date | null; +} diff --git a/archon-ui-main/src/pages/SettingsPage.tsx b/archon-ui-main/src/pages/SettingsPage.tsx index ad186e87..20c3c412 100644 --- a/archon-ui-main/src/pages/SettingsPage.tsx +++ b/archon-ui-main/src/pages/SettingsPage.tsx @@ -10,6 +10,8 @@ import { Code, FileCode, Bug, + Info, + Database, } from "lucide-react"; import { motion, AnimatePresence } from "framer-motion"; import { useToast } from "../features/ui/hooks/useToast"; @@ -28,6 +30,9 @@ import { RagSettings, CodeExtractionSettings as CodeExtractionSettingsType, } from "../services/credentialsService"; +import { UpdateBanner } from "../features/settings/version/components/UpdateBanner"; +import { VersionStatusCard } from "../features/settings/version/components/VersionStatusCard"; +import { MigrationStatusCard } from "../features/settings/migrations/components/MigrationStatusCard"; export const SettingsPage = () => { const [ragSettings, setRagSettings] = useState({ @@ -106,6 +111,9 @@ export const SettingsPage = () => { variants={containerVariants} className="w-full" > + {/* Update Banner */} + + {/* Header */} { + + {/* Version Status */} + + + + + + + {/* Migration Status */} + + + + + + {projectsEnabled && ( 0 THEN + -- Detect dimension + SELECT vector_dims(embedding) INTO dimension_detected + FROM archon_crawled_pages + WHERE embedding IS NOT NULL + LIMIT 1; + + IF dimension_detected = 1536 THEN + UPDATE archon_crawled_pages + SET embedding_1536 = embedding, + embedding_dimension = 1536, + embedding_model = COALESCE(embedding_model, 'text-embedding-3-small') + WHERE embedding IS NOT NULL AND embedding_1536 IS NULL; + END IF; + + -- Drop old column + ALTER TABLE archon_crawled_pages DROP COLUMN IF EXISTS embedding; + END IF; + + -- Same for code_examples + SELECT COUNT(*) INTO code_examples_count + FROM information_schema.columns + WHERE table_name = 'archon_code_examples' + AND column_name = 'embedding'; + + IF code_examples_count > 0 THEN + SELECT vector_dims(embedding) INTO dimension_detected + FROM archon_code_examples + WHERE embedding IS NOT NULL + LIMIT 1; + + IF dimension_detected = 1536 THEN + UPDATE archon_code_examples + SET embedding_1536 = embedding, + embedding_dimension = 1536, + embedding_model = COALESCE(embedding_model, 'text-embedding-3-small') + WHERE embedding IS NOT NULL AND embedding_1536 IS NULL; + END IF; + + ALTER TABLE archon_code_examples DROP COLUMN IF EXISTS embedding; + END IF; 
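+    -- Note: the two blocks above only copy data when the legacy `embedding` column holds
+    -- 1536-dimension vectors (the text-embedding-3-small default recorded above). Rows with any
+    -- other dimension are not copied before the legacy column is dropped, so such data would
+    -- need to be re-embedded after this migration.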
+END $$; + +-- Drop old indexes if they exist +DROP INDEX IF EXISTS idx_archon_crawled_pages_embedding; +DROP INDEX IF EXISTS idx_archon_code_examples_embedding; + +COMMIT; + +SELECT 'Ollama data migrated successfully' AS status; \ No newline at end of file diff --git a/migration/0.1.0/005_ollama_create_functions.sql b/migration/0.1.0/005_ollama_create_functions.sql new file mode 100644 index 00000000..0426cdf6 --- /dev/null +++ b/migration/0.1.0/005_ollama_create_functions.sql @@ -0,0 +1,172 @@ +-- ====================================================================== +-- Migration 005: Ollama Implementation - Create Functions +-- Creates search functions for multi-dimensional embeddings +-- ====================================================================== + +BEGIN; + +-- Helper function to detect embedding dimension +CREATE OR REPLACE FUNCTION detect_embedding_dimension(embedding_vector vector) +RETURNS INTEGER AS $$ +BEGIN + RETURN vector_dims(embedding_vector); +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +-- Helper function to get column name for dimension +CREATE OR REPLACE FUNCTION get_embedding_column_name(dimension INTEGER) +RETURNS TEXT AS $$ +BEGIN + CASE dimension + WHEN 384 THEN RETURN 'embedding_384'; + WHEN 768 THEN RETURN 'embedding_768'; + WHEN 1024 THEN RETURN 'embedding_1024'; + WHEN 1536 THEN RETURN 'embedding_1536'; + WHEN 3072 THEN RETURN 'embedding_3072'; + ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %', dimension; + END CASE; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +-- Multi-dimensional search for crawled pages +CREATE OR REPLACE FUNCTION match_archon_crawled_pages_multi ( + query_embedding VECTOR, + embedding_dimension INTEGER, + match_count INT DEFAULT 10, + filter JSONB DEFAULT '{}'::jsonb, + source_filter TEXT DEFAULT NULL +) RETURNS TABLE ( + id BIGINT, + url VARCHAR, + chunk_number INTEGER, + content TEXT, + metadata JSONB, + source_id TEXT, + similarity FLOAT +) +LANGUAGE plpgsql +AS $$ +#variable_conflict use_column +DECLARE + sql_query TEXT; + embedding_column TEXT; +BEGIN + CASE embedding_dimension + WHEN 384 THEN embedding_column := 'embedding_384'; + WHEN 768 THEN embedding_column := 'embedding_768'; + WHEN 1024 THEN embedding_column := 'embedding_1024'; + WHEN 1536 THEN embedding_column := 'embedding_1536'; + WHEN 3072 THEN embedding_column := 'embedding_3072'; + ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %', embedding_dimension; + END CASE; + + sql_query := format(' + SELECT id, url, chunk_number, content, metadata, source_id, + 1 - (%I <=> $1) AS similarity + FROM archon_crawled_pages + WHERE (%I IS NOT NULL) + AND metadata @> $3 + AND ($4 IS NULL OR source_id = $4) + ORDER BY %I <=> $1 + LIMIT $2', + embedding_column, embedding_column, embedding_column); + + RETURN QUERY EXECUTE sql_query USING query_embedding, match_count, filter, source_filter; +END; +$$; + +-- Multi-dimensional search for code examples +CREATE OR REPLACE FUNCTION match_archon_code_examples_multi ( + query_embedding VECTOR, + embedding_dimension INTEGER, + match_count INT DEFAULT 10, + filter JSONB DEFAULT '{}'::jsonb, + source_filter TEXT DEFAULT NULL +) RETURNS TABLE ( + id BIGINT, + url VARCHAR, + chunk_number INTEGER, + content TEXT, + summary TEXT, + metadata JSONB, + source_id TEXT, + similarity FLOAT +) +LANGUAGE plpgsql +AS $$ +#variable_conflict use_column +DECLARE + sql_query TEXT; + embedding_column TEXT; +BEGIN + CASE embedding_dimension + WHEN 384 THEN embedding_column := 'embedding_384'; + WHEN 768 THEN embedding_column := 'embedding_768'; + 
WHEN 1024 THEN embedding_column := 'embedding_1024'; + WHEN 1536 THEN embedding_column := 'embedding_1536'; + WHEN 3072 THEN embedding_column := 'embedding_3072'; + ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %', embedding_dimension; + END CASE; + + sql_query := format(' + SELECT id, url, chunk_number, content, summary, metadata, source_id, + 1 - (%I <=> $1) AS similarity + FROM archon_code_examples + WHERE (%I IS NOT NULL) + AND metadata @> $3 + AND ($4 IS NULL OR source_id = $4) + ORDER BY %I <=> $1 + LIMIT $2', + embedding_column, embedding_column, embedding_column); + + RETURN QUERY EXECUTE sql_query USING query_embedding, match_count, filter, source_filter; +END; +$$; + +-- Legacy compatibility (defaults to 1536D) +CREATE OR REPLACE FUNCTION match_archon_crawled_pages ( + query_embedding VECTOR(1536), + match_count INT DEFAULT 10, + filter JSONB DEFAULT '{}'::jsonb, + source_filter TEXT DEFAULT NULL +) RETURNS TABLE ( + id BIGINT, + url VARCHAR, + chunk_number INTEGER, + content TEXT, + metadata JSONB, + source_id TEXT, + similarity FLOAT +) +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY SELECT * FROM match_archon_crawled_pages_multi(query_embedding, 1536, match_count, filter, source_filter); +END; +$$; + +CREATE OR REPLACE FUNCTION match_archon_code_examples ( + query_embedding VECTOR(1536), + match_count INT DEFAULT 10, + filter JSONB DEFAULT '{}'::jsonb, + source_filter TEXT DEFAULT NULL +) RETURNS TABLE ( + id BIGINT, + url VARCHAR, + chunk_number INTEGER, + content TEXT, + summary TEXT, + metadata JSONB, + source_id TEXT, + similarity FLOAT +) +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY SELECT * FROM match_archon_code_examples_multi(query_embedding, 1536, match_count, filter, source_filter); +END; +$$; + +COMMIT; + +SELECT 'Ollama functions created successfully' AS status; \ No newline at end of file diff --git a/migration/0.1.0/006_ollama_create_indexes_optional.sql b/migration/0.1.0/006_ollama_create_indexes_optional.sql new file mode 100644 index 00000000..d8a38080 --- /dev/null +++ b/migration/0.1.0/006_ollama_create_indexes_optional.sql @@ -0,0 +1,67 @@ +-- ====================================================================== +-- Migration 006: Ollama Implementation - Create Indexes (Optional) +-- Creates vector indexes for performance (may timeout on large datasets) +-- ====================================================================== + +-- IMPORTANT: This migration creates vector indexes which are memory-intensive +-- If this fails, you can skip it and the system will use brute-force search +-- You can create these indexes later via direct database connection + +SET maintenance_work_mem = '512MB'; +SET statement_timeout = '10min'; + +-- Create ONE index at a time to avoid memory issues +-- Comment out any that fail and continue with the next + +-- Index 1 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_1536 +ON archon_crawled_pages USING ivfflat (embedding_1536 vector_cosine_ops) +WITH (lists = 100); + +-- Index 2 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_1536 +ON archon_code_examples USING ivfflat (embedding_1536 vector_cosine_ops) +WITH (lists = 100); + +-- Index 3 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_768 +ON archon_crawled_pages USING ivfflat (embedding_768 vector_cosine_ops) +WITH (lists = 100); + +-- Index 4 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_768 +ON archon_code_examples USING ivfflat (embedding_768 vector_cosine_ops) +WITH (lists = 100); + 
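+-- Tuning note: `lists = 100` controls how many clusters each ivfflat index is partitioned into.
+-- A rough rule of thumb from the pgvector documentation is about rows / 1000 lists for tables
+-- under ~1M rows; at query time, raising `ivfflat.probes` (default 1) trades speed for recall,
+-- for example: SET ivfflat.probes = 10;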
+-- Index 5 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_384 +ON archon_crawled_pages USING ivfflat (embedding_384 vector_cosine_ops) +WITH (lists = 100); + +-- Index 6 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_384 +ON archon_code_examples USING ivfflat (embedding_384 vector_cosine_ops) +WITH (lists = 100); + +-- Index 7 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_1024 +ON archon_crawled_pages USING ivfflat (embedding_1024 vector_cosine_ops) +WITH (lists = 100); + +-- Index 8 of 8 +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_1024 +ON archon_code_examples USING ivfflat (embedding_1024 vector_cosine_ops) +WITH (lists = 100); + +-- Simple B-tree indexes (these are fast) +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_model ON archon_crawled_pages (embedding_model); +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_dimension ON archon_crawled_pages (embedding_dimension); +CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_llm_chat_model ON archon_crawled_pages (llm_chat_model); +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_model ON archon_code_examples (embedding_model); +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_dimension ON archon_code_examples (embedding_dimension); +CREATE INDEX IF NOT EXISTS idx_archon_code_examples_llm_chat_model ON archon_code_examples (llm_chat_model); + +RESET maintenance_work_mem; +RESET statement_timeout; + +SELECT 'Ollama indexes created (or skipped if timed out - that issue will be obvious in Supabase)' AS status; \ No newline at end of file diff --git a/migration/add_priority_column_to_tasks.sql b/migration/0.1.0/007_add_priority_column_to_tasks.sql similarity index 100% rename from migration/add_priority_column_to_tasks.sql rename to migration/0.1.0/007_add_priority_column_to_tasks.sql diff --git a/migration/0.1.0/008_add_migration_tracking.sql b/migration/0.1.0/008_add_migration_tracking.sql new file mode 100644 index 00000000..5cac0c72 --- /dev/null +++ b/migration/0.1.0/008_add_migration_tracking.sql @@ -0,0 +1,65 @@ +-- Migration: 008_add_migration_tracking.sql +-- Description: Create archon_migrations table for tracking applied database migrations +-- Version: 0.1.0 +-- Author: Archon Team +-- Date: 2025 + +-- Create archon_migrations table for tracking applied migrations +CREATE TABLE IF NOT EXISTS archon_migrations ( + id UUID DEFAULT gen_random_uuid() PRIMARY KEY, + version VARCHAR(20) NOT NULL, + migration_name VARCHAR(255) NOT NULL, + applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + checksum VARCHAR(32), + UNIQUE(version, migration_name) +); + +-- Add index for fast lookups by version +CREATE INDEX IF NOT EXISTS idx_archon_migrations_version ON archon_migrations(version); + +-- Add index for sorting by applied date +CREATE INDEX IF NOT EXISTS idx_archon_migrations_applied_at ON archon_migrations(applied_at DESC); + +-- Add comment describing table purpose +COMMENT ON TABLE archon_migrations IS 'Tracks database migrations that have been applied to maintain schema version consistency'; +COMMENT ON COLUMN archon_migrations.version IS 'Archon version that introduced this migration'; +COMMENT ON COLUMN archon_migrations.migration_name IS 'Filename of the migration SQL file'; +COMMENT ON COLUMN archon_migrations.applied_at IS 'Timestamp when migration was applied'; +COMMENT ON COLUMN archon_migrations.checksum IS 'Optional MD5 checksum of migration file content'; + +-- Record this migration as 
applied (self-recording pattern) +-- This allows the migration system to bootstrap itself +INSERT INTO archon_migrations (version, migration_name) +VALUES ('0.1.0', '008_add_migration_tracking') +ON CONFLICT (version, migration_name) DO NOTHING; + +-- Retroactively record previously applied migrations (001-007) +-- Since these migrations couldn't self-record (table didn't exist yet), +-- we record them here to ensure the migration system knows they've been applied +INSERT INTO archon_migrations (version, migration_name) +VALUES + ('0.1.0', '001_add_source_url_display_name'), + ('0.1.0', '002_add_hybrid_search_tsvector'), + ('0.1.0', '003_ollama_add_columns'), + ('0.1.0', '004_ollama_migrate_data'), + ('0.1.0', '005_ollama_create_functions'), + ('0.1.0', '006_ollama_create_indexes_optional'), + ('0.1.0', '007_add_priority_column_to_tasks') +ON CONFLICT (version, migration_name) DO NOTHING; + +-- Enable Row Level Security on migrations table +ALTER TABLE archon_migrations ENABLE ROW LEVEL SECURITY; + +-- Drop existing policies if they exist (makes this idempotent) +DROP POLICY IF EXISTS "Allow service role full access to archon_migrations" ON archon_migrations; +DROP POLICY IF EXISTS "Allow authenticated users to read archon_migrations" ON archon_migrations; + +-- Create RLS policies for migrations table +-- Service role has full access +CREATE POLICY "Allow service role full access to archon_migrations" ON archon_migrations + FOR ALL USING (auth.role() = 'service_role'); + +-- Authenticated users can only read migrations (they cannot modify migration history) +CREATE POLICY "Allow authenticated users to read archon_migrations" ON archon_migrations + FOR SELECT TO authenticated + USING (true); \ No newline at end of file diff --git a/migration/0.1.0/DB_UPGRADE_INSTRUCTIONS.md b/migration/0.1.0/DB_UPGRADE_INSTRUCTIONS.md new file mode 100644 index 00000000..5523d26a --- /dev/null +++ b/migration/0.1.0/DB_UPGRADE_INSTRUCTIONS.md @@ -0,0 +1,157 @@ +# Archon Database Migrations + +This folder contains database migration scripts for upgrading existing Archon installations. + +## Available Migration Scripts + +### 1. `backup_database.sql` - Pre-Migration Backup +**Always run this FIRST before any migration!** + +Creates timestamped backup tables of all your existing data: +- ✅ Complete backup of `archon_crawled_pages` +- ✅ Complete backup of `archon_code_examples` +- ✅ Complete backup of `archon_sources` +- ✅ Easy restore commands provided +- ✅ Row count verification + +### 2. Migration Scripts (Run in Order) + +You only have to run the ones you haven't already! If you don't remember exactly, it is okay to rerun migration scripts. + +**2.1. `001_add_source_url_display_name.sql`** +- Adds display name field to sources table +- Improves UI presentation of crawled sources + +**2.2. `002_add_hybrid_search_tsvector.sql`** +- Adds full-text search capabilities +- Implements hybrid search with tsvector columns +- Creates optimized search indexes + +**2.3. `003_ollama_add_columns.sql`** +- Adds multi-dimensional embedding columns (384, 768, 1024, 1536, 3072 dimensions) +- Adds model tracking fields (`llm_chat_model`, `embedding_model`, `embedding_dimension`) + +**2.4. `004_ollama_migrate_data.sql`** +- Migrates existing embeddings to new multi-dimensional columns +- Drops old embedding column after migration +- Removes obsolete indexes + +**2.5. 
`005_ollama_create_functions.sql`** +- Creates search functions for multi-dimensional embeddings +- Adds helper functions for dimension detection +- Maintains backward compatibility with legacy search functions + +**2.6. `006_ollama_create_indexes_optional.sql`** +- Creates vector indexes for performance (may timeout on large datasets) +- Creates B-tree indexes for model fields +- Can be skipped if timeout occurs (system will use brute-force search) + +**2.7. `007_add_priority_column_to_tasks.sql`** +- Adds priority field to tasks table +- Enables task prioritization in project management + +**2.8. `008_add_migration_tracking.sql`** +- Creates migration tracking table +- Records all applied migrations +- Enables migration version control + +## Migration Process (Follow This Order!) + +### Step 1: Backup Your Data +```sql +-- Run: backup_database.sql +-- This creates timestamped backup tables of all your data +``` + +### Step 2: Run All Migration Scripts (In Order!) +```sql +-- Run each script in sequence: +-- 1. Run: 001_add_source_url_display_name.sql +-- 2. Run: 002_add_hybrid_search_tsvector.sql +-- 3. Run: 003_ollama_add_columns.sql +-- 4. Run: 004_ollama_migrate_data.sql +-- 5. Run: 005_ollama_create_functions.sql +-- 6. Run: 006_ollama_create_indexes_optional.sql (optional - may timeout) +-- 7. Run: 007_add_priority_column_to_tasks.sql +-- 8. Run: 008_add_migration_tracking.sql +``` + +### Step 3: Restart Services +```bash +docker compose restart +``` + +## How to Run Migrations + +### Method 1: Using Supabase Dashboard (Recommended) +1. Open your Supabase project dashboard +2. Go to **SQL Editor** +3. Copy and paste the contents of the migration file +4. Click **Run** to execute the migration +5. **Important**: Supabase only shows the result of the last query - all our scripts end with a status summary table that shows the complete results + +### Method 2: Using psql Command Line +```bash +# Connect to your database +psql -h your-supabase-host -p 5432 -U postgres -d postgres + +# Run the migrations in order +\i /path/to/001_add_source_url_display_name.sql +\i /path/to/002_add_hybrid_search_tsvector.sql +\i /path/to/003_ollama_add_columns.sql +\i /path/to/004_ollama_migrate_data.sql +\i /path/to/005_ollama_create_functions.sql +\i /path/to/006_ollama_create_indexes_optional.sql +\i /path/to/007_add_priority_column_to_tasks.sql +\i /path/to/008_add_migration_tracking.sql + +# Exit +\q +``` + +### Method 3: Using Docker (if using local Supabase) +```bash +# Copy migrations to container +docker cp 001_add_source_url_display_name.sql supabase-db:/tmp/ +docker cp 002_add_hybrid_search_tsvector.sql supabase-db:/tmp/ +docker cp 003_ollama_add_columns.sql supabase-db:/tmp/ +docker cp 004_ollama_migrate_data.sql supabase-db:/tmp/ +docker cp 005_ollama_create_functions.sql supabase-db:/tmp/ +docker cp 006_ollama_create_indexes_optional.sql supabase-db:/tmp/ +docker cp 007_add_priority_column_to_tasks.sql supabase-db:/tmp/ +docker cp 008_add_migration_tracking.sql supabase-db:/tmp/ + +# Execute migrations in order +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/001_add_source_url_display_name.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/002_add_hybrid_search_tsvector.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/003_ollama_add_columns.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/004_ollama_migrate_data.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/005_ollama_create_functions.sql 
+docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/006_ollama_create_indexes_optional.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/007_add_priority_column_to_tasks.sql +docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/008_add_migration_tracking.sql +``` + +## Migration Safety + +- ✅ **Safe to run multiple times** - Uses `IF NOT EXISTS` checks +- ✅ **Non-destructive** - Preserves all existing data +- ✅ **Automatic rollback** - Uses database transactions +- ✅ **Comprehensive logging** - Detailed progress notifications + +## After Migration + +1. **Restart Archon Services:** + ```bash + docker-compose restart + ``` + +2. **Verify Migration:** + - Check the Archon logs for any errors + - Try running a test crawl + - Verify search functionality works + +3. **Configure New Features:** + - Go to Settings page in Archon UI + - Configure your preferred LLM and embedding models + - New crawls will automatically use model tracking diff --git a/migration/DB_UPGRADE_INSTRUCTIONS.md b/migration/DB_UPGRADE_INSTRUCTIONS.md deleted file mode 100644 index 5ce32524..00000000 --- a/migration/DB_UPGRADE_INSTRUCTIONS.md +++ /dev/null @@ -1,167 +0,0 @@ -# Archon Database Migrations - -This folder contains database migration scripts for upgrading existing Archon installations. - -## Available Migration Scripts - -### 1. `backup_database.sql` - Pre-Migration Backup -**Always run this FIRST before any migration!** - -Creates timestamped backup tables of all your existing data: -- ✅ Complete backup of `archon_crawled_pages` -- ✅ Complete backup of `archon_code_examples` -- ✅ Complete backup of `archon_sources` -- ✅ Easy restore commands provided -- ✅ Row count verification - -### 2. `upgrade_database.sql` - Main Migration Script -**Use this migration if you:** -- Have an existing Archon installation from before multi-dimensional embedding support -- Want to upgrade to the latest features including model tracking -- Need to migrate existing embedding data to the new schema - -**Features added:** -- ✅ Multi-dimensional embedding support (384, 768, 1024, 1536, 3072 dimensions) -- ✅ Model tracking fields (`llm_chat_model`, `embedding_model`, `embedding_dimension`) -- ✅ Optimized indexes for improved search performance -- ✅ Enhanced search functions with dimension-aware querying -- ✅ Automatic migration of existing embedding data -- ✅ Legacy compatibility maintained - -### 3. `validate_migration.sql` - Post-Migration Validation -**Run this after the migration to verify everything worked correctly** - -Validates your migration results: -- ✅ Verifies all required columns were added -- ✅ Checks that database indexes were created -- ✅ Tests that all functions are working -- ✅ Shows sample data with new fields -- ✅ Provides clear success/failure reporting - -## Migration Process (Follow This Order!) - -### Step 1: Backup Your Data -```sql --- Run: backup_database.sql --- This creates timestamped backup tables of all your data -``` - -### Step 2: Run the Main Migration -```sql --- Run: upgrade_database.sql --- This adds all the new features and migrates existing data -``` - -### Step 3: Validate the Results -```sql --- Run: validate_migration.sql --- This verifies everything worked correctly -``` - -### Step 4: Restart Services -```bash -docker compose restart -``` - -## How to Run Migrations - -### Method 1: Using Supabase Dashboard (Recommended) -1. Open your Supabase project dashboard -2. Go to **SQL Editor** -3. Copy and paste the contents of the migration file -4. 
Click **Run** to execute the migration -5. **Important**: Supabase only shows the result of the last query - all our scripts end with a status summary table that shows the complete results - -### Method 2: Using psql Command Line -```bash -# Connect to your database -psql -h your-supabase-host -p 5432 -U postgres -d postgres - -# Run the migration -\i /path/to/upgrade_database.sql - -# Exit -\q -``` - -### Method 3: Using Docker (if using local Supabase) -```bash -# Copy migration to container -docker cp upgrade_database.sql supabase-db:/tmp/ - -# Execute migration -docker exec -it supabase-db psql -U postgres -d postgres -f /tmp/upgrade_database.sql -``` - -## Migration Safety - -- ✅ **Safe to run multiple times** - Uses `IF NOT EXISTS` checks -- ✅ **Non-destructive** - Preserves all existing data -- ✅ **Automatic rollback** - Uses database transactions -- ✅ **Comprehensive logging** - Detailed progress notifications - -## After Migration - -1. **Restart Archon Services:** - ```bash - docker-compose restart - ``` - -2. **Verify Migration:** - - Check the Archon logs for any errors - - Try running a test crawl - - Verify search functionality works - -3. **Configure New Features:** - - Go to Settings page in Archon UI - - Configure your preferred LLM and embedding models - - New crawls will automatically use model tracking - -## Troubleshooting - -### Permission Errors -If you get permission errors, ensure your database user has sufficient privileges: -```sql -GRANT ALL PRIVILEGES ON DATABASE postgres TO your_user; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO your_user; -``` - -### Index Creation Failures -If index creation fails due to resource constraints, the migration will continue. You can create indexes manually later: -```sql --- Example: Create missing index for 768-dimensional embeddings -CREATE INDEX idx_archon_crawled_pages_embedding_768 -ON archon_crawled_pages USING ivfflat (embedding_768 vector_cosine_ops) -WITH (lists = 100); -``` - -### Migration Verification -Check that the migration completed successfully: -```sql --- Verify new columns exist -SELECT column_name -FROM information_schema.columns -WHERE table_name = 'archon_crawled_pages' -AND column_name IN ('llm_chat_model', 'embedding_model', 'embedding_dimension', 'embedding_384', 'embedding_768'); - --- Verify functions exist -SELECT routine_name -FROM information_schema.routines -WHERE routine_name IN ('match_archon_crawled_pages_multi', 'detect_embedding_dimension'); -``` - -## Support - -If you encounter issues with the migration: - -1. Check the console output for detailed error messages -2. Verify your database connection and permissions -3. Ensure you have sufficient disk space for index creation -4. Create a GitHub issue with the error details if problems persist - -## Version Compatibility - -- **Archon v2.0+**: Use `upgrade_database.sql` -- **Earlier versions**: Use `complete_setup.sql` for fresh installations - -This migration is designed to bring any Archon installation up to the latest schema standards while preserving all existing data and functionality. 
\ No newline at end of file diff --git a/migration/RESET_DB.sql b/migration/RESET_DB.sql index 775464f5..ef0066a9 100644 --- a/migration/RESET_DB.sql +++ b/migration/RESET_DB.sql @@ -63,7 +63,11 @@ BEGIN -- Prompts policies DROP POLICY IF EXISTS "Allow service role full access to archon_prompts" ON archon_prompts; DROP POLICY IF EXISTS "Allow authenticated users to read archon_prompts" ON archon_prompts; - + + -- Migration tracking policies + DROP POLICY IF EXISTS "Allow service role full access to archon_migrations" ON archon_migrations; + DROP POLICY IF EXISTS "Allow authenticated users to read archon_migrations" ON archon_migrations; + -- Legacy table policies (for migration from old schema) DROP POLICY IF EXISTS "Allow service role full access" ON settings; DROP POLICY IF EXISTS "Allow authenticated users to read and update" ON settings; @@ -174,7 +178,10 @@ BEGIN -- Configuration System - new archon_ prefixed table DROP TABLE IF EXISTS archon_settings CASCADE; - + + -- Migration tracking table + DROP TABLE IF EXISTS archon_migrations CASCADE; + -- Legacy tables (without archon_ prefix) - for migration purposes DROP TABLE IF EXISTS document_versions CASCADE; DROP TABLE IF EXISTS project_sources CASCADE; diff --git a/migration/complete_setup.sql b/migration/complete_setup.sql index 322e0b2f..1609060c 100644 --- a/migration/complete_setup.sql +++ b/migration/complete_setup.sql @@ -951,6 +951,62 @@ COMMENT ON COLUMN archon_document_versions.change_type IS 'Type of change: creat COMMENT ON COLUMN archon_document_versions.document_id IS 'For docs arrays, the specific document ID that was changed'; COMMENT ON COLUMN archon_document_versions.task_id IS 'DEPRECATED: No longer used for new versions, kept for historical task version data'; +-- ===================================================== +-- SECTION 7: MIGRATION TRACKING +-- ===================================================== + +-- Create archon_migrations table for tracking applied database migrations +CREATE TABLE IF NOT EXISTS archon_migrations ( + id UUID DEFAULT gen_random_uuid() PRIMARY KEY, + version VARCHAR(20) NOT NULL, + migration_name VARCHAR(255) NOT NULL, + applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + checksum VARCHAR(32), + UNIQUE(version, migration_name) +); + +-- Add indexes for fast lookups +CREATE INDEX IF NOT EXISTS idx_archon_migrations_version ON archon_migrations(version); +CREATE INDEX IF NOT EXISTS idx_archon_migrations_applied_at ON archon_migrations(applied_at DESC); + +-- Add comments describing table purpose +COMMENT ON TABLE archon_migrations IS 'Tracks database migrations that have been applied to maintain schema version consistency'; +COMMENT ON COLUMN archon_migrations.version IS 'Archon version that introduced this migration'; +COMMENT ON COLUMN archon_migrations.migration_name IS 'Filename of the migration SQL file'; +COMMENT ON COLUMN archon_migrations.applied_at IS 'Timestamp when migration was applied'; +COMMENT ON COLUMN archon_migrations.checksum IS 'Optional MD5 checksum of migration file content'; + +-- Record all migrations as applied since this is a complete setup +-- This ensures the migration system knows the database is fully up-to-date +INSERT INTO archon_migrations (version, migration_name) +VALUES + ('0.1.0', '001_add_source_url_display_name'), + ('0.1.0', '002_add_hybrid_search_tsvector'), + ('0.1.0', '003_ollama_add_columns'), + ('0.1.0', '004_ollama_migrate_data'), + ('0.1.0', '005_ollama_create_functions'), + ('0.1.0', '006_ollama_create_indexes_optional'), + ('0.1.0', 
'007_add_priority_column_to_tasks'), + ('0.1.0', '008_add_migration_tracking') +ON CONFLICT (version, migration_name) DO NOTHING; + +-- Enable Row Level Security on migrations table +ALTER TABLE archon_migrations ENABLE ROW LEVEL SECURITY; + +-- Drop existing policies if they exist (makes this idempotent) +DROP POLICY IF EXISTS "Allow service role full access to archon_migrations" ON archon_migrations; +DROP POLICY IF EXISTS "Allow authenticated users to read archon_migrations" ON archon_migrations; + +-- Create RLS policies for migrations table +-- Service role has full access +CREATE POLICY "Allow service role full access to archon_migrations" ON archon_migrations + FOR ALL USING (auth.role() = 'service_role'); + +-- Authenticated users can only read migrations (they cannot modify migration history) +CREATE POLICY "Allow authenticated users to read archon_migrations" ON archon_migrations + FOR SELECT TO authenticated + USING (true); + -- ===================================================== -- SECTION 8: PROMPTS TABLE -- ===================================================== diff --git a/migration/upgrade_database.sql b/migration/upgrade_database.sql deleted file mode 100644 index 30a4f486..00000000 --- a/migration/upgrade_database.sql +++ /dev/null @@ -1,518 +0,0 @@ --- ====================================================================== --- UPGRADE TO MODEL TRACKING AND MULTI-DIMENSIONAL EMBEDDINGS --- ====================================================================== --- This migration upgrades existing Archon installations to support: --- 1. Multi-dimensional embedding columns (768, 1024, 1536, 3072) --- 2. Model tracking fields (llm_chat_model, embedding_model, embedding_dimension) --- 3. 384-dimension support for smaller embedding models --- 4. Enhanced search functions for multi-dimensional support --- ====================================================================== --- --- IMPORTANT: Run this ONLY if you have an existing Archon installation --- that was created BEFORE the multi-dimensional embedding support. --- --- This script is SAFE to run multiple times - it uses IF NOT EXISTS checks. 
--- ====================================================================== - -BEGIN; - --- ====================================================================== --- SECTION 1: ADD MULTI-DIMENSIONAL EMBEDDING COLUMNS --- ====================================================================== - --- Add multi-dimensional embedding columns to archon_crawled_pages -ALTER TABLE archon_crawled_pages -ADD COLUMN IF NOT EXISTS embedding_384 VECTOR(384), -- Small embedding models -ADD COLUMN IF NOT EXISTS embedding_768 VECTOR(768), -- Google/Ollama models -ADD COLUMN IF NOT EXISTS embedding_1024 VECTOR(1024), -- Ollama large models -ADD COLUMN IF NOT EXISTS embedding_1536 VECTOR(1536), -- OpenAI standard models -ADD COLUMN IF NOT EXISTS embedding_3072 VECTOR(3072); -- OpenAI large models - --- Add multi-dimensional embedding columns to archon_code_examples -ALTER TABLE archon_code_examples -ADD COLUMN IF NOT EXISTS embedding_384 VECTOR(384), -- Small embedding models -ADD COLUMN IF NOT EXISTS embedding_768 VECTOR(768), -- Google/Ollama models -ADD COLUMN IF NOT EXISTS embedding_1024 VECTOR(1024), -- Ollama large models -ADD COLUMN IF NOT EXISTS embedding_1536 VECTOR(1536), -- OpenAI standard models -ADD COLUMN IF NOT EXISTS embedding_3072 VECTOR(3072); -- OpenAI large models - --- ====================================================================== --- SECTION 2: ADD MODEL TRACKING COLUMNS --- ====================================================================== - --- Add model tracking columns to archon_crawled_pages -ALTER TABLE archon_crawled_pages -ADD COLUMN IF NOT EXISTS llm_chat_model TEXT, -- LLM model used for processing (e.g., 'gpt-4', 'llama3:8b') -ADD COLUMN IF NOT EXISTS embedding_model TEXT, -- Embedding model used (e.g., 'text-embedding-3-large', 'all-MiniLM-L6-v2') -ADD COLUMN IF NOT EXISTS embedding_dimension INTEGER; -- Dimension of the embedding used (384, 768, 1024, 1536, 3072) - --- Add model tracking columns to archon_code_examples -ALTER TABLE archon_code_examples -ADD COLUMN IF NOT EXISTS llm_chat_model TEXT, -- LLM model used for processing (e.g., 'gpt-4', 'llama3:8b') -ADD COLUMN IF NOT EXISTS embedding_model TEXT, -- Embedding model used (e.g., 'text-embedding-3-large', 'all-MiniLM-L6-v2') -ADD COLUMN IF NOT EXISTS embedding_dimension INTEGER; -- Dimension of the embedding used (384, 768, 1024, 1536, 3072) - --- ====================================================================== --- SECTION 3: MIGRATE EXISTING EMBEDDING DATA --- ====================================================================== - --- Check if there's existing embedding data in old 'embedding' column -DO $$ -DECLARE - crawled_pages_count INTEGER; - code_examples_count INTEGER; - dimension_detected INTEGER; -BEGIN - -- Check if old embedding column exists and has data - SELECT COUNT(*) INTO crawled_pages_count - FROM information_schema.columns - WHERE table_name = 'archon_crawled_pages' - AND column_name = 'embedding'; - - SELECT COUNT(*) INTO code_examples_count - FROM information_schema.columns - WHERE table_name = 'archon_code_examples' - AND column_name = 'embedding'; - - -- If old embedding columns exist, migrate the data - IF crawled_pages_count > 0 THEN - RAISE NOTICE 'Found existing embedding column in archon_crawled_pages - migrating data...'; - - -- Detect dimension from first non-null embedding - SELECT vector_dims(embedding) INTO dimension_detected - FROM archon_crawled_pages - WHERE embedding IS NOT NULL - LIMIT 1; - - IF dimension_detected IS NOT NULL THEN - RAISE NOTICE 'Detected 
embedding dimension: %', dimension_detected; - - -- Migrate based on detected dimension - CASE dimension_detected - WHEN 384 THEN - UPDATE archon_crawled_pages - SET embedding_384 = embedding, - embedding_dimension = 384, - embedding_model = COALESCE(embedding_model, 'legacy-384d-model') - WHERE embedding IS NOT NULL AND embedding_384 IS NULL; - - WHEN 768 THEN - UPDATE archon_crawled_pages - SET embedding_768 = embedding, - embedding_dimension = 768, - embedding_model = COALESCE(embedding_model, 'legacy-768d-model') - WHERE embedding IS NOT NULL AND embedding_768 IS NULL; - - WHEN 1024 THEN - UPDATE archon_crawled_pages - SET embedding_1024 = embedding, - embedding_dimension = 1024, - embedding_model = COALESCE(embedding_model, 'legacy-1024d-model') - WHERE embedding IS NOT NULL AND embedding_1024 IS NULL; - - WHEN 1536 THEN - UPDATE archon_crawled_pages - SET embedding_1536 = embedding, - embedding_dimension = 1536, - embedding_model = COALESCE(embedding_model, 'text-embedding-3-small') - WHERE embedding IS NOT NULL AND embedding_1536 IS NULL; - - WHEN 3072 THEN - UPDATE archon_crawled_pages - SET embedding_3072 = embedding, - embedding_dimension = 3072, - embedding_model = COALESCE(embedding_model, 'text-embedding-3-large') - WHERE embedding IS NOT NULL AND embedding_3072 IS NULL; - - ELSE - RAISE NOTICE 'Unsupported embedding dimension detected: %. Skipping migration.', dimension_detected; - END CASE; - - RAISE NOTICE 'Migrated existing embeddings to dimension-specific columns'; - END IF; - END IF; - - -- Migrate code examples if they exist - IF code_examples_count > 0 THEN - RAISE NOTICE 'Found existing embedding column in archon_code_examples - migrating data...'; - - -- Detect dimension from first non-null embedding - SELECT vector_dims(embedding) INTO dimension_detected - FROM archon_code_examples - WHERE embedding IS NOT NULL - LIMIT 1; - - IF dimension_detected IS NOT NULL THEN - RAISE NOTICE 'Detected code examples embedding dimension: %', dimension_detected; - - -- Migrate based on detected dimension - CASE dimension_detected - WHEN 384 THEN - UPDATE archon_code_examples - SET embedding_384 = embedding, - embedding_dimension = 384, - embedding_model = COALESCE(embedding_model, 'legacy-384d-model') - WHERE embedding IS NOT NULL AND embedding_384 IS NULL; - - WHEN 768 THEN - UPDATE archon_code_examples - SET embedding_768 = embedding, - embedding_dimension = 768, - embedding_model = COALESCE(embedding_model, 'legacy-768d-model') - WHERE embedding IS NOT NULL AND embedding_768 IS NULL; - - WHEN 1024 THEN - UPDATE archon_code_examples - SET embedding_1024 = embedding, - embedding_dimension = 1024, - embedding_model = COALESCE(embedding_model, 'legacy-1024d-model') - WHERE embedding IS NOT NULL AND embedding_1024 IS NULL; - - WHEN 1536 THEN - UPDATE archon_code_examples - SET embedding_1536 = embedding, - embedding_dimension = 1536, - embedding_model = COALESCE(embedding_model, 'text-embedding-3-small') - WHERE embedding IS NOT NULL AND embedding_1536 IS NULL; - - WHEN 3072 THEN - UPDATE archon_code_examples - SET embedding_3072 = embedding, - embedding_dimension = 3072, - embedding_model = COALESCE(embedding_model, 'text-embedding-3-large') - WHERE embedding IS NOT NULL AND embedding_3072 IS NULL; - - ELSE - RAISE NOTICE 'Unsupported code examples embedding dimension: %. 
Skipping migration.', dimension_detected; - END CASE; - - RAISE NOTICE 'Migrated existing code example embeddings to dimension-specific columns'; - END IF; - END IF; -END $$; - --- ====================================================================== --- SECTION 4: CLEANUP LEGACY EMBEDDING COLUMNS --- ====================================================================== - --- Remove old embedding columns after successful migration -DO $$ -DECLARE - crawled_pages_count INTEGER; - code_examples_count INTEGER; -BEGIN - -- Check if old embedding column exists in crawled pages - SELECT COUNT(*) INTO crawled_pages_count - FROM information_schema.columns - WHERE table_name = 'archon_crawled_pages' - AND column_name = 'embedding'; - - -- Check if old embedding column exists in code examples - SELECT COUNT(*) INTO code_examples_count - FROM information_schema.columns - WHERE table_name = 'archon_code_examples' - AND column_name = 'embedding'; - - -- Drop old embedding column from crawled pages if it exists - IF crawled_pages_count > 0 THEN - RAISE NOTICE 'Dropping legacy embedding column from archon_crawled_pages...'; - ALTER TABLE archon_crawled_pages DROP COLUMN embedding; - RAISE NOTICE 'Successfully removed legacy embedding column from archon_crawled_pages'; - END IF; - - -- Drop old embedding column from code examples if it exists - IF code_examples_count > 0 THEN - RAISE NOTICE 'Dropping legacy embedding column from archon_code_examples...'; - ALTER TABLE archon_code_examples DROP COLUMN embedding; - RAISE NOTICE 'Successfully removed legacy embedding column from archon_code_examples'; - END IF; - - -- Drop any indexes on the old embedding column if they exist - DROP INDEX IF EXISTS idx_archon_crawled_pages_embedding; - DROP INDEX IF EXISTS idx_archon_code_examples_embedding; - - RAISE NOTICE 'Legacy column cleanup completed'; -END $$; - --- ====================================================================== --- SECTION 5: CREATE OPTIMIZED INDEXES --- ====================================================================== - --- Create indexes for archon_crawled_pages (multi-dimensional support) -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_384 -ON archon_crawled_pages USING ivfflat (embedding_384 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_768 -ON archon_crawled_pages USING ivfflat (embedding_768 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_1024 -ON archon_crawled_pages USING ivfflat (embedding_1024 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_1536 -ON archon_crawled_pages USING ivfflat (embedding_1536 vector_cosine_ops) -WITH (lists = 100); - --- Note: 3072-dimensional embeddings cannot have vector indexes due to PostgreSQL vector extension 2000 dimension limit --- The embedding_3072 column exists but cannot be indexed with current pgvector version --- Brute force search will be used for 3072-dimensional vectors --- CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_3072 --- ON archon_crawled_pages USING hnsw (embedding_3072 vector_cosine_ops); - --- Create indexes for archon_code_examples (multi-dimensional support) -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_384 -ON archon_code_examples USING ivfflat (embedding_384 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_768 -ON archon_code_examples USING ivfflat 
(embedding_768 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_1024 -ON archon_code_examples USING ivfflat (embedding_1024 vector_cosine_ops) -WITH (lists = 100); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_1536 -ON archon_code_examples USING ivfflat (embedding_1536 vector_cosine_ops) -WITH (lists = 100); - --- Note: 3072-dimensional embeddings cannot have vector indexes due to PostgreSQL vector extension 2000 dimension limit --- The embedding_3072 column exists but cannot be indexed with current pgvector version --- Brute force search will be used for 3072-dimensional vectors --- CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_3072 --- ON archon_code_examples USING hnsw (embedding_3072 vector_cosine_ops); - --- Create indexes for model tracking columns -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_model -ON archon_crawled_pages (embedding_model); - -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_embedding_dimension -ON archon_crawled_pages (embedding_dimension); - -CREATE INDEX IF NOT EXISTS idx_archon_crawled_pages_llm_chat_model -ON archon_crawled_pages (llm_chat_model); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_model -ON archon_code_examples (embedding_model); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_embedding_dimension -ON archon_code_examples (embedding_dimension); - -CREATE INDEX IF NOT EXISTS idx_archon_code_examples_llm_chat_model -ON archon_code_examples (llm_chat_model); - --- ====================================================================== --- SECTION 6: HELPER FUNCTIONS FOR MULTI-DIMENSIONAL SUPPORT --- ====================================================================== - --- Function to detect embedding dimension from vector -CREATE OR REPLACE FUNCTION detect_embedding_dimension(embedding_vector vector) -RETURNS INTEGER AS $$ -BEGIN - RETURN vector_dims(embedding_vector); -END; -$$ LANGUAGE plpgsql IMMUTABLE; - --- Function to get the appropriate column name for a dimension -CREATE OR REPLACE FUNCTION get_embedding_column_name(dimension INTEGER) -RETURNS TEXT AS $$ -BEGIN - CASE dimension - WHEN 384 THEN RETURN 'embedding_384'; - WHEN 768 THEN RETURN 'embedding_768'; - WHEN 1024 THEN RETURN 'embedding_1024'; - WHEN 1536 THEN RETURN 'embedding_1536'; - WHEN 3072 THEN RETURN 'embedding_3072'; - ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %. 
Supported dimensions are: 384, 768, 1024, 1536, 3072', dimension; - END CASE; -END; -$$ LANGUAGE plpgsql IMMUTABLE; - --- ====================================================================== --- SECTION 7: ENHANCED SEARCH FUNCTIONS --- ====================================================================== - --- Create multi-dimensional function to search for documentation chunks -CREATE OR REPLACE FUNCTION match_archon_crawled_pages_multi ( - query_embedding VECTOR, - embedding_dimension INTEGER, - match_count INT DEFAULT 10, - filter JSONB DEFAULT '{}'::jsonb, - source_filter TEXT DEFAULT NULL -) RETURNS TABLE ( - id BIGINT, - url VARCHAR, - chunk_number INTEGER, - content TEXT, - metadata JSONB, - source_id TEXT, - similarity FLOAT -) -LANGUAGE plpgsql -AS $$ -#variable_conflict use_column -DECLARE - sql_query TEXT; - embedding_column TEXT; -BEGIN - -- Determine which embedding column to use based on dimension - CASE embedding_dimension - WHEN 384 THEN embedding_column := 'embedding_384'; - WHEN 768 THEN embedding_column := 'embedding_768'; - WHEN 1024 THEN embedding_column := 'embedding_1024'; - WHEN 1536 THEN embedding_column := 'embedding_1536'; - WHEN 3072 THEN embedding_column := 'embedding_3072'; - ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %', embedding_dimension; - END CASE; - - -- Build dynamic query - sql_query := format(' - SELECT id, url, chunk_number, content, metadata, source_id, - 1 - (%I <=> $1) AS similarity - FROM archon_crawled_pages - WHERE (%I IS NOT NULL) - AND metadata @> $3 - AND ($4 IS NULL OR source_id = $4) - ORDER BY %I <=> $1 - LIMIT $2', - embedding_column, embedding_column, embedding_column); - - -- Execute dynamic query - RETURN QUERY EXECUTE sql_query USING query_embedding, match_count, filter, source_filter; -END; -$$; - --- Create multi-dimensional function to search for code examples -CREATE OR REPLACE FUNCTION match_archon_code_examples_multi ( - query_embedding VECTOR, - embedding_dimension INTEGER, - match_count INT DEFAULT 10, - filter JSONB DEFAULT '{}'::jsonb, - source_filter TEXT DEFAULT NULL -) RETURNS TABLE ( - id BIGINT, - url VARCHAR, - chunk_number INTEGER, - content TEXT, - summary TEXT, - metadata JSONB, - source_id TEXT, - similarity FLOAT -) -LANGUAGE plpgsql -AS $$ -#variable_conflict use_column -DECLARE - sql_query TEXT; - embedding_column TEXT; -BEGIN - -- Determine which embedding column to use based on dimension - CASE embedding_dimension - WHEN 384 THEN embedding_column := 'embedding_384'; - WHEN 768 THEN embedding_column := 'embedding_768'; - WHEN 1024 THEN embedding_column := 'embedding_1024'; - WHEN 1536 THEN embedding_column := 'embedding_1536'; - WHEN 3072 THEN embedding_column := 'embedding_3072'; - ELSE RAISE EXCEPTION 'Unsupported embedding dimension: %', embedding_dimension; - END CASE; - - -- Build dynamic query - sql_query := format(' - SELECT id, url, chunk_number, content, summary, metadata, source_id, - 1 - (%I <=> $1) AS similarity - FROM archon_code_examples - WHERE (%I IS NOT NULL) - AND metadata @> $3 - AND ($4 IS NULL OR source_id = $4) - ORDER BY %I <=> $1 - LIMIT $2', - embedding_column, embedding_column, embedding_column); - - -- Execute dynamic query - RETURN QUERY EXECUTE sql_query USING query_embedding, match_count, filter, source_filter; -END; -$$; - --- ====================================================================== --- SECTION 8: LEGACY COMPATIBILITY FUNCTIONS --- ====================================================================== - --- Legacy compatibility function for crawled 
pages (defaults to 1536D) -CREATE OR REPLACE FUNCTION match_archon_crawled_pages ( - query_embedding VECTOR(1536), - match_count INT DEFAULT 10, - filter JSONB DEFAULT '{}'::jsonb, - source_filter TEXT DEFAULT NULL -) RETURNS TABLE ( - id BIGINT, - url VARCHAR, - chunk_number INTEGER, - content TEXT, - metadata JSONB, - source_id TEXT, - similarity FLOAT -) -LANGUAGE plpgsql -AS $$ -BEGIN - RETURN QUERY SELECT * FROM match_archon_crawled_pages_multi(query_embedding, 1536, match_count, filter, source_filter); -END; -$$; - --- Legacy compatibility function for code examples (defaults to 1536D) -CREATE OR REPLACE FUNCTION match_archon_code_examples ( - query_embedding VECTOR(1536), - match_count INT DEFAULT 10, - filter JSONB DEFAULT '{}'::jsonb, - source_filter TEXT DEFAULT NULL -) RETURNS TABLE ( - id BIGINT, - url VARCHAR, - chunk_number INTEGER, - content TEXT, - summary TEXT, - metadata JSONB, - source_id TEXT, - similarity FLOAT -) -LANGUAGE plpgsql -AS $$ -BEGIN - RETURN QUERY SELECT * FROM match_archon_code_examples_multi(query_embedding, 1536, match_count, filter, source_filter); -END; -$$; - -COMMIT; - --- ====================================================================== --- MIGRATION COMPLETE - SUPABASE-FRIENDLY STATUS REPORT --- ====================================================================== --- This final SELECT statement consolidates all status information for --- display in Supabase SQL Editor (users only see the last query result) - -SELECT - '🎉 ARCHON MODEL TRACKING UPGRADE COMPLETED! 🎉' AS status, - 'Successfully upgraded your Archon installation' AS message, - ARRAY[ - '✅ Multi-dimensional embedding support (384, 768, 1024, 1536, 3072)', - '✅ Model tracking fields (llm_chat_model, embedding_model, embedding_dimension)', - '✅ Optimized indexes for improved search performance', - '✅ Enhanced search functions with dimension-aware querying', - '✅ Legacy compatibility maintained for existing code', - '✅ Existing embedding data migrated (if any was found)', - '✅ Support for 3072-dimensional vectors (using brute force search)' - ] AS features_added, - ARRAY[ - '• Multiple embedding providers (OpenAI, Ollama, Google, etc.)', - '• Automatic model detection and tracking', - '• Improved search accuracy with dimension-specific indexing', - '• Full audit trail of which models processed your data' - ] AS capabilities_enabled, - ARRAY[ - '1. Restart your Archon services: docker compose restart', - '2. New crawls will automatically use the enhanced features', - '3. Check the Settings page to configure your preferred models', - '4. Run validate_migration.sql to verify everything works' - ] AS next_steps; \ No newline at end of file diff --git a/migration/validate_migration.sql b/migration/validate_migration.sql deleted file mode 100644 index 3ff31924..00000000 --- a/migration/validate_migration.sql +++ /dev/null @@ -1,287 +0,0 @@ --- ====================================================================== --- ARCHON MIGRATION VALIDATION SCRIPT --- ====================================================================== --- This script validates that the upgrade_to_model_tracking.sql migration --- completed successfully and all features are working. 
--- ====================================================================== - -DO $$ -DECLARE - crawled_pages_columns INTEGER := 0; - code_examples_columns INTEGER := 0; - crawled_pages_indexes INTEGER := 0; - code_examples_indexes INTEGER := 0; - functions_count INTEGER := 0; - migration_success BOOLEAN := TRUE; - error_messages TEXT := ''; -BEGIN - RAISE NOTICE '===================================================================='; - RAISE NOTICE ' VALIDATING ARCHON MIGRATION RESULTS'; - RAISE NOTICE '===================================================================='; - - -- Check if required columns exist in archon_crawled_pages - SELECT COUNT(*) INTO crawled_pages_columns - FROM information_schema.columns - WHERE table_name = 'archon_crawled_pages' - AND column_name IN ( - 'embedding_384', 'embedding_768', 'embedding_1024', 'embedding_1536', 'embedding_3072', - 'llm_chat_model', 'embedding_model', 'embedding_dimension' - ); - - -- Check if required columns exist in archon_code_examples - SELECT COUNT(*) INTO code_examples_columns - FROM information_schema.columns - WHERE table_name = 'archon_code_examples' - AND column_name IN ( - 'embedding_384', 'embedding_768', 'embedding_1024', 'embedding_1536', 'embedding_3072', - 'llm_chat_model', 'embedding_model', 'embedding_dimension' - ); - - -- Check if indexes were created for archon_crawled_pages - SELECT COUNT(*) INTO crawled_pages_indexes - FROM pg_indexes - WHERE tablename = 'archon_crawled_pages' - AND indexname IN ( - 'idx_archon_crawled_pages_embedding_384', - 'idx_archon_crawled_pages_embedding_768', - 'idx_archon_crawled_pages_embedding_1024', - 'idx_archon_crawled_pages_embedding_1536', - 'idx_archon_crawled_pages_embedding_model', - 'idx_archon_crawled_pages_embedding_dimension', - 'idx_archon_crawled_pages_llm_chat_model' - ); - - -- Check if indexes were created for archon_code_examples - SELECT COUNT(*) INTO code_examples_indexes - FROM pg_indexes - WHERE tablename = 'archon_code_examples' - AND indexname IN ( - 'idx_archon_code_examples_embedding_384', - 'idx_archon_code_examples_embedding_768', - 'idx_archon_code_examples_embedding_1024', - 'idx_archon_code_examples_embedding_1536', - 'idx_archon_code_examples_embedding_model', - 'idx_archon_code_examples_embedding_dimension', - 'idx_archon_code_examples_llm_chat_model' - ); - - -- Check if required functions exist - SELECT COUNT(*) INTO functions_count - FROM information_schema.routines - WHERE routine_name IN ( - 'match_archon_crawled_pages_multi', - 'match_archon_code_examples_multi', - 'detect_embedding_dimension', - 'get_embedding_column_name' - ); - - -- Validate results - RAISE NOTICE 'COLUMN VALIDATION:'; - IF crawled_pages_columns = 8 THEN - RAISE NOTICE '✅ archon_crawled_pages: All 8 required columns found'; - ELSE - RAISE NOTICE '❌ archon_crawled_pages: Expected 8 columns, found %', crawled_pages_columns; - migration_success := FALSE; - error_messages := error_messages || '• Missing columns in archon_crawled_pages' || chr(10); - END IF; - - IF code_examples_columns = 8 THEN - RAISE NOTICE '✅ archon_code_examples: All 8 required columns found'; - ELSE - RAISE NOTICE '❌ archon_code_examples: Expected 8 columns, found %', code_examples_columns; - migration_success := FALSE; - error_messages := error_messages || '• Missing columns in archon_code_examples' || chr(10); - END IF; - - RAISE NOTICE ''; - RAISE NOTICE 'INDEX VALIDATION:'; - IF crawled_pages_indexes >= 6 THEN - RAISE NOTICE '✅ archon_crawled_pages: % indexes created (expected 6+)', crawled_pages_indexes; - 
ELSE - RAISE NOTICE '⚠️ archon_crawled_pages: % indexes created (expected 6+)', crawled_pages_indexes; - RAISE NOTICE ' Note: Some indexes may have failed due to resource constraints - this is OK'; - END IF; - - IF code_examples_indexes >= 6 THEN - RAISE NOTICE '✅ archon_code_examples: % indexes created (expected 6+)', code_examples_indexes; - ELSE - RAISE NOTICE '⚠️ archon_code_examples: % indexes created (expected 6+)', code_examples_indexes; - RAISE NOTICE ' Note: Some indexes may have failed due to resource constraints - this is OK'; - END IF; - - RAISE NOTICE ''; - RAISE NOTICE 'FUNCTION VALIDATION:'; - IF functions_count = 4 THEN - RAISE NOTICE '✅ All 4 required functions created successfully'; - ELSE - RAISE NOTICE '❌ Expected 4 functions, found %', functions_count; - migration_success := FALSE; - error_messages := error_messages || '• Missing database functions' || chr(10); - END IF; - - -- Test function functionality - BEGIN - PERFORM detect_embedding_dimension(ARRAY[1,2,3]::vector); - RAISE NOTICE '✅ detect_embedding_dimension function working'; - EXCEPTION WHEN OTHERS THEN - RAISE NOTICE '❌ detect_embedding_dimension function failed: %', SQLERRM; - migration_success := FALSE; - error_messages := error_messages || '• detect_embedding_dimension function not working' || chr(10); - END; - - BEGIN - PERFORM get_embedding_column_name(1536); - RAISE NOTICE '✅ get_embedding_column_name function working'; - EXCEPTION WHEN OTHERS THEN - RAISE NOTICE '❌ get_embedding_column_name function failed: %', SQLERRM; - migration_success := FALSE; - error_messages := error_messages || '• get_embedding_column_name function not working' || chr(10); - END; - - RAISE NOTICE ''; - RAISE NOTICE '===================================================================='; - - IF migration_success THEN - RAISE NOTICE '🎉 MIGRATION VALIDATION SUCCESSFUL!'; - RAISE NOTICE ''; - RAISE NOTICE 'Your Archon installation has been successfully upgraded with:'; - RAISE NOTICE '✅ Multi-dimensional embedding support'; - RAISE NOTICE '✅ Model tracking capabilities'; - RAISE NOTICE '✅ Enhanced search functions'; - RAISE NOTICE '✅ Optimized database indexes'; - RAISE NOTICE ''; - RAISE NOTICE 'Next steps:'; - RAISE NOTICE '1. Restart your Archon services: docker compose restart'; - RAISE NOTICE '2. Test with a small crawl to verify functionality'; - RAISE NOTICE '3. 
Configure your preferred models in Settings'; - ELSE - RAISE NOTICE '❌ MIGRATION VALIDATION FAILED!'; - RAISE NOTICE ''; - RAISE NOTICE 'Issues found:'; - RAISE NOTICE '%', error_messages; - RAISE NOTICE 'Please check the migration logs and re-run if necessary.'; - END IF; - - RAISE NOTICE '===================================================================='; - - -- Show sample of existing data if any - DECLARE - sample_count INTEGER; - r RECORD; -- Declare the loop variable as RECORD type - BEGIN - SELECT COUNT(*) INTO sample_count FROM archon_crawled_pages LIMIT 1; - IF sample_count > 0 THEN - RAISE NOTICE ''; - RAISE NOTICE 'SAMPLE DATA CHECK:'; - - -- Show one record with the new columns - FOR r IN - SELECT url, embedding_model, embedding_dimension, - CASE WHEN llm_chat_model IS NOT NULL THEN '✅' ELSE '⚪' END as llm_status, - CASE WHEN embedding_384 IS NOT NULL THEN '✅ 384' - WHEN embedding_768 IS NOT NULL THEN '✅ 768' - WHEN embedding_1024 IS NOT NULL THEN '✅ 1024' - WHEN embedding_1536 IS NOT NULL THEN '✅ 1536' - WHEN embedding_3072 IS NOT NULL THEN '✅ 3072' - ELSE '⚪ None' END as embedding_status - FROM archon_crawled_pages - LIMIT 3 - LOOP - RAISE NOTICE 'Record: % | Model: % | Dimension: % | LLM: % | Embedding: %', - substring(r.url from 1 for 40), - COALESCE(r.embedding_model, 'None'), - COALESCE(r.embedding_dimension::text, 'None'), - r.llm_status, - r.embedding_status; - END LOOP; - END IF; - END; - -END $$; - --- ====================================================================== --- VALIDATION COMPLETE - SUPABASE-FRIENDLY STATUS REPORT --- ====================================================================== --- This final SELECT statement consolidates validation results for --- display in Supabase SQL Editor (users only see the last query result) - -WITH validation_results AS ( - -- Check if all required columns exist - SELECT - COUNT(*) FILTER (WHERE column_name IN ('embedding_384', 'embedding_768', 'embedding_1024', 'embedding_1536', 'embedding_3072')) as embedding_columns, - COUNT(*) FILTER (WHERE column_name IN ('llm_chat_model', 'embedding_model', 'embedding_dimension')) as tracking_columns - FROM information_schema.columns - WHERE table_name = 'archon_crawled_pages' -), -function_check AS ( - -- Check if required functions exist - SELECT - COUNT(*) FILTER (WHERE routine_name IN ('match_archon_crawled_pages_multi', 'match_archon_code_examples_multi', 'detect_embedding_dimension', 'get_embedding_column_name')) as functions_count - FROM information_schema.routines - WHERE routine_type = 'FUNCTION' -), -index_check AS ( - -- Check if indexes exist - SELECT - COUNT(*) FILTER (WHERE indexname LIKE '%embedding_%') as embedding_indexes - FROM pg_indexes - WHERE tablename IN ('archon_crawled_pages', 'archon_code_examples') -), -data_sample AS ( - -- Get sample of data with new columns - SELECT - COUNT(*) as total_records, - COUNT(*) FILTER (WHERE embedding_model IS NOT NULL) as records_with_model_tracking, - COUNT(*) FILTER (WHERE embedding_384 IS NOT NULL OR embedding_768 IS NOT NULL OR embedding_1024 IS NOT NULL OR embedding_1536 IS NOT NULL OR embedding_3072 IS NOT NULL) as records_with_multi_dim_embeddings - FROM archon_crawled_pages -), -overall_status AS ( - SELECT - CASE - WHEN v.embedding_columns = 5 AND v.tracking_columns = 3 AND f.functions_count >= 4 AND i.embedding_indexes > 0 - THEN '✅ MIGRATION VALIDATION SUCCESSFUL!' - ELSE '❌ MIGRATION VALIDATION FAILED!' 
- END as status, - v.embedding_columns, - v.tracking_columns, - f.functions_count, - i.embedding_indexes, - d.total_records, - d.records_with_model_tracking, - d.records_with_multi_dim_embeddings - FROM validation_results v, function_check f, index_check i, data_sample d -) -SELECT - status, - CASE - WHEN embedding_columns = 5 AND tracking_columns = 3 AND functions_count >= 4 AND embedding_indexes > 0 - THEN 'All validation checks passed successfully' - ELSE 'Some validation checks failed - please review the results' - END as message, - json_build_object( - 'embedding_columns_added', embedding_columns || '/5', - 'tracking_columns_added', tracking_columns || '/3', - 'search_functions_created', functions_count || '+ functions', - 'embedding_indexes_created', embedding_indexes || '+ indexes' - ) as technical_validation, - json_build_object( - 'total_records', total_records, - 'records_with_model_tracking', records_with_model_tracking, - 'records_with_multi_dimensional_embeddings', records_with_multi_dim_embeddings - ) as data_status, - CASE - WHEN embedding_columns = 5 AND tracking_columns = 3 AND functions_count >= 4 AND embedding_indexes > 0 - THEN ARRAY[ - '1. Restart Archon services: docker compose restart', - '2. Test with a small crawl to verify functionality', - '3. Configure your preferred models in Settings', - '4. New crawls will automatically use model tracking' - ] - ELSE ARRAY[ - '1. Check migration logs for specific errors', - '2. Re-run upgrade_database.sql if needed', - '3. Ensure database has sufficient permissions', - '4. Contact support if issues persist' - ] - END as next_steps -FROM overall_status; \ No newline at end of file diff --git a/python/src/server/api_routes/migration_api.py b/python/src/server/api_routes/migration_api.py new file mode 100644 index 00000000..fec04d24 --- /dev/null +++ b/python/src/server/api_routes/migration_api.py @@ -0,0 +1,170 @@ +""" +API routes for database migration tracking and management. +""" + +from datetime import datetime + +import logfire +from fastapi import APIRouter, Header, HTTPException, Response +from pydantic import BaseModel + +from ..config.version import ARCHON_VERSION +from ..services.migration_service import migration_service +from ..utils.etag_utils import check_etag, generate_etag + + +# Response models +class MigrationRecord(BaseModel): + """Represents an applied migration.""" + + version: str + migration_name: str + applied_at: datetime + checksum: str | None = None + + +class PendingMigration(BaseModel): + """Represents a pending migration.""" + + version: str + name: str + sql_content: str + file_path: str + checksum: str | None = None + + +class MigrationStatusResponse(BaseModel): + """Complete migration status response.""" + + pending_migrations: list[PendingMigration] + applied_migrations: list[MigrationRecord] + has_pending: bool + bootstrap_required: bool + current_version: str + pending_count: int + applied_count: int + + +class MigrationHistoryResponse(BaseModel): + """Migration history response.""" + + migrations: list[MigrationRecord] + total_count: int + current_version: str + + +# Create router +router = APIRouter(prefix="/api/migrations", tags=["migrations"]) + + +@router.get("/status", response_model=MigrationStatusResponse) +async def get_migration_status( + response: Response, if_none_match: str | None = Header(None) +): + """ + Get current migration status including pending and applied migrations. 
+ + Returns comprehensive migration status with: + - List of pending migrations with SQL content + - List of applied migrations + - Bootstrap flag if migrations table doesn't exist + - Current version information + """ + try: + # Get migration status from service + status = await migration_service.get_migration_status() + + # Generate ETag for response + etag = generate_etag(status) + + # Check if client has current data + if check_etag(if_none_match, etag): + # Client has current data, return 304 + response.status_code = 304 + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return Response(status_code=304) + else: + # Client needs new data + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return MigrationStatusResponse(**status) + + except Exception as e: + logfire.error(f"Error getting migration status: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get migration status: {str(e)}") from e + + +@router.get("/history", response_model=MigrationHistoryResponse) +async def get_migration_history(response: Response, if_none_match: str | None = Header(None)): + """ + Get history of applied migrations. + + Returns list of all applied migrations sorted by date. + """ + try: + # Get applied migrations from service + applied = await migration_service.get_applied_migrations() + + # Format response + history = { + "migrations": [ + MigrationRecord( + version=m.version, + migration_name=m.migration_name, + applied_at=m.applied_at, + checksum=m.checksum, + ) + for m in applied + ], + "total_count": len(applied), + "current_version": ARCHON_VERSION, + } + + # Generate ETag for response + etag = generate_etag(history) + + # Check if client has current data + if check_etag(if_none_match, etag): + # Client has current data, return 304 + response.status_code = 304 + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return Response(status_code=304) + else: + # Client needs new data + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return MigrationHistoryResponse(**history) + + except Exception as e: + logfire.error(f"Error getting migration history: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get migration history: {str(e)}") from e + + +@router.get("/pending", response_model=list[PendingMigration]) +async def get_pending_migrations(): + """ + Get list of pending migrations only. + + Returns simplified list of migrations that need to be applied. + """ + try: + # Get pending migrations from service + pending = await migration_service.get_pending_migrations() + + # Format response + return [ + PendingMigration( + version=m.version, + name=m.name, + sql_content=m.sql_content, + file_path=m.file_path, + checksum=m.checksum, + ) + for m in pending + ] + + except Exception as e: + logfire.error(f"Error getting pending migrations: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get pending migrations: {str(e)}") from e diff --git a/python/src/server/api_routes/version_api.py b/python/src/server/api_routes/version_api.py new file mode 100644 index 00000000..ebfd306f --- /dev/null +++ b/python/src/server/api_routes/version_api.py @@ -0,0 +1,121 @@ +""" +API routes for version checking and update management. 
+""" + +from datetime import datetime +from typing import Any + +import logfire +from fastapi import APIRouter, Header, HTTPException, Response +from pydantic import BaseModel + +from ..config.version import ARCHON_VERSION +from ..services.version_service import version_service +from ..utils.etag_utils import check_etag, generate_etag + + +# Response models +class ReleaseAsset(BaseModel): + """Represents a downloadable asset from a release.""" + + name: str + size: int + download_count: int + browser_download_url: str + content_type: str + + +class VersionCheckResponse(BaseModel): + """Version check response with update information.""" + + current: str + latest: str | None + update_available: bool + release_url: str | None + release_notes: str | None + published_at: datetime | None + check_error: str | None = None + assets: list[dict[str, Any]] | None = None + author: str | None = None + + +class CurrentVersionResponse(BaseModel): + """Simple current version response.""" + + version: str + timestamp: datetime + + +# Create router +router = APIRouter(prefix="/api/version", tags=["version"]) + + +@router.get("/check", response_model=VersionCheckResponse) +async def check_for_updates(response: Response, if_none_match: str | None = Header(None)): + """ + Check for available Archon updates. + + Queries GitHub releases API to determine if a newer version is available. + Results are cached for 1 hour to avoid rate limiting. + + Returns: + Version information including current, latest, and update availability + """ + try: + # Get version check results from service + result = await version_service.check_for_updates() + + # Generate ETag for response + etag = generate_etag(result) + + # Check if client has current data + if check_etag(if_none_match, etag): + # Client has current data, return 304 + response.status_code = 304 + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return Response(status_code=304) + else: + # Client needs new data + response.headers["ETag"] = f'"{etag}"' + response.headers["Cache-Control"] = "no-cache, must-revalidate" + return VersionCheckResponse(**result) + + except Exception as e: + logfire.error(f"Error checking for updates: {e}") + # Return safe response with error + return VersionCheckResponse( + current=ARCHON_VERSION, + latest=None, + update_available=False, + release_url=None, + release_notes=None, + published_at=None, + check_error=str(e), + ) + + +@router.get("/current", response_model=CurrentVersionResponse) +async def get_current_version(): + """ + Get the current Archon version. + + Simple endpoint that returns the installed version without checking for updates. + """ + return CurrentVersionResponse(version=ARCHON_VERSION, timestamp=datetime.now()) + + +@router.post("/clear-cache") +async def clear_version_cache(): + """ + Clear the version check cache. + + Forces the next version check to query GitHub API instead of using cached data. + Useful for testing or forcing an immediate update check. 
+ """ + try: + version_service.clear_cache() + return {"message": "Version cache cleared successfully", "success": True} + except Exception as e: + logfire.error(f"Error clearing version cache: {e}") + raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}") from e diff --git a/python/src/server/config/version.py b/python/src/server/config/version.py new file mode 100644 index 00000000..97b74302 --- /dev/null +++ b/python/src/server/config/version.py @@ -0,0 +1,11 @@ +""" +Version configuration for Archon. +""" + +# Current version of Archon +# Update this with each release +ARCHON_VERSION = "0.1.0" + +# Repository information for GitHub API +GITHUB_REPO_OWNER = "coleam00" +GITHUB_REPO_NAME = "Archon" diff --git a/python/src/server/main.py b/python/src/server/main.py index ba0b19cb..19456e06 100644 --- a/python/src/server/main.py +++ b/python/src/server/main.py @@ -23,10 +23,12 @@ from .api_routes.bug_report_api import router as bug_report_router from .api_routes.internal_api import router as internal_router from .api_routes.knowledge_api import router as knowledge_router from .api_routes.mcp_api import router as mcp_router +from .api_routes.migration_api import router as migration_router from .api_routes.ollama_api import router as ollama_router from .api_routes.progress_api import router as progress_router from .api_routes.projects_api import router as projects_router from .api_routes.providers_api import router as providers_router +from .api_routes.version_api import router as version_router # Import modular API routers from .api_routes.settings_api import router as settings_router @@ -188,6 +190,8 @@ app.include_router(agent_chat_router) app.include_router(internal_router) app.include_router(bug_report_router) app.include_router(providers_router) +app.include_router(version_router) +app.include_router(migration_router) # Root endpoint diff --git a/python/src/server/services/migration_service.py b/python/src/server/services/migration_service.py new file mode 100644 index 00000000..f47a4d68 --- /dev/null +++ b/python/src/server/services/migration_service.py @@ -0,0 +1,233 @@ +""" +Database migration tracking and management service. 
+""" + +import hashlib +from pathlib import Path +from typing import Any + +import logfire +from supabase import Client + +from .client_manager import get_supabase_client +from ..config.version import ARCHON_VERSION + + +class MigrationRecord: + """Represents a migration record from the database.""" + + def __init__(self, data: dict[str, Any]): + self.id = data.get("id") + self.version = data.get("version") + self.migration_name = data.get("migration_name") + self.applied_at = data.get("applied_at") + self.checksum = data.get("checksum") + + +class PendingMigration: + """Represents a pending migration from the filesystem.""" + + def __init__(self, version: str, name: str, sql_content: str, file_path: str): + self.version = version + self.name = name + self.sql_content = sql_content + self.file_path = file_path + self.checksum = self._calculate_checksum(sql_content) + + def _calculate_checksum(self, content: str) -> str: + """Calculate MD5 checksum of migration content.""" + return hashlib.md5(content.encode()).hexdigest() + + +class MigrationService: + """Service for managing database migrations.""" + + def __init__(self): + self._supabase: Client | None = None + # Handle both Docker (/app/migration) and local (./migration) environments + if Path("/app/migration").exists(): + self._migrations_dir = Path("/app/migration") + else: + self._migrations_dir = Path("migration") + + def _get_supabase_client(self) -> Client: + """Get or create Supabase client.""" + if not self._supabase: + self._supabase = get_supabase_client() + return self._supabase + + async def check_migrations_table_exists(self) -> bool: + """ + Check if the archon_migrations table exists in the database. + + Returns: + True if table exists, False otherwise + """ + try: + supabase = self._get_supabase_client() + + # Query to check if table exists + result = supabase.rpc( + "sql", + { + "query": """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = 'archon_migrations' + ) as exists + """ + } + ).execute() + + # Check if result indicates table exists + if result.data and len(result.data) > 0: + return result.data[0].get("exists", False) + return False + except Exception: + # If the SQL function doesn't exist or query fails, try direct query + try: + supabase = self._get_supabase_client() + # Try to select from the table with limit 0 + supabase.table("archon_migrations").select("id").limit(0).execute() + return True + except Exception as e: + logfire.info(f"Migrations table does not exist: {e}") + return False + + async def get_applied_migrations(self) -> list[MigrationRecord]: + """ + Get list of applied migrations from the database. + + Returns: + List of MigrationRecord objects + """ + try: + # Check if table exists first + if not await self.check_migrations_table_exists(): + logfire.info("Migrations table does not exist, returning empty list") + return [] + + supabase = self._get_supabase_client() + result = supabase.table("archon_migrations").select("*").order("applied_at", desc=True).execute() + + return [MigrationRecord(row) for row in result.data] + except Exception as e: + logfire.error(f"Error fetching applied migrations: {e}") + # Return empty list if we can't fetch migrations + return [] + + async def scan_migration_directory(self) -> list[PendingMigration]: + """ + Scan the migration directory for all SQL files. 
+ + Returns: + List of PendingMigration objects + """ + migrations = [] + + if not self._migrations_dir.exists(): + logfire.warning(f"Migration directory does not exist: {self._migrations_dir}") + return migrations + + # Scan all version directories + for version_dir in sorted(self._migrations_dir.iterdir()): + if not version_dir.is_dir(): + continue + + version = version_dir.name + + # Scan all SQL files in version directory + for sql_file in sorted(version_dir.glob("*.sql")): + try: + # Read SQL content + with open(sql_file, encoding="utf-8") as f: + sql_content = f.read() + + # Extract migration name (filename without extension) + migration_name = sql_file.stem + + # Create pending migration object + migration = PendingMigration( + version=version, + name=migration_name, + sql_content=sql_content, + file_path=str(sql_file.relative_to(Path.cwd())), + ) + migrations.append(migration) + except Exception as e: + logfire.error(f"Error reading migration file {sql_file}: {e}") + + return migrations + + async def get_pending_migrations(self) -> list[PendingMigration]: + """ + Get list of pending migrations by comparing filesystem with database. + + Returns: + List of PendingMigration objects that haven't been applied + """ + # Get all migrations from filesystem + all_migrations = await self.scan_migration_directory() + + # Check if migrations table exists + if not await self.check_migrations_table_exists(): + # Bootstrap case - all migrations are pending + logfire.info("Migrations table doesn't exist, all migrations are pending") + return all_migrations + + # Get applied migrations from database + applied_migrations = await self.get_applied_migrations() + + # Create set of applied migration identifiers + applied_set = {(m.version, m.migration_name) for m in applied_migrations} + + # Filter out applied migrations + pending = [m for m in all_migrations if (m.version, m.name) not in applied_set] + + return pending + + async def get_migration_status(self) -> dict[str, Any]: + """ + Get comprehensive migration status. + + Returns: + Dictionary with pending and applied migrations info + """ + pending = await self.get_pending_migrations() + applied = await self.get_applied_migrations() + + # Check if bootstrap is required + bootstrap_required = not await self.check_migrations_table_exists() + + return { + "pending_migrations": [ + { + "version": m.version, + "name": m.name, + "sql_content": m.sql_content, + "file_path": m.file_path, + "checksum": m.checksum, + } + for m in pending + ], + "applied_migrations": [ + { + "version": m.version, + "migration_name": m.migration_name, + "applied_at": m.applied_at, + "checksum": m.checksum, + } + for m in applied + ], + "has_pending": len(pending) > 0, + "bootstrap_required": bootstrap_required, + "current_version": ARCHON_VERSION, + "pending_count": len(pending), + "applied_count": len(applied), + } + + +# Export singleton instance +migration_service = MigrationService() diff --git a/python/src/server/services/version_service.py b/python/src/server/services/version_service.py new file mode 100644 index 00000000..b916c984 --- /dev/null +++ b/python/src/server/services/version_service.py @@ -0,0 +1,162 @@ +""" +Version checking service with GitHub API integration. 
+""" + +from datetime import datetime, timedelta +from typing import Any + +import httpx +import logfire + +from ..config.version import ARCHON_VERSION, GITHUB_REPO_NAME, GITHUB_REPO_OWNER +from ..utils.semantic_version import is_newer_version + + +class VersionService: + """Service for checking Archon version against GitHub releases.""" + + def __init__(self): + self._cache: dict[str, Any] | None = None + self._cache_time: datetime | None = None + self._cache_ttl = 3600 # 1 hour cache TTL + + def _is_cache_valid(self) -> bool: + """Check if cached data is still valid.""" + if not self._cache or not self._cache_time: + return False + + age = datetime.now() - self._cache_time + return age < timedelta(seconds=self._cache_ttl) + + async def get_latest_release(self) -> dict[str, Any] | None: + """ + Fetch latest release information from GitHub API. + + Returns: + Release data dictionary or None if no releases + """ + # Check cache first + if self._is_cache_valid(): + logfire.debug("Using cached version data") + return self._cache + + # GitHub API endpoint + url = f"https://api.github.com/repos/{GITHUB_REPO_OWNER}/{GITHUB_REPO_NAME}/releases/latest" + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + url, + headers={ + "Accept": "application/vnd.github.v3+json", + "User-Agent": f"Archon/{ARCHON_VERSION}", + }, + ) + + # Handle 404 - no releases yet + if response.status_code == 404: + logfire.info("No releases found on GitHub") + return None + + response.raise_for_status() + data = response.json() + + # Cache the successful response + self._cache = data + self._cache_time = datetime.now() + + return data + + except httpx.TimeoutException: + logfire.warning("GitHub API request timed out") + # Return cached data if available + if self._cache: + return self._cache + return None + except httpx.HTTPError as e: + logfire.error(f"HTTP error fetching latest release: {e}") + # Return cached data if available + if self._cache: + return self._cache + return None + except Exception as e: + logfire.error(f"Unexpected error fetching latest release: {e}") + # Return cached data if available + if self._cache: + return self._cache + return None + + async def check_for_updates(self) -> dict[str, Any]: + """ + Check if a newer version of Archon is available. 
+ + Returns: + Dictionary with version check results + """ + try: + # Get latest release from GitHub + release = await self.get_latest_release() + + if not release: + # No releases found or error occurred + return { + "current": ARCHON_VERSION, + "latest": None, + "update_available": False, + "release_url": None, + "release_notes": None, + "published_at": None, + "check_error": None, + } + + # Extract version from tag_name (e.g., "v1.0.0" -> "1.0.0") + latest_version = release.get("tag_name", "") + if latest_version.startswith("v"): + latest_version = latest_version[1:] + + # Check if update is available + update_available = is_newer_version(ARCHON_VERSION, latest_version) + + # Parse published date + published_at = None + if release.get("published_at"): + try: + published_at = datetime.fromisoformat( + release["published_at"].replace("Z", "+00:00") + ) + except Exception: + pass + + return { + "current": ARCHON_VERSION, + "latest": latest_version, + "update_available": update_available, + "release_url": release.get("html_url"), + "release_notes": release.get("body"), + "published_at": published_at, + "check_error": None, + "assets": release.get("assets", []), + "author": release.get("author", {}).get("login"), + } + + except Exception as e: + logfire.error(f"Error checking for updates: {e}") + # Return safe default with error + return { + "current": ARCHON_VERSION, + "latest": None, + "update_available": False, + "release_url": None, + "release_notes": None, + "published_at": None, + "check_error": str(e), + } + + def clear_cache(self): + """Clear the cached version data.""" + self._cache = None + self._cache_time = None + + +# Export singleton instance +version_service = VersionService() diff --git a/python/src/server/utils/semantic_version.py b/python/src/server/utils/semantic_version.py new file mode 100644 index 00000000..d869f7a8 --- /dev/null +++ b/python/src/server/utils/semantic_version.py @@ -0,0 +1,107 @@ +""" +Semantic version parsing and comparison utilities. +""" + +import re + + +def parse_version(version_string: str) -> tuple[int, int, int, str | None]: + """ + Parse a semantic version string into major, minor, patch, and optional prerelease. + + Supports formats like: + - "1.0.0" + - "v1.0.0" + - "1.0.0-beta" + - "v1.0.0-rc.1" + + Args: + version_string: Version string to parse + + Returns: + Tuple of (major, minor, patch, prerelease) + """ + # Remove 'v' prefix if present + version = version_string.strip() + if version.lower().startswith('v'): + version = version[1:] + + # Parse version with optional prerelease + pattern = r'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$' + match = re.match(pattern, version) + + if not match: + # Try to handle incomplete versions like "1.0" + simple_pattern = r'^(\d+)(?:\.(\d+))?(?:\.(\d+))?$' + simple_match = re.match(simple_pattern, version) + if simple_match: + major = int(simple_match.group(1)) + minor = int(simple_match.group(2) or 0) + patch = int(simple_match.group(3) or 0) + return (major, minor, patch, None) + raise ValueError(f"Invalid version string: {version_string}") + + major = int(match.group(1)) + minor = int(match.group(2)) + patch = int(match.group(3)) + prerelease = match.group(4) + + return (major, minor, patch, prerelease) + + +def compare_versions(version1: str, version2: str) -> int: + """ + Compare two semantic version strings. 
+ + Args: + version1: First version string + version2: Second version string + + Returns: + -1 if version1 < version2 + 0 if version1 == version2 + 1 if version1 > version2 + """ + v1 = parse_version(version1) + v2 = parse_version(version2) + + # Compare major, minor, patch + for i in range(3): + if v1[i] < v2[i]: + return -1 + elif v1[i] > v2[i]: + return 1 + + # If main versions are equal, check prerelease + # No prerelease is considered newer than any prerelease + if v1[3] is None and v2[3] is None: + return 0 + elif v1[3] is None: + return 1 # v1 is release, v2 is prerelease + elif v2[3] is None: + return -1 # v1 is prerelease, v2 is release + else: + # Both have prereleases, compare lexicographically + if v1[3] < v2[3]: + return -1 + elif v1[3] > v2[3]: + return 1 + return 0 + + +def is_newer_version(current: str, latest: str) -> bool: + """ + Check if latest version is newer than current version. + + Args: + current: Current version string + latest: Latest version string to compare + + Returns: + True if latest > current, False otherwise + """ + try: + return compare_versions(latest, current) > 0 + except ValueError: + # If we can't parse versions, assume no update + return False diff --git a/python/tests/server/api_routes/test_migration_api.py b/python/tests/server/api_routes/test_migration_api.py new file mode 100644 index 00000000..57b9da2c --- /dev/null +++ b/python/tests/server/api_routes/test_migration_api.py @@ -0,0 +1,206 @@ +""" +Unit tests for migration_api.py +""" + +from datetime import datetime +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.server.config.version import ARCHON_VERSION +from src.server.main import app +from src.server.services.migration_service import MigrationRecord, PendingMigration + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def mock_applied_migrations(): + """Mock applied migration data.""" + return [ + MigrationRecord({ + "version": "0.1.0", + "migration_name": "001_initial", + "applied_at": datetime(2025, 1, 1, 0, 0, 0), + "checksum": "abc123", + }), + MigrationRecord({ + "version": "0.1.0", + "migration_name": "002_add_column", + "applied_at": datetime(2025, 1, 2, 0, 0, 0), + "checksum": "def456", + }), + ] + + +@pytest.fixture +def mock_pending_migrations(): + """Mock pending migration data.""" + return [ + PendingMigration( + version="0.1.0", + name="003_add_index", + sql_content="CREATE INDEX idx_test ON test_table(name);", + file_path="migration/0.1.0/003_add_index.sql" + ), + PendingMigration( + version="0.1.0", + name="004_add_table", + sql_content="CREATE TABLE new_table (id INT);", + file_path="migration/0.1.0/004_add_table.sql" + ), + ] + + +@pytest.fixture +def mock_migration_status(mock_applied_migrations, mock_pending_migrations): + """Mock complete migration status.""" + return { + "pending_migrations": [ + {"version": m.version, "name": m.name, "sql_content": m.sql_content, "file_path": m.file_path, "checksum": m.checksum} + for m in mock_pending_migrations + ], + "applied_migrations": [ + {"version": m.version, "migration_name": m.migration_name, "applied_at": m.applied_at, "checksum": m.checksum} + for m in mock_applied_migrations + ], + "has_pending": True, + "bootstrap_required": False, + "current_version": ARCHON_VERSION, + "pending_count": 2, + "applied_count": 2, + } + + +def test_get_migration_status_success(client, mock_migration_status): + """Test successful migration status retrieval.""" + 
with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_migration_status = AsyncMock(return_value=mock_migration_status) + + response = client.get("/api/migrations/status") + + assert response.status_code == 200 + data = response.json() + assert data["current_version"] == ARCHON_VERSION + assert data["has_pending"] is True + assert data["bootstrap_required"] is False + assert data["pending_count"] == 2 + assert data["applied_count"] == 2 + assert len(data["pending_migrations"]) == 2 + assert len(data["applied_migrations"]) == 2 + + +def test_get_migration_status_bootstrap_required(client): + """Test migration status when bootstrap is required.""" + mock_status = { + "pending_migrations": [], + "applied_migrations": [], + "has_pending": True, + "bootstrap_required": True, + "current_version": ARCHON_VERSION, + "pending_count": 5, + "applied_count": 0, + } + + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_migration_status = AsyncMock(return_value=mock_status) + + response = client.get("/api/migrations/status") + + assert response.status_code == 200 + data = response.json() + assert data["bootstrap_required"] is True + assert data["applied_count"] == 0 + + +def test_get_migration_status_error(client): + """Test error handling in migration status.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_migration_status = AsyncMock(side_effect=Exception("Database error")) + + response = client.get("/api/migrations/status") + + assert response.status_code == 500 + assert "Failed to get migration status" in response.json()["detail"] + + +def test_get_migration_history_success(client, mock_applied_migrations): + """Test successful migration history retrieval.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_applied_migrations = AsyncMock(return_value=mock_applied_migrations) + + response = client.get("/api/migrations/history") + + assert response.status_code == 200 + data = response.json() + assert data["total_count"] == 2 + assert data["current_version"] == ARCHON_VERSION + assert len(data["migrations"]) == 2 + assert data["migrations"][0]["migration_name"] == "001_initial" + + +def test_get_migration_history_empty(client): + """Test migration history when no migrations applied.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_applied_migrations = AsyncMock(return_value=[]) + + response = client.get("/api/migrations/history") + + assert response.status_code == 200 + data = response.json() + assert data["total_count"] == 0 + assert len(data["migrations"]) == 0 + + +def test_get_migration_history_error(client): + """Test error handling in migration history.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_applied_migrations = AsyncMock(side_effect=Exception("Database error")) + + response = client.get("/api/migrations/history") + + assert response.status_code == 500 + assert "Failed to get migration history" in response.json()["detail"] + + +def test_get_pending_migrations_success(client, mock_pending_migrations): + """Test successful pending migrations retrieval.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_pending_migrations = AsyncMock(return_value=mock_pending_migrations) + + response = 
client.get("/api/migrations/pending") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["name"] == "003_add_index" + assert data[0]["sql_content"] == "CREATE INDEX idx_test ON test_table(name);" + assert data[1]["name"] == "004_add_table" + + +def test_get_pending_migrations_none(client): + """Test when no pending migrations exist.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_pending_migrations = AsyncMock(return_value=[]) + + response = client.get("/api/migrations/pending") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 0 + + +def test_get_pending_migrations_error(client): + """Test error handling in pending migrations.""" + with patch("src.server.api_routes.migration_api.migration_service") as mock_service: + mock_service.get_pending_migrations = AsyncMock(side_effect=Exception("File error")) + + response = client.get("/api/migrations/pending") + + assert response.status_code == 500 + assert "Failed to get pending migrations" in response.json()["detail"] \ No newline at end of file diff --git a/python/tests/server/api_routes/test_version_api.py b/python/tests/server/api_routes/test_version_api.py new file mode 100644 index 00000000..d704c613 --- /dev/null +++ b/python/tests/server/api_routes/test_version_api.py @@ -0,0 +1,147 @@ +""" +Unit tests for version_api.py +""" + +from datetime import datetime +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.server.config.version import ARCHON_VERSION +from src.server.main import app + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def mock_version_data(): + """Mock version check data.""" + return { + "current": ARCHON_VERSION, + "latest": "0.2.0", + "update_available": True, + "release_url": "https://github.com/coleam00/Archon/releases/tag/v0.2.0", + "release_notes": "New features and bug fixes", + "published_at": datetime(2025, 1, 1, 0, 0, 0), + "check_error": None, + "author": "coleam00", + "assets": [{"name": "archon.zip", "size": 1024000}], + } + + +def test_check_for_updates_success(client, mock_version_data): + """Test successful version check.""" + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.check_for_updates = AsyncMock(return_value=mock_version_data) + + response = client.get("/api/version/check") + + assert response.status_code == 200 + data = response.json() + assert data["current"] == ARCHON_VERSION + assert data["latest"] == "0.2.0" + assert data["update_available"] is True + assert data["release_url"] == mock_version_data["release_url"] + + +def test_check_for_updates_no_update(client): + """Test when no update is available.""" + mock_data = { + "current": ARCHON_VERSION, + "latest": ARCHON_VERSION, + "update_available": False, + "release_url": None, + "release_notes": None, + "published_at": None, + "check_error": None, + } + + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.check_for_updates = AsyncMock(return_value=mock_data) + + response = client.get("/api/version/check") + + assert response.status_code == 200 + data = response.json() + assert data["current"] == ARCHON_VERSION + assert data["latest"] == ARCHON_VERSION + assert data["update_available"] is False + + + + +def test_check_for_updates_with_etag_modified(client, mock_version_data): + 
"""Test ETag handling when data has changed.""" + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.check_for_updates = AsyncMock(return_value=mock_version_data) + + # First request + response1 = client.get("/api/version/check") + assert response1.status_code == 200 + old_etag = response1.headers.get("etag") + + # Modify data + modified_data = mock_version_data.copy() + modified_data["latest"] = "0.3.0" + mock_service.check_for_updates = AsyncMock(return_value=modified_data) + + # Second request with old ETag + response2 = client.get("/api/version/check", headers={"If-None-Match": old_etag}) + assert response2.status_code == 200 # Data changed, return new data + data = response2.json() + assert data["latest"] == "0.3.0" + + +def test_check_for_updates_error_handling(client): + """Test error handling in version check.""" + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.check_for_updates = AsyncMock(side_effect=Exception("API error")) + + response = client.get("/api/version/check") + + assert response.status_code == 200 # Should still return 200 + data = response.json() + assert data["current"] == ARCHON_VERSION + assert data["latest"] is None + assert data["update_available"] is False + assert data["check_error"] == "API error" + + +def test_get_current_version(client): + """Test getting current version.""" + response = client.get("/api/version/current") + + assert response.status_code == 200 + data = response.json() + assert data["version"] == ARCHON_VERSION + assert "timestamp" in data + + +def test_clear_version_cache_success(client): + """Test clearing version cache.""" + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.clear_cache.return_value = None + + response = client.post("/api/version/clear-cache") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["message"] == "Version cache cleared successfully" + mock_service.clear_cache.assert_called_once() + + +def test_clear_version_cache_error(client): + """Test error handling when clearing cache fails.""" + with patch("src.server.api_routes.version_api.version_service") as mock_service: + mock_service.clear_cache.side_effect = Exception("Cache error") + + response = client.post("/api/version/clear-cache") + + assert response.status_code == 500 + assert "Failed to clear cache" in response.json()["detail"] \ No newline at end of file diff --git a/python/tests/server/services/test_migration_service.py b/python/tests/server/services/test_migration_service.py new file mode 100644 index 00000000..83e46c9b --- /dev/null +++ b/python/tests/server/services/test_migration_service.py @@ -0,0 +1,271 @@ +""" +Fixed unit tests for migration_service.py +""" + +import hashlib +from datetime import datetime +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from src.server.config.version import ARCHON_VERSION +from src.server.services.migration_service import ( + MigrationRecord, + MigrationService, + PendingMigration, +) + + +@pytest.fixture +def migration_service(): + """Create a migration service instance.""" + with patch("src.server.services.migration_service.Path.exists") as mock_exists: + # Mock that migration directory exists locally + mock_exists.return_value = False # Docker path doesn't exist + service = MigrationService() + return service + + +@pytest.fixture +def mock_supabase_client(): + 
"""Mock Supabase client.""" + client = MagicMock() + return client + + +def test_pending_migration_init(): + """Test PendingMigration initialization and checksum calculation.""" + migration = PendingMigration( + version="0.1.0", + name="001_initial", + sql_content="CREATE TABLE test (id INT);", + file_path="migration/0.1.0/001_initial.sql" + ) + + assert migration.version == "0.1.0" + assert migration.name == "001_initial" + assert migration.sql_content == "CREATE TABLE test (id INT);" + assert migration.file_path == "migration/0.1.0/001_initial.sql" + assert migration.checksum == hashlib.md5("CREATE TABLE test (id INT);".encode()).hexdigest() + + +def test_migration_record_init(): + """Test MigrationRecord initialization from database data.""" + data = { + "id": "123-456", + "version": "0.1.0", + "migration_name": "001_initial", + "applied_at": "2025-01-01T00:00:00Z", + "checksum": "abc123" + } + + record = MigrationRecord(data) + + assert record.id == "123-456" + assert record.version == "0.1.0" + assert record.migration_name == "001_initial" + assert record.applied_at == "2025-01-01T00:00:00Z" + assert record.checksum == "abc123" + + +def test_migration_service_init_local(): + """Test MigrationService initialization with local path.""" + with patch("src.server.services.migration_service.Path.exists") as mock_exists: + # Mock that Docker path doesn't exist + mock_exists.return_value = False + + service = MigrationService() + assert service._migrations_dir == Path("migration") + + +def test_migration_service_init_docker(): + """Test MigrationService initialization with Docker path.""" + with patch("src.server.services.migration_service.Path.exists") as mock_exists: + # Mock that Docker path exists + mock_exists.return_value = True + + service = MigrationService() + assert service._migrations_dir == Path("/app/migration") + + +@pytest.mark.asyncio +async def test_get_applied_migrations_success(migration_service, mock_supabase_client): + """Test successful retrieval of applied migrations.""" + mock_response = MagicMock() + mock_response.data = [ + { + "id": "123", + "version": "0.1.0", + "migration_name": "001_initial", + "applied_at": "2025-01-01T00:00:00Z", + "checksum": "abc123", + }, + ] + + mock_supabase_client.table.return_value.select.return_value.order.return_value.execute.return_value = mock_response + + with patch.object(migration_service, '_get_supabase_client', return_value=mock_supabase_client): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=True): + result = await migration_service.get_applied_migrations() + + assert len(result) == 1 + assert isinstance(result[0], MigrationRecord) + assert result[0].version == "0.1.0" + assert result[0].migration_name == "001_initial" + + +@pytest.mark.asyncio +async def test_get_applied_migrations_table_not_exists(migration_service, mock_supabase_client): + """Test handling when migrations table doesn't exist.""" + with patch.object(migration_service, '_get_supabase_client', return_value=mock_supabase_client): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=False): + result = await migration_service.get_applied_migrations() + assert result == [] + + +@pytest.mark.asyncio +async def test_get_pending_migrations_with_files(migration_service, mock_supabase_client): + """Test getting pending migrations from filesystem.""" + # Mock scan_migration_directory to return test migrations + mock_migrations = [ + PendingMigration( + version="0.1.0", + name="001_initial", + 
sql_content="CREATE TABLE test;", + file_path="migration/0.1.0/001_initial.sql" + ), + PendingMigration( + version="0.1.0", + name="002_update", + sql_content="ALTER TABLE test ADD col TEXT;", + file_path="migration/0.1.0/002_update.sql" + ) + ] + + # Mock no applied migrations + with patch.object(migration_service, 'scan_migration_directory', return_value=mock_migrations): + with patch.object(migration_service, 'get_applied_migrations', return_value=[]): + result = await migration_service.get_pending_migrations() + + assert len(result) == 2 + assert all(isinstance(m, PendingMigration) for m in result) + assert result[0].name == "001_initial" + assert result[1].name == "002_update" + + +@pytest.mark.asyncio +async def test_get_pending_migrations_some_applied(migration_service, mock_supabase_client): + """Test getting pending migrations when some are already applied.""" + # Mock all migrations + mock_all_migrations = [ + PendingMigration( + version="0.1.0", + name="001_initial", + sql_content="CREATE TABLE test;", + file_path="migration/0.1.0/001_initial.sql" + ), + PendingMigration( + version="0.1.0", + name="002_update", + sql_content="ALTER TABLE test ADD col TEXT;", + file_path="migration/0.1.0/002_update.sql" + ) + ] + + # Mock first migration as applied + mock_applied = [ + MigrationRecord({ + "version": "0.1.0", + "migration_name": "001_initial", + "applied_at": "2025-01-01T00:00:00Z", + "checksum": None + }) + ] + + with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations): + with patch.object(migration_service, 'get_applied_migrations', return_value=mock_applied): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=True): + result = await migration_service.get_pending_migrations() + + assert len(result) == 1 + assert result[0].name == "002_update" + + +@pytest.mark.asyncio +async def test_get_migration_status_all_applied(migration_service, mock_supabase_client): + """Test migration status when all migrations are applied.""" + # Mock one migration file + mock_all_migrations = [ + PendingMigration( + version="0.1.0", + name="001_initial", + sql_content="CREATE TABLE test;", + file_path="migration/0.1.0/001_initial.sql" + ) + ] + + # Mock migration as applied + mock_applied = [ + MigrationRecord({ + "version": "0.1.0", + "migration_name": "001_initial", + "applied_at": "2025-01-01T00:00:00Z", + "checksum": None + }) + ] + + with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations): + with patch.object(migration_service, 'get_applied_migrations', return_value=mock_applied): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=True): + result = await migration_service.get_migration_status() + + assert result["current_version"] == ARCHON_VERSION + assert result["has_pending"] is False + assert result["bootstrap_required"] is False + assert result["pending_count"] == 0 + assert result["applied_count"] == 1 + + +@pytest.mark.asyncio +async def test_get_migration_status_bootstrap_required(migration_service, mock_supabase_client): + """Test migration status when bootstrap is required (table doesn't exist).""" + # Mock migration files + mock_all_migrations = [ + PendingMigration( + version="0.1.0", + name="001_initial", + sql_content="CREATE TABLE test;", + file_path="migration/0.1.0/001_initial.sql" + ), + PendingMigration( + version="0.1.0", + name="002_update", + sql_content="ALTER TABLE test ADD col TEXT;", + 
file_path="migration/0.1.0/002_update.sql" + ) + ] + + with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations): + with patch.object(migration_service, 'get_applied_migrations', return_value=[]): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=False): + result = await migration_service.get_migration_status() + + assert result["bootstrap_required"] is True + assert result["has_pending"] is True + assert result["pending_count"] == 2 + assert result["applied_count"] == 0 + assert len(result["pending_migrations"]) == 2 + + +@pytest.mark.asyncio +async def test_get_migration_status_no_files(migration_service, mock_supabase_client): + """Test migration status when no migration files exist.""" + with patch.object(migration_service, 'scan_migration_directory', return_value=[]): + with patch.object(migration_service, 'get_applied_migrations', return_value=[]): + with patch.object(migration_service, 'check_migrations_table_exists', return_value=True): + result = await migration_service.get_migration_status() + + assert result["has_pending"] is False + assert result["pending_count"] == 0 + assert len(result["pending_migrations"]) == 0 \ No newline at end of file diff --git a/python/tests/server/services/test_version_service.py b/python/tests/server/services/test_version_service.py new file mode 100644 index 00000000..0f76394d --- /dev/null +++ b/python/tests/server/services/test_version_service.py @@ -0,0 +1,234 @@ +""" +Unit tests for version_service.py +""" + +import json +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from src.server.config.version import ARCHON_VERSION +from src.server.services.version_service import VersionService + + +@pytest.fixture +def version_service(): + """Create a fresh version service instance for each test.""" + service = VersionService() + # Clear any cache from previous tests + service._cache = None + service._cache_time = None + return service + + +@pytest.fixture +def mock_release_data(): + """Mock GitHub release data.""" + return { + "tag_name": "v0.2.0", + "name": "Archon v0.2.0", + "html_url": "https://github.com/coleam00/Archon/releases/tag/v0.2.0", + "body": "## Release Notes\n\nNew features and bug fixes", + "published_at": "2025-01-01T00:00:00Z", + "author": {"login": "coleam00"}, + "assets": [ + { + "name": "archon-v0.2.0.zip", + "size": 1024000, + "download_count": 100, + "browser_download_url": "https://github.com/coleam00/Archon/releases/download/v0.2.0/archon-v0.2.0.zip", + "content_type": "application/zip", + } + ], + } + + +@pytest.mark.asyncio +async def test_get_latest_release_success(version_service, mock_release_data): + """Test successful fetching of latest release from GitHub.""" + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_release_data + mock_client.get.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await version_service.get_latest_release() + + assert result == mock_release_data + assert version_service._cache == mock_release_data + assert version_service._cache_time is not None + + +@pytest.mark.asyncio +async def test_get_latest_release_uses_cache(version_service, mock_release_data): + """Test that cache is used when available and not expired.""" + # Set up cache + 
version_service._cache = mock_release_data + version_service._cache_time = datetime.now() + + with patch("httpx.AsyncClient") as mock_client_class: + result = await version_service.get_latest_release() + + # Should not make HTTP request + mock_client_class.assert_not_called() + assert result == mock_release_data + + +@pytest.mark.asyncio +async def test_get_latest_release_cache_expired(version_service, mock_release_data): + """Test that cache is refreshed when expired.""" + # Set up expired cache + old_data = {"tag_name": "v0.1.0"} + version_service._cache = old_data + version_service._cache_time = datetime.now() - timedelta(hours=2) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_release_data + mock_client.get.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await version_service.get_latest_release() + + # Should make new HTTP request + mock_client.get.assert_called_once() + assert result == mock_release_data + assert version_service._cache == mock_release_data + + +@pytest.mark.asyncio +async def test_get_latest_release_404(version_service): + """Test handling of 404 (no releases).""" + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.status_code = 404 + mock_client.get.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await version_service.get_latest_release() + + assert result is None + + +@pytest.mark.asyncio +async def test_get_latest_release_timeout(version_service, mock_release_data): + """Test handling of timeout with cache fallback.""" + # Set up cache + version_service._cache = mock_release_data + version_service._cache_time = datetime.now() - timedelta(hours=2) # Expired + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.TimeoutException("Timeout") + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await version_service.get_latest_release() + + # Should return cached data + assert result == mock_release_data + + +@pytest.mark.asyncio +async def test_check_for_updates_new_version_available(version_service, mock_release_data): + """Test when a new version is available.""" + with patch.object(version_service, "get_latest_release", return_value=mock_release_data): + result = await version_service.check_for_updates() + + assert result["current"] == ARCHON_VERSION + assert result["latest"] == "0.2.0" + assert result["update_available"] is True + assert result["release_url"] == mock_release_data["html_url"] + assert result["release_notes"] == mock_release_data["body"] + assert result["published_at"] == datetime.fromisoformat("2025-01-01T00:00:00+00:00") + assert result["author"] == "coleam00" + assert len(result["assets"]) == 1 + + +@pytest.mark.asyncio +async def test_check_for_updates_same_version(version_service): + """Test when current version is up to date.""" + mock_data = {"tag_name": f"v{ARCHON_VERSION}", "html_url": "test_url", "body": "notes"} + + with patch.object(version_service, "get_latest_release", return_value=mock_data): + result = await version_service.check_for_updates() + + assert result["current"] == ARCHON_VERSION + assert result["latest"] == ARCHON_VERSION + assert result["update_available"] is False + + 
+@pytest.mark.asyncio +async def test_check_for_updates_no_release(version_service): + """Test when no releases are found.""" + with patch.object(version_service, "get_latest_release", return_value=None): + result = await version_service.check_for_updates() + + assert result["current"] == ARCHON_VERSION + assert result["latest"] is None + assert result["update_available"] is False + assert result["release_url"] is None + + +@pytest.mark.asyncio +async def test_check_for_updates_parse_version(version_service, mock_release_data): + """Test version parsing with and without 'v' prefix.""" + # Test with 'v' prefix + mock_release_data["tag_name"] = "v1.2.3" + with patch.object(version_service, "get_latest_release", return_value=mock_release_data): + result = await version_service.check_for_updates() + assert result["latest"] == "1.2.3" + + # Test without 'v' prefix + mock_release_data["tag_name"] = "2.0.0" + with patch.object(version_service, "get_latest_release", return_value=mock_release_data): + result = await version_service.check_for_updates() + assert result["latest"] == "2.0.0" + + +@pytest.mark.asyncio +async def test_check_for_updates_missing_fields(version_service): + """Test handling of incomplete release data.""" + mock_data = {"tag_name": "v0.2.0"} # Minimal data + + with patch.object(version_service, "get_latest_release", return_value=mock_data): + result = await version_service.check_for_updates() + + assert result["latest"] == "0.2.0" + assert result["release_url"] is None + assert result["release_notes"] is None + assert result["published_at"] is None + assert result["author"] is None + assert result["assets"] == [] # Empty list, not None + + +def test_clear_cache(version_service, mock_release_data): + """Test cache clearing.""" + # Set up cache + version_service._cache = mock_release_data + version_service._cache_time = datetime.now() + + # Clear cache + version_service.clear_cache() + + assert version_service._cache is None + assert version_service._cache_time is None + + +def test_is_newer_version(): + """Test version comparison logic using the utility function.""" + from src.server.utils.semantic_version import is_newer_version + + # Test various version comparisons + assert is_newer_version("1.0.0", "2.0.0") is True + assert is_newer_version("2.0.0", "1.0.0") is False + assert is_newer_version("1.0.0", "1.0.0") is False + assert is_newer_version("1.0.0", "1.1.0") is True + assert is_newer_version("1.0.0", "1.0.1") is True + assert is_newer_version("1.2.3", "1.2.3") is False \ No newline at end of file From d3a5c3311a96da9c1bbeb9af397ca2fdd0484021 Mon Sep 17 00:00:00 2001 From: Wirasm <152263317+Wirasm@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:54:55 +0300 Subject: [PATCH 4/7] refactor: move shared hooks from ui/hooks to shared/hooks (#729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reorganize hook structure to follow vertical slice architecture: - Move useSmartPolling, useThemeAware, useToast to features/shared/hooks - Update 38+ import statements across codebase - Update test file mocks to reference new locations - Remove old ui/hooks directory This change aligns shared utilities with the architectural pattern where truly shared code resides in the shared directory. 
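For example, a typical consumer import changes from the old ui/hooks path to the new shared/hooks location (illustrative, mirroring the updates in this diff):

// Before
import { useToast } from "../../features/ui/hooks/useToast";
// After
import { useToast } from "../../features/shared/hooks/useToast";
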
🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude --- archon-ui-main/src/components/bug-report/BugReportModal.tsx | 2 +- archon-ui-main/src/components/layout/MainLayout.tsx | 2 +- archon-ui-main/src/components/onboarding/ProviderStep.tsx | 2 +- archon-ui-main/src/components/settings/APIKeysSection.tsx | 2 +- .../src/components/settings/CodeExtractionSettings.tsx | 2 +- archon-ui-main/src/components/settings/FeaturesSection.tsx | 2 +- archon-ui-main/src/components/settings/IDEGlobalRules.tsx | 2 +- .../src/components/settings/OllamaConfigurationPanel.tsx | 2 +- .../src/components/settings/OllamaInstanceHealthIndicator.tsx | 2 +- .../src/components/settings/OllamaModelDiscoveryModal.tsx | 2 +- .../src/components/settings/OllamaModelSelectionModal.tsx | 2 +- archon-ui-main/src/components/settings/RAGSettings.tsx | 2 +- .../src/features/knowledge/components/AddKnowledgeDialog.tsx | 2 +- .../src/features/knowledge/components/KnowledgeTable.tsx | 2 +- .../knowledge/hooks/tests/useKnowledgeQueries.test.ts | 4 ++-- .../src/features/knowledge/hooks/useKnowledgeQueries.ts | 4 ++-- archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx | 2 +- .../src/features/mcp/components/McpConfigSection.tsx | 2 +- archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts | 2 +- .../src/features/progress/hooks/useProgressQueries.ts | 2 +- .../src/features/projects/components/ProjectCardActions.tsx | 2 +- .../features/projects/hooks/tests/useProjectQueries.test.ts | 4 ++-- .../src/features/projects/hooks/useProjectQueries.ts | 4 ++-- .../features/projects/tasks/components/TaskCardActions.tsx | 2 +- .../projects/tasks/hooks/tests/useTaskQueries.test.ts | 4 ++-- .../src/features/projects/tasks/hooks/useTaskEditor.ts | 2 +- .../src/features/projects/tasks/hooks/useTaskQueries.ts | 4 ++-- .../settings/migrations/components/PendingMigrationsModal.tsx | 2 +- .../features/settings/migrations/hooks/useMigrationQueries.ts | 2 +- .../src/features/settings/version/hooks/useVersionQueries.ts | 2 +- archon-ui-main/src/features/{ui => shared}/hooks/index.ts | 2 +- .../{ui => shared}/hooks/tests/useSmartPolling.test.ts | 0 .../src/features/{ui => shared}/hooks/useSmartPolling.ts | 0 .../src/features/{ui => shared}/hooks/useThemeAware.ts | 0 archon-ui-main/src/features/{ui => shared}/hooks/useToast.ts | 2 +- archon-ui-main/src/features/ui/components/ToastProvider.tsx | 2 +- archon-ui-main/src/pages/SettingsPage.tsx | 2 +- 37 files changed, 40 insertions(+), 40 deletions(-) rename archon-ui-main/src/features/{ui => shared}/hooks/index.ts (70%) rename archon-ui-main/src/features/{ui => shared}/hooks/tests/useSmartPolling.test.ts (100%) rename archon-ui-main/src/features/{ui => shared}/hooks/useSmartPolling.ts (100%) rename archon-ui-main/src/features/{ui => shared}/hooks/useThemeAware.ts (100%) rename archon-ui-main/src/features/{ui => shared}/hooks/useToast.ts (97%) diff --git a/archon-ui-main/src/components/bug-report/BugReportModal.tsx b/archon-ui-main/src/components/bug-report/BugReportModal.tsx index 69b40262..2bfcb007 100644 --- a/archon-ui-main/src/components/bug-report/BugReportModal.tsx +++ b/archon-ui-main/src/components/bug-report/BugReportModal.tsx @@ -5,7 +5,7 @@ import { Button } from "../ui/Button"; import { Input } from "../ui/Input"; import { Card } from "../ui/Card"; import { Select } from "../ui/Select"; -import { useToast } from "../../features/ui/hooks/useToast"; +import { useToast } from "../../features/shared/hooks/useToast"; import { bugReportService, BugContext, diff --git 
a/archon-ui-main/src/components/layout/MainLayout.tsx b/archon-ui-main/src/components/layout/MainLayout.tsx index da0b2696..73fcc1de 100644 --- a/archon-ui-main/src/components/layout/MainLayout.tsx +++ b/archon-ui-main/src/components/layout/MainLayout.tsx @@ -2,7 +2,7 @@ import { AlertCircle, WifiOff } from "lucide-react"; import type React from "react"; import { useEffect } from "react"; import { useLocation, useNavigate } from "react-router-dom"; -import { useToast } from "../../features/ui/hooks/useToast"; +import { useToast } from "../../features/shared/hooks/useToast"; import { cn } from "../../lib/utils"; import { credentialsService } from "../../services/credentialsService"; import { isLmConfigured } from "../../utils/onboarding"; diff --git a/archon-ui-main/src/components/onboarding/ProviderStep.tsx b/archon-ui-main/src/components/onboarding/ProviderStep.tsx index 546be5f7..1beae073 100644 --- a/archon-ui-main/src/components/onboarding/ProviderStep.tsx +++ b/archon-ui-main/src/components/onboarding/ProviderStep.tsx @@ -3,7 +3,7 @@ import { Key, ExternalLink, Save, Loader } from "lucide-react"; import { Input } from "../ui/Input"; import { Button } from "../ui/Button"; import { Select } from "../ui/Select"; -import { useToast } from "../../features/ui/hooks/useToast"; +import { useToast } from "../../features/shared/hooks/useToast"; import { credentialsService } from "../../services/credentialsService"; interface ProviderStepProps { diff --git a/archon-ui-main/src/components/settings/APIKeysSection.tsx b/archon-ui-main/src/components/settings/APIKeysSection.tsx index 231e1125..0d926014 100644 --- a/archon-ui-main/src/components/settings/APIKeysSection.tsx +++ b/archon-ui-main/src/components/settings/APIKeysSection.tsx @@ -4,7 +4,7 @@ import { Input } from '../ui/Input'; import { Button } from '../ui/Button'; import { Card } from '../ui/Card'; import { credentialsService, Credential } from '../../services/credentialsService'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; interface CustomCredential { key: string; diff --git a/archon-ui-main/src/components/settings/CodeExtractionSettings.tsx b/archon-ui-main/src/components/settings/CodeExtractionSettings.tsx index 2e7d40fb..2dd322df 100644 --- a/archon-ui-main/src/components/settings/CodeExtractionSettings.tsx +++ b/archon-ui-main/src/components/settings/CodeExtractionSettings.tsx @@ -3,7 +3,7 @@ import { Code, Check, Save, Loader } from 'lucide-react'; import { Card } from '../ui/Card'; import { Input } from '../ui/Input'; import { Button } from '../ui/Button'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { credentialsService } from '../../services/credentialsService'; interface CodeExtractionSettingsProps { diff --git a/archon-ui-main/src/components/settings/FeaturesSection.tsx b/archon-ui-main/src/components/settings/FeaturesSection.tsx index 5fc57fb4..0a61cf5c 100644 --- a/archon-ui-main/src/components/settings/FeaturesSection.tsx +++ b/archon-ui-main/src/components/settings/FeaturesSection.tsx @@ -4,7 +4,7 @@ import { Toggle } from '../ui/Toggle'; import { Card } from '../ui/Card'; import { useTheme } from '../../contexts/ThemeContext'; import { credentialsService } from '../../services/credentialsService'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { serverHealthService 
} from '../../services/serverHealthService'; export const FeaturesSection = () => { diff --git a/archon-ui-main/src/components/settings/IDEGlobalRules.tsx b/archon-ui-main/src/components/settings/IDEGlobalRules.tsx index 7f65ce4b..b4e29ef9 100644 --- a/archon-ui-main/src/components/settings/IDEGlobalRules.tsx +++ b/archon-ui-main/src/components/settings/IDEGlobalRules.tsx @@ -2,7 +2,7 @@ import { useState } from 'react'; import { FileCode, Copy, Check } from 'lucide-react'; import { Card } from '../ui/Card'; import { Button } from '../ui/Button'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { copyToClipboard } from '../../features/shared/utils/clipboard'; type RuleType = 'claude' | 'universal'; diff --git a/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx b/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx index c4a9e267..4da6f9a0 100644 --- a/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx +++ b/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx @@ -3,7 +3,7 @@ import { Card } from '../ui/Card'; import { Button } from '../ui/Button'; import { Input } from '../ui/Input'; import { Badge } from '../ui/Badge'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { cn } from '../../lib/utils'; import { credentialsService, OllamaInstance } from '../../services/credentialsService'; import { OllamaModelDiscoveryModal } from './OllamaModelDiscoveryModal'; diff --git a/archon-ui-main/src/components/settings/OllamaInstanceHealthIndicator.tsx b/archon-ui-main/src/components/settings/OllamaInstanceHealthIndicator.tsx index c65b2159..4c646dfa 100644 --- a/archon-ui-main/src/components/settings/OllamaInstanceHealthIndicator.tsx +++ b/archon-ui-main/src/components/settings/OllamaInstanceHealthIndicator.tsx @@ -3,7 +3,7 @@ import { Badge } from '../ui/Badge'; import { Button } from '../ui/Button'; import { Card } from '../ui/Card'; import { cn } from '../../lib/utils'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { ollamaService } from '../../services/ollamaService'; import type { HealthIndicatorProps } from './types/OllamaTypes'; diff --git a/archon-ui-main/src/components/settings/OllamaModelDiscoveryModal.tsx b/archon-ui-main/src/components/settings/OllamaModelDiscoveryModal.tsx index 7525f1bd..53a698b5 100644 --- a/archon-ui-main/src/components/settings/OllamaModelDiscoveryModal.tsx +++ b/archon-ui-main/src/components/settings/OllamaModelDiscoveryModal.tsx @@ -13,7 +13,7 @@ import { Button } from '../ui/Button'; import { Input } from '../ui/Input'; import { Badge } from '../ui/Badge'; import { Card } from '../ui/Card'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { ollamaService, type OllamaModel, type ModelDiscoveryResponse } from '../../services/ollamaService'; import type { OllamaInstance, ModelSelectionState } from './types/OllamaTypes'; diff --git a/archon-ui-main/src/components/settings/OllamaModelSelectionModal.tsx b/archon-ui-main/src/components/settings/OllamaModelSelectionModal.tsx index 9933526a..3c539f9c 100644 --- a/archon-ui-main/src/components/settings/OllamaModelSelectionModal.tsx +++ b/archon-ui-main/src/components/settings/OllamaModelSelectionModal.tsx @@ -3,7 +3,7 
@@ import ReactDOM from 'react-dom'; import { X, Search, RotateCcw, Zap, Server, Eye, Settings, Download, Box } from 'lucide-react'; import { Button } from '../ui/Button'; import { Input } from '../ui/Input'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; interface ContextInfo { current?: number; diff --git a/archon-ui-main/src/components/settings/RAGSettings.tsx b/archon-ui-main/src/components/settings/RAGSettings.tsx index a60925bb..ccba61ce 100644 --- a/archon-ui-main/src/components/settings/RAGSettings.tsx +++ b/archon-ui-main/src/components/settings/RAGSettings.tsx @@ -4,7 +4,7 @@ import { Card } from '../ui/Card'; import { Input } from '../ui/Input'; import { Select } from '../ui/Select'; import { Button } from '../ui/Button'; -import { useToast } from '../../features/ui/hooks/useToast'; +import { useToast } from '../../features/shared/hooks/useToast'; import { credentialsService } from '../../services/credentialsService'; import OllamaModelDiscoveryModal from './OllamaModelDiscoveryModal'; import OllamaModelSelectionModal from './OllamaModelSelectionModal'; diff --git a/archon-ui-main/src/features/knowledge/components/AddKnowledgeDialog.tsx b/archon-ui-main/src/features/knowledge/components/AddKnowledgeDialog.tsx index f6c7bc2a..3788affd 100644 --- a/archon-ui-main/src/features/knowledge/components/AddKnowledgeDialog.tsx +++ b/archon-ui-main/src/features/knowledge/components/AddKnowledgeDialog.tsx @@ -5,7 +5,7 @@ import { Globe, Loader2, Upload } from "lucide-react"; import { useId, useState } from "react"; -import { useToast } from "../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { Button, Input, Label } from "../../ui/primitives"; import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "../../ui/primitives/dialog"; import { cn } from "../../ui/primitives/styles"; diff --git a/archon-ui-main/src/features/knowledge/components/KnowledgeTable.tsx b/archon-ui-main/src/features/knowledge/components/KnowledgeTable.tsx index 18985523..63844333 100644 --- a/archon-ui-main/src/features/knowledge/components/KnowledgeTable.tsx +++ b/archon-ui-main/src/features/knowledge/components/KnowledgeTable.tsx @@ -6,7 +6,7 @@ import { formatDistanceToNowStrict } from "date-fns"; import { Code, ExternalLink, Eye, FileText, MoreHorizontal, Trash2 } from "lucide-react"; import { useState } from "react"; -import { useToast } from "../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { Button } from "../../ui/primitives"; import { DropdownMenu, diff --git a/archon-ui-main/src/features/knowledge/hooks/tests/useKnowledgeQueries.test.ts b/archon-ui-main/src/features/knowledge/hooks/tests/useKnowledgeQueries.test.ts index 630f213a..c2251e03 100644 --- a/archon-ui-main/src/features/knowledge/hooks/tests/useKnowledgeQueries.test.ts +++ b/archon-ui-main/src/features/knowledge/hooks/tests/useKnowledgeQueries.test.ts @@ -23,14 +23,14 @@ vi.mock("../../services", () => ({ })); // Mock the toast hook -vi.mock("../../../ui/hooks/useToast", () => ({ +vi.mock("@/features/shared/hooks/useToast", () => ({ useToast: () => ({ showToast: vi.fn(), }), })); // Mock smart polling -vi.mock("../../../ui/hooks", () => ({ +vi.mock("@/features/shared/hooks", () => ({ useSmartPolling: () => ({ refetchInterval: 30000, isPaused: false, diff --git a/archon-ui-main/src/features/knowledge/hooks/useKnowledgeQueries.ts 
b/archon-ui-main/src/features/knowledge/hooks/useKnowledgeQueries.ts index 874499e2..5a45561d 100644 --- a/archon-ui-main/src/features/knowledge/hooks/useKnowledgeQueries.ts +++ b/archon-ui-main/src/features/knowledge/hooks/useKnowledgeQueries.ts @@ -10,8 +10,8 @@ import { useActiveOperations } from "../../progress/hooks"; import { progressKeys } from "../../progress/hooks/useProgressQueries"; import type { ActiveOperation, ActiveOperationsResponse } from "../../progress/types"; import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/queryPatterns"; -import { useSmartPolling } from "../../ui/hooks"; -import { useToast } from "../../ui/hooks/useToast"; +import { useSmartPolling } from "@/features/shared/hooks"; +import { useToast } from "@/features/shared/hooks/useToast"; import { knowledgeService } from "../services"; import type { CrawlRequest, diff --git a/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx b/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx index 20d43650..6f6a66df 100644 --- a/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx +++ b/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx @@ -6,7 +6,7 @@ import { useEffect, useMemo, useRef, useState } from "react"; import { CrawlingProgress } from "../../progress/components/CrawlingProgress"; import type { ActiveOperation } from "../../progress/types"; -import { useToast } from "../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { AddKnowledgeDialog } from "../components/AddKnowledgeDialog"; import { KnowledgeHeader } from "../components/KnowledgeHeader"; import { KnowledgeList } from "../components/KnowledgeList"; diff --git a/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx b/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx index 3f011f9d..c36b2f01 100644 --- a/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx +++ b/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx @@ -1,7 +1,7 @@ import { Copy, ExternalLink } from "lucide-react"; import type React from "react"; import { useState } from "react"; -import { useToast } from "../../ui/hooks"; +import { useToast } from "@/features/shared/hooks"; import { Button, cn, glassmorphism, Tabs, TabsContent, TabsList, TabsTrigger } from "../../ui/primitives"; import type { McpServerConfig, McpServerStatus, SupportedIDE } from "../types"; import { copyToClipboard } from "../../shared/utils/clipboard"; diff --git a/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts b/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts index 409694f5..aef5ec68 100644 --- a/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts +++ b/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts @@ -1,6 +1,6 @@ import { useQuery } from "@tanstack/react-query"; import { STALE_TIMES } from "../../shared/queryPatterns"; -import { useSmartPolling } from "../../ui/hooks"; +import { useSmartPolling } from "@/features/shared/hooks"; import { mcpApi } from "../services"; // Query keys factory diff --git a/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts b/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts index 19c8e401..ae82ba17 100644 --- a/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts +++ b/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts @@ -7,7 +7,7 @@ import { type UseQueryResult, useQueries, useQuery, useQueryClient } from "@tans import { useEffect, useMemo, useRef } from "react"; import { 
APIServiceError } from "../../shared/errors"; import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/queryPatterns"; -import { useSmartPolling } from "../../ui/hooks"; +import { useSmartPolling } from "../../shared/hooks"; import { progressService } from "../services"; import type { ActiveOperationsResponse, ProgressResponse, ProgressStatus } from "../types"; diff --git a/archon-ui-main/src/features/projects/components/ProjectCardActions.tsx b/archon-ui-main/src/features/projects/components/ProjectCardActions.tsx index 06a9f57d..fa10e71d 100644 --- a/archon-ui-main/src/features/projects/components/ProjectCardActions.tsx +++ b/archon-ui-main/src/features/projects/components/ProjectCardActions.tsx @@ -1,6 +1,6 @@ import { Clipboard, Pin, Trash2 } from "lucide-react"; import type React from "react"; -import { useToast } from "../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { cn, glassmorphism } from "../../ui/primitives/styles"; import { SimpleTooltip } from "../../ui/primitives/tooltip"; diff --git a/archon-ui-main/src/features/projects/hooks/tests/useProjectQueries.test.ts b/archon-ui-main/src/features/projects/hooks/tests/useProjectQueries.test.ts index 1ad07cf4..19601382 100644 --- a/archon-ui-main/src/features/projects/hooks/tests/useProjectQueries.test.ts +++ b/archon-ui-main/src/features/projects/hooks/tests/useProjectQueries.test.ts @@ -20,14 +20,14 @@ vi.mock("../../services", () => ({ })); // Mock the toast hook -vi.mock("../../../ui/hooks/useToast", () => ({ +vi.mock("@/features/shared/hooks/useToast", () => ({ useToast: () => ({ showToast: vi.fn(), }), })); // Mock smart polling -vi.mock("../../../ui/hooks", () => ({ +vi.mock("@/features/shared/hooks", () => ({ useSmartPolling: () => ({ refetchInterval: 5000, isPaused: false, diff --git a/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts b/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts index eaa85e66..ae216e66 100644 --- a/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts +++ b/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts @@ -6,8 +6,8 @@ import { replaceOptimisticEntity, } from "@/features/shared/optimistic"; import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/queryPatterns"; -import { useSmartPolling } from "../../ui/hooks"; -import { useToast } from "../../ui/hooks/useToast"; +import { useSmartPolling } from "@/features/shared/hooks"; +import { useToast } from "@/features/shared/hooks/useToast"; import { projectService } from "../services"; import type { CreateProjectRequest, Project, UpdateProjectRequest } from "../types"; diff --git a/archon-ui-main/src/features/projects/tasks/components/TaskCardActions.tsx b/archon-ui-main/src/features/projects/tasks/components/TaskCardActions.tsx index 3070d521..7bf60a31 100644 --- a/archon-ui-main/src/features/projects/tasks/components/TaskCardActions.tsx +++ b/archon-ui-main/src/features/projects/tasks/components/TaskCardActions.tsx @@ -1,6 +1,6 @@ import { Clipboard, Edit, Trash2 } from "lucide-react"; import type React from "react"; -import { useToast } from "../../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { cn, glassmorphism } from "../../../ui/primitives/styles"; import { SimpleTooltip } from "../../../ui/primitives/tooltip"; diff --git a/archon-ui-main/src/features/projects/tasks/hooks/tests/useTaskQueries.test.ts b/archon-ui-main/src/features/projects/tasks/hooks/tests/useTaskQueries.test.ts index 
ed1c6089..b2612637 100644 --- a/archon-ui-main/src/features/projects/tasks/hooks/tests/useTaskQueries.test.ts +++ b/archon-ui-main/src/features/projects/tasks/hooks/tests/useTaskQueries.test.ts @@ -20,14 +20,14 @@ vi.mock("../../services", () => ({ const showToastMock = vi.fn(); // Mock the toast hook -vi.mock("../../../../ui/hooks/useToast", () => ({ +vi.mock("../../../../shared/hooks/useToast", () => ({ useToast: () => ({ showToast: showToastMock, }), })); // Mock smart polling -vi.mock("../../../../ui/hooks", () => ({ +vi.mock("../../../../shared/hooks", () => ({ useSmartPolling: () => ({ refetchInterval: 5000, isPaused: false, diff --git a/archon-ui-main/src/features/projects/tasks/hooks/useTaskEditor.ts b/archon-ui-main/src/features/projects/tasks/hooks/useTaskEditor.ts index efb37ab6..fff35286 100644 --- a/archon-ui-main/src/features/projects/tasks/hooks/useTaskEditor.ts +++ b/archon-ui-main/src/features/projects/tasks/hooks/useTaskEditor.ts @@ -1,5 +1,5 @@ import { useCallback } from "react"; -import { useToast } from "../../../ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import { useProjectFeatures } from "../../hooks/useProjectQueries"; import type { Assignee, CreateTaskRequest, Task, UpdateTaskRequest, UseTaskEditorReturn } from "../types"; import { useCreateTask, useUpdateTask } from "./useTaskQueries"; diff --git a/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts b/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts index b39cbb18..55b4bbd0 100644 --- a/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts +++ b/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts @@ -6,8 +6,8 @@ import { type OptimisticEntity, } from "@/features/shared/optimistic"; import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../../shared/queryPatterns"; -import { useSmartPolling } from "../../../ui/hooks"; -import { useToast } from "../../../ui/hooks/useToast"; +import { useSmartPolling } from "../../../shared/hooks"; +import { useToast } from "../../../shared/hooks/useToast"; import { taskService } from "../services"; import type { CreateTaskRequest, Task, UpdateTaskRequest } from "../types"; diff --git a/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx b/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx index f4bd23c0..ff5ec746 100644 --- a/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx +++ b/archon-ui-main/src/features/settings/migrations/components/PendingMigrationsModal.tsx @@ -6,7 +6,7 @@ import { AnimatePresence, motion } from "framer-motion"; import { CheckCircle, Copy, Database, ExternalLink, X } from "lucide-react"; import React from "react"; import { copyToClipboard } from "@/features/shared/utils/clipboard"; -import { useToast } from "@/features/ui/hooks/useToast"; +import { useToast } from "@/features/shared/hooks/useToast"; import type { PendingMigration } from "../types"; interface PendingMigrationsModalProps { diff --git a/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts b/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts index 1c2a6d7e..7a44ff8d 100644 --- a/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts +++ b/archon-ui-main/src/features/settings/migrations/hooks/useMigrationQueries.ts @@ -4,7 +4,7 @@ import { useQuery } from "@tanstack/react-query"; import { STALE_TIMES } from 
"@/features/shared/queryPatterns"; -import { useSmartPolling } from "@/features/ui/hooks/useSmartPolling"; +import { useSmartPolling } from "@/features/shared/hooks/useSmartPolling"; import { migrationService } from "../services/migrationService"; import type { MigrationHistoryResponse, MigrationStatusResponse, PendingMigration } from "../types"; diff --git a/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts b/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts index e1aefbd8..f9ea9165 100644 --- a/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts +++ b/archon-ui-main/src/features/settings/version/hooks/useVersionQueries.ts @@ -4,7 +4,7 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { STALE_TIMES } from "@/features/shared/queryPatterns"; -import { useSmartPolling } from "@/features/ui/hooks/useSmartPolling"; +import { useSmartPolling } from "@/features/shared/hooks/useSmartPolling"; import { versionService } from "../services/versionService"; import type { VersionCheckResponse } from "../types"; diff --git a/archon-ui-main/src/features/ui/hooks/index.ts b/archon-ui-main/src/features/shared/hooks/index.ts similarity index 70% rename from archon-ui-main/src/features/ui/hooks/index.ts rename to archon-ui-main/src/features/shared/hooks/index.ts index b23209b4..db280d76 100644 --- a/archon-ui-main/src/features/ui/hooks/index.ts +++ b/archon-ui-main/src/features/shared/hooks/index.ts @@ -1,3 +1,3 @@ export * from "./useSmartPolling"; export * from "./useThemeAware"; -export * from "./useToast"; +export * from "./useToast"; \ No newline at end of file diff --git a/archon-ui-main/src/features/ui/hooks/tests/useSmartPolling.test.ts b/archon-ui-main/src/features/shared/hooks/tests/useSmartPolling.test.ts similarity index 100% rename from archon-ui-main/src/features/ui/hooks/tests/useSmartPolling.test.ts rename to archon-ui-main/src/features/shared/hooks/tests/useSmartPolling.test.ts diff --git a/archon-ui-main/src/features/ui/hooks/useSmartPolling.ts b/archon-ui-main/src/features/shared/hooks/useSmartPolling.ts similarity index 100% rename from archon-ui-main/src/features/ui/hooks/useSmartPolling.ts rename to archon-ui-main/src/features/shared/hooks/useSmartPolling.ts diff --git a/archon-ui-main/src/features/ui/hooks/useThemeAware.ts b/archon-ui-main/src/features/shared/hooks/useThemeAware.ts similarity index 100% rename from archon-ui-main/src/features/ui/hooks/useThemeAware.ts rename to archon-ui-main/src/features/shared/hooks/useThemeAware.ts diff --git a/archon-ui-main/src/features/ui/hooks/useToast.ts b/archon-ui-main/src/features/shared/hooks/useToast.ts similarity index 97% rename from archon-ui-main/src/features/ui/hooks/useToast.ts rename to archon-ui-main/src/features/shared/hooks/useToast.ts index 6e71297e..49b40139 100644 --- a/archon-ui-main/src/features/ui/hooks/useToast.ts +++ b/archon-ui-main/src/features/shared/hooks/useToast.ts @@ -1,6 +1,6 @@ import { AlertCircle, CheckCircle, Info, XCircle } from "lucide-react"; import { createContext, useCallback, useContext, useEffect, useRef, useState } from "react"; -import { createOptimisticId } from "../../shared/optimistic"; +import { createOptimisticId } from "../optimistic"; // Toast types interface Toast { diff --git a/archon-ui-main/src/features/ui/components/ToastProvider.tsx b/archon-ui-main/src/features/ui/components/ToastProvider.tsx index 1657ac80..8a3c476b 100644 --- 
a/archon-ui-main/src/features/ui/components/ToastProvider.tsx +++ b/archon-ui-main/src/features/ui/components/ToastProvider.tsx @@ -1,5 +1,5 @@ import type React from "react"; -import { createToastContext, getToastIcon, ToastContext } from "../hooks/useToast"; +import { createToastContext, getToastIcon, ToastContext } from "../../shared/hooks/useToast"; import { ToastProvider as RadixToastProvider, Toast, diff --git a/archon-ui-main/src/pages/SettingsPage.tsx b/archon-ui-main/src/pages/SettingsPage.tsx index 20c3c412..35136616 100644 --- a/archon-ui-main/src/pages/SettingsPage.tsx +++ b/archon-ui-main/src/pages/SettingsPage.tsx @@ -14,7 +14,7 @@ import { Database, } from "lucide-react"; import { motion, AnimatePresence } from "framer-motion"; -import { useToast } from "../features/ui/hooks/useToast"; +import { useToast } from "../features/shared/hooks/useToast"; import { useSettings } from "../contexts/SettingsContext"; import { useStaggeredEntrance } from "../hooks/useStaggeredEntrance"; import { FeaturesSection } from "../components/settings/FeaturesSection"; From 63a92cf7d75bd4fbc236816196d5c506f5c620c3 Mon Sep 17 00:00:00 2001 From: Wirasm <152263317+Wirasm@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:59:33 +0300 Subject: [PATCH 5/7] refactor: reorganize features/shared directory for better maintainability (#730) * refactor: reorganize features/shared directory structure - Created organized subdirectories for better code organization: - api/ - API clients and HTTP utilities (renamed apiWithEtag.ts to apiClient.ts) - config/ - Configuration files (queryClient, queryPatterns) - types/ - Shared type definitions (errors) - utils/ - Pure utility functions (optimistic, clipboard) - hooks/ - Shared React hooks (already existed) - Updated all import paths across the codebase (~40+ files) - Updated all AI documentation in PRPs/ai_docs/ to reflect new structure - All tests passing, build successful, no functional changes This improves maintainability and follows vertical slice architecture patterns. 
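For orientation only (this snippet is not part of the patch): a minimal sketch of how imports look after the reorganization described above, using only paths that already appear in the hunks in this patch. Both the `@/features/...` alias form and the relative form (as used from files nested two levels inside a feature) resolve to the new `features/shared` subdirectories.

```typescript
// Illustrative sketch only; paths copied from the hunks in this patch.
// features/shared is now split into api/, config/, types/, utils/, and hooks/.

// Alias-based imports (used by most feature hooks and components):
import { DISABLED_QUERY_KEY, STALE_TIMES } from "@/features/shared/config/queryPatterns";
import { useSmartPolling } from "@/features/shared/hooks";
import { useToast } from "@/features/shared/hooks/useToast";
import { createOptimisticEntity, replaceOptimisticEntity } from "@/features/shared/utils/optimistic";

// Relative imports, as used from files two levels deep inside a feature (e.g. a services/ folder):
import { callAPIWithETag } from "../../shared/api/apiClient";
import { APIServiceError } from "../../shared/types/errors";
```

Per the rename entries in the diffstat below, the old flat files (`shared/apiWithEtag.ts`, `shared/queryClient.ts`, `shared/queryPatterns.ts`, `shared/errors.ts`, `shared/optimistic.ts`) are moved rather than duplicated, so call sites only need the updated paths shown above.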
Co-Authored-By: Claude * fix: address PR review comments and code improvements - Update imports to use @/features alias path for optimistic utils - Fix optimistic upload item replacement by matching on source_id instead of id - Clean up test suite naming and remove meta-terms from comments - Only set Content-Type header on requests with body - Add explicit TypeScript typing to useProjectFeatures hook - Complete Phase 4 improvements with proper query typing * fix: address additional PR review feedback - Clear feature queries when deleting project to prevent cache memory leaks - Update KnowledgeCard comments to follow documentation guidelines - Add explanatory comment for accessibility pattern in KnowledgeCard --------- Co-authored-by: Claude --- PRPs/ai_docs/API_NAMING_CONVENTIONS.md | 2 +- PRPs/ai_docs/ARCHITECTURE.md | 6 +- PRPs/ai_docs/DATA_FETCHING_ARCHITECTURE.md | 10 +- PRPs/ai_docs/ETAG_IMPLEMENTATION.md | 10 +- PRPs/ai_docs/QUERY_PATTERNS.md | 8 +- PRPs/ai_docs/optimistic_updates.md | 6 +- archon-ui-main/src/App.tsx | 2 +- .../layout/hooks/useBackendHealth.ts | 4 +- .../knowledge/components/KnowledgeCard.tsx | 7 +- .../knowledge/hooks/useKnowledgeQueries.ts | 17 +- .../components/KnowledgeInspector.tsx | 2 +- .../inspector/hooks/useInspectorPagination.ts | 2 +- .../knowledge/services/knowledgeService.ts | 4 +- .../utils/tests/providerErrorHandler.test.ts | 206 +++++++++--------- .../knowledge/views/KnowledgeView.tsx | 2 +- .../mcp/components/McpConfigSection.tsx | 5 +- .../src/features/mcp/hooks/useMcpQueries.ts | 2 +- .../src/features/mcp/services/mcpApi.ts | 2 +- .../hooks/tests/useProgressQueries.test.ts | 2 +- .../progress/hooks/useProgressQueries.ts | 4 +- .../progress/services/progressService.ts | 2 +- .../projects/components/ProjectCard.tsx | 2 +- .../documents/components/DocumentCard.tsx | 2 +- .../documents/hooks/useDocumentQueries.ts | 2 +- .../projects/hooks/useProjectQueries.ts | 14 +- .../projects/services/projectService.ts | 4 +- .../projects/tasks/components/TaskCard.tsx | 2 +- .../projects/tasks/hooks/useTaskQueries.ts | 8 +- .../projects/tasks/services/taskService.ts | 4 +- .../tasks/services/tests/taskService.test.ts | 4 +- .../components/MigrationStatusCard.tsx | 6 +- .../components/PendingMigrationsModal.tsx | 17 +- .../migrations/hooks/useMigrationQueries.ts | 2 +- .../migrations/services/migrationService.ts | 2 +- .../version/components/UpdateBanner.tsx | 3 +- .../version/components/VersionStatusCard.tsx | 3 +- .../version/hooks/useVersionQueries.ts | 2 +- .../version/services/versionService.ts | 2 +- .../{apiWithEtag.ts => api/apiClient.ts} | 23 +- .../tests/apiClient.test.ts} | 18 +- .../shared/{ => config}/queryClient.ts | 0 .../shared/{ => config}/queryPatterns.ts | 0 .../src/features/shared/hooks/index.ts | 2 +- .../src/features/shared/hooks/useToast.ts | 2 +- .../src/features/shared/{ => types}/errors.ts | 0 .../features/shared/{ => utils}/optimistic.ts | 0 .../{ => utils/tests}/optimistic.test.ts | 14 +- .../src/features/testing/test-utils.tsx | 2 +- 48 files changed, 230 insertions(+), 215 deletions(-) rename archon-ui-main/src/features/shared/{apiWithEtag.ts => api/apiClient.ts} (83%) rename archon-ui-main/src/features/shared/{apiWithEtag.test.ts => api/tests/apiClient.test.ts} (96%) rename archon-ui-main/src/features/shared/{ => config}/queryClient.ts (100%) rename archon-ui-main/src/features/shared/{ => config}/queryPatterns.ts (100%) rename archon-ui-main/src/features/shared/{ => types}/errors.ts (100%) rename archon-ui-main/src/features/shared/{ => 
utils}/optimistic.ts (100%) rename archon-ui-main/src/features/shared/{ => utils/tests}/optimistic.test.ts (97%) diff --git a/PRPs/ai_docs/API_NAMING_CONVENTIONS.md b/PRPs/ai_docs/API_NAMING_CONVENTIONS.md index 5688912b..2135bc8d 100644 --- a/PRPs/ai_docs/API_NAMING_CONVENTIONS.md +++ b/PRPs/ai_docs/API_NAMING_CONVENTIONS.md @@ -198,7 +198,7 @@ Database values used directly - no mapping layers: - Operation statuses: `"pending"`, `"processing"`, `"completed"`, `"failed"` ### Time Constants -**Location**: `archon-ui-main/src/features/shared/queryPatterns.ts` +**Location**: `archon-ui-main/src/features/shared/config/queryPatterns.ts` - `STALE_TIMES.instant` - 0ms - `STALE_TIMES.realtime` - 3 seconds - `STALE_TIMES.frequent` - 5 seconds diff --git a/PRPs/ai_docs/ARCHITECTURE.md b/PRPs/ai_docs/ARCHITECTURE.md index a5c0ae7a..eb3a7f81 100644 --- a/PRPs/ai_docs/ARCHITECTURE.md +++ b/PRPs/ai_docs/ARCHITECTURE.md @@ -88,8 +88,8 @@ Pattern: `{METHOD} /api/{resource}/{id?}/{sub-resource?}` ### Data Fetching **Core**: TanStack Query v5 -**Configuration**: `archon-ui-main/src/features/shared/queryClient.ts` -**Patterns**: `archon-ui-main/src/features/shared/queryPatterns.ts` +**Configuration**: `archon-ui-main/src/features/shared/config/queryClient.ts` +**Patterns**: `archon-ui-main/src/features/shared/config/queryPatterns.ts` ### State Management - **Server State**: TanStack Query @@ -139,7 +139,7 @@ TanStack Query is the single source of truth. No separate state management neede No translation layers. Database values (e.g., `"todo"`, `"doing"`) used directly in UI. ### Browser-Native Caching -ETags handled by browser, not JavaScript. See `archon-ui-main/src/features/shared/apiWithEtag.ts`. +ETags handled by browser, not JavaScript. See `archon-ui-main/src/features/shared/api/apiClient.ts`. ## Deployment diff --git a/PRPs/ai_docs/DATA_FETCHING_ARCHITECTURE.md b/PRPs/ai_docs/DATA_FETCHING_ARCHITECTURE.md index d8a9822b..8d1bbb62 100644 --- a/PRPs/ai_docs/DATA_FETCHING_ARCHITECTURE.md +++ b/PRPs/ai_docs/DATA_FETCHING_ARCHITECTURE.md @@ -8,7 +8,7 @@ Archon uses **TanStack Query v5** for all data fetching, caching, and synchroniz ### 1. Query Client Configuration -**Location**: `archon-ui-main/src/features/shared/queryClient.ts` +**Location**: `archon-ui-main/src/features/shared/config/queryClient.ts` Centralized QueryClient with: @@ -30,7 +30,7 @@ Visibility-aware polling that: ### 3. Query Patterns -**Location**: `archon-ui-main/src/features/shared/queryPatterns.ts` +**Location**: `archon-ui-main/src/features/shared/config/queryPatterns.ts` Shared constants: @@ -64,7 +64,7 @@ Standard pattern across all features: ### ETag Support -**Location**: `archon-ui-main/src/features/shared/apiWithEtag.ts` +**Location**: `archon-ui-main/src/features/shared/api/apiClient.ts` ETag implementation: @@ -83,7 +83,7 @@ Backend endpoints follow RESTful patterns: ## Optimistic Updates -**Utilities**: `archon-ui-main/src/features/shared/optimistic.ts` +**Utilities**: `archon-ui-main/src/features/shared/utils/optimistic.ts` All mutations use nanoid-based optimistic updates: @@ -105,7 +105,7 @@ Polling intervals are defined in each feature's query hooks. 
See actual implemen - **Progress**: `archon-ui-main/src/features/progress/hooks/useProgressQueries.ts` - **MCP**: `archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts` -Standard intervals from `archon-ui-main/src/features/shared/queryPatterns.ts`: +Standard intervals from `archon-ui-main/src/features/shared/config/queryPatterns.ts`: - `STALE_TIMES.instant`: 0ms (always fresh) - `STALE_TIMES.frequent`: 5 seconds (frequently changing data) - `STALE_TIMES.normal`: 30 seconds (standard cache) diff --git a/PRPs/ai_docs/ETAG_IMPLEMENTATION.md b/PRPs/ai_docs/ETAG_IMPLEMENTATION.md index 70e4ce63..8560dbb5 100644 --- a/PRPs/ai_docs/ETAG_IMPLEMENTATION.md +++ b/PRPs/ai_docs/ETAG_IMPLEMENTATION.md @@ -17,7 +17,7 @@ The backend generates ETags for API responses: - Returns `304 Not Modified` when ETags match ### Frontend Handling -**Location**: `archon-ui-main/src/features/shared/apiWithEtag.ts` +**Location**: `archon-ui-main/src/features/shared/api/apiClient.ts` The frontend relies on browser-native HTTP caching: - Browser automatically sends `If-None-Match` headers with cached ETags @@ -28,7 +28,7 @@ The frontend relies on browser-native HTTP caching: #### Browser vs Non-Browser Behavior - **Standard Browsers**: Per the Fetch spec, a 304 response freshens the HTTP cache and returns the cached body to JavaScript - **Non-Browser Runtimes** (React Native, custom fetch): May surface 304 with empty body to JavaScript -- **Client Fallback**: The `apiWithEtag.ts` implementation handles both scenarios, ensuring consistent behavior across environments +- **Client Fallback**: The `apiClient.ts` implementation handles both scenarios, ensuring consistent behavior across environments ## Implementation Details @@ -81,8 +81,8 @@ Unlike previous implementations, the current approach: ### Configuration Cache behavior is controlled through TanStack Query's `staleTime`: -- See `archon-ui-main/src/features/shared/queryPatterns.ts` for standard times -- See `archon-ui-main/src/features/shared/queryClient.ts` for global configuration +- See `archon-ui-main/src/features/shared/config/queryPatterns.ts` for standard times +- See `archon-ui-main/src/features/shared/config/queryClient.ts` for global configuration ## Performance Benefits @@ -100,7 +100,7 @@ Cache behavior is controlled through TanStack Query's `staleTime`: ### Core Implementation - **Backend Utilities**: `python/src/server/utils/etag_utils.py` -- **Frontend Client**: `archon-ui-main/src/features/shared/apiWithEtag.ts` +- **Frontend Client**: `archon-ui-main/src/features/shared/api/apiClient.ts` - **Tests**: `python/tests/server/utils/test_etag_utils.py` ### Usage Examples diff --git a/PRPs/ai_docs/QUERY_PATTERNS.md b/PRPs/ai_docs/QUERY_PATTERNS.md index 3c3204db..499daa36 100644 --- a/PRPs/ai_docs/QUERY_PATTERNS.md +++ b/PRPs/ai_docs/QUERY_PATTERNS.md @@ -5,7 +5,7 @@ This guide documents the standardized patterns for using TanStack Query v5 in th ## Core Principles 1. **Feature Ownership**: Each feature owns its query keys in `{feature}/hooks/use{Feature}Queries.ts` -2. **Consistent Patterns**: Always use shared patterns from `shared/queryPatterns.ts` +2. **Consistent Patterns**: Always use shared patterns from `shared/config/queryPatterns.ts` 3. **No Hardcoded Values**: Never hardcode stale times or disabled keys 4. 
**Mirror Backend API**: Query keys should exactly match backend API structure @@ -49,7 +49,7 @@ export const taskKeys = { ### Import Required Patterns ```typescript -import { DISABLED_QUERY_KEY, STALE_TIMES } from "@/features/shared/queryPatterns"; +import { DISABLED_QUERY_KEY, STALE_TIMES } from "@/features/shared/config/queryPatterns"; ``` ### Disabled Queries @@ -106,7 +106,7 @@ export function useFeatureDetail(id: string | undefined) { ## Mutations with Optimistic Updates ```typescript -import { createOptimisticEntity, replaceOptimisticEntity } from "@/features/shared/optimistic"; +import { createOptimisticEntity, replaceOptimisticEntity } from "@/features/shared/utils/optimistic"; export function useCreateFeature() { const queryClient = useQueryClient(); @@ -161,7 +161,7 @@ vi.mock("../../services", () => ({ })); // Mock shared patterns with ALL values -vi.mock("../../../shared/queryPatterns", () => ({ +vi.mock("../../../shared/config/queryPatterns", () => ({ DISABLED_QUERY_KEY: ["disabled"] as const, STALE_TIMES: { instant: 0, diff --git a/PRPs/ai_docs/optimistic_updates.md b/PRPs/ai_docs/optimistic_updates.md index 7be11ea6..219b7866 100644 --- a/PRPs/ai_docs/optimistic_updates.md +++ b/PRPs/ai_docs/optimistic_updates.md @@ -3,7 +3,7 @@ ## Core Architecture ### Shared Utilities Module -**Location**: `src/features/shared/optimistic.ts` +**Location**: `src/features/shared/utils/optimistic.ts` Provides type-safe utilities for managing optimistic state across all features: - `createOptimisticId()` - Generates stable UUIDs using nanoid @@ -73,13 +73,13 @@ Reusable component showing: - Uses `createOptimisticId()` directly for progress tracking ### Toasts -- **Location**: `src/features/ui/hooks/useToast.ts:43` +- **Location**: `src/features/shared/hooks/useToast.ts:43` - Uses `createOptimisticId()` for unique toast IDs ## Testing ### Unit Tests -**Location**: `src/features/shared/optimistic.test.ts` +**Location**: `src/features/shared/utils/tests/optimistic.test.ts` Covers all utility functions with 8 test cases: - ID uniqueness and format validation diff --git a/archon-ui-main/src/App.tsx b/archon-ui-main/src/App.tsx index 1d4e22d3..ea2539cc 100644 --- a/archon-ui-main/src/App.tsx +++ b/archon-ui-main/src/App.tsx @@ -2,7 +2,7 @@ import { useState, useEffect } from 'react'; import { BrowserRouter as Router, Routes, Route, Navigate } from 'react-router-dom'; import { QueryClientProvider } from '@tanstack/react-query'; import { ReactQueryDevtools } from '@tanstack/react-query-devtools'; -import { queryClient } from './features/shared/queryClient'; +import { queryClient } from './features/shared/config/queryClient'; import { KnowledgeBasePage } from './pages/KnowledgeBasePage'; import { SettingsPage } from './pages/SettingsPage'; import { MCPPage } from './pages/MCPPage'; diff --git a/archon-ui-main/src/components/layout/hooks/useBackendHealth.ts b/archon-ui-main/src/components/layout/hooks/useBackendHealth.ts index 626d23b6..59e9ccfa 100644 --- a/archon-ui-main/src/components/layout/hooks/useBackendHealth.ts +++ b/archon-ui-main/src/components/layout/hooks/useBackendHealth.ts @@ -1,6 +1,6 @@ import { useQuery } from "@tanstack/react-query"; -import { callAPIWithETag } from "../../../features/shared/apiWithEtag"; -import { createRetryLogic, STALE_TIMES } from "../../../features/shared/queryPatterns"; +import { callAPIWithETag } from "../../../features/shared/api/apiClient"; +import { createRetryLogic, STALE_TIMES } from "../../../features/shared/config/queryPatterns"; import type { 
HealthResponse } from "../types"; /** diff --git a/archon-ui-main/src/features/knowledge/components/KnowledgeCard.tsx b/archon-ui-main/src/features/knowledge/components/KnowledgeCard.tsx index bb49edd9..05c882de 100644 --- a/archon-ui-main/src/features/knowledge/components/KnowledgeCard.tsx +++ b/archon-ui-main/src/features/knowledge/components/KnowledgeCard.tsx @@ -1,6 +1,6 @@ /** - * Enhanced Knowledge Card Component - * Individual knowledge item card with excellent UX and inline progress + * Knowledge Card component + * Displays a knowledge item with inline progress and status UI * Following the pattern from ProjectCard */ @@ -10,7 +10,7 @@ import { Clock, Code, ExternalLink, File, FileText, Globe } from "lucide-react"; import { useState } from "react"; import { KnowledgeCardProgress } from "../../progress/components/KnowledgeCardProgress"; import type { ActiveOperation } from "../../progress/types"; -import { isOptimistic } from "../../shared/optimistic"; +import { isOptimistic } from "@/features/shared/utils/optimistic"; import { StatPill } from "../../ui/primitives"; import { OptimisticIndicator } from "../../ui/primitives/OptimisticIndicator"; import { cn } from "../../ui/primitives/styles"; @@ -144,6 +144,7 @@ export const KnowledgeCard: React.FC = ({ }; return ( + // biome-ignore lint/a11y/useSemanticElements: Card contains nested interactive elements (buttons, links) - using div to avoid invalid HTML nesting ); - const tempItemId = optimisticItem.id; // Update all summaries caches with optimistic data, respecting each cache's filter const entries = queryClient.getQueriesData({ @@ -229,7 +228,7 @@ export function useCrawlUrl() { }); // Return context for rollback and replacement - return { previousSummaries, previousOperations, tempProgressId, tempItemId }; + return { previousSummaries, previousOperations, tempProgressId }; }, onSuccess: (response, _variables, context) => { // Replace temporary IDs with real ones from the server @@ -313,7 +312,6 @@ export function useUploadDocument() { previousSummaries?: Array<[readonly unknown[], KnowledgeItemsResponse | undefined]>; previousOperations?: ActiveOperationsResponse; tempProgressId: string; - tempItemId: string; } >({ mutationFn: ({ file, metadata }: { file: File; metadata: UploadMetadata }) => @@ -352,7 +350,6 @@ export function useUploadDocument() { created_at: new Date().toISOString(), updated_at: new Date().toISOString(), } as Omit); - const tempItemId = optimisticItem.id; // Respect each cache's filter (knowledge_type, tags, etc.) 
const entries = queryClient.getQueriesData({ @@ -410,7 +407,7 @@ export function useUploadDocument() { }; }); - return { previousSummaries, previousOperations, tempProgressId, tempItemId }; + return { previousSummaries, previousOperations, tempProgressId }; }, onSuccess: (response, _variables, context) => { // Replace temporary IDs with real ones from the server @@ -421,7 +418,7 @@ export function useUploadDocument() { return { ...old, items: old.items.map((item) => { - if (item.id === context.tempItemId) { + if (item.source_id === context.tempProgressId) { return { ...item, source_id: response.progressId, diff --git a/archon-ui-main/src/features/knowledge/inspector/components/KnowledgeInspector.tsx b/archon-ui-main/src/features/knowledge/inspector/components/KnowledgeInspector.tsx index 69e8f050..334d4567 100644 --- a/archon-ui-main/src/features/knowledge/inspector/components/KnowledgeInspector.tsx +++ b/archon-ui-main/src/features/knowledge/inspector/components/KnowledgeInspector.tsx @@ -4,13 +4,13 @@ */ import { useCallback, useEffect, useState } from "react"; +import { copyToClipboard } from "../../../shared/utils/clipboard"; import { InspectorDialog, InspectorDialogContent, InspectorDialogTitle } from "../../../ui/primitives"; import type { CodeExample, DocumentChunk, InspectorSelectedItem, KnowledgeItem } from "../../types"; import { useInspectorPagination } from "../hooks/useInspectorPagination"; import { ContentViewer } from "./ContentViewer"; import { InspectorHeader } from "./InspectorHeader"; import { InspectorSidebar } from "./InspectorSidebar"; -import { copyToClipboard } from "../../../shared/utils/clipboard"; interface KnowledgeInspectorProps { item: KnowledgeItem; diff --git a/archon-ui-main/src/features/knowledge/inspector/hooks/useInspectorPagination.ts b/archon-ui-main/src/features/knowledge/inspector/hooks/useInspectorPagination.ts index 613aa19d..a1f286d5 100644 --- a/archon-ui-main/src/features/knowledge/inspector/hooks/useInspectorPagination.ts +++ b/archon-ui-main/src/features/knowledge/inspector/hooks/useInspectorPagination.ts @@ -5,7 +5,7 @@ import { useInfiniteQuery } from "@tanstack/react-query"; import { useMemo } from "react"; -import { STALE_TIMES } from "@/features/shared/queryPatterns"; +import { STALE_TIMES } from "@/features/shared/config/queryPatterns"; import { knowledgeKeys } from "../../hooks/useKnowledgeQueries"; import { knowledgeService } from "../../services"; import type { ChunksResponse, CodeExample, CodeExamplesResponse, DocumentChunk } from "../../types"; diff --git a/archon-ui-main/src/features/knowledge/services/knowledgeService.ts b/archon-ui-main/src/features/knowledge/services/knowledgeService.ts index b9d6af06..cfab3f7f 100644 --- a/archon-ui-main/src/features/knowledge/services/knowledgeService.ts +++ b/archon-ui-main/src/features/knowledge/services/knowledgeService.ts @@ -3,8 +3,8 @@ * Handles all knowledge-related API operations using TanStack Query patterns */ -import { callAPIWithETag } from "../../shared/apiWithEtag"; -import { APIServiceError } from "../../shared/errors"; +import { callAPIWithETag } from "../../shared/api/apiClient"; +import { APIServiceError } from "../../shared/types/errors"; import type { ChunksResponse, CodeExamplesResponse, diff --git a/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts b/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts index 193e2444..9ddf380a 100644 --- a/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts 
+++ b/archon-ui-main/src/features/knowledge/utils/tests/providerErrorHandler.test.ts @@ -1,70 +1,70 @@ -import { describe, it, expect } from 'vitest'; -import { parseProviderError, getProviderErrorMessage, type ProviderError } from '../providerErrorHandler'; +import { describe, expect, it } from "vitest"; +import { getProviderErrorMessage, type ProviderError, parseProviderError } from "../providerErrorHandler"; -describe('providerErrorHandler', () => { - describe('parseProviderError', () => { - it('should handle basic Error objects', () => { - const error = new Error('Basic error message'); +describe("providerErrorHandler", () => { + describe("parseProviderError", () => { + it("should handle basic Error objects", () => { + const error = new Error("Basic error message"); const result = parseProviderError(error); - expect(result.message).toBe('Basic error message'); + expect(result.message).toBe("Basic error message"); expect(result.isProviderError).toBeUndefined(); }); - it('should handle errors with statusCode property', () => { - const error = { statusCode: 401, message: 'Unauthorized' }; + it("should handle errors with statusCode property", () => { + const error = { statusCode: 401, message: "Unauthorized" }; const result = parseProviderError(error); expect(result.statusCode).toBe(401); - expect(result.message).toBe('Unauthorized'); + expect(result.message).toBe("Unauthorized"); }); - it('should handle errors with status property', () => { - const error = { status: 429, message: 'Rate limited' }; + it("should handle errors with status property", () => { + const error = { status: 429, message: "Rate limited" }; const result = parseProviderError(error); expect(result.statusCode).toBe(429); - expect(result.message).toBe('Rate limited'); + expect(result.message).toBe("Rate limited"); }); - it('should prioritize statusCode over status when both are present', () => { - const error = { statusCode: 401, status: 429, message: 'Auth error' }; + it("should prioritize statusCode over status when both are present", () => { + const error = { statusCode: 401, status: 429, message: "Auth error" }; const result = parseProviderError(error); expect(result.statusCode).toBe(401); }); - it('should parse structured provider errors from backend', () => { + it("should parse structured provider errors from backend", () => { const error = { message: JSON.stringify({ detail: { - error_type: 'authentication_failed', - provider: 'OpenAI', - message: 'Invalid API key' - } - }) + error_type: "authentication_failed", + provider: "OpenAI", + message: "Invalid API key", + }, + }), }; const result = parseProviderError(error); expect(result.isProviderError).toBe(true); - expect(result.provider).toBe('OpenAI'); - expect(result.errorType).toBe('authentication_failed'); - expect(result.message).toBe('Invalid API key'); + expect(result.provider).toBe("OpenAI"); + expect(result.errorType).toBe("authentication_failed"); + expect(result.message).toBe("Invalid API key"); }); - it('should handle malformed JSON in message gracefully', () => { + it("should handle malformed JSON in message gracefully", () => { const error = { - message: 'invalid json { detail' + message: "invalid json { detail", }; const result = parseProviderError(error); expect(result.isProviderError).toBeUndefined(); - expect(result.message).toBe('invalid json { detail'); + expect(result.message).toBe("invalid json { detail"); }); - it('should handle null and undefined inputs safely', () => { + it("should handle null and undefined inputs safely", () => { expect(() 
=> parseProviderError(null)).not.toThrow(); expect(() => parseProviderError(undefined)).not.toThrow(); @@ -75,7 +75,7 @@ describe('providerErrorHandler', () => { expect(undefinedResult).toBeDefined(); }); - it('should handle empty objects', () => { + it("should handle empty objects", () => { const result = parseProviderError({}); expect(result).toBeDefined(); @@ -83,171 +83,171 @@ describe('providerErrorHandler', () => { expect(result.isProviderError).toBeUndefined(); }); - it('should handle primitive values', () => { - expect(() => parseProviderError('string error')).not.toThrow(); + it("should handle primitive values", () => { + expect(() => parseProviderError("string error")).not.toThrow(); expect(() => parseProviderError(42)).not.toThrow(); expect(() => parseProviderError(true)).not.toThrow(); }); - it('should handle structured errors without provider field', () => { + it("should handle structured errors without provider field", () => { const error = { message: JSON.stringify({ detail: { - error_type: 'quota_exhausted', - message: 'Usage limit exceeded' - } - }) + error_type: "quota_exhausted", + message: "Usage limit exceeded", + }, + }), }; const result = parseProviderError(error); expect(result.isProviderError).toBe(true); - expect(result.provider).toBe('LLM'); // Default fallback - expect(result.errorType).toBe('quota_exhausted'); - expect(result.message).toBe('Usage limit exceeded'); + expect(result.provider).toBe("LLM"); // Default fallback + expect(result.errorType).toBe("quota_exhausted"); + expect(result.message).toBe("Usage limit exceeded"); }); - it('should handle partial structured errors', () => { + it("should handle partial structured errors", () => { const error = { message: JSON.stringify({ detail: { - error_type: 'rate_limit' + error_type: "rate_limit", // Missing message field - } - }) + }, + }), }; const result = parseProviderError(error); expect(result.isProviderError).toBe(true); - expect(result.errorType).toBe('rate_limit'); + expect(result.errorType).toBe("rate_limit"); expect(result.message).toBe(error.message); // Falls back to original message }); }); - describe('getProviderErrorMessage', () => { - it('should return user-friendly message for authentication_failed', () => { + describe("getProviderErrorMessage", () => { + it("should return user-friendly message for authentication_failed", () => { const error: ProviderError = { - name: 'Error', - message: 'Auth failed', + name: "Error", + message: "Auth failed", isProviderError: true, - provider: 'OpenAI', - errorType: 'authentication_failed' + provider: "OpenAI", + errorType: "authentication_failed", }; const result = getProviderErrorMessage(error); - expect(result).toBe('Please verify your OpenAI API key in Settings.'); + expect(result).toBe("Please verify your OpenAI API key in Settings."); }); - it('should return user-friendly message for quota_exhausted', () => { + it("should return user-friendly message for quota_exhausted", () => { const error: ProviderError = { - name: 'Error', - message: 'Quota exceeded', + name: "Error", + message: "Quota exceeded", isProviderError: true, - provider: 'Google AI', - errorType: 'quota_exhausted' + provider: "Google AI", + errorType: "quota_exhausted", }; const result = getProviderErrorMessage(error); - expect(result).toBe('Google AI quota exhausted. Please check your billing settings.'); + expect(result).toBe("Google AI quota exhausted. 
Please check your billing settings."); }); - it('should return user-friendly message for rate_limit', () => { + it("should return user-friendly message for rate_limit", () => { const error: ProviderError = { - name: 'Error', - message: 'Rate limited', + name: "Error", + message: "Rate limited", isProviderError: true, - provider: 'Anthropic', - errorType: 'rate_limit' + provider: "Anthropic", + errorType: "rate_limit", }; const result = getProviderErrorMessage(error); - expect(result).toBe('Anthropic rate limit exceeded. Please wait and try again.'); + expect(result).toBe("Anthropic rate limit exceeded. Please wait and try again."); }); - it('should return generic provider message for unknown error types', () => { + it("should return generic provider message for unknown error types", () => { const error: ProviderError = { - name: 'Error', - message: 'Unknown error', + name: "Error", + message: "Unknown error", isProviderError: true, - provider: 'OpenAI', - errorType: 'unknown_error' + provider: "OpenAI", + errorType: "unknown_error", }; const result = getProviderErrorMessage(error); - expect(result).toBe('OpenAI API error. Please check your configuration.'); + expect(result).toBe("OpenAI API error. Please check your configuration."); }); - it('should use default provider when provider is missing', () => { + it("should use default provider when provider is missing", () => { const error: ProviderError = { - name: 'Error', - message: 'Auth failed', + name: "Error", + message: "Auth failed", isProviderError: true, - errorType: 'authentication_failed' + errorType: "authentication_failed", }; const result = getProviderErrorMessage(error); - expect(result).toBe('Please verify your LLM API key in Settings.'); + expect(result).toBe("Please verify your LLM API key in Settings."); }); - it('should handle 401 status code for non-provider errors', () => { - const error = { statusCode: 401, message: 'Unauthorized' }; + it("should handle 401 status code for non-provider errors", () => { + const error = { statusCode: 401, message: "Unauthorized" }; const result = getProviderErrorMessage(error); - expect(result).toBe('Please verify your API key in Settings.'); + expect(result).toBe("Please verify your API key in Settings."); }); - it('should return original message for non-provider errors', () => { - const error = new Error('Network connection failed'); + it("should return original message for non-provider errors", () => { + const error = new Error("Network connection failed"); const result = getProviderErrorMessage(error); - expect(result).toBe('Network connection failed'); + expect(result).toBe("Network connection failed"); }); - it('should return default message when no message is available', () => { + it("should return default message when no message is available", () => { const error = {}; const result = getProviderErrorMessage(error); - expect(result).toBe('An error occurred.'); + expect(result).toBe("An error occurred."); }); - it('should handle complex error objects with structured backend response', () => { + it("should handle complex error objects with structured backend response", () => { const backendError = { statusCode: 400, message: JSON.stringify({ detail: { - error_type: 'authentication_failed', - provider: 'OpenAI', - message: 'API key invalid or expired' - } - }) + error_type: "authentication_failed", + provider: "OpenAI", + message: "API key invalid or expired", + }, + }), }; const result = getProviderErrorMessage(backendError); - expect(result).toBe('Please verify your OpenAI API key in 
Settings.'); + expect(result).toBe("Please verify your OpenAI API key in Settings."); }); it('should handle edge case: message contains "detail" but is not JSON', () => { const error = { - message: 'Error detail: something went wrong' + message: "Error detail: something went wrong", }; const result = getProviderErrorMessage(error); - expect(result).toBe('Error detail: something went wrong'); + expect(result).toBe("Error detail: something went wrong"); }); - it('should handle null and undefined gracefully', () => { - expect(getProviderErrorMessage(null)).toBe('An error occurred.'); - expect(getProviderErrorMessage(undefined)).toBe('An error occurred.'); + it("should handle null and undefined gracefully", () => { + expect(getProviderErrorMessage(null)).toBe("An error occurred."); + expect(getProviderErrorMessage(undefined)).toBe("An error occurred."); }); }); - describe('TypeScript strict mode compliance', () => { - it('should handle type-safe property access', () => { + describe("TypeScript strict mode compliance", () => { + it("should handle type-safe property access", () => { // Test that our type guards work properly const errorWithStatus = { statusCode: 500 }; - const errorWithMessage = { message: 'test' }; - const errorWithBoth = { statusCode: 401, message: 'unauthorized' }; + const errorWithMessage = { message: "test" }; + const errorWithBoth = { statusCode: 401, message: "unauthorized" }; // These should not throw TypeScript errors and should work correctly expect(() => parseProviderError(errorWithStatus)).not.toThrow(); @@ -259,13 +259,13 @@ describe('providerErrorHandler', () => { const result3 = parseProviderError(errorWithBoth); expect(result1.statusCode).toBe(500); - expect(result2.message).toBe('test'); + expect(result2.message).toBe("test"); expect(result3.statusCode).toBe(401); - expect(result3.message).toBe('unauthorized'); + expect(result3.message).toBe("unauthorized"); }); - it('should handle objects without expected properties safely', () => { - const objectWithoutStatus = { someOtherProperty: 'value' }; + it("should handle objects without expected properties safely", () => { + const objectWithoutStatus = { someOtherProperty: "value" }; const objectWithoutMessage = { anotherProperty: 42 }; expect(() => parseProviderError(objectWithoutStatus)).not.toThrow(); @@ -278,4 +278,4 @@ describe('providerErrorHandler', () => { expect(result2.message).toBeUndefined(); }); }); -}); \ No newline at end of file +}); diff --git a/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx b/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx index 6f6a66df..0bedc7b2 100644 --- a/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx +++ b/archon-ui-main/src/features/knowledge/views/KnowledgeView.tsx @@ -4,9 +4,9 @@ */ import { useEffect, useMemo, useRef, useState } from "react"; +import { useToast } from "@/features/shared/hooks/useToast"; import { CrawlingProgress } from "../../progress/components/CrawlingProgress"; import type { ActiveOperation } from "../../progress/types"; -import { useToast } from "@/features/shared/hooks/useToast"; import { AddKnowledgeDialog } from "../components/AddKnowledgeDialog"; import { KnowledgeHeader } from "../components/KnowledgeHeader"; import { KnowledgeList } from "../components/KnowledgeList"; diff --git a/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx b/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx index c36b2f01..b5344bda 100644 --- 
a/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx +++ b/archon-ui-main/src/features/mcp/components/McpConfigSection.tsx @@ -2,9 +2,9 @@ import { Copy, ExternalLink } from "lucide-react"; import type React from "react"; import { useState } from "react"; import { useToast } from "@/features/shared/hooks"; +import { copyToClipboard } from "../../shared/utils/clipboard"; import { Button, cn, glassmorphism, Tabs, TabsContent, TabsList, TabsTrigger } from "../../ui/primitives"; import type { McpServerConfig, McpServerStatus, SupportedIDE } from "../types"; -import { copyToClipboard } from "../../shared/utils/clipboard"; interface McpConfigSectionProps { config?: McpServerConfig; @@ -324,7 +324,8 @@ export const McpConfigSection: React.FC = ({ config, stat

Platform Note: The configuration below shows{" "} {navigator.platform.toLowerCase().includes("win") ? "Windows" : "Linux/macOS"} format. Adjust paths - according to your system. This setup is complex right now because Codex has some bugs with MCP currently. + according to your system. This setup is complex right now because Codex has some bugs with MCP + currently.

)} diff --git a/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts b/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts index aef5ec68..eaf8f404 100644 --- a/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts +++ b/archon-ui-main/src/features/mcp/hooks/useMcpQueries.ts @@ -1,6 +1,6 @@ import { useQuery } from "@tanstack/react-query"; -import { STALE_TIMES } from "../../shared/queryPatterns"; import { useSmartPolling } from "@/features/shared/hooks"; +import { STALE_TIMES } from "../../shared/config/queryPatterns"; import { mcpApi } from "../services"; // Query keys factory diff --git a/archon-ui-main/src/features/mcp/services/mcpApi.ts b/archon-ui-main/src/features/mcp/services/mcpApi.ts index 008c800c..d4b02ed4 100644 --- a/archon-ui-main/src/features/mcp/services/mcpApi.ts +++ b/archon-ui-main/src/features/mcp/services/mcpApi.ts @@ -1,4 +1,4 @@ -import { callAPIWithETag } from "../../shared/apiWithEtag"; +import { callAPIWithETag } from "../../shared/api/apiClient"; import type { McpClient, McpServerConfig, McpServerStatus, McpSessionInfo } from "../types"; export const mcpApi = { diff --git a/archon-ui-main/src/features/progress/hooks/tests/useProgressQueries.test.ts b/archon-ui-main/src/features/progress/hooks/tests/useProgressQueries.test.ts index 565919aa..d305a146 100644 --- a/archon-ui-main/src/features/progress/hooks/tests/useProgressQueries.test.ts +++ b/archon-ui-main/src/features/progress/hooks/tests/useProgressQueries.test.ts @@ -19,7 +19,7 @@ vi.mock("../../services", () => ({ })); // Mock shared query patterns -vi.mock("../../../shared/queryPatterns", () => ({ +vi.mock("../../../shared/config/queryPatterns", () => ({ DISABLED_QUERY_KEY: ["disabled"] as const, STALE_TIMES: { instant: 0, diff --git a/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts b/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts index ae82ba17..1ebec2a9 100644 --- a/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts +++ b/archon-ui-main/src/features/progress/hooks/useProgressQueries.ts @@ -5,9 +5,9 @@ import { type UseQueryResult, useQueries, useQuery, useQueryClient } from "@tanstack/react-query"; import { useEffect, useMemo, useRef } from "react"; -import { APIServiceError } from "../../shared/errors"; -import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/queryPatterns"; +import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/config/queryPatterns"; import { useSmartPolling } from "../../shared/hooks"; +import { APIServiceError } from "../../shared/types/errors"; import { progressService } from "../services"; import type { ActiveOperationsResponse, ProgressResponse, ProgressStatus } from "../types"; diff --git a/archon-ui-main/src/features/progress/services/progressService.ts b/archon-ui-main/src/features/progress/services/progressService.ts index d3f6e61e..ba0e68ba 100644 --- a/archon-ui-main/src/features/progress/services/progressService.ts +++ b/archon-ui-main/src/features/progress/services/progressService.ts @@ -3,7 +3,7 @@ * Uses ETag support for efficient polling */ -import { callAPIWithETag } from "../../shared/apiWithEtag"; +import { callAPIWithETag } from "../../shared/api/apiClient"; import type { ActiveOperationsResponse, ProgressResponse } from "../types"; export const progressService = { diff --git a/archon-ui-main/src/features/projects/components/ProjectCard.tsx b/archon-ui-main/src/features/projects/components/ProjectCard.tsx index df990710..a6b62349 100644 --- 
a/archon-ui-main/src/features/projects/components/ProjectCard.tsx +++ b/archon-ui-main/src/features/projects/components/ProjectCard.tsx @@ -1,7 +1,7 @@ import { motion } from "framer-motion"; import { Activity, CheckCircle2, ListTodo } from "lucide-react"; import type React from "react"; -import { isOptimistic } from "../../shared/optimistic"; +import { isOptimistic } from "@/features/shared/utils/optimistic"; import { OptimisticIndicator } from "../../ui/primitives/OptimisticIndicator"; import { cn } from "../../ui/primitives/styles"; import type { Project } from "../types"; diff --git a/archon-ui-main/src/features/projects/documents/components/DocumentCard.tsx b/archon-ui-main/src/features/projects/documents/components/DocumentCard.tsx index 25b12365..06241a46 100644 --- a/archon-ui-main/src/features/projects/documents/components/DocumentCard.tsx +++ b/archon-ui-main/src/features/projects/documents/components/DocumentCard.tsx @@ -13,9 +13,9 @@ import { } from "lucide-react"; import type React from "react"; import { memo, useCallback, useState } from "react"; +import { copyToClipboard } from "../../../shared/utils/clipboard"; import { Button } from "../../../ui/primitives"; import type { DocumentCardProps, DocumentType } from "../types"; -import { copyToClipboard } from "../../../shared/utils/clipboard"; const getDocumentIcon = (type?: DocumentType) => { switch (type) { diff --git a/archon-ui-main/src/features/projects/documents/hooks/useDocumentQueries.ts b/archon-ui-main/src/features/projects/documents/hooks/useDocumentQueries.ts index 0a7d23ee..00c6eea6 100644 --- a/archon-ui-main/src/features/projects/documents/hooks/useDocumentQueries.ts +++ b/archon-ui-main/src/features/projects/documents/hooks/useDocumentQueries.ts @@ -1,5 +1,5 @@ import { useQuery } from "@tanstack/react-query"; -import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../../shared/queryPatterns"; +import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../../shared/config/queryPatterns"; import { projectService } from "../../services"; import type { ProjectDocument } from "../types"; diff --git a/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts b/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts index ae216e66..946647ab 100644 --- a/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts +++ b/archon-ui-main/src/features/projects/hooks/useProjectQueries.ts @@ -1,13 +1,13 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { useSmartPolling } from "@/features/shared/hooks"; +import { useToast } from "@/features/shared/hooks/useToast"; import { createOptimisticEntity, type OptimisticEntity, removeDuplicateEntities, replaceOptimisticEntity, -} from "@/features/shared/optimistic"; -import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/queryPatterns"; -import { useSmartPolling } from "@/features/shared/hooks"; -import { useToast } from "@/features/shared/hooks/useToast"; +} from "@/features/shared/utils/optimistic"; +import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../shared/config/queryPatterns"; import { projectService } from "../services"; import type { CreateProjectRequest, Project, UpdateProjectRequest } from "../types"; @@ -36,9 +36,7 @@ export function useProjects() { // Fetch project features export function useProjectFeatures(projectId: string | undefined) { - // TODO: Phase 4 - Add explicit typing: useQuery>> - // See PRPs/local/frontend-state-management-refactor.md Phase 4: Configure Request Deduplication - return useQuery({ + 
return useQuery>>({ queryKey: projectId ? projectKeys.features(projectId) : DISABLED_QUERY_KEY, queryFn: () => (projectId ? projectService.getProjectFeatures(projectId) : Promise.reject("No project ID")), enabled: !!projectId, @@ -208,6 +206,8 @@ export function useDeleteProject() { // Don't refetch on success - trust optimistic update // Only remove the specific project's detail data (including nested keys) queryClient.removeQueries({ queryKey: projectKeys.detail(projectId), exact: false }); + // Also remove the project's feature queries + queryClient.removeQueries({ queryKey: projectKeys.features(projectId), exact: false }); showToast("Project deleted successfully", "success"); }, }); diff --git a/archon-ui-main/src/features/projects/services/projectService.ts b/archon-ui-main/src/features/projects/services/projectService.ts index f74675ca..58b1f3e6 100644 --- a/archon-ui-main/src/features/projects/services/projectService.ts +++ b/archon-ui-main/src/features/projects/services/projectService.ts @@ -3,8 +3,8 @@ * Focused service for project CRUD operations only */ -import { callAPIWithETag } from "../../shared/apiWithEtag"; -import { formatZodErrors, ValidationError } from "../../shared/errors"; +import { callAPIWithETag } from "../../shared/api/apiClient"; +import { formatZodErrors, ValidationError } from "../../shared/types/errors"; import { validateCreateProject, validateUpdateProject } from "../schemas"; import { formatRelativeTime } from "../shared/api"; import type { CreateProjectRequest, Project, ProjectFeatures, UpdateProjectRequest } from "../types"; diff --git a/archon-ui-main/src/features/projects/tasks/components/TaskCard.tsx b/archon-ui-main/src/features/projects/tasks/components/TaskCard.tsx index 913964c6..c8e09464 100644 --- a/archon-ui-main/src/features/projects/tasks/components/TaskCard.tsx +++ b/archon-ui-main/src/features/projects/tasks/components/TaskCard.tsx @@ -2,7 +2,7 @@ import { Tag } from "lucide-react"; import type React from "react"; import { useCallback } from "react"; import { useDrag, useDrop } from "react-dnd"; -import { isOptimistic } from "../../../shared/optimistic"; +import { isOptimistic } from "@/features/shared/utils/optimistic"; import { OptimisticIndicator } from "../../../ui/primitives/OptimisticIndicator"; import { useTaskActions } from "../hooks"; import type { Assignee, Task, TaskPriority } from "../types"; diff --git a/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts b/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts index 55b4bbd0..2020a96d 100644 --- a/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts +++ b/archon-ui-main/src/features/projects/tasks/hooks/useTaskQueries.ts @@ -1,11 +1,11 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { createOptimisticEntity, - replaceOptimisticEntity, - removeDuplicateEntities, type OptimisticEntity, -} from "@/features/shared/optimistic"; -import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../../shared/queryPatterns"; + removeDuplicateEntities, + replaceOptimisticEntity, +} from "@/features/shared/utils/optimistic"; +import { DISABLED_QUERY_KEY, STALE_TIMES } from "../../../shared/config/queryPatterns"; import { useSmartPolling } from "../../../shared/hooks"; import { useToast } from "../../../shared/hooks/useToast"; import { taskService } from "../services"; diff --git a/archon-ui-main/src/features/projects/tasks/services/taskService.ts b/archon-ui-main/src/features/projects/tasks/services/taskService.ts index 
223bdb73..dc2db1ed 100644 --- a/archon-ui-main/src/features/projects/tasks/services/taskService.ts +++ b/archon-ui-main/src/features/projects/tasks/services/taskService.ts @@ -3,8 +3,8 @@ * Focused service for task CRUD operations only */ -import { callAPIWithETag } from "../../../shared/apiWithEtag"; -import { formatZodErrors, ValidationError } from "../../../shared/errors"; +import { callAPIWithETag } from "../../../shared/api/apiClient"; +import { formatZodErrors, ValidationError } from "../../../shared/types/errors"; import { validateCreateTask, validateUpdateTask, validateUpdateTaskStatus } from "../schemas"; import type { CreateTaskRequest, DatabaseTaskStatus, Task, TaskCounts, UpdateTaskRequest } from "../types"; diff --git a/archon-ui-main/src/features/projects/tasks/services/tests/taskService.test.ts b/archon-ui-main/src/features/projects/tasks/services/tests/taskService.test.ts index d86cc94d..d4215814 100644 --- a/archon-ui-main/src/features/projects/tasks/services/tests/taskService.test.ts +++ b/archon-ui-main/src/features/projects/tasks/services/tests/taskService.test.ts @@ -1,10 +1,10 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { callAPIWithETag } from "../../../../shared/apiWithEtag"; +import { callAPIWithETag } from "../../../../shared/api/apiClient"; import type { CreateTaskRequest, DatabaseTaskStatus, Task, UpdateTaskRequest } from "../../types"; import { taskService } from "../taskService"; // Mock the API call -vi.mock("../../../../shared/apiWithEtag", () => ({ +vi.mock("../../../../shared/api/apiClient", () => ({ callAPIWithETag: vi.fn(), })); diff --git a/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx b/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx index 2b29531c..be4317a5 100644 --- a/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx +++ b/archon-ui-main/src/features/settings/migrations/components/MigrationStatusCard.tsx @@ -29,7 +29,8 @@ export function MigrationStatusCard() {

Database Migrations

- - - ))} - - {/* Provider-specific configuration */} - {ragSettings.LLM_PROVIDER === 'ollama' && ( -
-
-
-

Ollama Configuration

-

Configure separate Ollama instances for LLM and embedding models

-
-
- {(llmStatus.online && embeddingStatus.online) ? "2 / 2 Online" : - (llmStatus.online || embeddingStatus.online) ? "1 / 2 Online" : "0 / 2 Online"} -
-
- - {/* LLM Instance Card */} -
-
-
-

LLM Instance

-

For chat completions and text generation

-
-
- {llmStatus.checking ? ( - Checking... - ) : llmStatus.online ? ( - Online ({llmStatus.responseTime}ms) - ) : ( - Offline - )} - {llmInstanceConfig.name && llmInstanceConfig.url && ( - - )} -
-
- -
-
- {llmInstanceConfig.name && llmInstanceConfig.url ? ( - <> -
-
{llmInstanceConfig.name}
-
{llmInstanceConfig.url}
-
- -
-
Model:
-
{getDisplayedChatModel(ragSettings)}
-
- -
- {llmStatus.checking ? ( - - ) : null} - {ollamaMetrics.loading ? 'Loading...' : `${ollamaMetrics.llmInstanceModels.total} models available`} -
- - ) : ( -
-
No LLM instance configured
-
Configure an instance to use LLM features
- - {/* Quick setup for single host users */} - {!embeddingInstanceConfig.url && ( -
- -
Sets up both LLM and Embedding for one host
-
- )} - - -
- )} -
- - {llmInstanceConfig.name && llmInstanceConfig.url && ( -
- - - -
- )} -
-
- - {/* Embedding Instance Card */} -
-
-
-

Embedding Instance

-

For generating text embeddings and vector search

-
-
- {embeddingStatus.checking ? ( - Checking... - ) : embeddingStatus.online ? ( - Online ({embeddingStatus.responseTime}ms) - ) : ( - Offline - )} - {embeddingInstanceConfig.name && embeddingInstanceConfig.url && ( - - )} -
-
- -
-
- {embeddingInstanceConfig.name && embeddingInstanceConfig.url ? ( - <> -
-
{embeddingInstanceConfig.name}
-
{embeddingInstanceConfig.url}
-
- -
-
Model:
-
{getDisplayedEmbeddingModel(ragSettings)}
-
- -
- {embeddingStatus.checking ? ( - - ) : null} - {ollamaMetrics.loading ? 'Loading...' : `${ollamaMetrics.embeddingInstanceModels.total} models available`} -
- - ) : ( -
-
No Embedding instance configured
-
Configure an instance to use embedding features
- -
- )} -
- - {embeddingInstanceConfig.name && embeddingInstanceConfig.url && ( -
- - - -
- )} -
-
- - {/* Single Host Indicator */} - {llmInstanceConfig.url && embeddingInstanceConfig.url && - llmInstanceConfig.url === embeddingInstanceConfig.url && ( -
-
- - - - Single Host Setup -
-

- Both LLM and Embedding instances are using the same Ollama host ({llmInstanceConfig.name}) -

-
- )} - - {/* Configuration Summary */} -
-

Configuration Summary

- - {/* Instance Comparison Table */} -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ConfigurationLLM InstanceEmbedding Instance
Instance Name - {llmInstanceConfig.name || Not configured} - - {embeddingInstanceConfig.name || Not configured} -
Status - - {llmStatus.checking ? "Checking..." : llmStatus.online ? `Online (${llmStatus.responseTime}ms)` : "Offline"} - - - - {embeddingStatus.checking ? "Checking..." : embeddingStatus.online ? `Online (${embeddingStatus.responseTime}ms)` : "Offline"} - -
Selected Model - {getDisplayedChatModel(ragSettings) || No model selected} - - {getDisplayedEmbeddingModel(ragSettings) || No model selected} -
Available Models - {ollamaMetrics.loading ? ( - - ) : ( -
-
{ollamaMetrics.llmInstanceModels.total} Total Models
- {ollamaMetrics.llmInstanceModels.total > 0 && ( -
- - {ollamaMetrics.llmInstanceModels.chat} Chat - - - {ollamaMetrics.llmInstanceModels.embedding} Embedding - -
- )} -
- )} -
- {ollamaMetrics.loading ? ( - - ) : ( -
-
{ollamaMetrics.embeddingInstanceModels.total} Total Models
- {ollamaMetrics.embeddingInstanceModels.total > 0 && ( -
- - {ollamaMetrics.embeddingInstanceModels.chat} Chat - - - {ollamaMetrics.embeddingInstanceModels.embedding} Embedding - -
- )} -
- )} -
- - {/* System Readiness Summary */} -
-
- System Readiness: - - {(llmStatus.online && embeddingStatus.online) ? "✓ Ready (Both Instances Online)" : - (llmStatus.online || embeddingStatus.online) ? "⚠ Partial (1 of 2 Online)" : "✗ Not Ready (No Instances Online)"} - -
- - {/* Overall Model Metrics */} -
-
- - - - Overall Available: - - {ollamaMetrics.loading ? ( - - ) : ( - `${ollamaMetrics.totalModels} total (${ollamaMetrics.chatModels} chat, ${ollamaMetrics.embeddingModels} embedding)` - )} - -
-
-
-
-
-
- )} - {shouldShowProviderAlert && (

{providerAlertMessage}

)} -
- + )} + + {/* Save Settings Button */} +
+ + {/* Expandable Ollama Configuration Container */} + {showOllamaConfig && ((activeSelection === 'chat' && chatProvider === 'ollama') || + (activeSelection === 'embedding' && embeddingProvider === 'ollama')) && ( +
+
+
+

+ {activeSelection === 'chat' ? 'LLM Chat Configuration' : 'Embedding Configuration'} +

+

+ {activeSelection === 'chat' + ? 'Configure Ollama instance for chat completions' + : 'Configure Ollama instance for text embeddings'} +

+
+
+ {(activeSelection === 'chat' ? llmStatus.online : embeddingStatus.online) + ? "Online" : "Offline"} +
+
+ + {/* Configuration Content */} +
+ {activeSelection === 'chat' ? ( + // Chat Model Configuration +
+ {llmInstanceConfig.name && llmInstanceConfig.url ? ( + <> +
+
{llmInstanceConfig.name}
+
{llmInstanceConfig.url}
+
+ +
+
Model:
+
{getDisplayedChatModel(ragSettings)}
+
+ +
+ {llmStatus.checking ? ( + + ) : null} + {ollamaMetrics.loading ? 'Loading...' : `${ollamaMetrics.llmInstanceModels?.chat || 0} chat models available`} +
+ +
+ + + +
+ + ) : ( +
+
No LLM instance configured
+
Configure an instance to use LLM chat features
+ +
+ )} +
+ ) : ( + // Embedding Model Configuration +
+ {embeddingInstanceConfig.name && embeddingInstanceConfig.url ? ( + <> +
+
{embeddingInstanceConfig.name}
+
{embeddingInstanceConfig.url}
+
+ +
+
Model:
+
{getDisplayedEmbeddingModel(ragSettings)}
+
+ +
+ {embeddingStatus.checking ? ( + + ) : null} + {ollamaMetrics.loading ? 'Loading...' : `${ollamaMetrics.embeddingInstanceModels?.embedding || 0} embedding models available`} +
+ +
+ + + +
+ + ) : ( +
+
No Embedding instance configured
+
Configure an instance to use embedding features
+ +
+ )} +
+ )} +
+ + {/* Context-Aware Configuration Summary */} +
+

+ {activeSelection === 'chat' ? 'LLM Instance Summary' : 'Embedding Instance Summary'} +

+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Configuration + {activeSelection === 'chat' ? 'LLM Instance' : 'Embedding Instance'} +
Instance Name + {activeSelection === 'chat' + ? (llmInstanceConfig.name || Not configured) + : (embeddingInstanceConfig.name || Not configured) + } +
Instance URL + {activeSelection === 'chat' + ? (llmInstanceConfig.url || Not configured) + : (embeddingInstanceConfig.url || Not configured) + } +
Status + {activeSelection === 'chat' ? ( + + {llmStatus.checking ? "Checking..." : llmStatus.online ? `Online (${llmStatus.responseTime}ms)` : "Offline"} + + ) : ( + + {embeddingStatus.checking ? "Checking..." : embeddingStatus.online ? `Online (${embeddingStatus.responseTime}ms)` : "Offline"} + + )} +
Selected Model + {activeSelection === 'chat' + ? (getDisplayedChatModel(ragSettings) || No model selected) + : (getDisplayedEmbeddingModel(ragSettings) || No model selected) + } +
Available Models + {ollamaMetrics.loading ? ( + + ) : activeSelection === 'chat' ? ( +
+ {ollamaMetrics.llmInstanceModels?.chat || 0} + chat models +
+ ) : ( +
+ {ollamaMetrics.embeddingInstanceModels?.embedding || 0} + embedding models +
+ )} +
+ + {/* Instance-Specific Readiness */} +
+
+ + {activeSelection === 'chat' ? 'LLM Instance Status:' : 'Embedding Instance Status:'} + + + {activeSelection === 'chat' + ? (llmStatus.online ? "✓ Ready" : "✗ Not Ready") + : (embeddingStatus.online ? "✓ Ready" : "✗ Not Ready") + } + +
+ + {/* Instance-Specific Model Metrics */} +
+
+ + + + Available on this instance: + + {ollamaMetrics.loading ? ( + + ) : activeSelection === 'chat' ? ( + `${ollamaMetrics.llmInstanceModels?.chat || 0} chat models` + ) : ( + `${ollamaMetrics.embeddingInstanceModels?.embedding || 0} embedding models` + )} + +
+
+
+
+
+
+ )} - {/* Model Settings Row - Only show for non-Ollama providers */} - {ragSettings.LLM_PROVIDER !== 'ollama' && ( -
-
- setRagSettings({ - ...ragSettings, - MODEL_CHOICE: e.target.value - })} - placeholder={getModelPlaceholder(ragSettings.LLM_PROVIDER || 'openai')} - accentColor="green" - /> -
-
- setRagSettings({ - ...ragSettings, - EMBEDDING_MODEL: e.target.value - })} - placeholder={getEmbeddingPlaceholder(ragSettings.LLM_PROVIDER || 'openai')} - accentColor="green" - /> -
-
- )} - + {/* Second row: Contextual Embeddings, Max Workers, and description */}
@@ -1778,7 +2210,16 @@ export const RAGSettings = ({ showToast('LLM instance updated successfully', 'success'); // Wait 1 second then automatically test connection and refresh models setTimeout(() => { - manualTestConnection(llmInstanceConfig.url, setLLMStatus, llmInstanceConfig.name); + manualTestConnection( + llmInstanceConfig.url, + setLLMStatus, + llmInstanceConfig.name, + 'chat', + { suppressToast: true } + ).then((success) => { + setOllamaManualConfirmed(success); + setOllamaServerStatus(success ? 'online' : 'offline'); + }); fetchOllamaMetrics(); // Refresh model metrics after saving }, 1000); }} @@ -1829,7 +2270,16 @@ export const RAGSettings = ({ showToast('Embedding instance updated successfully', 'success'); // Wait 1 second then automatically test connection and refresh models setTimeout(() => { - manualTestConnection(embeddingInstanceConfig.url, setEmbeddingStatus, embeddingInstanceConfig.name); + manualTestConnection( + embeddingInstanceConfig.url, + setEmbeddingStatus, + embeddingInstanceConfig.name, + 'embedding', + { suppressToast: true } + ).then((success) => { + setOllamaManualConfirmed(success); + setOllamaServerStatus(success ? 'online' : 'offline'); + }); fetchOllamaMetrics(); // Refresh model metrics after saving }, 1000); }} @@ -1854,7 +2304,7 @@ export const RAGSettings = ({ ]} currentModel={ragSettings.MODEL_CHOICE} modelType="chat" - selectedInstanceUrl={llmInstanceConfig.url.replace('/v1', '')} + selectedInstanceUrl={normalizeBaseUrl(llmInstanceConfig.url) ?? ''} onSelectModel={(modelName: string) => { setRagSettings({ ...ragSettings, MODEL_CHOICE: modelName }); showToast(`Selected LLM model: ${modelName}`, 'success'); @@ -1873,7 +2323,7 @@ export const RAGSettings = ({ ]} currentModel={ragSettings.EMBEDDING_MODEL} modelType="embedding" - selectedInstanceUrl={embeddingInstanceConfig.url.replace('/v1', '')} + selectedInstanceUrl={normalizeBaseUrl(embeddingInstanceConfig.url) ?? 
''} onSelectModel={(modelName: string) => { setRagSettings({ ...ragSettings, EMBEDDING_MODEL: modelName }); showToast(`Selected embedding model: ${modelName}`, 'success'); @@ -1907,7 +2357,7 @@ export const RAGSettings = ({ }; // Helper functions to get provider-specific model display -function getDisplayedChatModel(ragSettings: any): string { +function getDisplayedChatModel(ragSettings: RAGSettingsProps["ragSettings"]): string { const provider = ragSettings.LLM_PROVIDER || 'openai'; const modelChoice = ragSettings.MODEL_CHOICE; @@ -1935,8 +2385,8 @@ function getDisplayedChatModel(ragSettings: any): string { } } -function getDisplayedEmbeddingModel(ragSettings: any): string { - const provider = ragSettings.LLM_PROVIDER || 'openai'; +function getDisplayedEmbeddingModel(ragSettings: RAGSettingsProps["ragSettings"]): string { + const provider = ragSettings.EMBEDDING_PROVIDER || ragSettings.LLM_PROVIDER || 'openai'; const embeddingModel = ragSettings.EMBEDDING_MODEL; // Always prioritize user input to allow editing @@ -1964,7 +2414,7 @@ function getDisplayedEmbeddingModel(ragSettings: any): string { } // Helper functions for model placeholders -function getModelPlaceholder(provider: string): string { +function getModelPlaceholder(provider: ProviderKey): string { switch (provider) { case 'openai': return 'e.g., gpt-4o-mini'; @@ -1983,7 +2433,7 @@ function getModelPlaceholder(provider: string): string { } } -function getEmbeddingPlaceholder(provider: string): string { +function getEmbeddingPlaceholder(provider: ProviderKey): string { switch (provider) { case 'openai': return 'Default: text-embedding-3-small'; diff --git a/archon-ui-main/src/services/credentialsService.ts b/archon-ui-main/src/services/credentialsService.ts index f52d9679..b2d2da52 100644 --- a/archon-ui-main/src/services/credentialsService.ts +++ b/archon-ui-main/src/services/credentialsService.ts @@ -23,6 +23,7 @@ export interface RagSettings { OLLAMA_EMBEDDING_URL?: string; OLLAMA_EMBEDDING_INSTANCE_NAME?: string; EMBEDDING_MODEL?: string; + EMBEDDING_PROVIDER?: string; // Crawling Performance Settings CRAWL_BATCH_SIZE?: number; CRAWL_MAX_CONCURRENT?: number; @@ -75,6 +76,16 @@ import { getApiUrl } from "../config/api"; class CredentialsService { private baseUrl = getApiUrl(); + private notifyCredentialUpdate(keys: string[]): void { + if (typeof window === "undefined") { + return; + } + + window.dispatchEvent( + new CustomEvent("archon:credentials-updated", { detail: { keys } }) + ); + } + private handleCredentialError(error: any, context: string): Error { const errorMessage = error instanceof Error ? 
error.message : String(error); @@ -182,15 +193,16 @@ class CredentialsService { USE_CONTEXTUAL_EMBEDDINGS: false, CONTEXTUAL_EMBEDDINGS_MAX_WORKERS: 3, USE_HYBRID_SEARCH: true, - USE_AGENTIC_RAG: true, - USE_RERANKING: true, - MODEL_CHOICE: "gpt-4.1-nano", - LLM_PROVIDER: "openai", - LLM_BASE_URL: "", - LLM_INSTANCE_NAME: "", - OLLAMA_EMBEDDING_URL: "", - OLLAMA_EMBEDDING_INSTANCE_NAME: "", - EMBEDDING_MODEL: "", + USE_AGENTIC_RAG: true, + USE_RERANKING: true, + MODEL_CHOICE: "gpt-4.1-nano", + LLM_PROVIDER: "openai", + LLM_BASE_URL: "", + LLM_INSTANCE_NAME: "", + OLLAMA_EMBEDDING_URL: "", + OLLAMA_EMBEDDING_INSTANCE_NAME: "", + EMBEDDING_PROVIDER: "openai", + EMBEDDING_MODEL: "", // Crawling Performance Settings defaults CRAWL_BATCH_SIZE: 50, CRAWL_MAX_CONCURRENT: 10, @@ -221,6 +233,7 @@ class CredentialsService { "LLM_INSTANCE_NAME", "OLLAMA_EMBEDDING_URL", "OLLAMA_EMBEDDING_INSTANCE_NAME", + "EMBEDDING_PROVIDER", "EMBEDDING_MODEL", "CRAWL_WAIT_STRATEGY", ].includes(cred.key) @@ -278,7 +291,9 @@ class CredentialsService { throw new Error(`HTTP ${response.status}: ${errorText}`); } - return response.json(); + const updated = await response.json(); + this.notifyCredentialUpdate([credential.key]); + return updated; } catch (error) { throw this.handleCredentialError( error, @@ -302,7 +317,9 @@ class CredentialsService { throw new Error(`HTTP ${response.status}: ${errorText}`); } - return response.json(); + const created = await response.json(); + this.notifyCredentialUpdate([credential.key]); + return created; } catch (error) { throw this.handleCredentialError( error, @@ -321,6 +338,8 @@ class CredentialsService { const errorText = await response.text(); throw new Error(`HTTP ${response.status}: ${errorText}`); } + + this.notifyCredentialUpdate([key]); } catch (error) { throw this.handleCredentialError(error, `Deleting credential '${key}'`); } diff --git a/migration/0.1.0/009_add_provider_placeholders.sql b/migration/0.1.0/009_add_provider_placeholders.sql new file mode 100644 index 00000000..85d526e6 --- /dev/null +++ b/migration/0.1.0/009_add_provider_placeholders.sql @@ -0,0 +1,18 @@ +-- Migration: 009_add_provider_placeholders.sql +-- Description: Add placeholder API key rows for OpenRouter, Anthropic, and Grok +-- Version: 0.1.0 +-- Author: Archon Team +-- Date: 2025 + +-- Insert provider API key placeholders (idempotent) +INSERT INTO archon_settings (key, encrypted_value, is_encrypted, category, description) +VALUES + ('OPENROUTER_API_KEY', NULL, true, 'api_keys', 'OpenRouter API key for hosted community models. Get from: https://openrouter.ai/keys'), + ('ANTHROPIC_API_KEY', NULL, true, 'api_keys', 'Anthropic API key for Claude models. Get from: https://console.anthropic.com/account/keys'), + ('GROK_API_KEY', NULL, true, 'api_keys', 'Grok API key for xAI models. Get from: https://console.x.ai/') +ON CONFLICT (key) DO NOTHING; + +-- Record migration application for tracking +INSERT INTO archon_migrations (version, migration_name) +VALUES ('0.1.0', '009_add_provider_placeholders') +ON CONFLICT (version, migration_name) DO NOTHING; diff --git a/migration/complete_setup.sql b/migration/complete_setup.sql index 1609060c..801b07b4 100644 --- a/migration/complete_setup.sql +++ b/migration/complete_setup.sql @@ -100,7 +100,10 @@ ON CONFLICT (key) DO NOTHING; -- Add provider API key placeholders INSERT INTO archon_settings (key, encrypted_value, is_encrypted, category, description) VALUES -('GOOGLE_API_KEY', NULL, true, 'api_keys', 'Google API Key for Gemini models. 
Get from: https://aistudio.google.com/apikey') +('GOOGLE_API_KEY', NULL, true, 'api_keys', 'Google API key for Gemini models. Get from: https://aistudio.google.com/apikey'), +('OPENROUTER_API_KEY', NULL, true, 'api_keys', 'OpenRouter API key for hosted community models. Get from: https://openrouter.ai/keys'), +('ANTHROPIC_API_KEY', NULL, true, 'api_keys', 'Anthropic API key for Claude models. Get from: https://console.anthropic.com/account/keys'), +('GROK_API_KEY', NULL, true, 'api_keys', 'Grok API key for xAI models. Get from: https://console.x.ai/') ON CONFLICT (key) DO NOTHING; -- Code Extraction Settings Migration diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 1f26dace..47a3d9db 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -1288,7 +1288,7 @@ async def stop_crawl_task(progress_id: str): found = False # Step 1: Cancel the orchestration service - orchestration = get_active_orchestration(progress_id) + orchestration = await get_active_orchestration(progress_id) if orchestration: orchestration.cancel() found = True @@ -1306,7 +1306,7 @@ async def stop_crawl_task(progress_id: str): found = True # Step 3: Remove from active orchestrations registry - unregister_orchestration(progress_id) + await unregister_orchestration(progress_id) # Step 4: Update progress tracker to reflect cancellation (only if we found and cancelled something) if found: diff --git a/python/src/server/services/crawling/code_extraction_service.py b/python/src/server/services/crawling/code_extraction_service.py index 1a540f57..b1705b02 100644 --- a/python/src/server/services/crawling/code_extraction_service.py +++ b/python/src/server/services/crawling/code_extraction_service.py @@ -140,6 +140,7 @@ class CodeExtractionService: progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, provider: str | None = None, + embedding_provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
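The extraction, summary, and storage hunks below each scale a raw progress value into their own window (0-20%, 20-90%, 90-100%) and now tolerate callbacks that report either a 0-1 fraction or a 0-100 percentage. A minimal standalone sketch of that normalize-and-clamp step, assuming the same conventions as the patch; the helper name and the example values are illustrative, not part of the diff:

def scale_progress(raw, lower: int, upper: int) -> int:
    """Normalize a raw progress value (0-1 fraction or 0-100 percent) and clamp it into [lower, upper]."""
    try:
        value = float(raw)
    except (TypeError, ValueError):
        value = 0.0
    if 0.0 <= value <= 1.0:
        value *= 100.0  # small values are treated as fractions, as in the patch
    span = upper - lower
    return min(upper, max(lower, lower + int(value * span / 100)))

# e.g. the summary phase maps its own 0-100% of work onto the 20-90% window
assert scale_progress(0.5, 20, 90) == 55
assert scale_progress(150, 90, 100) == 100  # out-of-range input is clamped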
@@ -150,6 +151,8 @@ class CodeExtractionService: source_id: The unique source_id for all documents progress_callback: Optional async callback for progress updates cancellation_check: Optional function to check for cancellation + provider: Optional LLM provider identifier for summary generation + embedding_provider: Optional embedding provider override for vector creation Returns: Number of code examples stored @@ -158,9 +161,16 @@ class CodeExtractionService: extraction_callback = None if progress_callback: async def extraction_progress(data: dict): - # Scale progress to 0-20% range - raw_progress = data.get("progress", 0) - scaled_progress = int(raw_progress * 0.2) # 0-20% + # Scale progress to 0-20% range with normalization similar to later phases + raw = data.get("progress", data.get("percentage", 0)) + try: + raw_num = float(raw) + except (TypeError, ValueError): + raw_num = 0.0 + if 0.0 <= raw_num <= 1.0: + raw_num *= 100.0 + # 0-20% with clamping + scaled_progress = min(20, max(0, int(raw_num * 0.2))) data["progress"] = scaled_progress await progress_callback(data) extraction_callback = extraction_progress @@ -197,8 +207,15 @@ class CodeExtractionService: if progress_callback: async def summary_progress(data: dict): # Scale progress to 20-90% range - raw_progress = data.get("progress", 0) - scaled_progress = 20 + int(raw_progress * 0.7) # 20-90% + raw = data.get("progress", data.get("percentage", 0)) + try: + raw_num = float(raw) + except (TypeError, ValueError): + raw_num = 0.0 + if 0.0 <= raw_num <= 1.0: + raw_num *= 100.0 + # 20-90% with clamping + scaled_progress = min(90, max(20, 20 + int(raw_num * 0.7))) data["progress"] = scaled_progress await progress_callback(data) summary_callback = summary_progress @@ -216,15 +233,26 @@ class CodeExtractionService: if progress_callback: async def storage_progress(data: dict): # Scale progress to 90-100% range - raw_progress = data.get("progress", 0) - scaled_progress = 90 + int(raw_progress * 0.1) # 90-100% + raw = data.get("progress", data.get("percentage", 0)) + try: + raw_num = float(raw) + except (TypeError, ValueError): + raw_num = 0.0 + if 0.0 <= raw_num <= 1.0: + raw_num *= 100.0 + # 90-100% with clamping + scaled_progress = min(100, max(90, 90 + int(raw_num * 0.1))) data["progress"] = scaled_progress await progress_callback(data) storage_callback = storage_progress # Store code examples in database return await self._store_code_examples( - storage_data, url_to_full_document, storage_callback, provider + storage_data, + url_to_full_document, + storage_callback, + provider, + embedding_provider, ) async def _extract_code_blocks_from_documents( @@ -880,9 +908,20 @@ class CodeExtractionService: current_indent = indent block_start_idx = i current_block.append(line) - elif current_block and len("\n".join(current_block)) >= min_length: + elif current_block: + block_text = "\n".join(current_block) + threshold = ( + min_length + if min_length is not None + else await self._get_min_code_length() + ) + if len(block_text) < threshold: + current_block = [] + current_indent = None + continue + # End of indented block, check if it's code - code_content = "\n".join(current_block) + code_content = block_text # Try to detect language from content language = self._detect_language_from_content(code_content) @@ -1670,12 +1709,20 @@ class CodeExtractionService: url_to_full_document: dict[str, str], progress_callback: Callable | None = None, provider: str | None = None, + embedding_provider: str | None = None, ) -> int: """ Store code examples in the 
database. Returns: Number of code examples stored + + Args: + storage_data: Prepared code example payloads + url_to_full_document: Mapping of URLs to their full document content + progress_callback: Optional callback for progress updates + provider: Optional LLM provider identifier for summaries + embedding_provider: Optional embedding provider override for vector storage """ # Create progress callback for storage phase storage_progress_callback = None @@ -1713,6 +1760,7 @@ class CodeExtractionService: url_to_full_document=url_to_full_document, progress_callback=storage_progress_callback, provider=provider, + embedding_provider=embedding_provider, ) # Report completion of code extraction/storage phase diff --git a/python/src/server/services/crawling/crawling_service.py b/python/src/server/services/crawling/crawling_service.py index 69c65719..82a98c0c 100644 --- a/python/src/server/services/crawling/crawling_service.py +++ b/python/src/server/services/crawling/crawling_service.py @@ -14,6 +14,7 @@ from typing import Any, Optional from ...config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info from ...utils import get_supabase_client from ...utils.progress.progress_tracker import ProgressTracker +from ..credential_service import credential_service # Import strategies # Import operations @@ -32,22 +33,35 @@ logger = get_logger(__name__) # Global registry to track active orchestration services for cancellation support _active_orchestrations: dict[str, "CrawlingService"] = {} +_orchestration_lock: asyncio.Lock | None = None -def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]: +def _ensure_orchestration_lock() -> asyncio.Lock: + global _orchestration_lock + if _orchestration_lock is None: + _orchestration_lock = asyncio.Lock() + return _orchestration_lock + + +async def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]: """Get an active orchestration service by progress ID.""" - return _active_orchestrations.get(progress_id) + lock = _ensure_orchestration_lock() + async with lock: + return _active_orchestrations.get(progress_id) -def register_orchestration(progress_id: str, orchestration: "CrawlingService"): +async def register_orchestration(progress_id: str, orchestration: "CrawlingService"): """Register an active orchestration service.""" - _active_orchestrations[progress_id] = orchestration + lock = _ensure_orchestration_lock() + async with lock: + _active_orchestrations[progress_id] = orchestration -def unregister_orchestration(progress_id: str): +async def unregister_orchestration(progress_id: str): """Unregister an orchestration service.""" - if progress_id in _active_orchestrations: - del _active_orchestrations[progress_id] + lock = _ensure_orchestration_lock() + async with lock: + _active_orchestrations.pop(progress_id, None) class CrawlingService: @@ -246,7 +260,7 @@ class CrawlingService: # Register this orchestration service for cancellation support if self.progress_id: - register_orchestration(self.progress_id, self) + await register_orchestration(self.progress_id, self) # Start the crawl as an async task in the main event loop # Store the task reference for proper cancellation @@ -477,15 +491,27 @@ class CrawlingService: try: # Extract provider from request or use credential service default provider = request.get("provider") + embedding_provider = None + if not provider: try: - from ..credential_service import credential_service provider_config = await credential_service.get_active_provider("llm") provider = 
provider_config.get("provider", "openai") except Exception as e: - logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai") + logger.warning( + f"Failed to get provider from credential service: {e}, defaulting to openai" + ) provider = "openai" + try: + embedding_config = await credential_service.get_active_provider("embedding") + embedding_provider = embedding_config.get("provider") + except Exception as e: + logger.warning( + f"Failed to get embedding provider from credential service: {e}. Using configured default." + ) + embedding_provider = None + code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples( crawl_results, storage_results["url_to_full_document"], @@ -493,6 +519,7 @@ class CrawlingService: code_progress_callback, self._check_cancellation, provider, + embedding_provider, ) except RuntimeError as e: # Code extraction failed, continue crawl with warning @@ -548,7 +575,7 @@ class CrawlingService: # Unregister after successful completion if self.progress_id: - unregister_orchestration(self.progress_id) + await unregister_orchestration(self.progress_id) safe_logfire_info( f"Unregistered orchestration service after completion | progress_id={self.progress_id}" ) @@ -567,7 +594,7 @@ class CrawlingService: ) # Unregister on cancellation if self.progress_id: - unregister_orchestration(self.progress_id) + await unregister_orchestration(self.progress_id) safe_logfire_info( f"Unregistered orchestration service on cancellation | progress_id={self.progress_id}" ) @@ -591,7 +618,7 @@ class CrawlingService: await self.progress_tracker.error(error_message) # Unregister on error if self.progress_id: - unregister_orchestration(self.progress_id) + await unregister_orchestration(self.progress_id) safe_logfire_info( f"Unregistered orchestration service on error | progress_id={self.progress_id}" ) diff --git a/python/src/server/services/crawling/document_storage_operations.py b/python/src/server/services/crawling/document_storage_operations.py index 88ed8e80..8bfa4560 100644 --- a/python/src/server/services/crawling/document_storage_operations.py +++ b/python/src/server/services/crawling/document_storage_operations.py @@ -352,6 +352,7 @@ class DocumentStorageOperations: progress_callback: Callable | None = None, cancellation_check: Callable[[], None] | None = None, provider: str | None = None, + embedding_provider: str | None = None, ) -> int: """ Extract code examples from crawled documents and store them. 
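The crawling-service hunk above resolves the chat provider and the embedding provider as two independent lookups, each degrading gracefully when the credential service cannot answer: chat falls back to "openai", while embeddings fall back to the configured default by passing None downstream. A minimal sketch of that pattern under the same assumptions as the patch (get_active_provider returning a dict with a "provider" key); the helper name is hypothetical:

import logging

logger = logging.getLogger(__name__)

async def resolve_providers(request: dict, credential_service) -> tuple[str, str | None]:
    """Return (llm_provider, embedding_provider) using the same fallback order as the patch."""
    llm_provider = request.get("provider")
    if not llm_provider:
        try:
            config = await credential_service.get_active_provider("llm")
            llm_provider = config.get("provider", "openai")
        except Exception as exc:
            logger.warning("Failed to get provider from credential service: %s, defaulting to openai", exc)
            llm_provider = "openai"

    embedding_provider = None
    try:
        embedding_config = await credential_service.get_active_provider("embedding")
        embedding_provider = embedding_config.get("provider")
    except Exception as exc:
        logger.warning("Failed to get embedding provider: %s. Using configured default.", exc)

    return llm_provider, embedding_provider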
@@ -363,12 +364,19 @@ class DocumentStorageOperations: progress_callback: Optional callback for progress updates cancellation_check: Optional function to check for cancellation provider: Optional LLM provider to use for code summaries + embedding_provider: Optional embedding provider override for code example embeddings Returns: Number of code examples stored """ result = await self.code_extraction_service.extract_and_store_code_examples( - crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider + crawl_results, + url_to_full_document, + source_id, + progress_callback, + cancellation_check, + provider, + embedding_provider, ) return result diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py index 62fbb47a..a8aee849 100644 --- a/python/src/server/services/credential_service.py +++ b/python/src/server/services/credential_service.py @@ -36,42 +36,6 @@ class CredentialItem: description: str | None = None -def _detect_embedding_provider_from_model(embedding_model: str) -> str: - """ - Detect the appropriate embedding provider based on model name. - - Args: - embedding_model: The embedding model name - - Returns: - Provider name: 'google', 'openai', or 'openai' (default) - """ - if not embedding_model: - return "openai" # Default - - model_lower = embedding_model.lower() - - # Google embedding models - google_patterns = [ - "text-embedding-004", - "text-embedding-005", - "text-multilingual-embedding", - "gemini-embedding", - "multimodalembedding" - ] - - if any(pattern in model_lower for pattern in google_patterns): - return "google" - - # OpenAI embedding models (and default for unknown) - openai_patterns = [ - "text-embedding-ada-002", - "text-embedding-3-small", - "text-embedding-3-large" - ] - - # Default to OpenAI for OpenAI models or unknown models - return "openai" class CredentialService: @@ -475,26 +439,24 @@ class CredentialService: # Get the selected provider based on service type if service_type == "embedding": - # Get the LLM provider setting to determine embedding provider - llm_provider = rag_settings.get("LLM_PROVIDER", "openai") - embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small") + # First check for explicit EMBEDDING_PROVIDER setting (new split provider approach) + explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER") - # Determine embedding provider based on LLM provider - if llm_provider == "google": - provider = "google" - elif llm_provider == "ollama": - provider = "ollama" - elif llm_provider == "openrouter": - # OpenRouter supports both OpenAI and Google embedding models - provider = _detect_embedding_provider_from_model(embedding_model) - elif llm_provider in ["anthropic", "grok"]: - # Anthropic and Grok support both OpenAI and Google embedding models - provider = _detect_embedding_provider_from_model(embedding_model) + # Validate that embedding provider actually supports embeddings + embedding_capable_providers = {"openai", "google", "ollama"} + + if (explicit_embedding_provider and + explicit_embedding_provider != "" and + explicit_embedding_provider in embedding_capable_providers): + # Use the explicitly set embedding provider + provider = explicit_embedding_provider + logger.debug(f"Using explicit embedding provider: '{provider}'") else: - # Default case (openai, or unknown providers) + # Fall back to OpenAI as default embedding provider for backward compatibility + if explicit_embedding_provider and 
explicit_embedding_provider not in embedding_capable_providers: + logger.warning(f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI") provider = "openai" - - logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'") + logger.debug(f"No explicit embedding provider set, defaulting to OpenAI for backward compatibility") else: provider = rag_settings.get("LLM_PROVIDER", "openai") # Ensure provider is a valid string, not a boolean or other type diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py index 4f825f1d..87ce390b 100644 --- a/python/src/server/services/embeddings/embedding_service.py +++ b/python/src/server/services/embeddings/embedding_service.py @@ -5,15 +5,19 @@ Handles all OpenAI embedding operations with proper rate limiting and error hand """ import asyncio +import inspect import os +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any +import httpx +import numpy as np import openai from ...config.logfire_config import safe_span, search_logger from ..credential_service import credential_service -from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model +from ..llm_provider_service import get_embedding_model, get_llm_client from ..threading_service import get_threading_service from .embedding_exceptions import ( EmbeddingAPIError, @@ -64,6 +68,167 @@ class EmbeddingBatchResult: return self.success_count + self.failure_count +class EmbeddingProviderAdapter(ABC): + """Adapter interface for embedding providers.""" + + @abstractmethod + async def create_embeddings( + self, + texts: list[str], + model: str, + dimensions: int | None = None, + ) -> list[list[float]]: + """Create embeddings for the given texts.""" + + +class OpenAICompatibleEmbeddingAdapter(EmbeddingProviderAdapter): + """Adapter for providers using the OpenAI embeddings API shape.""" + + def __init__(self, client: Any): + self._client = client + + async def create_embeddings( + self, + texts: list[str], + model: str, + dimensions: int | None = None, + ) -> list[list[float]]: + request_args: dict[str, Any] = { + "model": model, + "input": texts, + } + if dimensions is not None: + request_args["dimensions"] = dimensions + + response = await self._client.embeddings.create(**request_args) + return [item.embedding for item in response.data] + + +class GoogleEmbeddingAdapter(EmbeddingProviderAdapter): + """Adapter for Google's native embedding endpoint.""" + + async def create_embeddings( + self, + texts: list[str], + model: str, + dimensions: int | None = None, + ) -> list[list[float]]: + try: + google_api_key = await credential_service.get_credential("GOOGLE_API_KEY") + if not google_api_key: + raise EmbeddingAPIError("Google API key not found") + + async with httpx.AsyncClient(timeout=30.0) as http_client: + embeddings = await asyncio.gather( + *( + self._fetch_single_embedding(http_client, google_api_key, model, text, dimensions) + for text in texts + ) + ) + + return embeddings + + except httpx.HTTPStatusError as error: + error_content = error.response.text + search_logger.error( + f"Google embedding API returned {error.response.status_code} - {error_content}", + exc_info=True, + ) + raise EmbeddingAPIError( + f"Google embedding API error: {error.response.status_code} - {error_content}", + 
original_error=error, + ) from error + except Exception as error: + search_logger.error(f"Error calling Google embedding API: {error}", exc_info=True) + raise EmbeddingAPIError( + f"Google embedding error: {str(error)}", original_error=error + ) from error + + async def _fetch_single_embedding( + self, + http_client: httpx.AsyncClient, + api_key: str, + model: str, + text: str, + dimensions: int | None = None, + ) -> list[float]: + if model.startswith("models/"): + url_model = model[len("models/") :] + payload_model = model + else: + url_model = model + payload_model = f"models/{model}" + url = f"https://generativelanguage.googleapis.com/v1beta/models/{url_model}:embedContent" + headers = { + "x-goog-api-key": api_key, + "Content-Type": "application/json", + } + payload = { + "model": payload_model, + "content": {"parts": [{"text": text}]}, + } + + # Add output_dimensionality parameter if dimensions are specified and supported + if dimensions is not None and dimensions > 0: + model_name = payload_model.removeprefix("models/") + if model_name.startswith("textembedding-gecko"): + supported_dimensions = {128, 256, 512, 768} + else: + supported_dimensions = {128, 256, 512, 768, 1024, 1536, 2048, 3072} + + if dimensions in supported_dimensions: + payload["outputDimensionality"] = dimensions + else: + search_logger.warning( + f"Requested dimension {dimensions} is not supported by Google model '{model_name}'. " + "Falling back to the provider default." + ) + + response = await http_client.post(url, headers=headers, json=payload) + response.raise_for_status() + + result = response.json() + embedding = result.get("embedding", {}) + values = embedding.get("values") if isinstance(embedding, dict) else None + if not isinstance(values, list): + raise EmbeddingAPIError(f"Invalid embedding payload from Google: {result}") + + # Normalize embeddings for dimensions < 3072 as per Google's documentation + actual_dimension = len(values) + if actual_dimension > 0 and actual_dimension < 3072: + values = self._normalize_embedding(values) + + return values + + def _normalize_embedding(self, embedding: list[float]) -> list[float]: + """Normalize embedding vector for dimensions < 3072.""" + try: + embedding_array = np.array(embedding, dtype=np.float32) + norm = np.linalg.norm(embedding_array) + if norm > 0: + normalized = embedding_array / norm + return normalized.tolist() + else: + search_logger.warning("Zero-norm embedding detected, returning unnormalized") + return embedding + except Exception as e: + search_logger.error(f"Failed to normalize embedding: {e}") + # Return original embedding if normalization fails + return embedding + + +def _get_embedding_adapter(provider: str, client: Any) -> EmbeddingProviderAdapter: + provider_name = (provider or "").lower() + if provider_name == "google": + return GoogleEmbeddingAdapter() + return OpenAICompatibleEmbeddingAdapter(client) + + +async def _maybe_await(value: Any) -> Any: + """Await the value if it is awaitable, otherwise return as-is.""" + + return await value if inspect.isawaitable(value) else value + # Provider-aware client factory get_openai_client = get_llm_client @@ -185,27 +350,25 @@ async def create_embeddings_batch( "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts) ) as span: try: - # Intelligent embedding provider routing based on model type - # Get the embedding model first to determine the correct provider - embedding_model = await get_embedding_model(provider=provider) + embedding_config = await _maybe_await( + 
credential_service.get_active_provider(service_type="embedding") + ) - # Route to correct provider based on model type - if is_google_embedding_model(embedding_model): - embedding_provider = "google" - search_logger.info(f"Routing to Google for embedding model: {embedding_model}") - elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower(): + embedding_provider = provider or embedding_config.get("provider") + + if not isinstance(embedding_provider, str) or not embedding_provider.strip(): embedding_provider = "openai" - search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}") - else: - # Keep original provider for ollama and other providers - embedding_provider = provider - search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}") + if not embedding_provider: + search_logger.error("No embedding provider configured") + raise ValueError("No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable.") + + search_logger.info(f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)") async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client: # Load batch size and dimensions from settings try: - rag_settings = await credential_service.get_credentials_by_category( - "rag_strategy" + rag_settings = await _maybe_await( + credential_service.get_credentials_by_category("rag_strategy") ) batch_size = int(rag_settings.get("EMBEDDING_BATCH_SIZE", "100")) embedding_dimensions = int(rag_settings.get("EMBEDDING_DIMENSIONS", "1536")) @@ -215,6 +378,8 @@ async def create_embeddings_batch( embedding_dimensions = 1536 total_tokens_used = 0 + adapter = _get_embedding_adapter(embedding_provider, client) + dimensions_to_use = embedding_dimensions if embedding_dimensions > 0 else None for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] @@ -243,16 +408,14 @@ async def create_embeddings_batch( try: # Create embeddings for this batch embedding_model = await get_embedding_model(provider=embedding_provider) - - response = await client.embeddings.create( - model=embedding_model, - input=batch, - dimensions=embedding_dimensions, + embeddings = await adapter.create_embeddings( + batch, + embedding_model, + dimensions=dimensions_to_use, ) - # Add successful embeddings - for text, item in zip(batch, response.data, strict=False): - result.add_success(item.embedding, text) + for text, vector in zip(batch, embeddings, strict=False): + result.add_success(vector, text) break # Success, exit retry loop @@ -297,6 +460,17 @@ async def create_embeddings_batch( await asyncio.sleep(wait_time) else: raise # Will be caught by outer try + except EmbeddingRateLimitError as e: + retry_count += 1 + if retry_count < max_retries: + wait_time = 2**retry_count + search_logger.warning( + f"Embedding rate limit for batch {batch_index}: {e}. 
" + f"Waiting {wait_time}s before retry {retry_count}/{max_retries}" + ) + await asyncio.sleep(wait_time) + else: + raise except Exception as e: # This batch failed - track failures but continue with next batch diff --git a/python/src/server/services/storage/code_storage_service.py b/python/src/server/services/storage/code_storage_service.py index a993bc70..8e237f7e 100644 --- a/python/src/server/services/storage/code_storage_service.py +++ b/python/src/server/services/storage/code_storage_service.py @@ -1091,6 +1091,7 @@ async def add_code_examples_to_supabase( url_to_full_document: dict[str, str] | None = None, progress_callback: Callable | None = None, provider: str | None = None, + embedding_provider: str | None = None, ): """ Add code examples to the Supabase code_examples table in batches. @@ -1105,6 +1106,8 @@ async def add_code_examples_to_supabase( batch_size: Size of each batch for insertion url_to_full_document: Optional mapping of URLs to full document content progress_callback: Optional async callback for progress updates + provider: Optional LLM provider used for summary generation tracking + embedding_provider: Optional embedding provider override for vector generation """ if not urls: return @@ -1183,8 +1186,8 @@ async def add_code_examples_to_supabase( # Use original combined texts batch_texts = combined_texts - # Create embeddings for the batch - result = await create_embeddings_batch(batch_texts, provider=provider) + # Create embeddings for the batch (optionally overriding the embedding provider) + result = await create_embeddings_batch(batch_texts, provider=embedding_provider) # Log any failures if result.has_failures: @@ -1201,7 +1204,7 @@ async def add_code_examples_to_supabase( from ..llm_provider_service import get_embedding_model # Get embedding model name - embedding_model_name = await get_embedding_model(provider=provider) + embedding_model_name = await get_embedding_model(provider=embedding_provider) # Get LLM chat model (used for code summaries and contextual embeddings if enabled) llm_chat_model = None diff --git a/python/tests/test_code_extraction_source_id.py b/python/tests/test_code_extraction_source_id.py index 05405ee7..7899c7fc 100644 --- a/python/tests/test_code_extraction_source_id.py +++ b/python/tests/test_code_extraction_source_id.py @@ -111,8 +111,8 @@ class TestCodeExtractionSourceId: assert args[2] == source_id assert args[3] is None assert args[4] is None - if len(args) > 5: - assert args[5] is None + assert args[5] is None + assert args[6] is None assert result == 5 @pytest.mark.asyncio From 489415d72303c43781175e4aae7e9d1bef98fa3a Mon Sep 17 00:00:00 2001 From: Wirasm <152263317+Wirasm@users.noreply.github.com> Date: Thu, 9 Oct 2025 17:52:06 +0300 Subject: [PATCH 7/7] Fix: Database timeout when deleting large sources (#737) * fix: implement CASCADE DELETE for source deletion timeout issue - Add migration 009 to add CASCADE DELETE constraints to foreign keys - Simplify delete_source() to only delete parent record - Database now handles cascading deletes efficiently - Fixes timeout issues when deleting sources with thousands of pages * chore: update complete_setup.sql to include CASCADE DELETE constraints - Add ON DELETE CASCADE to foreign keys in initial setup - Include migration 009 in the migrations tracking - Ensures new installations have CASCADE DELETE from the start --- .../009_add_cascade_delete_constraints.sql | 67 +++++++++++ migration/complete_setup.sql | 11 +- .../services/source_management_service.py | 113 +++++++----------- 3 files 
changed, 116 insertions(+), 75 deletions(-) create mode 100644 migration/0.1.0/009_add_cascade_delete_constraints.sql diff --git a/migration/0.1.0/009_add_cascade_delete_constraints.sql b/migration/0.1.0/009_add_cascade_delete_constraints.sql new file mode 100644 index 00000000..a8e71a47 --- /dev/null +++ b/migration/0.1.0/009_add_cascade_delete_constraints.sql @@ -0,0 +1,67 @@ +-- ===================================================== +-- Migration 009: Add CASCADE DELETE constraints +-- ===================================================== +-- This migration adds CASCADE DELETE to foreign key constraints +-- for archon_crawled_pages and archon_code_examples tables +-- to fix database timeout issues when deleting large sources +-- +-- Issue: Deleting sources with thousands of crawled pages times out +-- Solution: Let the database handle cascading deletes efficiently +-- ===================================================== + +-- Start transaction for atomic changes +BEGIN; + +-- Drop existing foreign key constraints +ALTER TABLE archon_crawled_pages + DROP CONSTRAINT IF EXISTS archon_crawled_pages_source_id_fkey; + +ALTER TABLE archon_code_examples + DROP CONSTRAINT IF EXISTS archon_code_examples_source_id_fkey; + +-- Re-add foreign key constraints with CASCADE DELETE +ALTER TABLE archon_crawled_pages + ADD CONSTRAINT archon_crawled_pages_source_id_fkey + FOREIGN KEY (source_id) + REFERENCES archon_sources(source_id) + ON DELETE CASCADE; + +ALTER TABLE archon_code_examples + ADD CONSTRAINT archon_code_examples_source_id_fkey + FOREIGN KEY (source_id) + REFERENCES archon_sources(source_id) + ON DELETE CASCADE; + +-- Add comment explaining the CASCADE behavior +COMMENT ON CONSTRAINT archon_crawled_pages_source_id_fkey ON archon_crawled_pages IS + 'Foreign key with CASCADE DELETE - automatically deletes all crawled pages when source is deleted'; + +COMMENT ON CONSTRAINT archon_code_examples_source_id_fkey ON archon_code_examples IS + 'Foreign key with CASCADE DELETE - automatically deletes all code examples when source is deleted'; + +-- Record the migration +INSERT INTO archon_migrations (version, migration_name) +VALUES ('0.1.0', '009_add_cascade_delete_constraints') +ON CONFLICT (version, migration_name) DO NOTHING; + +-- Commit transaction +COMMIT; + +-- ===================================================== +-- Verification queries (run separately if needed) +-- ===================================================== +-- To verify the constraints after migration: +-- +-- SELECT +-- tc.table_name, +-- tc.constraint_name, +-- tc.constraint_type, +-- rc.delete_rule +-- FROM information_schema.table_constraints tc +-- JOIN information_schema.referential_constraints rc +-- ON tc.constraint_name = rc.constraint_name +-- WHERE tc.table_name IN ('archon_crawled_pages', 'archon_code_examples') +-- AND tc.constraint_type = 'FOREIGN KEY'; +-- +-- Expected result: Both constraints should show delete_rule = 'CASCADE' +-- ===================================================== \ No newline at end of file diff --git a/migration/complete_setup.sql b/migration/complete_setup.sql index 801b07b4..99917060 100644 --- a/migration/complete_setup.sql +++ b/migration/complete_setup.sql @@ -223,8 +223,8 @@ CREATE TABLE IF NOT EXISTS archon_crawled_pages ( -- Add a unique constraint to prevent duplicate chunks for the same URL UNIQUE(url, chunk_number), - -- Add foreign key constraint to sources table - FOREIGN KEY (source_id) REFERENCES archon_sources(source_id) + -- Add foreign key constraint to sources table with 
CASCADE DELETE + FOREIGN KEY (source_id) REFERENCES archon_sources(source_id) ON DELETE CASCADE ); -- Multi-dimensional indexes @@ -272,8 +272,8 @@ CREATE TABLE IF NOT EXISTS archon_code_examples ( -- Add a unique constraint to prevent duplicate chunks for the same URL UNIQUE(url, chunk_number), - -- Add foreign key constraint to sources table - FOREIGN KEY (source_id) REFERENCES archon_sources(source_id) + -- Add foreign key constraint to sources table with CASCADE DELETE + FOREIGN KEY (source_id) REFERENCES archon_sources(source_id) ON DELETE CASCADE ); -- Multi-dimensional indexes @@ -990,7 +990,8 @@ VALUES ('0.1.0', '005_ollama_create_functions'), ('0.1.0', '006_ollama_create_indexes_optional'), ('0.1.0', '007_add_priority_column_to_tasks'), - ('0.1.0', '008_add_migration_tracking') + ('0.1.0', '008_add_migration_tracking'), + ('0.1.0', '009_add_cascade_delete_constraints') ON CONFLICT (version, migration_name) DO NOTHING; -- Enable Row Level Security on migrations table diff --git a/python/src/server/services/source_management_service.py b/python/src/server/services/source_management_service.py index f8a27023..7152f830 100644 --- a/python/src/server/services/source_management_service.py +++ b/python/src/server/services/source_management_service.py @@ -11,7 +11,7 @@ from supabase import Client from ..config.logfire_config import get_logger, search_logger from .client_manager import get_supabase_client -from .llm_provider_service import extract_message_text, get_llm_client +from .llm_provider_service import extract_message_text, get_llm_client logger = get_logger(__name__) @@ -72,21 +72,21 @@ The above content is from the documentation for '{source_id}'. Please provide a ) # Extract the generated summary with proper error handling - if not response or not response.choices or len(response.choices) == 0: - search_logger.error(f"Empty or invalid response from LLM for {source_id}") - return default_summary - - choice = response.choices[0] - summary_text, _, _ = extract_message_text(choice) - if not summary_text: - search_logger.error(f"LLM returned None content for {source_id}") - return default_summary - - summary = summary_text.strip() - - # Ensure the summary is not too long - if len(summary) > max_length: - summary = summary[:max_length] + "..." + if not response or not response.choices or len(response.choices) == 0: + search_logger.error(f"Empty or invalid response from LLM for {source_id}") + return default_summary + + choice = response.choices[0] + summary_text, _, _ = extract_message_text(choice) + if not summary_text: + search_logger.error(f"LLM returned None content for {source_id}") + return default_summary + + summary = summary_text.strip() + + # Ensure the summary is not too long + if len(summary) > max_length: + summary = summary[:max_length] + "..." return summary @@ -188,9 +188,9 @@ Generate only the title, nothing else.""" ], ) - choice = response.choices[0] - generated_title, _, _ = extract_message_text(choice) - generated_title = generated_title.strip() + choice = response.choices[0] + generated_title, _, _ = extract_message_text(choice) + generated_title = generated_title.strip() # Clean up the title generated_title = generated_title.strip("\"'") if len(generated_title) < 50: # Sanity check @@ -400,7 +400,10 @@ class SourceManagementService: def delete_source(self, source_id: str) -> tuple[bool, dict[str, Any]]: """ - Delete a source and all associated crawled pages and code examples from the database. + Delete a source from the database. 
+ + With CASCADE DELETE constraints in place (migration 009), deleting the source + will automatically delete all associated crawled_pages and code_examples. Args: source_id: The source ID to delete @@ -411,61 +414,31 @@ class SourceManagementService: try: logger.info(f"Starting delete_source for source_id: {source_id}") - # Delete from crawled_pages table - try: - logger.info(f"Deleting from crawled_pages table for source_id: {source_id}") - pages_response = ( - self.supabase_client.table("archon_crawled_pages") - .delete() - .eq("source_id", source_id) - .execute() - ) - pages_deleted = len(pages_response.data) if pages_response.data else 0 - logger.info(f"Deleted {pages_deleted} pages from crawled_pages") - except Exception as pages_error: - logger.error(f"Failed to delete from crawled_pages: {pages_error}") - return False, {"error": f"Failed to delete crawled pages: {str(pages_error)}"} + # With CASCADE DELETE, we only need to delete from the sources table + # The database will automatically handle deleting related records + logger.info(f"Deleting source {source_id} (CASCADE will handle related records)") - # Delete from code_examples table - try: - logger.info(f"Deleting from code_examples table for source_id: {source_id}") - code_response = ( - self.supabase_client.table("archon_code_examples") - .delete() - .eq("source_id", source_id) - .execute() - ) - code_deleted = len(code_response.data) if code_response.data else 0 - logger.info(f"Deleted {code_deleted} code examples") - except Exception as code_error: - logger.error(f"Failed to delete from code_examples: {code_error}") - return False, {"error": f"Failed to delete code examples: {str(code_error)}"} + source_response = ( + self.supabase_client.table("archon_sources") + .delete() + .eq("source_id", source_id) + .execute() + ) - # Delete from sources table - try: - logger.info(f"Deleting from sources table for source_id: {source_id}") - source_response = ( - self.supabase_client.table("archon_sources") - .delete() - .eq("source_id", source_id) - .execute() - ) - source_deleted = len(source_response.data) if source_response.data else 0 - logger.info(f"Deleted {source_deleted} source records") - except Exception as source_error: - logger.error(f"Failed to delete from sources: {source_error}") - return False, {"error": f"Failed to delete source: {str(source_error)}"} + source_deleted = len(source_response.data) if source_response.data else 0 - logger.info("Delete operation completed successfully") - return True, { - "source_id": source_id, - "pages_deleted": pages_deleted, - "code_examples_deleted": code_deleted, - "source_records_deleted": source_deleted, - } + if source_deleted > 0: + logger.info(f"Successfully deleted source {source_id} and all related data via CASCADE") + return True, { + "source_id": source_id, + "message": "Source and all related data deleted successfully via CASCADE DELETE" + } + else: + logger.warning(f"No source found with ID {source_id}") + return False, {"error": f"Source {source_id} not found"} except Exception as e: - logger.error(f"Unexpected error in delete_source: {e}") + logger.error(f"Error deleting source {source_id}: {e}") return False, {"error": f"Error deleting source: {str(e)}"} def update_source_metadata(
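With the CASCADE constraints from migration 009 in place, callers only handle the simplified (success, payload) contract of delete_source; the database removes the dependent crawled pages and code examples itself. A brief usage sketch; the pre-constructed service instance and the source_id value are illustrative assumptions, not part of the patch:

# `service` is an already-constructed SourceManagementService
success, payload = service.delete_source("docs_example_site")

if success:
    # crawled pages and code examples were removed by the CASCADE constraints
    print(payload["message"])
else:
    print(f"Delete failed: {payload['error']}")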