diff --git a/.env.example b/.env.example
index 1daa64bc..2c14308e 100644
--- a/.env.example
+++ b/.env.example
@@ -124,7 +124,8 @@ PROD=false
 # Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table.
 # Then use the Settings page in the web UI to manage:
 # - OPENAI_API_KEY (encrypted)
-# - MODEL_CHOICE
+# - OPENROUTER_API_KEY (encrypted, format: sk-or-v1-..., get from https://openrouter.ai/keys)
+# - MODEL_CHOICE
 # - TRANSPORT settings
 # - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.)
 # - Crawler settings:
diff --git a/archon-ui-main/src/components/settings/RAGSettings.tsx b/archon-ui-main/src/components/settings/RAGSettings.tsx
index 62739fc7..a5b9a945 100644
--- a/archon-ui-main/src/components/settings/RAGSettings.tsx
+++ b/archon-ui-main/src/components/settings/RAGSettings.tsx
@@ -15,7 +15,7 @@ import OllamaModelSelectionModal from './OllamaModelSelectionModal';
 type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
 
 // Providers that support embedding models
-const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'ollama'];
+const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama'];
 
 interface ProviderModels {
   chatModel: string;
@@ -42,7 +42,7 @@ const getDefaultModels = (provider: ProviderKey): ProviderModels => {
     anthropic: 'text-embedding-3-small', // Fallback to OpenAI
     google: 'text-embedding-004',
     grok: 'text-embedding-3-small', // Fallback to OpenAI
-    openrouter: 'text-embedding-3-small',
+    openrouter: 'openai/text-embedding-3-small', // MUST include provider prefix for OpenRouter
     ollama: 'nomic-embed-text'
   };
 
@@ -1291,7 +1291,7 @@ const manualTestConnection = async (
                 Select {activeSelection === 'chat' ? 'Chat' : 'Embedding'} Provider
                 {[
                   { key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
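
Note on the default-model change above: OpenRouter namespaces every model ID as <provider>/<model>, so a bare text-embedding-3-small is rejected as an unknown model. A minimal sketch of what the prefixed default implies for a raw embeddings call — illustrative only, not part of this patch, and assuming OpenRouter's OpenAI-compatible endpoint at https://openrouter.ai/api/v1 accepts embeddings requests and OPENROUTER_API_KEY is set:

    import os
    from openai import OpenAI

    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.environ["OPENROUTER_API_KEY"],
    )
    # The "openai/" prefix is mandatory; OpenRouter routes requests on it.
    resp = client.embeddings.create(
        model="openai/text-embedding-3-small",
        input="hello world",
    )
    print(len(resp.data[0].embedding))  # 1536 for this model
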
diff --git a/archon-ui-main/src/services/openrouterService.ts b/archon-ui-main/src/services/openrouterService.ts
new file mode 100644
index 00000000..71cc94c3
--- /dev/null
+++ b/archon-ui-main/src/services/openrouterService.ts
@@ -0,0 +1,244 @@
+/**
+ * OpenRouter Service Client
+ *
+ * Provides frontend API client for OpenRouter model discovery.
+ */
+
+import { getApiUrl } from "../config/api";
+
+// Type definitions for OpenRouter API responses
+export interface OpenRouterEmbeddingModel {
+  id: string;
+  provider: string;
+  name: string;
+  dimensions: number;
+  context_length: number;
+  pricing_per_1m_tokens: number;
+  supports_dimension_reduction: boolean;
+}
+
+export interface OpenRouterModelListResponse {
+  embedding_models: OpenRouterEmbeddingModel[];
+  total_count: number;
+}
+
+class OpenRouterService {
+  private getBaseUrl = () => getApiUrl();
+  private cacheKey = "openrouter_models_cache";
+  private cacheTTL = 5 * 60 * 1000; // 5 minutes
+
+  private handleApiError(error: unknown, context: string): Error {
+    const errorMessage = error instanceof Error ? error.message : String(error);
+
+    // Check for network errors
+    if (
+      errorMessage.toLowerCase().includes("network") ||
+      errorMessage.includes("fetch") ||
+      errorMessage.includes("Failed to fetch")
+    ) {
+      return new Error(
+        `Network error while ${context.toLowerCase()}: ${errorMessage}. ` +
+          "Please check your connection.",
+      );
+    }
+
+    // Check for timeout errors
+    if (errorMessage.includes("timeout") || errorMessage.includes("AbortError")) {
+      return new Error(
+        `Timeout error while ${context.toLowerCase()}: The server may be slow to respond.`,
+      );
+    }
+
+    // Return original error with context
+    return new Error(`${context} failed: ${errorMessage}`);
+  }
+
+  /**
+   * Type guard to validate cache entry structure
+   */
+  private isCacheEntry(
+    value: unknown,
+  ): value is { data: OpenRouterModelListResponse; timestamp: number } {
+    if (typeof value !== "object" || value === null) {
+      return false;
+    }
+
+    const obj = value as Record<string, unknown>;
+
+    // Validate timestamp is a number
+    if (typeof obj.timestamp !== "number") {
+      return false;
+    }
+
+    // Validate data property exists and is an object
+    if (typeof obj.data !== "object" || obj.data === null) {
+      return false;
+    }
+
+    const data = obj.data as Record<string, unknown>;
+
+    // Validate OpenRouterModelListResponse structure
+    if (!Array.isArray(data.embedding_models)) {
+      return false;
+    }
+
+    if (typeof data.total_count !== "number") {
+      return false;
+    }
+
+    // Validate each model in the array has required fields
+    for (const model of data.embedding_models) {
+      if (typeof model !== "object" || model === null) {
+        return false;
+      }
+      const m = model as Record<string, unknown>;
+      if (
+        typeof m.id !== "string" ||
+        typeof m.provider !== "string" ||
+        typeof m.name !== "string" ||
+        typeof m.dimensions !== "number" ||
+        typeof m.context_length !== "number" ||
+        typeof m.pricing_per_1m_tokens !== "number" ||
+        typeof m.supports_dimension_reduction !== "boolean"
+      ) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  /**
+   * Get cached models if available and not expired
+   */
+  private getCachedModels(): OpenRouterModelListResponse | null {
+    try {
+      const cached = sessionStorage.getItem(this.cacheKey);
+      if (!cached) return null;
+
+      const parsed: unknown = JSON.parse(cached);
+
+      // Validate cache structure
+      if (!this.isCacheEntry(parsed)) {
+        // Cache is corrupted, remove it to avoid repeated failures
+        sessionStorage.removeItem(this.cacheKey);
+        return null;
+      }
+
+      const now = Date.now();
+
+      // Check expiration
+      if (now - parsed.timestamp > this.cacheTTL) {
+        sessionStorage.removeItem(this.cacheKey);
+        return null;
+      }
+
+      return parsed.data;
+    } catch {
+      // JSON parsing failed or other error, clear cache
+      sessionStorage.removeItem(this.cacheKey);
+      return null;
+    }
+  }
+
+  /**
+   * Cache models for the TTL duration
+   */
+  private cacheModels(data: OpenRouterModelListResponse): void {
+    try {
+      const cacheData = {
+        data,
+        timestamp: Date.now(),
+      };
+      sessionStorage.setItem(this.cacheKey, JSON.stringify(cacheData));
+    } catch {
+      // Ignore cache errors
+    }
+  }
+
+  /**
+   * Discover available OpenRouter embedding models
+   */
+  async discoverModels(): Promise<OpenRouterModelListResponse> {
+    try {
+      // Check cache first
+      const cached = this.getCachedModels();
+      if (cached) {
+        return cached;
+      }
+
+      const response = await fetch(`${this.getBaseUrl()}/api/openrouter/models`, {
+        method: "GET",
+        headers: {
+          "Content-Type": "application/json",
+        },
+      });
+
+      if (!response.ok) {
+        const errorText = await response.text();
+        throw new Error(`HTTP ${response.status}: ${errorText}`);
+      }
+
+      const data = await response.json();
+
+      // Validate response structure
+      if (!data.embedding_models || !Array.isArray(data.embedding_models)) {
+        throw new Error("Invalid response structure: missing or invalid embedding_models array");
+      }
+
+      if (typeof data.total_count !== "number" || data.total_count < 0) {
+        throw new Error("Invalid response structure: total_count must be a non-negative number");
+      }
+
+      if (data.total_count !== data.embedding_models.length) {
+        throw new Error(
+          `Response structure mismatch: total_count (${data.total_count}) does not match embedding_models length (${data.embedding_models.length})`,
+        );
+      }
+
+      // Validate at least one model has required fields
+      if (data.embedding_models.length > 0) {
+        const firstModel = data.embedding_models[0];
+        if (
+          !firstModel.id ||
+          typeof firstModel.id !== "string" ||
+          !firstModel.provider ||
+          typeof firstModel.provider !== "string" ||
+          typeof firstModel.dimensions !== "number" ||
+          firstModel.dimensions <= 0
+        ) {
+          throw new Error(
+            "Invalid model structure: models must have id (string), provider (string), and positive dimensions",
+          );
+        }
+
+        // Validate provider name is from expected set
+        const validProviders = ["openai", "google", "qwen", "mistralai"];
+        if (!validProviders.includes(firstModel.provider)) {
+          throw new Error(`Invalid provider name: ${firstModel.provider}`);
+        }
+      }
+
+      // Cache the successful response
+      this.cacheModels(data);
+
+      return data;
+    } catch (error) {
+      throw this.handleApiError(error, "Model discovery");
+    }
+  }
+
+  /**
+   * Clear the models cache
+   */
+  clearCache(): void {
+    try {
+      sessionStorage.removeItem(this.cacheKey);
+    } catch {
+      // Ignore cache clearing errors
+    }
+  }
+}
+
+// Export singleton instance
+export const openrouterService = new OpenRouterService();
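
The structural checks the client enforces (total_count consistent with the array length, provider-prefixed IDs) can be spot-checked against the backend route defined in the next file; a minimal sketch, assuming the API server listens on localhost:8181 (the port is an assumption) and httpx is installed:

    import httpx

    resp = httpx.get("http://localhost:8181/api/openrouter/models", timeout=10)
    resp.raise_for_status()
    payload = resp.json()

    # Mirror the frontend's validation of the response shape.
    assert payload["total_count"] == len(payload["embedding_models"])
    for model in payload["embedding_models"]:
        assert "/" in model["id"], f"missing provider prefix: {model['id']}"
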
+""" + +from fastapi import APIRouter + +from ..services.openrouter_discovery_service import OpenRouterModelListResponse, openrouter_discovery_service + +router = APIRouter(prefix="/api/openrouter", tags=["openrouter"]) + + +@router.get("/models", response_model=OpenRouterModelListResponse) +async def get_openrouter_models() -> OpenRouterModelListResponse: + """ + Get available OpenRouter embedding models. + + Returns a list of embedding models available through OpenRouter, + including models from OpenAI, Google, Qwen, and Mistral providers. + + Returns: + OpenRouterModelListResponse: List of embedding models with metadata + """ + models = await openrouter_discovery_service.discover_embedding_models() + + return OpenRouterModelListResponse(embedding_models=models, total_count=len(models)) diff --git a/python/src/server/config/config.py b/python/src/server/config/config.py index d8104bb0..df035037 100644 --- a/python/src/server/config/config.py +++ b/python/src/server/config/config.py @@ -66,6 +66,19 @@ def validate_openai_api_key(api_key: str) -> bool: return True +def validate_openrouter_api_key(api_key: str) -> bool: + """Validate OpenRouter API key format.""" + if not api_key: + raise ConfigurationError("OpenRouter API key cannot be empty") + + if not api_key.startswith("sk-or-v1-"): + raise ConfigurationError( + "OpenRouter API key must start with 'sk-or-v1-'. " "Get your key at https://openrouter.ai/keys" + ) + + return True + + def validate_supabase_key(supabase_key: str) -> tuple[bool, str]: """Validate Supabase key type and return validation result. diff --git a/python/src/server/main.py b/python/src/server/main.py index e83dac1b..b7d272a6 100644 --- a/python/src/server/main.py +++ b/python/src/server/main.py @@ -26,6 +26,7 @@ from .api_routes.knowledge_api import router as knowledge_router from .api_routes.mcp_api import router as mcp_router from .api_routes.migration_api import router as migration_router from .api_routes.ollama_api import router as ollama_router +from .api_routes.openrouter_api import router as openrouter_router from .api_routes.pages_api import router as pages_router from .api_routes.progress_api import router as progress_router from .api_routes.projects_api import router as projects_router @@ -187,6 +188,7 @@ app.include_router(mcp_router) app.include_router(knowledge_router) app.include_router(pages_router) app.include_router(ollama_router) +app.include_router(openrouter_router) app.include_router(projects_router) app.include_router(progress_router) app.include_router(agent_chat_router) diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py index a8aee849..f4fb275b 100644 --- a/python/src/server/services/credential_service.py +++ b/python/src/server/services/credential_service.py @@ -443,7 +443,7 @@ class CredentialService: explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER") # Validate that embedding provider actually supports embeddings - embedding_capable_providers = {"openai", "google", "ollama"} + embedding_capable_providers = {"openai", "google", "openrouter", "ollama"} if (explicit_embedding_provider and explicit_embedding_provider != "" and diff --git a/python/src/server/services/embeddings/provider_error_adapters.py b/python/src/server/services/embeddings/provider_error_adapters.py index 5fea9d5e..8ecf1b02 100644 --- a/python/src/server/services/embeddings/provider_error_adapters.py +++ b/python/src/server/services/embeddings/provider_error_adapters.py @@ -8,13 +8,6 @@ with 
diff --git a/python/src/server/services/embeddings/provider_error_adapters.py b/python/src/server/services/embeddings/provider_error_adapters.py
index 5fea9d5e..8ecf1b02 100644
--- a/python/src/server/services/embeddings/provider_error_adapters.py
+++ b/python/src/server/services/embeddings/provider_error_adapters.py
@@ -8,13 +8,6 @@ with unified error handling and sanitization patterns.
 import re
 from abc import ABC, abstractmethod
 
-from .embedding_exceptions import (
-    EmbeddingAPIError,
-    EmbeddingAuthenticationError,
-    EmbeddingQuotaExhaustedError,
-    EmbeddingRateLimitError,
-)
-
 
 class ProviderErrorAdapter(ABC):
     """Abstract base class for provider-specific error handling."""
@@ -37,7 +30,7 @@ class OpenAIErrorAdapter(ProviderErrorAdapter):
             return "OpenAI API encountered an error. Please verify your API key and quota."
 
         sanitized = message
-        
+
         # Comprehensive OpenAI patterns with case-insensitive matching
         patterns = [
             (r'sk-[a-zA-Z0-9]{48}', '[REDACTED_KEY]'),  # OpenAI API keys
@@ -68,7 +61,7 @@ class GoogleAIErrorAdapter(ProviderErrorAdapter):
             return "Google AI API encountered an error. Please verify your API key."
 
         sanitized = message
-        
+
         # Comprehensive Google AI patterns
         patterns = [
             (r'AIza[a-zA-Z0-9_-]{35}', '[REDACTED_KEY]'),  # Google AI API keys
@@ -99,7 +92,7 @@ class AnthropicErrorAdapter(ProviderErrorAdapter):
             return "Anthropic API encountered an error. Please verify your API key."
 
         sanitized = message
-        
+
         # Comprehensive Anthropic patterns
         patterns = [
             (r'sk-ant-[a-zA-Z0-9_-]{10,}', '[REDACTED_KEY]'),  # Anthropic API keys
@@ -118,6 +111,34 @@ class AnthropicErrorAdapter(ProviderErrorAdapter):
         return sanitized
 
 
+class OpenRouterErrorAdapter(ProviderErrorAdapter):
+    def get_provider_name(self) -> str:
+        return "openrouter"
+
+    def sanitize_error_message(self, message: str) -> str:
+        if not isinstance(message, str) or not message.strip() or len(message) > 2000:
+            return "OpenRouter API encountered an error. Please verify your API key and quota."
+
+        sanitized = message
+
+        # Comprehensive OpenRouter patterns
+        patterns = [
+            (r'sk-or-v1-[a-zA-Z0-9_-]{10,}', '[REDACTED_KEY]'),  # OpenRouter API keys
+            (r'https?://[^\s]*openrouter\.ai[^\s]*', '[REDACTED_URL]'),  # OpenRouter URLs
+            (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer [REDACTED_TOKEN]'),  # Bearer tokens
+        ]
+
+        for pattern, replacement in patterns:
+            sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)
+
+        # Check for sensitive words
+        sensitive_words = ['internal', 'server', 'endpoint']
+        if any(word in sanitized.lower() for word in sensitive_words):
+            return "OpenRouter API encountered an error. Please verify your API key and quota."
+
+        return sanitized
+
+
 class ProviderErrorFactory:
     """Factory for provider-agnostic error handling."""
 
@@ -125,6 +146,7 @@ class ProviderErrorFactory:
         "openai": OpenAIErrorAdapter(),
         "google": GoogleAIErrorAdapter(),
         "anthropic": AnthropicErrorAdapter(),
+        "openrouter": OpenRouterErrorAdapter(),
     }
 
     @classmethod
@@ -141,22 +163,18 @@ class ProviderErrorFactory:
         """Detect provider from error message with comprehensive pattern matching."""
         if not error_str:
             return "openai"
-        
+
         error_lower = error_str.lower()
-        
+
         # Case-insensitive provider detection with multiple patterns
-        if ("anthropic" in error_lower or
-            re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
-            "claude" in error_lower):
+        # Check OpenRouter first since it may contain "openai" in model names
+        if ("openrouter" in error_lower or re.search(r'sk-or-v1-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE)):
+            return "openrouter"
+        elif ("anthropic" in error_lower or re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "claude" in error_lower):
             return "anthropic"
-        elif ("google" in error_lower or
-            re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
-            "googleapis" in error_lower or
-            "vertex" in error_lower):
+        elif ("google" in error_lower or re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "googleapis" in error_lower or "vertex" in error_lower):
             return "google"
-        elif ("openai" in error_lower or
-            re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or
-            "gpt" in error_lower):
+        elif ("openai" in error_lower or re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or "gpt" in error_lower):
             return "openai"
         else:
-            return "openai"  # Safe default
\ No newline at end of file
+            return "openai"  # Safe default
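
To see the new adapter's redaction behavior end to end (fabricated key and message, using only the patterns from the code above):

    from src.server.services.embeddings.provider_error_adapters import OpenRouterErrorAdapter

    raw = "Request to https://openrouter.ai/api/v1 failed: invalid key sk-or-v1-deadbeef123456"
    print(OpenRouterErrorAdapter().sanitize_error_message(raw))
    # -> "Request to [REDACTED_URL] failed: invalid key [REDACTED_KEY]"

The detection reordering above matters: sk-or-v1- keys and openrouter.ai mentions must be classified before the generic OpenAI checks, since OpenRouter error text often contains "openai" inside prefixed model names.
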
diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py
index 00197926..a4ac96c5 100644
--- a/python/src/server/services/llm_provider_service.py
+++ b/python/src/server/services/llm_provider_service.py
@@ -554,12 +554,12 @@ async def _get_optimal_ollama_instance(instance_type: str | None = None,
                                        base_url_override: str | None = None) -> str:
     """
     Get the optimal Ollama instance URL based on configuration and health status.
-    
+
     Args:
         instance_type: Preferred instance type ('chat', 'embedding', 'both', or None)
         use_embedding_provider: Whether this is for embedding operations
        base_url_override: Override URL if specified
-    
+
     Returns:
         Best available Ollama instance URL
     """
@@ -655,8 +655,8 @@ async def get_embedding_model(provider: str | None = None) -> str:
         return "text-embedding-004"
     elif provider_name == "openrouter":
         # OpenRouter supports both OpenAI and Google embedding models
-        # Default to OpenAI's latest for compatibility
-        return "text-embedding-3-small"
+        # Model names MUST include provider prefix for OpenRouter API
+        return "openai/text-embedding-3-small"
     elif provider_name == "anthropic":
         # Anthropic supports OpenAI and Google embedding models through their API
         # Default to OpenAI's latest for compatibility
@@ -846,7 +846,7 @@ def _extract_reasoning_strings(value: Any) -> list[str]:
         text = value.strip()
         return [text] if text else []
 
-    if isinstance(value, (list, tuple, set)):
+    if isinstance(value, list | tuple | set):
         collected: list[str] = []
         for item in value:
             collected.extend(_extract_reasoning_strings(item))
@@ -1135,11 +1135,11 @@ def prepare_chat_completion_params(model: str, params: dict) -> dict:
 async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]:
     """
     Get the embedding model with intelligent routing for multi-instance setups.
-    
+
     Args:
         provider: Override provider selection
         instance_url: Specific instance URL to use
-    
+
     Returns:
         Tuple of (model_name, instance_url) for embedding operations
     """
@@ -1171,11 +1171,11 @@ async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]:
 async def validate_provider_instance(provider: str, instance_url: str | None = None) -> dict[str, any]:
     """
     Validate a provider instance and return health information.
-    
+
     Args:
         provider: Provider name (openai, ollama, google, etc.)
         instance_url: Instance URL for providers that support multiple instances
-    
+
     Returns:
         Dictionary with validation results and health status
     """
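
A sketch of the backend default now matching the frontend default (illustrative, assuming no embedding provider override is configured in settings):

    import asyncio

    from src.server.services.llm_provider_service import get_embedding_model

    async def main() -> None:
        model = await get_embedding_model(provider="openrouter")
        # Prefixed, unlike every other provider's default.
        assert model == "openai/text-embedding-3-small"

    asyncio.run(main())
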
diff --git a/python/src/server/services/openrouter_discovery_service.py b/python/src/server/services/openrouter_discovery_service.py
new file mode 100644
index 00000000..95f4451f
--- /dev/null
+++ b/python/src/server/services/openrouter_discovery_service.py
@@ -0,0 +1,137 @@
+"""
+OpenRouter model discovery service.
+
+Provides discovery and metadata for OpenRouter embedding models.
+"""
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class OpenRouterEmbeddingModel(BaseModel):
+    """OpenRouter embedding model metadata."""
+
+    id: str = Field(..., description="Full model ID with provider prefix (e.g., openai/text-embedding-3-large)")
+    provider: str = Field(..., description="Provider name (openai, google, qwen, mistralai)")
+    name: str = Field(..., description="Display name without prefix")
+    dimensions: int = Field(..., description="Embedding dimensions")
+    context_length: int = Field(..., description="Maximum context window in tokens")
+    pricing_per_1m_tokens: float = Field(..., description="Cost per 1M tokens in USD")
+    supports_dimension_reduction: bool = Field(default=False, description="Whether model supports dimension parameter")
+
+    @field_validator("id")
+    @classmethod
+    def validate_model_id_has_prefix(cls, v: str) -> str:
+        """Ensure model ID includes provider prefix."""
+        if "/" not in v:
+            raise ValueError("OpenRouter model IDs must include provider prefix (e.g., openai/model-name)")
+        return v
+
+
+class OpenRouterModelListResponse(BaseModel):
+    """Response from OpenRouter model discovery."""
+
+    embedding_models: list[OpenRouterEmbeddingModel] = Field(default_factory=list)
+    total_count: int = Field(..., description="Total number of embedding models")
+
+
+class OpenRouterDiscoveryService:
+    """Discover and manage OpenRouter embedding models."""
+
+    async def discover_embedding_models(self) -> list[OpenRouterEmbeddingModel]:
+        """
+        Get available OpenRouter embedding models.
+
+        Returns hardcoded list of supported embedding models with metadata.
+        Future enhancement: Could fetch from OpenRouter API if they provide a models endpoint.
+        """
+        return [
+            # OpenAI models via OpenRouter
+            OpenRouterEmbeddingModel(
+                id="openai/text-embedding-3-small",
+                provider="openai",
+                name="text-embedding-3-small",
+                dimensions=1536,
+                context_length=8191,
+                pricing_per_1m_tokens=0.02,
+                supports_dimension_reduction=True,
+            ),
+            OpenRouterEmbeddingModel(
+                id="openai/text-embedding-3-large",
+                provider="openai",
+                name="text-embedding-3-large",
+                dimensions=3072,
+                context_length=8191,
+                pricing_per_1m_tokens=0.13,
+                supports_dimension_reduction=True,
+            ),
+            OpenRouterEmbeddingModel(
+                id="openai/text-embedding-ada-002",
+                provider="openai",
+                name="text-embedding-ada-002",
+                dimensions=1536,
+                context_length=8191,
+                pricing_per_1m_tokens=0.10,
+                supports_dimension_reduction=False,
+            ),
+            # Google models via OpenRouter
+            OpenRouterEmbeddingModel(
+                id="google/gemini-embedding-001",
+                provider="google",
+                name="gemini-embedding-001",
+                dimensions=768,
+                context_length=20000,
+                pricing_per_1m_tokens=0.00,  # Free tier available
+                supports_dimension_reduction=True,
+            ),
+            OpenRouterEmbeddingModel(
+                id="google/text-embedding-004",
+                provider="google",
+                name="text-embedding-004",
+                dimensions=768,
+                context_length=20000,
+                pricing_per_1m_tokens=0.00,  # Free tier available
+                supports_dimension_reduction=True,
+            ),
+            # Qwen models via OpenRouter
+            OpenRouterEmbeddingModel(
+                id="qwen/qwen3-embedding-0.6b",
+                provider="qwen",
+                name="qwen3-embedding-0.6b",
+                dimensions=1024,
+                context_length=32768,
+                pricing_per_1m_tokens=0.01,
+                supports_dimension_reduction=False,
+            ),
+            OpenRouterEmbeddingModel(
+                id="qwen/qwen3-embedding-4b",
+                provider="qwen",
+                name="qwen3-embedding-4b",
+                dimensions=1024,
+                context_length=32768,
+                pricing_per_1m_tokens=0.01,
+                supports_dimension_reduction=False,
+            ),
+            OpenRouterEmbeddingModel(
+                id="qwen/qwen3-embedding-8b",
+                provider="qwen",
+                name="qwen3-embedding-8b",
+                dimensions=1024,
+                context_length=32768,
+                pricing_per_1m_tokens=0.01,
+                supports_dimension_reduction=False,
+            ),
+            # Mistral models via OpenRouter
+            OpenRouterEmbeddingModel(
+                id="mistralai/mistral-embed",
+                provider="mistralai",
+                name="mistral-embed",
+                dimensions=1024,
+                context_length=8192,
+                pricing_per_1m_tokens=0.10,
+                supports_dimension_reduction=False,
+            ),
+        ]
+
+
+# Create singleton instance
+openrouter_discovery_service = OpenRouterDiscoveryService()
provider="qwen", + name="qwen3-embedding-8b", + dimensions=1024, + context_length=32768, + pricing_per_1m_tokens=0.01, + supports_dimension_reduction=False, + ), + # Mistral models via OpenRouter + OpenRouterEmbeddingModel( + id="mistralai/mistral-embed", + provider="mistralai", + name="mistral-embed", + dimensions=1024, + context_length=8192, + pricing_per_1m_tokens=0.10, + supports_dimension_reduction=False, + ), + ] + + +# Create singleton instance +openrouter_discovery_service = OpenRouterDiscoveryService() diff --git a/python/tests/test_openrouter_discovery.py b/python/tests/test_openrouter_discovery.py new file mode 100644 index 00000000..d9cf8633 --- /dev/null +++ b/python/tests/test_openrouter_discovery.py @@ -0,0 +1,158 @@ +""" +Unit tests for OpenRouter model discovery service. +""" + +import pytest + +from src.server.services.openrouter_discovery_service import ( + OpenRouterDiscoveryService, + OpenRouterEmbeddingModel, + OpenRouterModelListResponse, +) + + +@pytest.fixture +def discovery_service(): + """Create OpenRouter discovery service instance.""" + return OpenRouterDiscoveryService() + + +@pytest.mark.asyncio +async def test_discover_embedding_models_returns_valid_list(discovery_service): + """Test that discover_embedding_models returns a non-empty list of models.""" + models = await discovery_service.discover_embedding_models() + + assert isinstance(models, list) + assert len(models) > 0 + assert all(isinstance(model, OpenRouterEmbeddingModel) for model in models) + + +@pytest.mark.asyncio +async def test_all_models_have_provider_prefix(discovery_service): + """Test that all model IDs include provider prefix.""" + models = await discovery_service.discover_embedding_models() + + for model in models: + assert "/" in model.id, f"Model ID '{model.id}' missing provider prefix" + assert model.id.startswith( + f"{model.provider}/" + ), f"Model ID '{model.id}' doesn't match provider '{model.provider}'" + + +@pytest.mark.asyncio +async def test_dimensions_are_positive_integers(discovery_service): + """Test that all models have positive integer dimensions.""" + models = await discovery_service.discover_embedding_models() + + for model in models: + assert isinstance(model.dimensions, int), f"Model '{model.id}' dimensions is not an integer" + assert model.dimensions > 0, f"Model '{model.id}' has non-positive dimensions: {model.dimensions}" + + +@pytest.mark.asyncio +async def test_pricing_is_non_negative(discovery_service): + """Test that all models have non-negative pricing.""" + models = await discovery_service.discover_embedding_models() + + for model in models: + assert isinstance( + model.pricing_per_1m_tokens, (int, float) + ), f"Model '{model.id}' pricing is not numeric" + assert ( + model.pricing_per_1m_tokens >= 0 + ), f"Model '{model.id}' has negative pricing: {model.pricing_per_1m_tokens}" + + +@pytest.mark.asyncio +async def test_context_length_is_positive(discovery_service): + """Test that all models have positive context length.""" + models = await discovery_service.discover_embedding_models() + + for model in models: + assert isinstance( + model.context_length, int + ), f"Model '{model.id}' context_length is not an integer" + assert ( + model.context_length > 0 + ), f"Model '{model.id}' has non-positive context_length: {model.context_length}" + + +@pytest.mark.asyncio +async def test_model_providers_are_valid(discovery_service): + """Test that all models have valid provider names.""" + models = await discovery_service.discover_embedding_models() + valid_providers = 
{"openai", "google", "qwen", "mistralai"} + + for model in models: + assert ( + model.provider in valid_providers + ), f"Model '{model.id}' has invalid provider: {model.provider}" + + +@pytest.mark.asyncio +async def test_openai_models_present(discovery_service): + """Test that OpenAI models are included in the list.""" + models = await discovery_service.discover_embedding_models() + openai_models = [m for m in models if m.provider == "openai"] + + assert len(openai_models) > 0, "No OpenAI models found" + assert any( + "text-embedding-3-small" in m.id for m in openai_models + ), "text-embedding-3-small not found" + assert any( + "text-embedding-3-large" in m.id for m in openai_models + ), "text-embedding-3-large not found" + + +@pytest.mark.asyncio +async def test_qwen_models_present(discovery_service): + """Test that Qwen models are included in the list.""" + models = await discovery_service.discover_embedding_models() + qwen_models = [m for m in models if m.provider == "qwen"] + + assert len(qwen_models) > 0, "No Qwen models found" + # Verify at least one Qwen3 embedding model is present + assert any("qwen3-embedding" in m.id for m in qwen_models), "No Qwen3 embedding models found" + + +@pytest.mark.asyncio +async def test_model_list_response_structure(): + """Test OpenRouterModelListResponse structure.""" + service = OpenRouterDiscoveryService() + models = await service.discover_embedding_models() + + response = OpenRouterModelListResponse(embedding_models=models, total_count=len(models)) + + assert response.total_count == len(models) + assert response.total_count == len(response.embedding_models) + assert response.total_count > 0 + + +def test_model_id_validation_requires_prefix(): + """Test that model ID validation enforces provider prefix.""" + with pytest.raises(ValueError, match="must include provider prefix"): + OpenRouterEmbeddingModel( + id="text-embedding-3-small", # Missing provider prefix + provider="openai", + name="text-embedding-3-small", + dimensions=1536, + context_length=8191, + pricing_per_1m_tokens=0.02, + supports_dimension_reduction=True, + ) + + +def test_model_with_valid_prefix_accepted(): + """Test that model with valid provider prefix is accepted.""" + model = OpenRouterEmbeddingModel( + id="openai/text-embedding-3-small", + provider="openai", + name="text-embedding-3-small", + dimensions=1536, + context_length=8191, + pricing_per_1m_tokens=0.02, + supports_dimension_reduction=True, + ) + + assert model.id == "openai/text-embedding-3-small" + assert "/" in model.id