mirror of
https://github.com/coleam00/Archon.git
synced 2025-12-24 02:39:17 -05:00
Merge pull request #852 from coleam00/feature/openrouter-embeddings-support
Add OpenRouter Embeddings Support
This commit is contained in:
@@ -124,6 +124,7 @@ PROD=false
|
|||||||
# Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table.
|
# Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table.
|
||||||
# Then use the Settings page in the web UI to manage:
|
# Then use the Settings page in the web UI to manage:
|
||||||
# - OPENAI_API_KEY (encrypted)
|
# - OPENAI_API_KEY (encrypted)
|
||||||
|
# - OPENROUTER_API_KEY (encrypted, format: sk-or-v1-..., get from https://openrouter.ai/keys)
|
||||||
# - MODEL_CHOICE
|
# - MODEL_CHOICE
|
||||||
# - TRANSPORT settings
|
# - TRANSPORT settings
|
||||||
# - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.)
|
# - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import OllamaModelSelectionModal from './OllamaModelSelectionModal';
|
|||||||
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
|
type ProviderKey = 'openai' | 'google' | 'ollama' | 'anthropic' | 'grok' | 'openrouter';
|
||||||
|
|
||||||
// Providers that support embedding models
|
// Providers that support embedding models
|
||||||
const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'ollama'];
|
const EMBEDDING_CAPABLE_PROVIDERS: ProviderKey[] = ['openai', 'google', 'openrouter', 'ollama'];
|
||||||
|
|
||||||
interface ProviderModels {
|
interface ProviderModels {
|
||||||
chatModel: string;
|
chatModel: string;
|
||||||
@@ -42,7 +42,7 @@ const getDefaultModels = (provider: ProviderKey): ProviderModels => {
|
|||||||
anthropic: 'text-embedding-3-small', // Fallback to OpenAI
|
anthropic: 'text-embedding-3-small', // Fallback to OpenAI
|
||||||
google: 'text-embedding-004',
|
google: 'text-embedding-004',
|
||||||
grok: 'text-embedding-3-small', // Fallback to OpenAI
|
grok: 'text-embedding-3-small', // Fallback to OpenAI
|
||||||
openrouter: 'text-embedding-3-small',
|
openrouter: 'openai/text-embedding-3-small', // MUST include provider prefix for OpenRouter
|
||||||
ollama: 'nomic-embed-text'
|
ollama: 'nomic-embed-text'
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1291,7 +1291,7 @@ const manualTestConnection = async (
|
|||||||
Select {activeSelection === 'chat' ? 'Chat' : 'Embedding'} Provider
|
Select {activeSelection === 'chat' ? 'Chat' : 'Embedding'} Provider
|
||||||
</label>
|
</label>
|
||||||
<div className={`grid gap-3 mb-4 ${
|
<div className={`grid gap-3 mb-4 ${
|
||||||
activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-3'
|
activeSelection === 'chat' ? 'grid-cols-6' : 'grid-cols-4'
|
||||||
}`}>
|
}`}>
|
||||||
{[
|
{[
|
||||||
{ key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
|
{ key: 'openai', name: 'OpenAI', logo: '/img/OpenAI.png', color: 'green' },
|
||||||
|
|||||||
244
archon-ui-main/src/services/openrouterService.ts
Normal file
244
archon-ui-main/src/services/openrouterService.ts
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
/**
|
||||||
|
* OpenRouter Service Client
|
||||||
|
*
|
||||||
|
* Provides frontend API client for OpenRouter model discovery.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { getApiUrl } from "../config/api";
|
||||||
|
|
||||||
|
// Type definitions for OpenRouter API responses
|
||||||
|
/**
 * Metadata for a single embedding model offered through OpenRouter,
 * as returned by the backend model-discovery endpoint.
 */
export interface OpenRouterEmbeddingModel {
  /** Full model ID including the provider prefix (e.g. `openai/text-embedding-3-large`). */
  id: string;
  /** Display name without the provider prefix. */
  name: string;
  /** Provider name (e.g. `openai`, `google`, `qwen`, `mistralai`). */
  provider: string;
  /** Embedding dimensions. */
  dimensions: number;
  /** Maximum context window in tokens. */
  context_length: number;
  /** Cost per 1M tokens in USD. */
  pricing_per_1m_tokens: number;
  /** Whether the model supports a dimension parameter. */
  supports_dimension_reduction: boolean;
}
|
||||||
|
|
||||||
|
/**
 * Payload of the OpenRouter model-discovery endpoint.
 */
export interface OpenRouterModelListResponse {
  /** Number of entries the backend reports for `embedding_models`. */
  total_count: number;
  /** The embedding models available through OpenRouter. */
  embedding_models: OpenRouterEmbeddingModel[];
}
|
||||||
|
|
||||||
|
/**
 * Frontend client for the backend's OpenRouter model-discovery endpoint.
 *
 * Successful responses are cached in `sessionStorage` for `cacheTTL`
 * milliseconds. Every storage access is best-effort: a missing or blocked
 * `sessionStorage` never prevents a discovery request from being made.
 */
class OpenRouterService {
  private getBaseUrl = () => getApiUrl();
  private cacheKey = "openrouter_models_cache";
  private cacheTTL = 5 * 60 * 1000; // 5 minutes
  // Provider names the discovery endpoint is expected to return.
  private validProviders: readonly string[] = ["openai", "google", "qwen", "mistralai"];

  /**
   * Wrap an arbitrary thrown value in an `Error` with a user-facing message.
   *
   * @param error - Whatever was thrown (not necessarily an `Error`).
   * @param context - Short label of the failed operation, e.g. "Model discovery".
   * @returns A new `Error`; the caller decides whether to throw it.
   */
  private handleApiError(error: unknown, context: string): Error {
    const errorMessage = error instanceof Error ? error.message : String(error);
    // Aborted or timed-out requests are identified by the error's `name`;
    // their messages vary between runtimes and usually do not contain it.
    const errorName =
      typeof error === "object" && error !== null && typeof (error as { name?: unknown }).name === "string"
        ? (error as { name: string }).name
        : "";
    const timeoutError = new Error(
      `Timeout error while ${context.toLowerCase()}: The server may be slow to respond.`,
    );

    if (errorName === "AbortError" || errorName === "TimeoutError") {
      return timeoutError;
    }

    // Check for network errors ("Failed to fetch" is covered by the "fetch" test)
    if (errorMessage.toLowerCase().includes("network") || errorMessage.includes("fetch")) {
      return new Error(
        `Network error while ${context.toLowerCase()}: ${errorMessage}. ` +
          "Please check your connection.",
      );
    }

    // Check for timeout errors reported only through the message text
    if (errorMessage.includes("timeout") || errorMessage.includes("AbortError")) {
      return timeoutError;
    }

    // Return original error with context
    return new Error(`${context} failed: ${errorMessage}`);
  }

  /**
   * Type guard: `value` carries every field of `OpenRouterEmbeddingModel`
   * with the expected primitive type.
   */
  private isEmbeddingModel(value: unknown): value is OpenRouterEmbeddingModel {
    if (typeof value !== "object" || value === null) {
      return false;
    }

    const m = value as Record<string, unknown>;
    return (
      typeof m.id === "string" &&
      typeof m.provider === "string" &&
      typeof m.name === "string" &&
      typeof m.dimensions === "number" &&
      typeof m.context_length === "number" &&
      typeof m.pricing_per_1m_tokens === "number" &&
      typeof m.supports_dimension_reduction === "boolean"
    );
  }

  /**
   * Type guard to validate cache entry structure
   */
  private isCacheEntry(
    value: unknown,
  ): value is { data: OpenRouterModelListResponse; timestamp: number } {
    if (typeof value !== "object" || value === null) {
      return false;
    }

    const obj = value as Record<string, unknown>;
    if (typeof obj.timestamp !== "number") {
      return false;
    }
    if (typeof obj.data !== "object" || obj.data === null) {
      return false;
    }

    const data = obj.data as Record<string, unknown>;
    if (!Array.isArray(data.embedding_models) || typeof data.total_count !== "number") {
      return false;
    }

    return data.embedding_models.every((model) => this.isEmbeddingModel(model));
  }

  /**
   * Validate a decoded discovery response.
   *
   * Every model is checked, not just the first one, so that anything accepted
   * here also satisfies `isCacheEntry` when it is read back from the cache.
   *
   * @param data - Decoded JSON body of the discovery endpoint.
   * @returns The same object, narrowed to `OpenRouterModelListResponse`.
   * @throws Error describing the first structural problem found.
   */
  private validateModelListResponse(data: unknown): OpenRouterModelListResponse {
    const payload = (typeof data === "object" && data !== null ? data : {}) as Record<string, unknown>;
    const models = payload.embedding_models;
    const totalCount = payload.total_count;

    if (!Array.isArray(models)) {
      throw new Error("Invalid response structure: missing or invalid embedding_models array");
    }

    if (typeof totalCount !== "number" || totalCount < 0) {
      throw new Error("Invalid response structure: total_count must be a non-negative number");
    }

    if (totalCount !== models.length) {
      throw new Error(
        `Response structure mismatch: total_count (${totalCount}) does not match embedding_models length (${models.length})`,
      );
    }

    for (const model of models) {
      const m = (typeof model === "object" && model !== null ? model : {}) as Record<string, unknown>;

      if (
        !m.id ||
        typeof m.id !== "string" ||
        !m.provider ||
        typeof m.provider !== "string" ||
        typeof m.dimensions !== "number" ||
        m.dimensions <= 0
      ) {
        throw new Error(
          "Invalid model structure: models must have id (string), provider (string), and positive dimensions",
        );
      }

      // Validate provider name is from expected set
      if (!this.validProviders.includes(m.provider)) {
        throw new Error(`Invalid provider name: ${m.provider}`);
      }

      if (!this.isEmbeddingModel(m)) {
        throw new Error(
          "Invalid model structure: models must have name (string), context_length (number), " +
            "pricing_per_1m_tokens (number), and supports_dimension_reduction (boolean)",
        );
      }
    }

    return data as OpenRouterModelListResponse;
  }

  /**
   * Get cached models if available and not expired.
   *
   * Never throws: an unreadable, corrupted or expired entry is cleared
   * (best-effort) and reported as a cache miss.
   */
  private getCachedModels(): OpenRouterModelListResponse | null {
    try {
      const cached = sessionStorage.getItem(this.cacheKey);
      if (!cached) return null;

      const parsed: unknown = JSON.parse(cached);

      // Corrupted entries are removed to avoid repeated validation failures
      if (!this.isCacheEntry(parsed)) {
        this.clearCache();
        return null;
      }

      // Check expiration
      if (Date.now() - parsed.timestamp > this.cacheTTL) {
        this.clearCache();
        return null;
      }

      return parsed.data;
    } catch {
      // JSON parsing failed or storage is unavailable. clearCache() swallows
      // its own errors, so a blocked sessionStorage cannot escape from here.
      this.clearCache();
      return null;
    }
  }

  /**
   * Cache models for the TTL duration
   */
  private cacheModels(data: OpenRouterModelListResponse): void {
    try {
      const cacheData = {
        data,
        timestamp: Date.now(),
      };
      sessionStorage.setItem(this.cacheKey, JSON.stringify(cacheData));
    } catch {
      // Ignore cache errors (quota exceeded, storage unavailable, ...)
    }
  }

  /**
   * Discover available OpenRouter embedding models.
   *
   * Serves a fresh cache entry when there is one; otherwise requests
   * `/api/openrouter/models`, validates the payload and caches it.
   *
   * @returns The validated model list.
   * @throws Error with a contextual message on HTTP, network, timeout or
   *   validation failures.
   */
  async discoverModels(): Promise<OpenRouterModelListResponse> {
    try {
      // Check cache first
      const cached = this.getCachedModels();
      if (cached) {
        return cached;
      }

      const response = await fetch(`${this.getBaseUrl()}/api/openrouter/models`, {
        method: "GET",
        headers: {
          "Content-Type": "application/json",
        },
      });

      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(`HTTP ${response.status}: ${errorText}`);
      }

      const body: unknown = await response.json();
      const data = this.validateModelListResponse(body);

      // Cache the successful response
      this.cacheModels(data);

      return data;
    } catch (error) {
      throw this.handleApiError(error, "Model discovery");
    }
  }

  /**
   * Clear the models cache
   */
  clearCache(): void {
    try {
      sessionStorage.removeItem(this.cacheKey);
    } catch {
      // Ignore cache clearing errors
    }
  }
}

// Export singleton instance
export const openrouterService = new OpenRouterService();
|
||||||
27
python/src/server/api_routes/openrouter_api.py
Normal file
27
python/src/server/api_routes/openrouter_api.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
OpenRouter API routes.
|
||||||
|
|
||||||
|
Endpoints for OpenRouter model discovery and configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from ..services.openrouter_discovery_service import OpenRouterModelListResponse, openrouter_discovery_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/openrouter", tags=["openrouter"])


@router.get("/models", response_model=OpenRouterModelListResponse)
async def get_openrouter_models() -> OpenRouterModelListResponse:
    """
    List the embedding models that can be used through OpenRouter.

    Asks the discovery service for its embedding models and wraps them,
    together with their count, in the response model.

    Returns:
        OpenRouterModelListResponse: The embedding models and their total count
    """
    embedding_models = await openrouter_discovery_service.discover_embedding_models()
    model_count = len(embedding_models)

    return OpenRouterModelListResponse(
        embedding_models=embedding_models,
        total_count=model_count,
    )
|
||||||
@@ -66,6 +66,19 @@ def validate_openai_api_key(api_key: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def validate_openrouter_api_key(api_key: str) -> bool:
    """Validate OpenRouter API key format.

    Only the shape of the key is checked; no request is made to OpenRouter.

    Args:
        api_key: The key to check.

    Returns:
        True when the key starts with ``sk-or-v1-`` and has content after it.

    Raises:
        ConfigurationError: If the key is empty, lacks the ``sk-or-v1-``
            prefix, or consists of the prefix alone.
    """
    expected_prefix = "sk-or-v1-"

    if not api_key:
        raise ConfigurationError("OpenRouter API key cannot be empty")

    if not api_key.startswith(expected_prefix):
        raise ConfigurationError(
            "OpenRouter API key must start with 'sk-or-v1-'. Get your key at https://openrouter.ai/keys"
        )

    # A bare prefix is a truncated paste, not a usable key.
    if len(api_key) == len(expected_prefix):
        raise ConfigurationError(
            "OpenRouter API key is incomplete: nothing follows the 'sk-or-v1-' prefix. "
            "Get your key at https://openrouter.ai/keys"
        )

    return True
|
||||||
|
|
||||||
|
|
||||||
def validate_supabase_key(supabase_key: str) -> tuple[bool, str]:
|
def validate_supabase_key(supabase_key: str) -> tuple[bool, str]:
|
||||||
"""Validate Supabase key type and return validation result.
|
"""Validate Supabase key type and return validation result.
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from .api_routes.knowledge_api import router as knowledge_router
|
|||||||
from .api_routes.mcp_api import router as mcp_router
|
from .api_routes.mcp_api import router as mcp_router
|
||||||
from .api_routes.migration_api import router as migration_router
|
from .api_routes.migration_api import router as migration_router
|
||||||
from .api_routes.ollama_api import router as ollama_router
|
from .api_routes.ollama_api import router as ollama_router
|
||||||
|
from .api_routes.openrouter_api import router as openrouter_router
|
||||||
from .api_routes.pages_api import router as pages_router
|
from .api_routes.pages_api import router as pages_router
|
||||||
from .api_routes.progress_api import router as progress_router
|
from .api_routes.progress_api import router as progress_router
|
||||||
from .api_routes.projects_api import router as projects_router
|
from .api_routes.projects_api import router as projects_router
|
||||||
@@ -187,6 +188,7 @@ app.include_router(mcp_router)
|
|||||||
app.include_router(knowledge_router)
|
app.include_router(knowledge_router)
|
||||||
app.include_router(pages_router)
|
app.include_router(pages_router)
|
||||||
app.include_router(ollama_router)
|
app.include_router(ollama_router)
|
||||||
|
app.include_router(openrouter_router)
|
||||||
app.include_router(projects_router)
|
app.include_router(projects_router)
|
||||||
app.include_router(progress_router)
|
app.include_router(progress_router)
|
||||||
app.include_router(agent_chat_router)
|
app.include_router(agent_chat_router)
|
||||||
|
|||||||
@@ -443,7 +443,7 @@ class CredentialService:
|
|||||||
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
|
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
|
||||||
|
|
||||||
# Validate that embedding provider actually supports embeddings
|
# Validate that embedding provider actually supports embeddings
|
||||||
embedding_capable_providers = {"openai", "google", "ollama"}
|
embedding_capable_providers = {"openai", "google", "openrouter", "ollama"}
|
||||||
|
|
||||||
if (explicit_embedding_provider and
|
if (explicit_embedding_provider and
|
||||||
explicit_embedding_provider != "" and
|
explicit_embedding_provider != "" and
|
||||||
|
|||||||
@@ -8,13 +8,6 @@ with unified error handling and sanitization patterns.
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from .embedding_exceptions import (
|
|
||||||
EmbeddingAPIError,
|
|
||||||
EmbeddingAuthenticationError,
|
|
||||||
EmbeddingQuotaExhaustedError,
|
|
||||||
EmbeddingRateLimitError,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderErrorAdapter(ABC):
|
class ProviderErrorAdapter(ABC):
|
||||||
"""Abstract base class for provider-specific error handling."""
|
"""Abstract base class for provider-specific error handling."""
|
||||||
@@ -118,6 +111,34 @@ class AnthropicErrorAdapter(ProviderErrorAdapter):
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterErrorAdapter(ProviderErrorAdapter):
    """Error adapter for the "openrouter" provider."""

    def get_provider_name(self) -> str:
        """Return the provider key this adapter handles."""
        return "openrouter"

    def sanitize_error_message(self, message: str) -> str:
        """Redact secrets from an OpenRouter error message.

        A generic message is returned when the input is not a string, is
        blank, is longer than 2000 characters, or still mentions
        "internal", "server" or "endpoint" after redaction.

        Args:
            message: Raw error text.

        Returns:
            The message with API keys, openrouter.ai URLs and bearer tokens
            replaced by placeholders, or the generic fallback text.
        """
        fallback = "OpenRouter API encountered an error. Please verify your API key and quota."

        usable = isinstance(message, str) and bool(message.strip()) and len(message) <= 2000
        if not usable:
            return fallback

        # (regex, placeholder) pairs, applied in order and case-insensitively
        redactions = (
            (r'sk-or-v1-[a-zA-Z0-9_-]{10,}', '[REDACTED_KEY]'),  # OpenRouter API keys
            (r'https?://[^\s]*openrouter\.ai[^\s]*', '[REDACTED_URL]'),  # OpenRouter URLs
            (r'Bearer\s+[a-zA-Z0-9._-]+', 'Bearer [REDACTED_TOKEN]'),  # Bearer tokens
        )

        cleaned = message
        for regex, placeholder in redactions:
            cleaned = re.sub(regex, placeholder, cleaned, flags=re.IGNORECASE)

        # Messages that still mention these words are replaced wholesale
        lowered = cleaned.lower()
        for flagged_word in ('internal', 'server', 'endpoint'):
            if flagged_word in lowered:
                return fallback

        return cleaned
|
||||||
|
|
||||||
|
|
||||||
class ProviderErrorFactory:
|
class ProviderErrorFactory:
|
||||||
"""Factory for provider-agnostic error handling."""
|
"""Factory for provider-agnostic error handling."""
|
||||||
|
|
||||||
@@ -125,6 +146,7 @@ class ProviderErrorFactory:
|
|||||||
"openai": OpenAIErrorAdapter(),
|
"openai": OpenAIErrorAdapter(),
|
||||||
"google": GoogleAIErrorAdapter(),
|
"google": GoogleAIErrorAdapter(),
|
||||||
"anthropic": AnthropicErrorAdapter(),
|
"anthropic": AnthropicErrorAdapter(),
|
||||||
|
"openrouter": OpenRouterErrorAdapter(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -145,18 +167,14 @@ class ProviderErrorFactory:
|
|||||||
error_lower = error_str.lower()
|
error_lower = error_str.lower()
|
||||||
|
|
||||||
# Case-insensitive provider detection with multiple patterns
|
# Case-insensitive provider detection with multiple patterns
|
||||||
if ("anthropic" in error_lower or
|
# Check OpenRouter first since it may contain "openai" in model names
|
||||||
re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
|
if ("openrouter" in error_lower or re.search(r'sk-or-v1-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE)):
|
||||||
"claude" in error_lower):
|
return "openrouter"
|
||||||
|
elif ("anthropic" in error_lower or re.search(r'sk-ant-[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "claude" in error_lower):
|
||||||
return "anthropic"
|
return "anthropic"
|
||||||
elif ("google" in error_lower or
|
elif ("google" in error_lower or re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or "googleapis" in error_lower or "vertex" in error_lower):
|
||||||
re.search(r'AIza[a-zA-Z0-9_-]+', error_str, re.IGNORECASE) or
|
|
||||||
"googleapis" in error_lower or
|
|
||||||
"vertex" in error_lower):
|
|
||||||
return "google"
|
return "google"
|
||||||
elif ("openai" in error_lower or
|
elif ("openai" in error_lower or re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or "gpt" in error_lower):
|
||||||
re.search(r'sk-[a-zA-Z0-9]{48}', error_str, re.IGNORECASE) or
|
|
||||||
"gpt" in error_lower):
|
|
||||||
return "openai"
|
return "openai"
|
||||||
else:
|
else:
|
||||||
return "openai" # Safe default
|
return "openai" # Safe default
|
||||||
@@ -655,8 +655,8 @@ async def get_embedding_model(provider: str | None = None) -> str:
|
|||||||
return "text-embedding-004"
|
return "text-embedding-004"
|
||||||
elif provider_name == "openrouter":
|
elif provider_name == "openrouter":
|
||||||
# OpenRouter supports both OpenAI and Google embedding models
|
# OpenRouter supports both OpenAI and Google embedding models
|
||||||
# Default to OpenAI's latest for compatibility
|
# Model names MUST include provider prefix for OpenRouter API
|
||||||
return "text-embedding-3-small"
|
return "openai/text-embedding-3-small"
|
||||||
elif provider_name == "anthropic":
|
elif provider_name == "anthropic":
|
||||||
# Anthropic supports OpenAI and Google embedding models through their API
|
# Anthropic supports OpenAI and Google embedding models through their API
|
||||||
# Default to OpenAI's latest for compatibility
|
# Default to OpenAI's latest for compatibility
|
||||||
@@ -846,7 +846,7 @@ def _extract_reasoning_strings(value: Any) -> list[str]:
|
|||||||
text = value.strip()
|
text = value.strip()
|
||||||
return [text] if text else []
|
return [text] if text else []
|
||||||
|
|
||||||
if isinstance(value, (list, tuple, set)):
|
if isinstance(value, list | tuple | set):
|
||||||
collected: list[str] = []
|
collected: list[str] = []
|
||||||
for item in value:
|
for item in value:
|
||||||
collected.extend(_extract_reasoning_strings(item))
|
collected.extend(_extract_reasoning_strings(item))
|
||||||
|
|||||||
137
python/src/server/services/openrouter_discovery_service.py
Normal file
137
python/src/server/services/openrouter_discovery_service.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
OpenRouter model discovery service.
|
||||||
|
|
||||||
|
Provides discovery and metadata for OpenRouter embedding models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterEmbeddingModel(BaseModel):
    """OpenRouter embedding model metadata.

    Numeric fields are constrained so that an entry with non-positive
    dimensions or context length, or with negative pricing, is rejected
    at construction time.
    """

    id: str = Field(..., description="Full model ID with provider prefix (e.g., openai/text-embedding-3-large)")
    provider: str = Field(..., description="Provider name (openai, google, qwen, mistralai)")
    name: str = Field(..., description="Display name without prefix")
    dimensions: int = Field(..., gt=0, description="Embedding dimensions")
    context_length: int = Field(..., gt=0, description="Maximum context window in tokens")
    pricing_per_1m_tokens: float = Field(..., ge=0, description="Cost per 1M tokens in USD")
    supports_dimension_reduction: bool = Field(default=False, description="Whether model supports dimension parameter")

    @field_validator("id")
    @classmethod
    def validate_model_id_has_prefix(cls, v: str) -> str:
        """Ensure model ID includes provider prefix.

        The ID must contain a "/" with text on both sides of the first
        one, so "/model", "provider/" and "/" are rejected as well.

        Raises:
            ValueError: If the provider prefix or the model name is missing.
        """
        provider_prefix, separator, model_name = v.partition("/")
        if not (separator and provider_prefix and model_name):
            raise ValueError("OpenRouter model IDs must include provider prefix (e.g., openai/model-name)")
        return v
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterModelListResponse(BaseModel):
    """Envelope returned by OpenRouter model discovery: the models plus their count."""

    # Defaults to an empty list when omitted; total_count is always required.
    embedding_models: list[OpenRouterEmbeddingModel] = Field(default_factory=list)
    total_count: int = Field(
        ...,
        description="Total number of embedding models",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterDiscoveryService:
    """Discover and manage OpenRouter embedding models."""

    async def discover_embedding_models(self) -> list[OpenRouterEmbeddingModel]:
        """
        Get available OpenRouter embedding models.

        The catalog is hardcoded; no network request is made. A new list of
        freshly built model objects is returned on every call. Each model ID
        is the provider and the model name joined with "/".

        NOTE(review): the dimension figures for gemini-embedding-001 and the
        qwen3-embedding 4b/8b entries look lower than those models' native
        output sizes - confirm whether they are intentional (e.g. a reduced
        size requested by this application) before relying on them.
        """
        # Columns: provider, name, dimensions, context_length,
        #          pricing_per_1m_tokens (USD), supports_dimension_reduction
        catalog = (
            # OpenAI models via OpenRouter
            ("openai", "text-embedding-3-small", 1536, 8191, 0.02, True),
            ("openai", "text-embedding-3-large", 3072, 8191, 0.13, True),
            ("openai", "text-embedding-ada-002", 1536, 8191, 0.10, False),
            # Google models via OpenRouter (free tier available)
            ("google", "gemini-embedding-001", 768, 20000, 0.00, True),
            ("google", "text-embedding-004", 768, 20000, 0.00, True),
            # Qwen models via OpenRouter
            ("qwen", "qwen3-embedding-0.6b", 1024, 32768, 0.01, False),
            ("qwen", "qwen3-embedding-4b", 1024, 32768, 0.01, False),
            ("qwen", "qwen3-embedding-8b", 1024, 32768, 0.01, False),
            # Mistral models via OpenRouter
            ("mistralai", "mistral-embed", 1024, 8192, 0.10, False),
        )

        return [
            OpenRouterEmbeddingModel(
                id=f"{provider}/{model_name}",
                provider=provider,
                name=model_name,
                dimensions=dimensions,
                context_length=context_length,
                pricing_per_1m_tokens=price,
                supports_dimension_reduction=reducible,
            )
            for provider, model_name, dimensions, context_length, price, reducible in catalog
        ]


# Create singleton instance
openrouter_discovery_service = OpenRouterDiscoveryService()
|
||||||
158
python/tests/test_openrouter_discovery.py
Normal file
158
python/tests/test_openrouter_discovery.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for OpenRouter model discovery service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.server.services.openrouter_discovery_service import (
|
||||||
|
OpenRouterDiscoveryService,
|
||||||
|
OpenRouterEmbeddingModel,
|
||||||
|
OpenRouterModelListResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def discovery_service():
    """Provide a fresh OpenRouterDiscoveryService for each test."""
    service = OpenRouterDiscoveryService()
    return service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_discover_embedding_models_returns_valid_list(discovery_service):
    """Discovery yields a non-empty list containing only OpenRouterEmbeddingModel items."""
    discovered = await discovery_service.discover_embedding_models()

    assert isinstance(discovered, list)
    assert len(discovered) > 0
    assert all(isinstance(entry, OpenRouterEmbeddingModel) for entry in discovered)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_all_models_have_provider_prefix(discovery_service):
    """Every model ID contains a "/" and begins with its own provider name."""
    discovered = await discovery_service.discover_embedding_models()

    for entry in discovered:
        expected_prefix = f"{entry.provider}/"

        assert "/" in entry.id, f"Model ID '{entry.id}' missing provider prefix"
        assert entry.id.startswith(expected_prefix), f"Model ID '{entry.id}' doesn't match provider '{entry.provider}'"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dimensions_are_positive_integers(discovery_service):
    """Each model reports its dimensions as an int greater than zero."""
    discovered = await discovery_service.discover_embedding_models()

    for entry in discovered:
        dims = entry.dimensions
        assert isinstance(dims, int), f"Model '{entry.id}' dimensions is not an integer"
        assert dims > 0, f"Model '{entry.id}' has non-positive dimensions: {dims}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pricing_is_non_negative(discovery_service):
    """Each model's price per 1M tokens is numeric and not below zero."""
    discovered = await discovery_service.discover_embedding_models()

    for entry in discovered:
        price = entry.pricing_per_1m_tokens
        assert isinstance(price, (int, float)), f"Model '{entry.id}' pricing is not numeric"
        assert price >= 0, f"Model '{entry.id}' has negative pricing: {price}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_context_length_is_positive(discovery_service):
    """Each model reports its context length as an int greater than zero."""
    discovered = await discovery_service.discover_embedding_models()

    for entry in discovered:
        window = entry.context_length
        assert isinstance(window, int), f"Model '{entry.id}' context_length is not an integer"
        assert window > 0, f"Model '{entry.id}' has non-positive context_length: {window}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_providers_are_valid(discovery_service):
    """Each model's provider is one of openai, google, qwen or mistralai."""
    allowed = {"openai", "google", "qwen", "mistralai"}
    discovered = await discovery_service.discover_embedding_models()

    for entry in discovered:
        assert entry.provider in allowed, f"Model '{entry.id}' has invalid provider: {entry.provider}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_openai_models_present(discovery_service):
    """The catalog includes OpenAI's text-embedding-3-small and -3-large models."""
    discovered = await discovery_service.discover_embedding_models()
    from_openai = [entry for entry in discovered if entry.provider == "openai"]
    openai_ids = [entry.id for entry in from_openai]

    assert len(from_openai) > 0, "No OpenAI models found"
    assert any("text-embedding-3-small" in model_id for model_id in openai_ids), "text-embedding-3-small not found"
    assert any("text-embedding-3-large" in model_id for model_id in openai_ids), "text-embedding-3-large not found"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_qwen_models_present(discovery_service):
    """The catalog includes at least one Qwen3 embedding model."""
    discovered = await discovery_service.discover_embedding_models()
    from_qwen = [entry for entry in discovered if entry.provider == "qwen"]

    assert len(from_qwen) > 0, "No Qwen models found"
    # At least one ID must mention the qwen3-embedding family
    assert any("qwen3-embedding" in entry.id for entry in from_qwen), "No Qwen3 embedding models found"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_list_response_structure():
    """A response built from the discovered models has a consistent, non-zero total_count."""
    discovered = await OpenRouterDiscoveryService().discover_embedding_models()
    expected_count = len(discovered)

    response = OpenRouterModelListResponse(embedding_models=discovered, total_count=expected_count)

    assert response.total_count == expected_count
    assert response.total_count == len(response.embedding_models)
    assert response.total_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_id_validation_requires_prefix():
    """Constructing a model whose ID has no provider prefix raises a ValueError."""
    fields_without_prefix = {
        "id": "text-embedding-3-small",  # Missing provider prefix
        "provider": "openai",
        "name": "text-embedding-3-small",
        "dimensions": 1536,
        "context_length": 8191,
        "pricing_per_1m_tokens": 0.02,
        "supports_dimension_reduction": True,
    }

    with pytest.raises(ValueError, match="must include provider prefix"):
        OpenRouterEmbeddingModel(**fields_without_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_with_valid_prefix_accepted():
    """A model whose ID carries a provider prefix is constructed with that ID intact."""
    prefixed_id = "openai/text-embedding-3-small"

    accepted = OpenRouterEmbeddingModel(
        id=prefixed_id,
        provider="openai",
        name="text-embedding-3-small",
        dimensions=1536,
        context_length=8191,
        pricing_per_1m_tokens=0.02,
        supports_dimension_reduction=True,
    )

    assert accepted.id == "openai/text-embedding-3-small"
    assert "/" in accepted.id
|
||||||
Reference in New Issue
Block a user