Merge pull request #852 from coleam00/feature/openrouter-embeddings-support

Add OpenRouter Embeddings Support
Authored by sean-esk on 2025-11-29 14:24:10 -05:00; committed by GitHub.
11 changed files with 637 additions and 37 deletions

View File

@@ -124,6 +124,7 @@ PROD=false
# Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table.
# Then use the Settings page in the web UI to manage:
# - OPENAI_API_KEY (encrypted)
+# - OPENROUTER_API_KEY (encrypted, format: sk-or-v1-..., get from https://openrouter.ai/keys)
# - MODEL_CHOICE
# - TRANSPORT settings
# - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.)

View File

@@ -15,7 +15,7 @@ import OllamaModelSelectionModal from './OllamaModelSelectionModal';
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
// Providers that support embedding models
-const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'ollama'];
+const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama'];
interface ProviderModels {
  chatModel: string;
@@ -42,7 +42,7 @@ const getDefaultModels = (provider: ProviderKey): ProviderModels => {
    anthropic: 'text-embedding-3-small', // Fallback to OpenAI
    google: 'text-embedding-004',
    grok: 'text-embedding-3-small', // Fallback to OpenAI
-    openrouter: 'text-embedding-3-small',
+    openrouter: 'openai/text-embedding-3-small', // MUST include provider prefix for OpenRouter
    ollama: 'nomic-embed-text'
  };
@@ -1291,7 +1291,7 @@ const manualTestConnection = async (
  Select {activeSelection === 'chat' ? 'Chat' : 'Embedding'} Provider
</label>
<div className={`grid gap-3 mb-4 ${
-  activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-3'
+  activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-4'
}`}>
  {[
    { key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },

View File

@@ -0,0 +1,244 @@
/**
* OpenRouter Service Client
*
 * Provides a frontend API client for OpenRouter model discovery.
*/
import { getApiUrl } from "../config/api";
// Type definitions for OpenRouter API responses
export interface OpenRouterEmbeddingModel {
id: string;
provider: string;
name: string;
dimensions: number;
context_length: number;
pricing_per_1m_tokens: number;
supports_dimension_reduction: boolean;
}
export interface OpenRouterModelListResponse {
embedding_models: OpenRouterEmbeddingModel[];
total_count: number;
}
class OpenRouterService {
private getBaseUrl = () => getApiUrl();
private cacheKey = "openrouter_models_cache";
private cacheTTL = 5 * 60 * 1000; // 5 minutes
private handleApiError(error: unknown, context: string): Error {
const errorMessage = error instanceof Error ? error.message : String(error);
// Check for network errors
if (
errorMessage.toLowerCase().includes("network") ||
errorMessage.includes("fetch") ||
errorMessage.includes("Failed to fetch")
) {
return new Error(
`Network error while ${context.toLowerCase()}: ${errorMessage}. ` +
"Please check your connection.",
);
}
// Check for timeout errors
if (errorMessage.includes("timeout") || errorMessage.includes("AbortError")) {
return new Error(
`Timeout error while ${context.toLowerCase()}: The server may be slow to respond.`,
);
}
// Return original error with context
return new Error(`${context} failed: ${errorMessage}`);
}
/**
* Type guard to validate cache entry structure
*/
private isCacheEntry(
value: unknown,
): value is { data: OpenRouterModelListResponse; timestamp: number } {
if (typeof value !== "object" || value === null) {
return false;
}
const obj = value as Record<string, unknown>;
// Validate timestamp is a number
if (typeof obj.timestamp !== "number") {
return false;
}
// Validate data property exists and is an object
if (typeof obj.data !== "object" || obj.data === null) {
return false;
}
const data = obj.data as Record<string, unknown>;
// Validate OpenRouterModelListResponse structure
if (!Array.isArray(data.embedding_models)) {
return false;
}
if (typeof data.total_count !== "number") {
return false;
}
// Validate each model in the array has required fields
for (const model of data.embedding_models) {
if (typeof model !== "object" || model === null) {
return false;
}
const m = model as Record<string, unknown>;
if (
typeof m.id !== "string" ||
typeof m.provider !== "string" ||
typeof m.name !== "string" ||
typeof m.dimensions !== "number" ||
typeof m.context_length !== "number" ||
typeof m.pricing_per_1m_tokens !== "number" ||
typeof m.supports_dimension_reduction !== "boolean"
) {
return false;
}
}
return true;
}
/**
* Get cached models if available and not expired
*/
private getCachedModels(): OpenRouterModelListResponse | null {
try {
const cached = sessionStorage.getItem(this.cacheKey);
if (!cached) return null;
const parsed: unknown = JSON.parse(cached);
// Validate cache structure
if (!this.isCacheEntry(parsed)) {
// Cache is corrupted, remove it to avoid repeated failures
sessionStorage.removeItem(this.cacheKey);
return null;
}
const now = Date.now();
// Check expiration
if (now - parsed.timestamp > this.cacheTTL) {
sessionStorage.removeItem(this.cacheKey);
return null;
}
return parsed.data;
} catch {
// JSON parsing failed or other error, clear cache
sessionStorage.removeItem(this.cacheKey);
return null;
}
}
/**
* Cache models for the TTL duration
*/
private cacheModels(data: OpenRouterModelListResponse): void {
try {
const cacheData = {
data,
timestamp: Date.now(),
};
sessionStorage.setItem(this.cacheKey, JSON.stringify(cacheData));
} catch {
// Ignore cache errors
}
}
/**
* Discover available OpenRouter embedding models
*/
async discoverModels(): Promise<OpenRouterModelListResponse> {
try {
// Check cache first
const cached = this.getCachedModels();
if (cached) {
return cached;
}
const response = await fetch(`${this.getBaseUrl()}/api/openrouter/models`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`HTTP ${response.status}: ${errorText}`);
}
const data = await response.json();
// Validate response structure
if (!data.embedding_models || !Array.isArray(data.embedding_models)) {
throw new Error("Invalid response structure: missing or invalid embedding_models array");
}
if (typeof data.total_count !== "number" || data.total_count < 0) {
throw new Error("Invalid response structure: total_count must be a non-negative number");
}
if (data.total_count !== data.embedding_models.length) {
throw new Error(
`Response structure mismatch: total_count (${data.total_count}) does not match embedding_models length (${data.embedding_models.length})`,
);
}
// Validate at least one model has required fields
if (data.embedding_models.length > 0) {
const firstModel = data.embedding_models[0];
if (
!firstModel.id ||
typeof firstModel.id !== "string" ||
!firstModel.provider ||
typeof firstModel.provider !== "string" ||
typeof firstModel.dimensions !== "number" ||
firstModel.dimensions <= 0
) {
throw new Error(
"Invalid model structure: models must have id (string), provider (string), and positive dimensions",
);
}
// Validate provider name is from expected set
const validProviders = ["openai", "google", "qwen", "mistralai"];
if (!validProviders.includes(firstModel.provider)) {
throw new Error(`Invalid provider name: ${firstModel.provider}`);
}
}
// Cache the successful response
this.cacheModels(data);
return data;
} catch (error) {
throw this.handleApiError(error, "Model discovery");
}
}
/**
* Clear the models cache
*/
clearCache(): void {
try {
sessionStorage.removeItem(this.cacheKey);
} catch {
// Ignore cache clearing errors
}
}
}
// Export singleton instance
export const openrouterService = new OpenRouterService();

View File

@@ -0,0 +1,27 @@
"""
OpenRouter API routes.
Endpoints for OpenRouter model discovery and configuration.
"""
from fastapi import APIRouter
from ..services.openrouter_discovery_service import OpenRouterModelListResponse, openrouter_discovery_service
router = APIRouter(prefix="/api/openrouter", tags=["openrouter"])
@router.get("/models", response_model=OpenRouterModelListResponse)
async def get_openrouter_models() -> OpenRouterModelListResponse:
"""
Get available OpenRouter embedding models.
Returns a list of embedding models available through OpenRouter,
including models from OpenAI, Google, Qwen, and Mistral providers.
Returns:
OpenRouterModelListResponse: List of embedding models with metadata
"""
models = await openrouter_discovery_service.discover_embedding_models()
return OpenRouterModelListResponse(embedding_models=models, total_count=len(models))
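
For a quick smoke test of the new endpoint, a minimal async client sketch (the base URL is an assumption; substitute your server's address):

import asyncio

import httpx

async def list_openrouter_embedding_models() -> None:
    base_url = "http://localhost:8181"  # assumption: adjust to your deployment
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{base_url}/api/openrouter/models")
        resp.raise_for_status()
        for model in resp.json()["embedding_models"]:
            print(f"{model['id']}: {model['dimensions']} dims, ${model['pricing_per_1m_tokens']}/1M tokens")

asyncio.run(list_openrouter_embedding_models())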

View File

@@ -66,6 +66,19 @@ def validate_openai_api_key(api_key: str) -> bool:
    return True


+def validate_openrouter_api_key(api_key: str) -> bool:
+    """Validate OpenRouter API key format."""
+    if not api_key:
+        raise ConfigurationError("OpenRouter API key cannot be empty")
+    if not api_key.startswith("sk-or-v1-"):
+        raise ConfigurationError(
+            "OpenRouter API key must start with 'sk-or-v1-'. Get your key at https://openrouter.ai/keys"
+        )
+    return True


def validate_supabase_key(supabase_key: str) -> tuple[bool, str]:
    """Validate Supabase key type and return validation result.

View File

@@ -26,6 +26,7 @@ from .api_routes.knowledge_api import router as knowledge_router
from .api_routes.mcp_api import router as mcp_router
from .api_routes.migration_api import router as migration_router
from .api_routes.ollama_api import router as ollama_router
+from .api_routes.openrouter_api import router as openrouter_router
from .api_routes.pages_api import router as pages_router
from .api_routes.progress_api import router as progress_router
from .api_routes.projects_api import router as projects_router
@@ -187,6 +188,7 @@ app.include_router(mcp_router)
app.include_router(knowledge_router)
app.include_router(pages_router)
app.include_router(ollama_router)
+app.include_router(openrouter_router)
app.include_router(projects_router)
app.include_router(progress_router)
app.include_router(agent_chat_router)

View File

@@ -443,7 +443,7 @@ class CredentialService:
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
# Validate that embedding provider actually supports embeddings
-embedding_capable_providers = {"openai", "google", "ollama"}
+embedding_capable_providers = {"openai", "google", "openrouter", "ollama"}
if (explicit_embedding_provider and
    explicit_embedding_provider != "" and

View File

@@ -8,13 +8,6 @@ with unified error handling and sanitization patterns.
import re
from abc import ABC, abstractmethod

-from .embedding_exceptions import (
-    EmbeddingAPIError,
-    EmbeddingAuthenticationError,
-    EmbeddingQuotaExhaustedError,
-    EmbeddingRateLimitError,
-)

class ProviderErrorAdapter(ABC):
    """Abstract base class for provider-specific error handling."""
@@ -118,6 +111,34 @@ class AnthropicErrorAdapter(ProviderErrorAdapter):
        return sanitized
+class OpenRouterErrorAdapter(ProviderErrorAdapter):
+    def get_provider_name(self) -> str:
+        return "openrouter"
+
+    def sanitize_error_message(self, message: str) -> str:
+        if not isinstance(message, str) or not message.strip() or len(message) > 2000:
+            return "OpenRouter API encountered an error. Please verify your API key and quota."
+        sanitized = message
+        # Comprehensive OpenRouter patterns
+        patterns = [
+            (r'sk-or-v1-[a-zA-Z0-9_-]{10,}', '[REDACTED_KEY]'),  # OpenRouter API keys
+            (r'https?://[^\s]*openrouter\.ai[^\s]*', '[REDACTED_URL]'),  # OpenRouter URLs
+            (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer [REDACTED_TOKEN]'),  # Bearer tokens
+        ]
+        for pattern, replacement in patterns:
+            sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)
+        # Check for sensitive words
+        sensitive_words = ['internal', 'server', 'endpoint']
+        if any(word in sanitized.lower() for word in sensitive_words):
+            return "OpenRouter API encountered an error. Please verify your API key and quota."
+        return sanitized
class ProviderErrorFactory:
    """Factory for provider-agnostic error handling."""
@@ -125,6 +146,7 @@ class ProviderErrorFactory:
"openai": OpenAIErrorAdapter(), "openai": OpenAIErrorAdapter(),
"google": GoogleAIErrorAdapter(), "google": GoogleAIErrorAdapter(),
"anthropic": AnthropicErrorAdapter(), "anthropic": AnthropicErrorAdapter(),
"openrouter": OpenRouterErrorAdapter(),
} }
@classmethod @classmethod
@@ -145,18 +167,14 @@ class ProviderErrorFactory:
        error_lower = error_str.lower()
        # Case-insensitive provider detection with multiple patterns
-        if ("anthropic" in error_lower or
-                re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
-                "claude" in error_lower):
+        # Check OpenRouter first since it may contain "openai" in model names
+        if ("openrouter" in error_lower or re.search(r'sk-or-v1-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE)):
+            return "openrouter"
+        elif ("anthropic" in error_lower or re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "claude" in error_lower):
            return "anthropic"
-        elif ("google" in error_lower or
-                re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
-                "googleapis" in error_lower or
-                "vertex" in error_lower):
+        elif ("google" in error_lower or re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "googleapis" in error_lower or "vertex" in error_lower):
            return "google"
-        elif ("openai" in error_lower or
-                re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or
-                "gpt" in error_lower):
+        elif ("openai" in error_lower or re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or "gpt" in error_lower):
            return "openai"
        else:
            return "openai"  # Safe default

View File

@@ -655,8 +655,8 @@ async def get_embedding_model(provider: str | None = None) -> str:
return "text-embedding-004" return "text-embedding-004"
elif provider_name == "openrouter": elif provider_name == "openrouter":
# OpenRouter supports both OpenAI and Google embedding models # OpenRouter supports both OpenAI and Google embedding models
# Default to OpenAI's latest for compatibility # Model names MUST include provider prefix for OpenRouter API
return "text-embedding-3-small" return "openai/text-embedding-3-small"
elif provider_name == "anthropic": elif provider_name == "anthropic":
# Anthropic supports OpenAI and Google embedding models through their API # Anthropic supports OpenAI and Google embedding models through their API
# Default to OpenAI's latest for compatibility # Default to OpenAI's latest for compatibility
@@ -846,7 +846,7 @@ def _extract_reasoning_strings(value: Any) -> list[str]:
        text = value.strip()
        return [text] if text else []
-    if isinstance(value, (list, tuple, set)):
+    if isinstance(value, list | tuple | set):
        collected: list[str] = []
        for item in value:
            collected.extend(_extract_reasoning_strings(item))
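
Why the prefix matters: when an OpenAI-compatible client is pointed at OpenRouter, the model string has to carry the provider segment. A hedged sketch, assuming OpenRouter's OpenAI-compatible /embeddings endpoint and base URL (neither is shown in this diff):

import asyncio

from openai import AsyncOpenAI  # OpenRouter exposes an OpenAI-compatible surface

client = AsyncOpenAI(
    api_key="sk-or-v1-...",                    # placeholder OpenRouter key
    base_url="https://openrouter.ai/api/v1",   # assumption about the base URL
)

async def embed(text: str) -> list[float]:
    resp = await client.embeddings.create(
        model="openai/text-embedding-3-small",  # provider prefix is required
        input=text,
    )
    return resp.data[0].embedding

print(len(asyncio.run(embed("hello"))))  # 1536 dims for text-embedding-3-small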

View File

@@ -0,0 +1,137 @@
"""
OpenRouter model discovery service.
Provides discovery and metadata for OpenRouter embedding models.
"""
from pydantic import BaseModel, Field, field_validator
class OpenRouterEmbeddingModel(BaseModel):
"""OpenRouter embedding model metadata."""
id: str = Field(..., description="Full model ID with provider prefix (e.g., openai/text-embedding-3-large)")
provider: str = Field(..., description="Provider name (openai, google, qwen, mistralai)")
name: str = Field(..., description="Display name without prefix")
dimensions: int = Field(..., description="Embedding dimensions")
context_length: int = Field(..., description="Maximum context window in tokens")
pricing_per_1m_tokens: float = Field(..., description="Cost per 1M tokens in USD")
supports_dimension_reduction: bool = Field(default=False, description="Whether model supports dimension parameter")
@field_validator("id")
@classmethod
def validate_model_id_has_prefix(cls, v: str) -> str:
"""Ensure model ID includes provider prefix."""
if "/" not in v:
raise ValueError("OpenRouter model IDs must include provider prefix (e.g., openai/model-name)")
return v
class OpenRouterModelListResponse(BaseModel):
"""Response from OpenRouter model discovery."""
embedding_models: list[OpenRouterEmbeddingModel] = Field(default_factory=list)
total_count: int = Field(..., description="Total number of embedding models")
class OpenRouterDiscoveryService:
"""Discover and manage OpenRouter embedding models."""
async def discover_embedding_models(self) -> list[OpenRouterEmbeddingModel]:
"""
Get available OpenRouter embedding models.
Returns a hardcoded list of supported embedding models with metadata.
Future enhancement: fetch from the OpenRouter API if it exposes a models endpoint.
"""
return [
# OpenAI models via OpenRouter
OpenRouterEmbeddingModel(
id="openai/text-embedding-3-small",
provider="openai",
name="text-embedding-3-small",
dimensions=1536,
context_length=8191,
pricing_per_1m_tokens=0.02,
supports_dimension_reduction=True,
),
OpenRouterEmbeddingModel(
id="openai/text-embedding-3-large",
provider="openai",
name="text-embedding-3-large",
dimensions=3072,
context_length=8191,
pricing_per_1m_tokens=0.13,
supports_dimension_reduction=True,
),
OpenRouterEmbeddingModel(
id="openai/text-embedding-ada-002",
provider="openai",
name="text-embedding-ada-002",
dimensions=1536,
context_length=8191,
pricing_per_1m_tokens=0.10,
supports_dimension_reduction=False,
),
# Google models via OpenRouter
OpenRouterEmbeddingModel(
id="google/gemini-embedding-001",
provider="google",
name="gemini-embedding-001",
dimensions=768,
context_length=20000,
pricing_per_1m_tokens=0.00, # Free tier available
supports_dimension_reduction=True,
),
OpenRouterEmbeddingModel(
id="google/text-embedding-004",
provider="google",
name="text-embedding-004",
dimensions=768,
context_length=20000,
pricing_per_1m_tokens=0.00, # Free tier available
supports_dimension_reduction=True,
),
# Qwen models via OpenRouter
OpenRouterEmbeddingModel(
id="qwen/qwen3-embedding-0.6b",
provider="qwen",
name="qwen3-embedding-0.6b",
dimensions=1024,
context_length=32768,
pricing_per_1m_tokens=0.01,
supports_dimension_reduction=False,
),
OpenRouterEmbeddingModel(
id="qwen/qwen3-embedding-4b",
provider="qwen",
name="qwen3-embedding-4b",
dimensions=1024,
context_length=32768,
pricing_per_1m_tokens=0.01,
supports_dimension_reduction=False,
),
OpenRouterEmbeddingModel(
id="qwen/qwen3-embedding-8b",
provider="qwen",
name="qwen3-embedding-8b",
dimensions=1024,
context_length=32768,
pricing_per_1m_tokens=0.01,
supports_dimension_reduction=False,
),
# Mistral models via OpenRouter
OpenRouterEmbeddingModel(
id="mistralai/mistral-embed",
provider="mistralai",
name="mistral-embed",
dimensions=1024,
context_length=8192,
pricing_per_1m_tokens=0.10,
supports_dimension_reduction=False,
),
]
# Create singleton instance
openrouter_discovery_service = OpenRouterDiscoveryService()
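
Downstream callers can filter the static catalog directly from the singleton; for example, keeping only the models that accept a dimensions parameter:

import asyncio

from src.server.services.openrouter_discovery_service import openrouter_discovery_service

async def main() -> None:
    models = await openrouter_discovery_service.discover_embedding_models()
    reducible = [m.id for m in models if m.supports_dimension_reduction]
    print(reducible)
    # ['openai/text-embedding-3-small', 'openai/text-embedding-3-large',
    #  'google/gemini-embedding-001', 'google/text-embedding-004']

asyncio.run(main())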

View File

@@ -0,0 +1,158 @@
"""
Unit tests for OpenRouter model discovery service.
"""
import pytest
from src.server.services.openrouter_discovery_service import (
OpenRouterDiscoveryService,
OpenRouterEmbeddingModel,
OpenRouterModelListResponse,
)
@pytest.fixture
def discovery_service():
"""Create OpenRouter discovery service instance."""
return OpenRouterDiscoveryService()
@pytest.mark.asyncio
async def test_discover_embedding_models_returns_valid_list(discovery_service):
"""Test that discover_embedding_models returns a non-empty list of models."""
models = await discovery_service.discover_embedding_models()
assert isinstance(models, list)
assert len(models) > 0
assert all(isinstance(model, OpenRouterEmbeddingModel) for model in models)
@pytest.mark.asyncio
async def test_all_models_have_provider_prefix(discovery_service):
"""Test that all model IDs include provider prefix."""
models = await discovery_service.discover_embedding_models()
for model in models:
assert "/" in model.id, f"Model ID '{model.id}' missing provider prefix"
assert model.id.startswith(
f"{model.provider}/"
), f"Model ID '{model.id}' doesn't match provider '{model.provider}'"
@pytest.mark.asyncio
async def test_dimensions_are_positive_integers(discovery_service):
"""Test that all models have positive integer dimensions."""
models = await discovery_service.discover_embedding_models()
for model in models:
assert isinstance(model.dimensions, int), f"Model '{model.id}' dimensions is not an integer"
assert model.dimensions > 0, f"Model '{model.id}' has non-positive dimensions: {model.dimensions}"
@pytest.mark.asyncio
async def test_pricing_is_non_negative(discovery_service):
"""Test that all models have non-negative pricing."""
models = await discovery_service.discover_embedding_models()
for model in models:
assert isinstance(
model.pricing_per_1m_tokens, (int, float)
), f"Model '{model.id}' pricing is not numeric"
assert (
model.pricing_per_1m_tokens >= 0
), f"Model '{model.id}' has negative pricing: {model.pricing_per_1m_tokens}"
@pytest.mark.asyncio
async def test_context_length_is_positive(discovery_service):
"""Test that all models have positive context length."""
models = await discovery_service.discover_embedding_models()
for model in models:
assert isinstance(
model.context_length, int
), f"Model '{model.id}' context_length is not an integer"
assert (
model.context_length > 0
), f"Model '{model.id}' has non-positive context_length: {model.context_length}"
@pytest.mark.asyncio
async def test_model_providers_are_valid(discovery_service):
"""Test that all models have valid provider names."""
models = await discovery_service.discover_embedding_models()
valid_providers = {"openai", "google", "qwen", "mistralai"}
for model in models:
assert (
model.provider in valid_providers
), f"Model '{model.id}' has invalid provider: {model.provider}"
@pytest.mark.asyncio
async def test_openai_models_present(discovery_service):
"""Test that OpenAI models are included in the list."""
models = await discovery_service.discover_embedding_models()
openai_models = [m for m in models if m.provider == "openai"]
assert len(openai_models) > 0, "No OpenAI models found"
assert any(
"text-embedding-3-small" in m.id for m in openai_models
), "text-embedding-3-small not found"
assert any(
"text-embedding-3-large" in m.id for m in openai_models
), "text-embedding-3-large not found"
@pytest.mark.asyncio
async def test_qwen_models_present(discovery_service):
"""Test that Qwen models are included in the list."""
models = await discovery_service.discover_embedding_models()
qwen_models = [m for m in models if m.provider == "qwen"]
assert len(qwen_models) > 0, "No Qwen models found"
# Verify at least one Qwen3 embedding model is present
assert any("qwen3-embedding" in m.id for m in qwen_models), "No Qwen3 embedding models found"
@pytest.mark.asyncio
async def test_model_list_response_structure():
"""Test OpenRouterModelListResponse structure."""
service = OpenRouterDiscoveryService()
models = await service.discover_embedding_models()
response = OpenRouterModelListResponse(embedding_models=models, total_count=len(models))
assert response.total_count == len(models)
assert response.total_count == len(response.embedding_models)
assert response.total_count > 0
def test_model_id_validation_requires_prefix():
"""Test that model ID validation enforces provider prefix."""
with pytest.raises(ValueError, match="must include provider prefix"):
OpenRouterEmbeddingModel(
id="text-embedding-3-small", # Missing provider prefix
provider="openai",
name="text-embedding-3-small",
dimensions=1536,
context_length=8191,
pricing_per_1m_tokens=0.02,
supports_dimension_reduction=True,
)
def test_model_with_valid_prefix_accepted():
"""Test that model with valid provider prefix is accepted."""
model = OpenRouterEmbeddingModel(
id="openai/text-embedding-3-small",
provider="openai",
name="text-embedding-3-small",
dimensions=1536,
context_length=8191,
pricing_per_1m_tokens=0.02,
supports_dimension_reduction=True,
)
assert model.id == "openai/text-embedding-3-small"
assert "/" in model.id