Mirror of https://github.com/coleam00/Archon.git (synced 2025-12-30 21:49:30 -05:00)
Feature/LLM-Providers-UI-Polished (#736)
* Add Anthropic and Grok provider support

* feat: Add crucial GPT-5 and reasoning model support for OpenRouter
- Add requires_max_completion_tokens() function for the GPT-5, o1, o3, and Grok-3 series
- Add prepare_chat_completion_params() for reasoning model compatibility (a hedged sketch of this handling appears after the commit message below)
- Implement max_tokens → max_completion_tokens conversion for reasoning models
- Add temperature handling for reasoning models (must default to 1.0)
- Enhanced provider validation and API key security in provider endpoints
- Streamlined retry logic (3→2 attempts) for faster issue detection
- Add failure tracking and circuit breaker analysis for debugging
- Support OpenRouter format detection (openai/gpt-5-nano, openai/o1-mini)
- Improved Grok provider empty-response handling with structured fallbacks
- Enhanced contextual embedding with provider-aware model selection

Core provider functionality:
- OpenRouter, Grok, and Anthropic provider support with full embedding integration
- Provider-specific model defaults and validation
- Secure API connectivity testing endpoints
- Provider context passing for code generation workflows

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>

* Fully working model providers; addressed security and code-related concerns and thoroughly hardened the code

* Added multi-provider support and embedding model support, and cleaned up the PR; still need to fix the health check, asyncio task errors, and the contextual embeddings error

* Fixed the contextual embeddings issue

* Added inspect-aware shutdown handling so get_llm_client always closes the underlying AsyncOpenAI / httpx.AsyncClient while the loop is still alive, with defensive logging if shutdown happens late (python/src/server/services/llm_provider_service.py:14, python/src/server/services/llm_provider_service.py:520).

* Restructured get_llm_client so client creation and usage live in separate try/finally blocks; fallback clients now close without logging a spurious "Error creating LLM client" when downstream code raises (python/src/server/services/llm_provider_service.py:335-556).
- Close logic now sanitizes provider names consistently and awaits whichever aclose/close coroutine the SDK exposes, keeping the loop shut down cleanly (python/src/server/services/llm_provider_service.py:530-559).

Robust JSON parsing:
- Added _extract_json_payload to strip code fences and extra text returned by Ollama before json.loads runs, averting the markdown-induced decode errors seen in the logs (python/src/server/services/storage/code_storage_service.py:40-63).
- Swapped the direct parse call for the sanitized payload and emit a debug preview when cleanup alters the content (python/src/server/services/storage/code_storage_service.py:858-864).

* Added provider connection support

* Added a warning when a provider API key is not configured

* Updated get_llm_client so missing OpenAI keys automatically fall back to Ollama (matching existing tests) and unsupported providers still raise the legacy ValueError the suite expects. The fallback now reuses _get_optimal_ollama_instance and rethrows ValueError("OpenAI API key not found and Ollama fallback failed") when it can't connect. Adjusted test_code_extraction_source_id.py to accept the new optional argument on the mocked extractor (and confirm it is None when present).

* Resolved several CodeRabbit suggestions
- Updated the knowledge API key validation to call create_embedding with the provider argument and removed the hard-coded OpenAI fallback (python/src/server/api_routes/knowledge_api.py).
- Broadened embedding provider detection so prefixed OpenRouter/OpenAI model names route through the correct client (python/src/server/services/embeddings/embedding_service.py, python/src/server/services/llm_provider_service.py).
- Removed the duplicate helper definitions from llm_provider_service.py, eliminating the stray docstring that was causing the import-time syntax error.

* Updated via CodeRabbit PR review; CodeRabbit in my IDE found no issues and no nitpicks with the updates. What was done:
- The credential service now persists the provider under the uppercase key LLM_PROVIDER, matching the read path (no new EMBEDDING_PROVIDER usage introduced).
- Embedding batch creation stops inserting blank strings, logging failures and skipping invalid items before they ever hit the provider (python/src/server/services/embeddings/embedding_service.py); see the second sketch after this commit message.
- Contextual embedding prompts use real newline characters everywhere, both when constructing the batch prompt and when parsing the model's response (python/src/server/services/embeddings/contextual_embedding_service.py).
- Embedding provider routing already recognizes OpenRouter-prefixed OpenAI models via is_openai_embedding_model; no further change was needed there.
- Embedding insertion now skips unsupported vector dimensions instead of forcing them into the 1536-dimension column, and the backoff loop uses await asyncio.sleep so we no longer block the event loop (python/src/server/services/storage/code_storage_service.py).
- RAG settings props were extended to include LLM_INSTANCE_NAME and OLLAMA_EMBEDDING_INSTANCE_NAME, and the debug log no longer prints API-key prefixes (the rest of the TanStack refactor / EMBEDDING_PROVIDER support remains deferred).

* Test fix

* Enhanced OpenRouter's parsing logic to automatically detect reasoning models and parse their output whether or not it is JSON. This gives Archon's parsing a robust way to work with OpenRouter automatically, regardless of the model you're using, ensuring proper functionality without breaking any generation capabilities.

* Updated the LLM UI, added a separate embeddings provider, and made the system fully capable of mixing and matching LLM providers (local and non-local) for chat & embeddings. Mainly updated the RAGSettings.tsx UI, along with core functionality.

* Added warning labels and updated Ollama health checks

* Ready for review; fixed some error warnings and consolidated Ollama status health checks

* Fixed FAILED test_async_embedding_service.py

* CodeRabbit fixes

* Separated the code-summary LLM provider from the embedding provider, so code example storage now forwards a dedicated embedding provider override end-to-end without hijacking the embedding pipeline. This addresses CodeRabbit's "Preserve provider override in create_embeddings_batch" suggestion.

* Swapped API credential storage to booleans so decrypted keys never sit in React state (archon-ui-main/src/components/settings/RAGSettings.tsx).
- Normalized Ollama instance URLs and gated the metrics effect on real state changes to avoid mis-counts and duplicate fetches (RAGSettings.tsx).
- Tightened crawl progress scaling and indented-block parsing to handle min_length=None safely (python/src/server/services/crawling/code_extraction_service.py:160, python/src/server/services/crawling/code_extraction_service.py:911).
- Added provider-agnostic embedding rate-limit retries so Google and friends back off gracefully (python/src/server/services/embeddings/embedding_service.py:427).
- Made the orchestration registry async + thread-safe and updated every caller to await it (python/src/server/services/crawling/crawling_service.py:34, python/src/server/api_routes/knowledge_api.py:1291).

* Update RAGSettings.tsx - the 'LLM Settings' header is now 'LLM Provider Settings'

* (RAG Settings) Ollama health checks & metrics:
- Added a 10-second timeout to the health fetch so it doesn't hang.
- Adjusted logic so metric refreshes run for embedding-only Ollama setups too.
- Initial page load now checks Ollama if either the chat or the embedding provider uses it.
- Metrics and alerts now respect which provider (chat/embedding) is currently selected.

Provider sync & alerts:
- Fixed a sync bug so the very first provider change updates settings as expected.
- Alerts now track the active provider (chat vs embedding) rather than only the LLM provider.
- Warnings about missing credentials now skip whichever provider is currently selected.

Modals & types:
- Normalize URLs before handing them to selection modals to keep consistent data.
- Strengthened helper function types (getDisplayedChatModel, getModelPlaceholder, etc.).

(Crawling Service)
- Made the orchestration registry lock lazy-initialized to avoid issues in Python 3.12 and wrapped registry commands (register, unregister) in async calls. This keeps things thread-safe even during concurrent crawling and cancellation.

* migration/complete_setup.sql:101 seeds Google/OpenRouter/Anthropic/Grok API key rows so fresh databases expose every provider by default.
- migration/0.1.0/009_add_provider_placeholders.sql:1 backfills the same rows for existing Supabase instances and records the migration.
- archon-ui-main/src/components/settings/RAGSettings.tsx:121 introduces a shared credential-to-provider map, reloadApiCredentials runs through all five providers, and the status poller includes the new keys.
- archon-ui-main/src/components/settings/RAGSettings.tsx:353 subscribes to the archon:credentials-updated browser event so adding/removing a key immediately refetches credential status and pings the corresponding connectivity test.
- archon-ui-main/src/components/settings/RAGSettings.tsx:926 now treats missing Anthropic/OpenRouter/Grok keys as missing, preventing stale "connected" badges when a key is removed.

* archon-ui-main/src/components/settings/RAGSettings.tsx:90 adds a simple display-name map and reuses one red alert style.
- archon-ui-main/src/components/settings/RAGSettings.tsx:1016 now shows exactly one red banner when the active provider is missing its API key.
- Removed the old duplicate "Missing API Key Configuration" block, so the panel no longer stacks two warnings.

* Update credentialsService.ts default model

* Updated the Google embedding adapter for multi-dimensional RAG querying

* Thought this micro-fix to the Google embedding adapter shipped with the embedding update the other day; it didn't, so pushing it now.

---------

Co-authored-by: Chillbruhhh <joshchesser97@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
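A note on the reasoning-model parameter handling listed in the first commit: below is a minimal sketch of what requires_max_completion_tokens() and prepare_chat_completion_params() could look like. The function names come from the commit message, but the prefix list and exact signatures are illustrative assumptions, not the repository's actual code.

```python
# Hedged sketch only: the model-prefix list and signatures are assumptions.
REASONING_MODEL_PREFIXES = ("gpt-5", "o1", "o3", "grok-3")


def requires_max_completion_tokens(model: str) -> bool:
    """True for reasoning models that reject max_tokens in favor of max_completion_tokens."""
    # Handle OpenRouter-style names such as "openai/gpt-5-nano" or "openai/o1-mini".
    bare_model = model.split("/", 1)[-1].lower()
    return bare_model.startswith(REASONING_MODEL_PREFIXES)


def prepare_chat_completion_params(model: str, params: dict) -> dict:
    """Adapt chat-completion kwargs so reasoning models accept them."""
    prepared = dict(params)
    if requires_max_completion_tokens(model):
        # Reasoning models take max_completion_tokens instead of max_tokens.
        if "max_tokens" in prepared:
            prepared["max_completion_tokens"] = prepared.pop("max_tokens")
        # Reasoning models require the default temperature of 1.0.
        prepared["temperature"] = 1.0
    return prepared
```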
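Likewise, a rough sketch of the embedding batch hygiene mentioned in the CodeRabbit follow-ups: blank strings are filtered before they reach the provider, and retries back off with await asyncio.sleep so the event loop is never blocked. The helper name and signature below are hypothetical and do not mirror the real create_embeddings_batch.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def embed_with_backoff(texts, embed_fn, max_retries=3):
    """Illustrative only: filter blank inputs, then retry with non-blocking backoff."""
    # Skip blank or whitespace-only strings instead of sending them to the provider.
    valid_texts = [t for t in texts if t and t.strip()]
    skipped = len(texts) - len(valid_texts)
    if skipped:
        logger.warning(f"Skipped {skipped} blank text(s) before embedding")

    for attempt in range(1, max_retries + 1):
        try:
            return await embed_fn(valid_texts)
        except Exception as exc:  # stand-in for a provider rate-limit error
            if attempt == max_retries:
                raise
            wait_time = 2 ** attempt
            logger.warning(f"Embedding attempt {attempt} failed ({exc}); retrying in {wait_time}s")
            # await asyncio.sleep keeps the event loop free during the backoff.
            await asyncio.sleep(wait_time)
```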
@@ -1288,7 +1288,7 @@ async def stop_crawl_task(progress_id: str):

found = False
# Step 1: Cancel the orchestration service
orchestration = get_active_orchestration(progress_id)
orchestration = await get_active_orchestration(progress_id)
if orchestration:
orchestration.cancel()
found = True
@@ -1306,7 +1306,7 @@ async def stop_crawl_task(progress_id: str):
found = True

# Step 3: Remove from active orchestrations registry
unregister_orchestration(progress_id)
await unregister_orchestration(progress_id)

# Step 4: Update progress tracker to reflect cancellation (only if we found and cancelled something)
if found:

@@ -140,6 +140,7 @@ class CodeExtractionService:
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -150,6 +151,8 @@ class CodeExtractionService:
source_id: The unique source_id for all documents
progress_callback: Optional async callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider identifier for summary generation
embedding_provider: Optional embedding provider override for vector creation

Returns:
Number of code examples stored
@@ -158,9 +161,16 @@ class CodeExtractionService:
extraction_callback = None
if progress_callback:
async def extraction_progress(data: dict):
# Scale progress to 0-20% range
raw_progress = data.get("progress", 0)
scaled_progress = int(raw_progress * 0.2) # 0-20%
# Scale progress to 0-20% range with normalization similar to later phases
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 0-20% with clamping
scaled_progress = min(20, max(0, int(raw_num * 0.2)))
data["progress"] = scaled_progress
await progress_callback(data)
extraction_callback = extraction_progress
@@ -197,8 +207,15 @@ class CodeExtractionService:
if progress_callback:
async def summary_progress(data: dict):
# Scale progress to 20-90% range
raw_progress = data.get("progress", 0)
scaled_progress = 20 + int(raw_progress * 0.7) # 20-90%
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 20-90% with clamping
scaled_progress = min(90, max(20, 20 + int(raw_num * 0.7)))
data["progress"] = scaled_progress
await progress_callback(data)
summary_callback = summary_progress
@@ -216,15 +233,26 @@ class CodeExtractionService:
if progress_callback:
async def storage_progress(data: dict):
# Scale progress to 90-100% range
raw_progress = data.get("progress", 0)
scaled_progress = 90 + int(raw_progress * 0.1) # 90-100%
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 90-100% with clamping
scaled_progress = min(100, max(90, 90 + int(raw_num * 0.1)))
data["progress"] = scaled_progress
await progress_callback(data)
storage_callback = storage_progress

# Store code examples in database
return await self._store_code_examples(
storage_data, url_to_full_document, storage_callback, provider
storage_data,
url_to_full_document,
storage_callback,
provider,
embedding_provider,
)

async def _extract_code_blocks_from_documents(
@@ -880,9 +908,20 @@ class CodeExtractionService:
current_indent = indent
block_start_idx = i
current_block.append(line)
elif current_block and len("\n".join(current_block)) >= min_length:
elif current_block:
block_text = "\n".join(current_block)
threshold = (
min_length
if min_length is not None
else await self._get_min_code_length()
)
if len(block_text) < threshold:
current_block = []
current_indent = None
continue

# End of indented block, check if it's code
code_content = "\n".join(current_block)
code_content = block_text

# Try to detect language from content
language = self._detect_language_from_content(code_content)
@@ -1670,12 +1709,20 @@ class CodeExtractionService:
url_to_full_document: dict[str, str],
progress_callback: Callable | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Store code examples in the database.

Returns:
Number of code examples stored

Args:
storage_data: Prepared code example payloads
url_to_full_document: Mapping of URLs to their full document content
progress_callback: Optional callback for progress updates
provider: Optional LLM provider identifier for summaries
embedding_provider: Optional embedding provider override for vector storage
"""
# Create progress callback for storage phase
storage_progress_callback = None
@@ -1713,6 +1760,7 @@ class CodeExtractionService:
url_to_full_document=url_to_full_document,
progress_callback=storage_progress_callback,
provider=provider,
embedding_provider=embedding_provider,
)

# Report completion of code extraction/storage phase

@@ -14,6 +14,7 @@ from typing import Any, Optional
from ...config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ...utils import get_supabase_client
from ...utils.progress.progress_tracker import ProgressTracker
from ..credential_service import credential_service

# Import strategies
# Import operations
@@ -32,22 +33,35 @@ logger = get_logger(__name__)

# Global registry to track active orchestration services for cancellation support
_active_orchestrations: dict[str, "CrawlingService"] = {}
_orchestration_lock: asyncio.Lock | None = None


def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]:
def _ensure_orchestration_lock() -> asyncio.Lock:
global _orchestration_lock
if _orchestration_lock is None:
_orchestration_lock = asyncio.Lock()
return _orchestration_lock


async def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]:
"""Get an active orchestration service by progress ID."""
return _active_orchestrations.get(progress_id)
lock = _ensure_orchestration_lock()
async with lock:
return _active_orchestrations.get(progress_id)


def register_orchestration(progress_id: str, orchestration: "CrawlingService"):
async def register_orchestration(progress_id: str, orchestration: "CrawlingService"):
"""Register an active orchestration service."""
_active_orchestrations[progress_id] = orchestration
lock = _ensure_orchestration_lock()
async with lock:
_active_orchestrations[progress_id] = orchestration


def unregister_orchestration(progress_id: str):
async def unregister_orchestration(progress_id: str):
"""Unregister an orchestration service."""
if progress_id in _active_orchestrations:
del _active_orchestrations[progress_id]
lock = _ensure_orchestration_lock()
async with lock:
_active_orchestrations.pop(progress_id, None)


class CrawlingService:
@@ -246,7 +260,7 @@ class CrawlingService:

# Register this orchestration service for cancellation support
if self.progress_id:
register_orchestration(self.progress_id, self)
await register_orchestration(self.progress_id, self)

# Start the crawl as an async task in the main event loop
# Store the task reference for proper cancellation
@@ -477,15 +491,27 @@ class CrawlingService:
try:
# Extract provider from request or use credential service default
provider = request.get("provider")
embedding_provider = None

if not provider:
try:
from ..credential_service import credential_service
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
except Exception as e:
logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
logger.warning(
f"Failed to get provider from credential service: {e}, defaulting to openai"
)
provider = "openai"

try:
embedding_config = await credential_service.get_active_provider("embedding")
embedding_provider = embedding_config.get("provider")
except Exception as e:
logger.warning(
f"Failed to get embedding provider from credential service: {e}. Using configured default."
)
embedding_provider = None

code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
crawl_results,
storage_results["url_to_full_document"],
@@ -493,6 +519,7 @@ class CrawlingService:
code_progress_callback,
self._check_cancellation,
provider,
embedding_provider,
)
except RuntimeError as e:
# Code extraction failed, continue crawl with warning
@@ -548,7 +575,7 @@ class CrawlingService:

# Unregister after successful completion
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service after completion | progress_id={self.progress_id}"
)
@@ -567,7 +594,7 @@ class CrawlingService:
)
# Unregister on cancellation
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service on cancellation | progress_id={self.progress_id}"
)
@@ -591,7 +618,7 @@ class CrawlingService:
await self.progress_tracker.error(error_message)
# Unregister on error
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service on error | progress_id={self.progress_id}"
)

@@ -352,6 +352,7 @@ class DocumentStorageOperations:
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -363,12 +364,19 @@ class DocumentStorageOperations:
progress_callback: Optional callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider to use for code summaries
embedding_provider: Optional embedding provider override for code example embeddings

Returns:
Number of code examples stored
"""
result = await self.code_extraction_service.extract_and_store_code_examples(
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
crawl_results,
url_to_full_document,
source_id,
progress_callback,
cancellation_check,
provider,
embedding_provider,
)

return result

@@ -36,42 +36,6 @@ class CredentialItem:
description: str | None = None


def _detect_embedding_provider_from_model(embedding_model: str) -> str:
"""
Detect the appropriate embedding provider based on model name.

Args:
embedding_model: The embedding model name

Returns:
Provider name: 'google', 'openai', or 'openai' (default)
"""
if not embedding_model:
return "openai" # Default

model_lower = embedding_model.lower()

# Google embedding models
google_patterns = [
"text-embedding-004",
"text-embedding-005",
"text-multilingual-embedding",
"gemini-embedding",
"multimodalembedding"
]

if any(pattern in model_lower for pattern in google_patterns):
return "google"

# OpenAI embedding models (and default for unknown)
openai_patterns = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large"
]

# Default to OpenAI for OpenAI models or unknown models
return "openai"


class CredentialService:
@@ -475,26 +439,24 @@ class CredentialService:

# Get the selected provider based on service type
if service_type == "embedding":
# Get the LLM provider setting to determine embedding provider
llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
# First check for explicit EMBEDDING_PROVIDER setting (new split provider approach)
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")

# Determine embedding provider based on LLM provider
if llm_provider == "google":
provider = "google"
elif llm_provider == "ollama":
provider = "ollama"
elif llm_provider == "openrouter":
# OpenRouter supports both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
elif llm_provider in ["anthropic", "grok"]:
# Anthropic and Grok support both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
# Validate that embedding provider actually supports embeddings
embedding_capable_providers = {"openai", "google", "ollama"}

if (explicit_embedding_provider and
explicit_embedding_provider != "" and
explicit_embedding_provider in embedding_capable_providers):
# Use the explicitly set embedding provider
provider = explicit_embedding_provider
logger.debug(f"Using explicit embedding provider: '{provider}'")
else:
# Default case (openai, or unknown providers)
# Fall back to OpenAI as default embedding provider for backward compatibility
if explicit_embedding_provider and explicit_embedding_provider not in embedding_capable_providers:
logger.warning(f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI")
provider = "openai"

logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
logger.debug(f"No explicit embedding provider set, defaulting to OpenAI for backward compatibility")
else:
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Ensure provider is a valid string, not a boolean or other type

@@ -5,15 +5,19 @@ Handles all OpenAI embedding operations with proper rate limiting and error hand
"""

import asyncio
import inspect
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

import httpx
import numpy as np
import openai

from ...config.logfire_config import safe_span, search_logger
from ..credential_service import credential_service
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
from ..llm_provider_service import get_embedding_model, get_llm_client
from ..threading_service import get_threading_service
from .embedding_exceptions import (
EmbeddingAPIError,
@@ -64,6 +68,167 @@ class EmbeddingBatchResult:
return self.success_count + self.failure_count


class EmbeddingProviderAdapter(ABC):
"""Adapter interface for embedding providers."""

@abstractmethod
async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
"""Create embeddings for the given texts."""


class OpenAICompatibleEmbeddingAdapter(EmbeddingProviderAdapter):
"""Adapter for providers using the OpenAI embeddings API shape."""

def __init__(self, client: Any):
self._client = client

async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
request_args: dict[str, Any] = {
"model": model,
"input": texts,
}
if dimensions is not None:
request_args["dimensions"] = dimensions

response = await self._client.embeddings.create(**request_args)
return [item.embedding for item in response.data]


class GoogleEmbeddingAdapter(EmbeddingProviderAdapter):
"""Adapter for Google's native embedding endpoint."""

async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
try:
google_api_key = await credential_service.get_credential("GOOGLE_API_KEY")
if not google_api_key:
raise EmbeddingAPIError("Google API key not found")

async with httpx.AsyncClient(timeout=30.0) as http_client:
embeddings = await asyncio.gather(
*(
self._fetch_single_embedding(http_client, google_api_key, model, text, dimensions)
for text in texts
)
)

return embeddings

except httpx.HTTPStatusError as error:
error_content = error.response.text
search_logger.error(
f"Google embedding API returned {error.response.status_code} - {error_content}",
exc_info=True,
)
raise EmbeddingAPIError(
f"Google embedding API error: {error.response.status_code} - {error_content}",
original_error=error,
) from error
except Exception as error:
search_logger.error(f"Error calling Google embedding API: {error}", exc_info=True)
raise EmbeddingAPIError(
f"Google embedding error: {str(error)}", original_error=error
) from error

async def _fetch_single_embedding(
self,
http_client: httpx.AsyncClient,
api_key: str,
model: str,
text: str,
dimensions: int | None = None,
) -> list[float]:
if model.startswith("models/"):
url_model = model[len("models/") :]
payload_model = model
else:
url_model = model
payload_model = f"models/{model}"
url = f"https://generativelanguage.googleapis.com/v1beta/models/{url_model}:embedContent"
headers = {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
payload = {
"model": payload_model,
"content": {"parts": [{"text": text}]},
}

# Add output_dimensionality parameter if dimensions are specified and supported
if dimensions is not None and dimensions > 0:
model_name = payload_model.removeprefix("models/")
if model_name.startswith("textembedding-gecko"):
supported_dimensions = {128, 256, 512, 768}
else:
supported_dimensions = {128, 256, 512, 768, 1024, 1536, 2048, 3072}

if dimensions in supported_dimensions:
payload["outputDimensionality"] = dimensions
else:
search_logger.warning(
f"Requested dimension {dimensions} is not supported by Google model '{model_name}'. "
"Falling back to the provider default."
)

response = await http_client.post(url, headers=headers, json=payload)
response.raise_for_status()

result = response.json()
embedding = result.get("embedding", {})
values = embedding.get("values") if isinstance(embedding, dict) else None
if not isinstance(values, list):
raise EmbeddingAPIError(f"Invalid embedding payload from Google: {result}")

# Normalize embeddings for dimensions < 3072 as per Google's documentation
actual_dimension = len(values)
if actual_dimension > 0 and actual_dimension < 3072:
values = self._normalize_embedding(values)

return values

def _normalize_embedding(self, embedding: list[float]) -> list[float]:
"""Normalize embedding vector for dimensions < 3072."""
try:
embedding_array = np.array(embedding, dtype=np.float32)
norm = np.linalg.norm(embedding_array)
if norm > 0:
normalized = embedding_array / norm
return normalized.tolist()
else:
search_logger.warning("Zero-norm embedding detected, returning unnormalized")
return embedding
except Exception as e:
search_logger.error(f"Failed to normalize embedding: {e}")
# Return original embedding if normalization fails
return embedding


def _get_embedding_adapter(provider: str, client: Any) -> EmbeddingProviderAdapter:
provider_name = (provider or "").lower()
if provider_name == "google":
return GoogleEmbeddingAdapter()
return OpenAICompatibleEmbeddingAdapter(client)


async def _maybe_await(value: Any) -> Any:
"""Await the value if it is awaitable, otherwise return as-is."""

return await value if inspect.isawaitable(value) else value

# Provider-aware client factory
get_openai_client = get_llm_client

@@ -185,27 +350,25 @@ async def create_embeddings_batch(
"create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
) as span:
try:
# Intelligent embedding provider routing based on model type
# Get the embedding model first to determine the correct provider
embedding_model = await get_embedding_model(provider=provider)
embedding_config = await _maybe_await(
credential_service.get_active_provider(service_type="embedding")
)

# Route to correct provider based on model type
if is_google_embedding_model(embedding_model):
embedding_provider = "google"
search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
embedding_provider = provider or embedding_config.get("provider")

if not isinstance(embedding_provider, str) or not embedding_provider.strip():
embedding_provider = "openai"
search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
else:
# Keep original provider for ollama and other providers
embedding_provider = provider
search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")

if not embedding_provider:
search_logger.error("No embedding provider configured")
raise ValueError("No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable.")

search_logger.info(f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)")
async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
# Load batch size and dimensions from settings
try:
rag_settings = await credential_service.get_credentials_by_category(
"rag_strategy"
rag_settings = await _maybe_await(
credential_service.get_credentials_by_category("rag_strategy")
)
batch_size = int(rag_settings.get("EMBEDDING_BATCH_SIZE", "100"))
embedding_dimensions = int(rag_settings.get("EMBEDDING_DIMENSIONS", "1536"))
@@ -215,6 +378,8 @@ async def create_embeddings_batch(
embedding_dimensions = 1536

total_tokens_used = 0
adapter = _get_embedding_adapter(embedding_provider, client)
dimensions_to_use = embedding_dimensions if embedding_dimensions > 0 else None

for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
@@ -243,16 +408,14 @@ async def create_embeddings_batch(
try:
# Create embeddings for this batch
embedding_model = await get_embedding_model(provider=embedding_provider)

response = await client.embeddings.create(
model=embedding_model,
input=batch,
dimensions=embedding_dimensions,
embeddings = await adapter.create_embeddings(
batch,
embedding_model,
dimensions=dimensions_to_use,
)

# Add successful embeddings
for text, item in zip(batch, response.data, strict=False):
result.add_success(item.embedding, text)
for text, vector in zip(batch, embeddings, strict=False):
result.add_success(vector, text)

break # Success, exit retry loop

@@ -297,6 +460,17 @@ async def create_embeddings_batch(
await asyncio.sleep(wait_time)
else:
raise # Will be caught by outer try
except EmbeddingRateLimitError as e:
retry_count += 1
if retry_count < max_retries:
wait_time = 2**retry_count
search_logger.warning(
f"Embedding rate limit for batch {batch_index}: {e}. "
f"Waiting {wait_time}s before retry {retry_count}/{max_retries}"
)
await asyncio.sleep(wait_time)
else:
raise

except Exception as e:
# This batch failed - track failures but continue with next batch

@@ -1091,6 +1091,7 @@ async def add_code_examples_to_supabase(
url_to_full_document: dict[str, str] | None = None,
progress_callback: Callable | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
):
"""
Add code examples to the Supabase code_examples table in batches.
@@ -1105,6 +1106,8 @@ async def add_code_examples_to_supabase(
batch_size: Size of each batch for insertion
url_to_full_document: Optional mapping of URLs to full document content
progress_callback: Optional async callback for progress updates
provider: Optional LLM provider used for summary generation tracking
embedding_provider: Optional embedding provider override for vector generation
"""
if not urls:
return
@@ -1183,8 +1186,8 @@ async def add_code_examples_to_supabase(
# Use original combined texts
batch_texts = combined_texts

# Create embeddings for the batch
result = await create_embeddings_batch(batch_texts, provider=provider)
# Create embeddings for the batch (optionally overriding the embedding provider)
result = await create_embeddings_batch(batch_texts, provider=embedding_provider)

# Log any failures
if result.has_failures:
@@ -1201,7 +1204,7 @@ async def add_code_examples_to_supabase(
from ..llm_provider_service import get_embedding_model

# Get embedding model name
embedding_model_name = await get_embedding_model(provider=provider)
embedding_model_name = await get_embedding_model(provider=embedding_provider)

# Get LLM chat model (used for code summaries and contextual embeddings if enabled)
llm_chat_model = None

@@ -111,8 +111,8 @@ class TestCodeExtractionSourceId:
assert args[2] == source_id
assert args[3] is None
assert args[4] is None
if len(args) > 5:
assert args[5] is None
assert args[5] is None
assert args[6] is None
assert result == 5

@pytest.mark.asyncio