mirror of
https://github.com/coleam00/Archon.git
synced 2025-12-24 02:39:17 -05:00
Merge pull request #852 from coleam00/feature/openrouter-embeddings-support
Add OpenRouter Embeddings Support
This commit is contained in:
@@ -124,6 +124,7 @@ PROD=false
|
||||
# Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table.
|
||||
# Then use the Settings page in the web UI to manage:
|
||||
# - OPENAI_API_KEY (encrypted)
|
||||
# - OPENROUTER_API_KEY (encrypted, format: sk-or-v1-..., get from https://openrouter.ai/keys)
|
||||
# - MODEL_CHOICE
|
||||
# - TRANSPORT settings
|
||||
# - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.)
|
||||
|
||||
@@ -15,7 +15,7 @@ import OllamaModelSelectionModal from './OllamaModelSelectionModal';
|
||||
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
|
||||
|
||||
// Providers that support embedding models
|
||||
const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'ollama'];
|
||||
const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama'];
|
||||
|
||||
interface ProviderModels {
|
||||
chatModel: string;
|
||||
@@ -42,7 +42,7 @@ const getDefaultModels = (provider: ProviderKey): ProviderModels => {
|
||||
anthropic: 'text-embedding-3-small', // Fallback to OpenAI
|
||||
google: 'text-embedding-004',
|
||||
grok: 'text-embedding-3-small', // Fallback to OpenAI
|
||||
openrouter: 'text-embedding-3-small',
|
||||
openrouter: 'openai/text-embedding-3-small', // MUST include provider prefix for OpenRouter
|
||||
ollama: 'nomic-embed-text'
|
||||
};
|
||||
|
||||
@@ -1291,7 +1291,7 @@ const manualTestConnection = async (
|
||||
Select {activeSelection === 'chat' ? 'Chat' : 'Embedding'} Provider
|
||||
</label>
|
||||
<div className={`grid gap-3 mb-4 ${
|
||||
activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-3'
|
||||
activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-4'
|
||||
}`}>
|
||||
{[
|
||||
{ key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
|
||||
|
||||
244
archon-ui-main/src/services/openrouterService.ts
Normal file
244
archon-ui-main/src/services/openrouterService.ts
Normal file
@@ -0,0 +1,244 @@
|
||||
/**
|
||||
* OpenRouter Service Client
|
||||
*
|
||||
* Provides frontend API client for OpenRouter model discovery.
|
||||
*/
|
||||
|
||||
import { getApiUrl } from "../config/api";
|
||||
|
||||
// Type definitions for OpenRouter API responses
|
||||
export interface OpenRouterEmbeddingModel {
|
||||
id: string;
|
||||
provider: string;
|
||||
name: string;
|
||||
dimensions: number;
|
||||
context_length: number;
|
||||
pricing_per_1m_tokens: number;
|
||||
supports_dimension_reduction: boolean;
|
||||
}
|
||||
|
||||
export interface OpenRouterModelListResponse {
|
||||
embedding_models: OpenRouterEmbeddingModel[];
|
||||
total_count: number;
|
||||
}
|
||||
|
||||
class OpenRouterService {
|
||||
private getBaseUrl = () => getApiUrl();
|
||||
private cacheKey = "openrouter_models_cache";
|
||||
private cacheTTL = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
private handleApiError(error: unknown, context: string): Error {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
|
||||
// Check for network errors
|
||||
if (
|
||||
errorMessage.toLowerCase().includes("network") ||
|
||||
errorMessage.includes("fetch") ||
|
||||
errorMessage.includes("Failed to fetch")
|
||||
) {
|
||||
return new Error(
|
||||
`Network error while ${context.toLowerCase()}: ${errorMessage}. ` +
|
||||
"Please check your connection.",
|
||||
);
|
||||
}
|
||||
|
||||
// Check for timeout errors
|
||||
if (errorMessage.includes("timeout") || errorMessage.includes("AbortError")) {
|
||||
return new Error(
|
||||
`Timeout error while ${context.toLowerCase()}: The server may be slow to respond.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Return original error with context
|
||||
return new Error(`${context} failed: ${errorMessage}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to validate cache entry structure
|
||||
*/
|
||||
private isCacheEntry(
|
||||
value: unknown,
|
||||
): value is { data: OpenRouterModelListResponse; timestamp: number } {
|
||||
if (typeof value !== "object" || value === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const obj = value as Record<string, unknown>;
|
||||
|
||||
// Validate timestamp is a number
|
||||
if (typeof obj.timestamp !== "number") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate data property exists and is an object
|
||||
if (typeof obj.data !== "object" || obj.data === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const data = obj.data as Record<string, unknown>;
|
||||
|
||||
// Validate OpenRouterModelListResponse structure
|
||||
if (!Array.isArray(data.embedding_models)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeof data.total_count !== "number") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate each model in the array has required fields
|
||||
for (const model of data.embedding_models) {
|
||||
if (typeof model !== "object" || model === null) {
|
||||
return false;
|
||||
}
|
||||
const m = model as Record<string, unknown>;
|
||||
if (
|
||||
typeof m.id !== "string" ||
|
||||
typeof m.provider !== "string" ||
|
||||
typeof m.name !== "string" ||
|
||||
typeof m.dimensions !== "number" ||
|
||||
typeof m.context_length !== "number" ||
|
||||
typeof m.pricing_per_1m_tokens !== "number" ||
|
||||
typeof m.supports_dimension_reduction !== "boolean"
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cached models if available and not expired
|
||||
*/
|
||||
private getCachedModels(): OpenRouterModelListResponse | null {
|
||||
try {
|
||||
const cached = sessionStorage.getItem(this.cacheKey);
|
||||
if (!cached) return null;
|
||||
|
||||
const parsed: unknown = JSON.parse(cached);
|
||||
|
||||
// Validate cache structure
|
||||
if (!this.isCacheEntry(parsed)) {
|
||||
// Cache is corrupted, remove it to avoid repeated failures
|
||||
sessionStorage.removeItem(this.cacheKey);
|
||||
return null;
|
||||
}
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
// Check expiration
|
||||
if (now - parsed.timestamp > this.cacheTTL) {
|
||||
sessionStorage.removeItem(this.cacheKey);
|
||||
return null;
|
||||
}
|
||||
|
||||
return parsed.data;
|
||||
} catch {
|
||||
// JSON parsing failed or other error, clear cache
|
||||
sessionStorage.removeItem(this.cacheKey);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cache models for the TTL duration
|
||||
*/
|
||||
private cacheModels(data: OpenRouterModelListResponse): void {
|
||||
try {
|
||||
const cacheData = {
|
||||
data,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
sessionStorage.setItem(this.cacheKey, JSON.stringify(cacheData));
|
||||
} catch {
|
||||
// Ignore cache errors
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover available OpenRouter embedding models
|
||||
*/
|
||||
async discoverModels(): Promise<OpenRouterModelListResponse> {
|
||||
try {
|
||||
// Check cache first
|
||||
const cached = this.getCachedModels();
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.getBaseUrl()}/api/openrouter/models`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`HTTP ${response.status}: ${errorText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// Validate response structure
|
||||
if (!data.embedding_models || !Array.isArray(data.embedding_models)) {
|
||||
throw new Error("Invalid response structure: missing or invalid embedding_models array");
|
||||
}
|
||||
|
||||
if (typeof data.total_count !== "number" || data.total_count < 0) {
|
||||
throw new Error("Invalid response structure: total_count must be a non-negative number");
|
||||
}
|
||||
|
||||
if (data.total_count !== data.embedding_models.length) {
|
||||
throw new Error(
|
||||
`Response structure mismatch: total_count (${data.total_count}) does not match embedding_models length (${data.embedding_models.length})`,
|
||||
);
|
||||
}
|
||||
|
||||
// Validate at least one model has required fields
|
||||
if (data.embedding_models.length > 0) {
|
||||
const firstModel = data.embedding_models[0];
|
||||
if (
|
||||
!firstModel.id ||
|
||||
typeof firstModel.id !== "string" ||
|
||||
!firstModel.provider ||
|
||||
typeof firstModel.provider !== "string" ||
|
||||
typeof firstModel.dimensions !== "number" ||
|
||||
firstModel.dimensions <= 0
|
||||
) {
|
||||
throw new Error(
|
||||
"Invalid model structure: models must have id (string), provider (string), and positive dimensions",
|
||||
);
|
||||
}
|
||||
|
||||
// Validate provider name is from expected set
|
||||
const validProviders = ["openai", "google", "qwen", "mistralai"];
|
||||
if (!validProviders.includes(firstModel.provider)) {
|
||||
throw new Error(`Invalid provider name: ${firstModel.provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the successful response
|
||||
this.cacheModels(data);
|
||||
|
||||
return data;
|
||||
} catch (error) {
|
||||
throw this.handleApiError(error, "Model discovery");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the models cache
|
||||
*/
|
||||
clearCache(): void {
|
||||
try {
|
||||
sessionStorage.removeItem(this.cacheKey);
|
||||
} catch {
|
||||
// Ignore cache clearing errors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const openrouterService = new OpenRouterService();
|
||||
27
python/src/server/api_routes/openrouter_api.py
Normal file
27
python/src/server/api_routes/openrouter_api.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
OpenRouter API routes.
|
||||
|
||||
Endpoints for OpenRouter model discovery and configuration.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from ..services.openrouter_discovery_service import OpenRouterModelListResponse, openrouter_discovery_service
|
||||
|
||||
router = APIRouter(prefix="/api/openrouter", tags=["openrouter"])
|
||||
|
||||
|
||||
@router.get("/models", response_model=OpenRouterModelListResponse)
|
||||
async def get_openrouter_models() -> OpenRouterModelListResponse:
|
||||
"""
|
||||
Get available OpenRouter embedding models.
|
||||
|
||||
Returns a list of embedding models available through OpenRouter,
|
||||
including models from OpenAI, Google, Qwen, and Mistral providers.
|
||||
|
||||
Returns:
|
||||
OpenRouterModelListResponse: List of embedding models with metadata
|
||||
"""
|
||||
models = await openrouter_discovery_service.discover_embedding_models()
|
||||
|
||||
return OpenRouterModelListResponse(embedding_models=models, total_count=len(models))
|
||||
@@ -66,6 +66,19 @@ def validate_openai_api_key(api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def validate_openrouter_api_key(api_key: str) -> bool:
|
||||
"""Validate OpenRouter API key format."""
|
||||
if not api_key:
|
||||
raise ConfigurationError("OpenRouter API key cannot be empty")
|
||||
|
||||
if not api_key.startswith("sk-or-v1-"):
|
||||
raise ConfigurationError(
|
||||
"OpenRouter API key must start with 'sk-or-v1-'. " "Get your key at https://openrouter.ai/keys"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_supabase_key(supabase_key: str) -> tuple[bool, str]:
|
||||
"""Validate Supabase key type and return validation result.
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from .api_routes.knowledge_api import router as knowledge_router
|
||||
from .api_routes.mcp_api import router as mcp_router
|
||||
from .api_routes.migration_api import router as migration_router
|
||||
from .api_routes.ollama_api import router as ollama_router
|
||||
from .api_routes.openrouter_api import router as openrouter_router
|
||||
from .api_routes.pages_api import router as pages_router
|
||||
from .api_routes.progress_api import router as progress_router
|
||||
from .api_routes.projects_api import router as projects_router
|
||||
@@ -187,6 +188,7 @@ app.include_router(mcp_router)
|
||||
app.include_router(knowledge_router)
|
||||
app.include_router(pages_router)
|
||||
app.include_router(ollama_router)
|
||||
app.include_router(openrouter_router)
|
||||
app.include_router(projects_router)
|
||||
app.include_router(progress_router)
|
||||
app.include_router(agent_chat_router)
|
||||
|
||||
@@ -443,7 +443,7 @@ class CredentialService:
|
||||
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
|
||||
|
||||
# Validate that embedding provider actually supports embeddings
|
||||
embedding_capable_providers = {"openai", "google", "ollama"}
|
||||
embedding_capable_providers = {"openai", "google", "openrouter", "ollama"}
|
||||
|
||||
if (explicit_embedding_provider and
|
||||
explicit_embedding_provider != "" and
|
||||
|
||||
@@ -8,13 +8,6 @@ with unified error handling and sanitization patterns.
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .embedding_exceptions import (
|
||||
EmbeddingAPIError,
|
||||
EmbeddingAuthenticationError,
|
||||
EmbeddingQuotaExhaustedError,
|
||||
EmbeddingRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class ProviderErrorAdapter(ABC):
|
||||
"""Abstract base class for provider-specific error handling."""
|
||||
@@ -118,6 +111,34 @@ class AnthropicErrorAdapter(ProviderErrorAdapter):
|
||||
return sanitized
|
||||
|
||||
|
||||
class OpenRouterErrorAdapter(ProviderErrorAdapter):
|
||||
def get_provider_name(self) -> str:
|
||||
return "openrouter"
|
||||
|
||||
def sanitize_error_message(self, message: str) -> str:
|
||||
if not isinstance(message, str) or not message.strip() or len(message) > 2000:
|
||||
return "OpenRouter API encountered an error. Please verify your API key and quota."
|
||||
|
||||
sanitized = message
|
||||
|
||||
# Comprehensive OpenRouter patterns
|
||||
patterns = [
|
||||
(r'sk-or-v1-[a-zA-Z0-9_-]{10,}', '[REDACTED_KEY]'), # OpenRouter API keys
|
||||
(r'https?://[^\s]*openrouter\.ai[^\s]*', '[REDACTED_URL]'), # OpenRouter URLs
|
||||
(r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer [REDACTED_TOKEN]'), # Bearer tokens
|
||||
]
|
||||
|
||||
for pattern, replacement in patterns:
|
||||
sanitized = re.sub(pattern, replacement, sanitized, flags=re.IGNORECASE)
|
||||
|
||||
# Check for sensitive words
|
||||
sensitive_words = ['internal', 'server', 'endpoint']
|
||||
if any(word in sanitized.lower() for word in sensitive_words):
|
||||
return "OpenRouter API encountered an error. Please verify your API key and quota."
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
class ProviderErrorFactory:
|
||||
"""Factory for provider-agnostic error handling."""
|
||||
|
||||
@@ -125,6 +146,7 @@ class ProviderErrorFactory:
|
||||
"openai": OpenAIErrorAdapter(),
|
||||
"google": GoogleAIErrorAdapter(),
|
||||
"anthropic": AnthropicErrorAdapter(),
|
||||
"openrouter": OpenRouterErrorAdapter(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -145,18 +167,14 @@ class ProviderErrorFactory:
|
||||
error_lower = error_str.lower()
|
||||
|
||||
# Case-insensitive provider detection with multiple patterns
|
||||
if ("anthropic" in error_lower or
|
||||
re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
|
||||
"claude" in error_lower):
|
||||
# Check OpenRouter first since it may contain "openai" in model names
|
||||
if ("openrouter" in error_lower or re.search(r'sk-or-v1-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE)):
|
||||
return "openrouter"
|
||||
elif ("anthropic" in error_lower or re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "claude" in error_lower):
|
||||
return "anthropic"
|
||||
elif ("google" in error_lower or
|
||||
re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
|
||||
"googleapis" in error_lower or
|
||||
"vertex" in error_lower):
|
||||
elif ("google" in error_lower or re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "googleapis" in error_lower or "vertex" in error_lower):
|
||||
return "google"
|
||||
elif ("openai" in error_lower or
|
||||
re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or
|
||||
"gpt" in error_lower):
|
||||
elif ("openai" in error_lower or re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or "gpt" in error_lower):
|
||||
return "openai"
|
||||
else:
|
||||
return "openai" # Safe default
|
||||
@@ -655,8 +655,8 @@ async def get_embedding_model(provider: str | None = None) -> str:
|
||||
return "text-embedding-004"
|
||||
elif provider_name == "openrouter":
|
||||
# OpenRouter supports both OpenAI and Google embedding models
|
||||
# Default to OpenAI's latest for compatibility
|
||||
return "text-embedding-3-small"
|
||||
# Model names MUST include provider prefix for OpenRouter API
|
||||
return "openai/text-embedding-3-small"
|
||||
elif provider_name == "anthropic":
|
||||
# Anthropic supports OpenAI and Google embedding models through their API
|
||||
# Default to OpenAI's latest for compatibility
|
||||
@@ -846,7 +846,7 @@ def _extract_reasoning_strings(value: Any) -> list[str]:
|
||||
text = value.strip()
|
||||
return [text] if text else []
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
collected: list[str] = []
|
||||
for item in value:
|
||||
collected.extend(_extract_reasoning_strings(item))
|
||||
|
||||
137
python/src/server/services/openrouter_discovery_service.py
Normal file
137
python/src/server/services/openrouter_discovery_service.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
OpenRouter model discovery service.
|
||||
|
||||
Provides discovery and metadata for OpenRouter embedding models.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class OpenRouterEmbeddingModel(BaseModel):
|
||||
"""OpenRouter embedding model metadata."""
|
||||
|
||||
id: str = Field(..., description="Full model ID with provider prefix (e.g., openai/text-embedding-3-large)")
|
||||
provider: str = Field(..., description="Provider name (openai, google, qwen, mistralai)")
|
||||
name: str = Field(..., description="Display name without prefix")
|
||||
dimensions: int = Field(..., description="Embedding dimensions")
|
||||
context_length: int = Field(..., description="Maximum context window in tokens")
|
||||
pricing_per_1m_tokens: float = Field(..., description="Cost per 1M tokens in USD")
|
||||
supports_dimension_reduction: bool = Field(default=False, description="Whether model supports dimension parameter")
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def validate_model_id_has_prefix(cls, v: str) -> str:
|
||||
"""Ensure model ID includes provider prefix."""
|
||||
if "/" not in v:
|
||||
raise ValueError("OpenRouter model IDs must include provider prefix (e.g., openai/model-name)")
|
||||
return v
|
||||
|
||||
|
||||
class OpenRouterModelListResponse(BaseModel):
|
||||
"""Response from OpenRouter model discovery."""
|
||||
|
||||
embedding_models: list[OpenRouterEmbeddingModel] = Field(default_factory=list)
|
||||
total_count: int = Field(..., description="Total number of embedding models")
|
||||
|
||||
|
||||
class OpenRouterDiscoveryService:
|
||||
"""Discover and manage OpenRouter embedding models."""
|
||||
|
||||
async def discover_embedding_models(self) -> list[OpenRouterEmbeddingModel]:
|
||||
"""
|
||||
Get available OpenRouter embedding models.
|
||||
|
||||
Returns hardcoded list of supported embedding models with metadata.
|
||||
Future enhancement: Could fetch from OpenRouter API if they provide a models endpoint.
|
||||
"""
|
||||
return [
|
||||
# OpenAI models via OpenRouter
|
||||
OpenRouterEmbeddingModel(
|
||||
id="openai/text-embedding-3-small",
|
||||
provider="openai",
|
||||
name="text-embedding-3-small",
|
||||
dimensions=1536,
|
||||
context_length=8191,
|
||||
pricing_per_1m_tokens=0.02,
|
||||
supports_dimension_reduction=True,
|
||||
),
|
||||
OpenRouterEmbeddingModel(
|
||||
id="openai/text-embedding-3-large",
|
||||
provider="openai",
|
||||
name="text-embedding-3-large",
|
||||
dimensions=3072,
|
||||
context_length=8191,
|
||||
pricing_per_1m_tokens=0.13,
|
||||
supports_dimension_reduction=True,
|
||||
),
|
||||
OpenRouterEmbeddingModel(
|
||||
id="openai/text-embedding-ada-002",
|
||||
provider="openai",
|
||||
name="text-embedding-ada-002",
|
||||
dimensions=1536,
|
||||
context_length=8191,
|
||||
pricing_per_1m_tokens=0.10,
|
||||
supports_dimension_reduction=False,
|
||||
),
|
||||
# Google models via OpenRouter
|
||||
OpenRouterEmbeddingModel(
|
||||
id="google/gemini-embedding-001",
|
||||
provider="google",
|
||||
name="gemini-embedding-001",
|
||||
dimensions=768,
|
||||
context_length=20000,
|
||||
pricing_per_1m_tokens=0.00, # Free tier available
|
||||
supports_dimension_reduction=True,
|
||||
),
|
||||
OpenRouterEmbeddingModel(
|
||||
id="google/text-embedding-004",
|
||||
provider="google",
|
||||
name="text-embedding-004",
|
||||
dimensions=768,
|
||||
context_length=20000,
|
||||
pricing_per_1m_tokens=0.00, # Free tier available
|
||||
supports_dimension_reduction=True,
|
||||
),
|
||||
# Qwen models via OpenRouter
|
||||
OpenRouterEmbeddingModel(
|
||||
id="qwen/qwen3-embedding-0.6b",
|
||||
provider="qwen",
|
||||
name="qwen3-embedding-0.6b",
|
||||
dimensions=1024,
|
||||
context_length=32768,
|
||||
pricing_per_1m_tokens=0.01,
|
||||
supports_dimension_reduction=False,
|
||||
),
|
||||
OpenRouterEmbeddingModel(
|
||||
id="qwen/qwen3-embedding-4b",
|
||||
provider="qwen",
|
||||
name="qwen3-embedding-4b",
|
||||
dimensions=1024,
|
||||
context_length=32768,
|
||||
pricing_per_1m_tokens=0.01,
|
||||
supports_dimension_reduction=False,
|
||||
),
|
||||
OpenRouterEmbeddingModel(
|
||||
id="qwen/qwen3-embedding-8b",
|
||||
provider="qwen",
|
||||
name="qwen3-embedding-8b",
|
||||
dimensions=1024,
|
||||
context_length=32768,
|
||||
pricing_per_1m_tokens=0.01,
|
||||
supports_dimension_reduction=False,
|
||||
),
|
||||
# Mistral models via OpenRouter
|
||||
OpenRouterEmbeddingModel(
|
||||
id="mistralai/mistral-embed",
|
||||
provider="mistralai",
|
||||
name="mistral-embed",
|
||||
dimensions=1024,
|
||||
context_length=8192,
|
||||
pricing_per_1m_tokens=0.10,
|
||||
supports_dimension_reduction=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
openrouter_discovery_service = OpenRouterDiscoveryService()
|
||||
158
python/tests/test_openrouter_discovery.py
Normal file
158
python/tests/test_openrouter_discovery.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Unit tests for OpenRouter model discovery service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.openrouter_discovery_service import (
|
||||
OpenRouterDiscoveryService,
|
||||
OpenRouterEmbeddingModel,
|
||||
OpenRouterModelListResponse,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery_service():
|
||||
"""Create OpenRouter discovery service instance."""
|
||||
return OpenRouterDiscoveryService()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_embedding_models_returns_valid_list(discovery_service):
|
||||
"""Test that discover_embedding_models returns a non-empty list of models."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
|
||||
assert isinstance(models, list)
|
||||
assert len(models) > 0
|
||||
assert all(isinstance(model, OpenRouterEmbeddingModel) for model in models)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_models_have_provider_prefix(discovery_service):
|
||||
"""Test that all model IDs include provider prefix."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
|
||||
for model in models:
|
||||
assert "/" in model.id, f"Model ID '{model.id}' missing provider prefix"
|
||||
assert model.id.startswith(
|
||||
f"{model.provider}/"
|
||||
), f"Model ID '{model.id}' doesn't match provider '{model.provider}'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dimensions_are_positive_integers(discovery_service):
|
||||
"""Test that all models have positive integer dimensions."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
|
||||
for model in models:
|
||||
assert isinstance(model.dimensions, int), f"Model '{model.id}' dimensions is not an integer"
|
||||
assert model.dimensions > 0, f"Model '{model.id}' has non-positive dimensions: {model.dimensions}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pricing_is_non_negative(discovery_service):
|
||||
"""Test that all models have non-negative pricing."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
|
||||
for model in models:
|
||||
assert isinstance(
|
||||
model.pricing_per_1m_tokens, (int, float)
|
||||
), f"Model '{model.id}' pricing is not numeric"
|
||||
assert (
|
||||
model.pricing_per_1m_tokens >= 0
|
||||
), f"Model '{model.id}' has negative pricing: {model.pricing_per_1m_tokens}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_length_is_positive(discovery_service):
|
||||
"""Test that all models have positive context length."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
|
||||
for model in models:
|
||||
assert isinstance(
|
||||
model.context_length, int
|
||||
), f"Model '{model.id}' context_length is not an integer"
|
||||
assert (
|
||||
model.context_length > 0
|
||||
), f"Model '{model.id}' has non-positive context_length: {model.context_length}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_providers_are_valid(discovery_service):
|
||||
"""Test that all models have valid provider names."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
valid_providers = {"openai", "google", "qwen", "mistralai"}
|
||||
|
||||
for model in models:
|
||||
assert (
|
||||
model.provider in valid_providers
|
||||
), f"Model '{model.id}' has invalid provider: {model.provider}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_models_present(discovery_service):
|
||||
"""Test that OpenAI models are included in the list."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
openai_models = [m for m in models if m.provider == "openai"]
|
||||
|
||||
assert len(openai_models) > 0, "No OpenAI models found"
|
||||
assert any(
|
||||
"text-embedding-3-small" in m.id for m in openai_models
|
||||
), "text-embedding-3-small not found"
|
||||
assert any(
|
||||
"text-embedding-3-large" in m.id for m in openai_models
|
||||
), "text-embedding-3-large not found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qwen_models_present(discovery_service):
|
||||
"""Test that Qwen models are included in the list."""
|
||||
models = await discovery_service.discover_embedding_models()
|
||||
qwen_models = [m for m in models if m.provider == "qwen"]
|
||||
|
||||
assert len(qwen_models) > 0, "No Qwen models found"
|
||||
# Verify at least one Qwen3 embedding model is present
|
||||
assert any("qwen3-embedding" in m.id for m in qwen_models), "No Qwen3 embedding models found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list_response_structure():
|
||||
"""Test OpenRouterModelListResponse structure."""
|
||||
service = OpenRouterDiscoveryService()
|
||||
models = await service.discover_embedding_models()
|
||||
|
||||
response = OpenRouterModelListResponse(embedding_models=models, total_count=len(models))
|
||||
|
||||
assert response.total_count == len(models)
|
||||
assert response.total_count == len(response.embedding_models)
|
||||
assert response.total_count > 0
|
||||
|
||||
|
||||
def test_model_id_validation_requires_prefix():
|
||||
"""Test that model ID validation enforces provider prefix."""
|
||||
with pytest.raises(ValueError, match="must include provider prefix"):
|
||||
OpenRouterEmbeddingModel(
|
||||
id="text-embedding-3-small", # Missing provider prefix
|
||||
provider="openai",
|
||||
name="text-embedding-3-small",
|
||||
dimensions=1536,
|
||||
context_length=8191,
|
||||
pricing_per_1m_tokens=0.02,
|
||||
supports_dimension_reduction=True,
|
||||
)
|
||||
|
||||
|
||||
def test_model_with_valid_prefix_accepted():
|
||||
"""Test that model with valid provider prefix is accepted."""
|
||||
model = OpenRouterEmbeddingModel(
|
||||
id="openai/text-embedding-3-small",
|
||||
provider="openai",
|
||||
name="text-embedding-3-small",
|
||||
dimensions=1536,
|
||||
context_length=8191,
|
||||
pricing_per_1m_tokens=0.02,
|
||||
supports_dimension_reduction=True,
|
||||
)
|
||||
|
||||
assert model.id == "openai/text-embedding-3-small"
|
||||
assert "/" in model.id
|
||||
Reference in New Issue
Block a user