Feature/LLM-Providers-UI-Polished (#736)

* Add Anthropic and Grok provider support

* feat: Add crucial GPT-5 and reasoning model support for OpenRouter

- Add requires_max_completion_tokens() function for GPT-5, o1, o3, and Grok-3 series models
- Add prepare_chat_completion_params() for reasoning model compatibility (see the sketch after this list)
- Implement max_tokens → max_completion_tokens conversion for reasoning models
- Add temperature handling for reasoning models (must default to 1.0)
- Enhanced provider validation and API key security in provider endpoints
- Streamlined retry logic (3→2 attempts) for faster issue detection
- Add failure tracking and circuit breaker analysis for debugging
- Support OpenRouter format detection (openai/gpt-5-nano, openai/o1-mini)
- Improved Grok provider empty response handling with structured fallbacks
- Enhanced contextual embedding with provider-aware model selection
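
A hedged sketch of the two helpers named above, assuming reasoning-model detection works on name prefixes and OpenRouter-style `openai/...` identifiers as described; the actual implementation in the provider service may differ:

```python
# Illustrative sketch only; mirrors the behaviour described in the bullets above.
REASONING_MODEL_PREFIXES = ("gpt-5", "o1", "o3", "grok-3")


def requires_max_completion_tokens(model: str) -> bool:
    """True for reasoning models that need max_completion_tokens instead of max_tokens."""
    # Handle OpenRouter-format names such as "openai/gpt-5-nano" or "openai/o1-mini".
    name = model.lower().split("/")[-1]
    return name.startswith(REASONING_MODEL_PREFIXES)


def prepare_chat_completion_params(model: str, params: dict) -> dict:
    """Adapt generic chat-completion params for reasoning-model compatibility."""
    prepared = dict(params)
    if requires_max_completion_tokens(model):
        # Reasoning models reject max_tokens; convert it to max_completion_tokens.
        if "max_tokens" in prepared:
            prepared["max_completion_tokens"] = prepared.pop("max_tokens")
        # Temperature must stay at the default of 1.0 for these models.
        prepared["temperature"] = 1.0
    return prepared
```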

Core provider functionality:
- OpenRouter, Grok, Anthropic provider support with full embedding integration
- Provider-specific model defaults and validation
- Secure API connectivity testing endpoints
- Provider context passing for code generation workflows

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fully working model providers, addressing security and code-related concerns, thoroughly hardening our code

* added multi-provider support and embeddings model support, cleaned up the PR; still need to fix the health check, asyncio task errors, and the contextual embeddings error

* fixed contextual embeddings issue

* - Added inspect-aware shutdown handling so get_llm_client always closes the underlying AsyncOpenAI / httpx.AsyncClient while the loop is still alive, with defensive logging if shutdown happens late (python/src/server/services/llm_provider_service.py:14, python/src/server/services/llm_provider_service.py:520).
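
A minimal sketch of the inspect-aware close path, assuming only that the client exposes either aclose() or close(); the real logic in llm_provider_service.py handles more edge cases:

```python
import inspect
import logging

logger = logging.getLogger(__name__)


async def _close_llm_client(client) -> None:
    """Close an AsyncOpenAI or httpx.AsyncClient, whichever close API it exposes."""
    closer = getattr(client, "aclose", None) or getattr(client, "close", None)
    if closer is None:
        return
    try:
        result = closer()
        # Some SDKs return a coroutine, others close synchronously; only await awaitables.
        if inspect.isawaitable(result):
            await result
    except RuntimeError as exc:
        # Defensive logging when shutdown happens after the event loop has gone away.
        logger.warning("LLM client close ran after loop shutdown: %s", exc)
```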

* - Restructured get_llm_client so client creation and usage live in separate try/finally blocks; fallback clients now close without logging a spurious "Error creating LLM client" when downstream code raises (python/src/server/services/llm_provider_service.py:335-556).
  - Close logic now sanitizes provider names consistently and awaits whichever aclose/close coroutine the SDK exposes, keeping the loop shut down cleanly (python/src/server/services/llm_provider_service.py:530-559).

  Robust JSON Parsing

  - Added _extract_json_payload to strip code fences / extra text returned by Ollama before json.loads runs, averting the markdown-induced decode errors seen in the logs (python/src/server/services/storage/code_storage_service.py:40-63). A sketch of this helper follows below.
  - Swapped the direct parse call for the sanitized payload and emit a debug preview when cleanup alters the content (python/src/server/services/storage/code_storage_service.py:858-864).
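
A sketch of what a fence-stripping helper like _extract_json_payload might look like; the regex and brace fallback here are assumptions, not the exact code at code_storage_service.py:40-63:

```python
import re

# Matches a ```json ... ``` (or plain ```) fenced block anywhere in the response.
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _extract_json_payload(raw: str) -> str:
    """Strip markdown fences and surrounding prose so json.loads receives clean JSON."""
    text = raw.strip()
    match = _FENCE_RE.search(text)
    if match:
        text = match.group(1).strip()
    # Fall back to the outermost braces when the model adds prose around bare JSON.
    if not text.startswith("{"):
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start : end + 1]
    return text

# Usage: json.loads(_extract_json_payload(model_output))
```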

* added provider connection support

* added a warning when a provider API key is not configured

* Updated get_llm_client so missing OpenAI keys automatically fall back to Ollama (matching existing tests) and so unsupported providers still raise the legacy ValueError the suite expects. The fallback now reuses _get_optimal_ollama_instance and rethrows ValueError("OpenAI API key not found and Ollama fallback failed") when it can't connect. Also adjusted test_code_extraction_source_id.py to accept the new optional argument on the mocked extractor (and confirm it's None when present). A sketch of the fallback path follows.
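
A compact sketch of the fallback behaviour; the Ollama URL and the helper stub are stand-ins, and only the error message matches the one quoted above:

```python
async def _get_optimal_ollama_instance() -> str:
    # Stand-in for the real instance picker in llm_provider_service.py.
    return "http://localhost:11434/v1"


async def resolve_chat_backend(openai_api_key: str | None) -> tuple[str, str]:
    """Sketch: use OpenAI when a key exists, otherwise fall back to Ollama."""
    if openai_api_key:
        return "openai", openai_api_key
    try:
        base_url = await _get_optimal_ollama_instance()
        return "ollama", base_url
    except Exception as exc:
        # Matches the legacy error the test suite expects.
        raise ValueError("OpenAI API key not found and Ollama fallback failed") from exc
```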

* Resolved a few needed CodeRabbit suggestions:

  - Updated the knowledge API key validation to call create_embedding with the provider argument and removed the hard-coded OpenAI fallback (python/src/server/api_routes/knowledge_api.py).
  - Broadened embedding provider detection so prefixed OpenRouter/OpenAI model names route through the correct client (python/src/server/services/embeddings/embedding_service.py, python/src/server/services/llm_provider_service.py).
  - Removed the duplicate helper definitions from llm_provider_service.py, eliminating the stray docstring that was causing the import-time syntax error.

* updated via CodeRabbit PR review; CodeRabbit in my IDE found no issues and no nitpicks with the updates. What was done:

  - Credential service now persists the provider under the uppercase key LLM_PROVIDER, matching the read path (no new EMBEDDING_PROVIDER usage introduced).
  - Embedding batch creation stops inserting blank strings, logging failures and skipping invalid items before they ever hit the provider (python/src/server/services/embeddings/embedding_service.py); a sketch of this filtering follows after this list.
  - Contextual embedding prompts use real newline characters everywhere, both when constructing the batch prompt and when parsing the model's response (python/src/server/services/embeddings/contextual_embedding_service.py).
  - Embedding provider routing already recognizes OpenRouter-prefixed OpenAI models via is_openai_embedding_model; no further change needed there.
  - Embedding insertion now skips unsupported vector dimensions instead of forcing them into the 1536-dimension column, and the backoff loop uses await asyncio.sleep so we no longer block the event loop (python/src/server/services/storage/code_storage_service.py).
  - RAG settings props were extended to include LLM_INSTANCE_NAME and OLLAMA_EMBEDDING_INSTANCE_NAME, and the debug log no longer prints API-key prefixes (the rest of the TanStack refactor / EMBEDDING_PROVIDER support remains deferred).
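
A small sketch of the blank-string filtering described in the second bullet; the function name and return shape are illustrative:

```python
import logging

logger = logging.getLogger(__name__)


def filter_embedding_inputs(texts: list) -> tuple[list[str], list[int]]:
    """Drop blank or non-string items before they reach the embedding provider."""
    valid_texts: list[str] = []
    skipped_indices: list[int] = []
    for index, text in enumerate(texts):
        if isinstance(text, str) and text.strip():
            valid_texts.append(text)
        else:
            skipped_indices.append(index)
            logger.warning("Skipping invalid embedding input at index %d", index)
    return valid_texts, skipped_indices
```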

* test fix

* enhanced OpenRouter's parsing logic to automatically detect reasoning models and parse responses whether or not they are JSON. This makes Archon's parsing work robustly with OpenRouter regardless of the model you're using, ensuring proper functionality without breaking any generation capabilities. A sketch of the tolerant parsing follows.
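
A hedged sketch of the "JSON or not" idea: try strict JSON first, then fall back to a structured wrapper so callers always receive a dict; the actual OpenRouter handling is more involved:

```python
import json


def parse_model_response(content: str) -> dict:
    """Return structured data whether the model answered in JSON or plain prose."""
    candidate = content.strip()
    try:
        parsed = json.loads(candidate)
        if isinstance(parsed, dict):
            return parsed
    except json.JSONDecodeError:
        pass
    # Non-JSON output (common with reasoning models) is wrapped instead of discarded.
    return {"text": candidate, "parsed_as_json": False}
```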

* updated the LLM settings UI, added a separate embeddings provider, and made the system fully capable of mixing and matching LLM providers (local and non-local) for chat & embeddings. Mainly updated the RAGSettings.tsx UI, along with core functionality.

* added warning labels and updated Ollama health checks

* ready for review; fixed some error warnings and consolidated Ollama status health checks

* fixed FAILED test_async_embedding_service.py

* CodeRabbit fixes

* Separated the code-summary LLM provider from the embedding provider, so code example storage now forwards a dedicated embedding provider override end-to-end without hijacking the embedding pipeline. This addresses CodeRabbit's "Preserve provider override in create_embeddings_batch" suggestion. A sketch of the separation follows.
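
A condensed sketch of the separation; generate_summary and create_embeddings_batch are stand-ins here, while the real signatures appear in the diff below:

```python
async def generate_summary(example: dict, provider: str | None) -> str:
    # Stand-in: the chat provider only drives summary generation.
    return f"summary via {provider or 'default'}: {example['code'][:40]}"


async def create_embeddings_batch(texts: list[str], provider: str | None = None) -> list[list[float]]:
    # Stand-in for the real batch embedder shown in the diff below.
    return [[0.0, 0.0, 0.0] for _ in texts]


async def store_code_examples(
    examples: list[dict],
    provider: str | None = None,            # LLM provider for code summaries
    embedding_provider: str | None = None,  # separate override for vector creation
) -> int:
    summaries = [await generate_summary(ex, provider=provider) for ex in examples]
    # The embedding override is forwarded untouched instead of reusing the summary provider.
    await create_embeddings_batch(summaries, provider=embedding_provider)
    return len(summaries)
```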

* - Swapped API credential storage to booleans so decrypted keys never sit in React state (archon-ui-main/src/components/settings/RAGSettings.tsx).
  - Normalized Ollama instance URLs and gated the metrics effect on real state changes to avoid mis-counts and duplicate fetches (RAGSettings.tsx).
  - Tightened crawl progress scaling and indented-block parsing to handle min_length=None safely (python/src/server/services/crawling/code_extraction_service.py:160, python/src/server/services/crawling/code_extraction_service.py:911).
  - Added provider-agnostic embedding rate-limit retries so Google and friends back off gracefully (python/src/server/services/embeddings/embedding_service.py:427); see the retry sketch after this list.
  - Made the orchestration registry async + thread-safe and updated every caller to await it (python/src/server/services/crawling/crawling_service.py:34, python/src/server/api_routes/knowledge_api.py:1291).
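
A sketch of the provider-agnostic backoff, assuming the project's EmbeddingRateLimitError is raised by each adapter (a stand-in class is defined here to keep the snippet self-contained):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class EmbeddingRateLimitError(Exception):
    """Stand-in for the project's embedding_exceptions.EmbeddingRateLimitError."""


async def embed_with_backoff(create_batch, texts: list[str], max_retries: int = 3):
    """Retry an embedding call with exponential backoff on rate-limit errors."""
    for attempt in range(1, max_retries + 1):
        try:
            return await create_batch(texts)
        except EmbeddingRateLimitError as exc:
            if attempt == max_retries:
                raise
            wait = 2 ** attempt
            logger.warning("Rate limited (%s); retrying in %ss (%d/%d)", exc, wait, attempt, max_retries)
            # await asyncio.sleep keeps the event loop responsive during the wait.
            await asyncio.sleep(wait)
```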

* Update RAGSettings.tsx - header for 'LLM Settings' is now 'LLM Provider Settings'

* (RAG Settings)

  - Ollama Health Checks & Metrics
      - Added a 10-second timeout to the health fetch so it doesn't hang.
      - Adjusted logic so metric refreshes run for embedding-only Ollama setups too.
      - Initial page load now checks Ollama if either chat or embedding provider uses it.
      - Metrics and alerts now respect which provider (chat/embedding) is currently selected.
  - Provider Sync & Alerts
      - Fixed a sync bug so the very first provider change updates settings as expected.
      - Alerts now track the active provider (chat vs embedding) rather than only the LLM provider.
      - Warnings about missing credentials now skip whichever provider is currently selected.
  - Modals & Types
      - Normalize URLs before handing them to selection modals to keep consistent data.
      - Strengthened helper function types (getDisplayedChatModel, getModelPlaceholder, etc.).

 (Crawling Service)

  - Made the orchestration registry lock lazy-initialized to avoid issues in Python 3.12 and wrapped registry commands
  (register, unregister) in async calls. This keeps things thread-safe even during concurrent crawling and cancellation.

* - migration/complete_setup.sql:101 seeds Google/OpenRouter/Anthropic/Grok API key rows so fresh databases expose every provider by default.
  - migration/0.1.0/009_add_provider_placeholders.sql:1 backfills the same rows for existing Supabase instances and records the migration.
  - archon-ui-main/src/components/settings/RAGSettings.tsx:121 introduces a shared credential→provider map, reloadApiCredentials runs through all five providers, and the status poller includes the new keys.
  - archon-ui-main/src/components/settings/RAGSettings.tsx:353 subscribes to the archon:credentials-updated browser event so adding/removing a key immediately refetches credential status and pings the corresponding connectivity test.
  - archon-ui-main/src/components/settings/RAGSettings.tsx:926 now treats missing Anthropic/OpenRouter/Grok keys as not configured, preventing stale "connected" badges when a key is removed.

* - archon-ui-main/src/components/settings/RAGSettings.tsx:90 adds a simple display-name map and reuses one red alert style.
  - archon-ui-main/src/components/settings/RAGSettings.tsx:1016 now shows exactly one red banner when the active provider is missing its API key.
  - Removed the old duplicate Missing API Key Configuration block, so the panel no longer stacks two warnings.

* Update credentialsService.ts default model

* updated the Google embedding adapter for multi-dimensional RAG querying; a short usage sketch follows
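
Based on the GoogleEmbeddingAdapter changes in the diff below, a hedged usage sketch for requesting a non-default dimension; the import path and model name are assumptions, and a configured GOOGLE_API_KEY is required:

```python
import asyncio

# Assumed import path based on the repository layout (python/src/server/...).
from src.server.services.embeddings.embedding_service import GoogleEmbeddingAdapter


async def main() -> None:
    adapter = GoogleEmbeddingAdapter()
    # Unsupported dimensions fall back to the provider default, and vectors
    # shorter than 3072 dimensions are L2-normalized by the adapter.
    vectors = await adapter.create_embeddings(
        ["retrieval query text"],
        model="text-embedding-004",
        dimensions=768,
    )
    print(len(vectors[0]))


if __name__ == "__main__":
    asyncio.run(main())
```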

* thought this micro-fix to the Google embedding adapter was pushed with the embedding update the other day, but it wasn't. pushing it now

---------

Co-authored-by: Chillbruhhh <joshchesser97@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Josh
2025-10-05 13:49:09 -05:00
committed by GitHub
parent 63a92cf7d7
commit a580fdfe66
14 changed files with 2194 additions and 1388 deletions

View File

@@ -1288,7 +1288,7 @@ async def stop_crawl_task(progress_id: str):
found = False
# Step 1: Cancel the orchestration service
orchestration = get_active_orchestration(progress_id)
orchestration = await get_active_orchestration(progress_id)
if orchestration:
orchestration.cancel()
found = True
@@ -1306,7 +1306,7 @@ async def stop_crawl_task(progress_id: str):
found = True
# Step 3: Remove from active orchestrations registry
unregister_orchestration(progress_id)
await unregister_orchestration(progress_id)
# Step 4: Update progress tracker to reflect cancellation (only if we found and cancelled something)
if found:

View File

@@ -140,6 +140,7 @@ class CodeExtractionService:
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -150,6 +151,8 @@ class CodeExtractionService:
source_id: The unique source_id for all documents
progress_callback: Optional async callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider identifier for summary generation
embedding_provider: Optional embedding provider override for vector creation
Returns:
Number of code examples stored
@@ -158,9 +161,16 @@ class CodeExtractionService:
extraction_callback = None
if progress_callback:
async def extraction_progress(data: dict):
# Scale progress to 0-20% range
raw_progress = data.get("progress", 0)
scaled_progress = int(raw_progress * 0.2) # 0-20%
# Scale progress to 0-20% range with normalization similar to later phases
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 0-20% with clamping
scaled_progress = min(20, max(0, int(raw_num * 0.2)))
data["progress"] = scaled_progress
await progress_callback(data)
extraction_callback = extraction_progress
@@ -197,8 +207,15 @@ class CodeExtractionService:
if progress_callback:
async def summary_progress(data: dict):
# Scale progress to 20-90% range
raw_progress = data.get("progress", 0)
scaled_progress = 20 + int(raw_progress * 0.7) # 20-90%
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 20-90% with clamping
scaled_progress = min(90, max(20, 20 + int(raw_num * 0.7)))
data["progress"] = scaled_progress
await progress_callback(data)
summary_callback = summary_progress
@@ -216,15 +233,26 @@ class CodeExtractionService:
if progress_callback:
async def storage_progress(data: dict):
# Scale progress to 90-100% range
raw_progress = data.get("progress", 0)
scaled_progress = 90 + int(raw_progress * 0.1) # 90-100%
raw = data.get("progress", data.get("percentage", 0))
try:
raw_num = float(raw)
except (TypeError, ValueError):
raw_num = 0.0
if 0.0 <= raw_num <= 1.0:
raw_num *= 100.0
# 90-100% with clamping
scaled_progress = min(100, max(90, 90 + int(raw_num * 0.1)))
data["progress"] = scaled_progress
await progress_callback(data)
storage_callback = storage_progress
# Store code examples in database
return await self._store_code_examples(
storage_data, url_to_full_document, storage_callback, provider
storage_data,
url_to_full_document,
storage_callback,
provider,
embedding_provider,
)
async def _extract_code_blocks_from_documents(
@@ -880,9 +908,20 @@ class CodeExtractionService:
current_indent = indent
block_start_idx = i
current_block.append(line)
elif current_block and len("\n".join(current_block)) >= min_length:
elif current_block:
block_text = "\n".join(current_block)
threshold = (
min_length
if min_length is not None
else await self._get_min_code_length()
)
if len(block_text) < threshold:
current_block = []
current_indent = None
continue
# End of indented block, check if it's code
code_content = "\n".join(current_block)
code_content = block_text
# Try to detect language from content
language = self._detect_language_from_content(code_content)
@@ -1670,12 +1709,20 @@ class CodeExtractionService:
url_to_full_document: dict[str, str],
progress_callback: Callable | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Store code examples in the database.
Returns:
Number of code examples stored
Args:
storage_data: Prepared code example payloads
url_to_full_document: Mapping of URLs to their full document content
progress_callback: Optional callback for progress updates
provider: Optional LLM provider identifier for summaries
embedding_provider: Optional embedding provider override for vector storage
"""
# Create progress callback for storage phase
storage_progress_callback = None
@@ -1713,6 +1760,7 @@ class CodeExtractionService:
url_to_full_document=url_to_full_document,
progress_callback=storage_progress_callback,
provider=provider,
embedding_provider=embedding_provider,
)
# Report completion of code extraction/storage phase

View File

@@ -14,6 +14,7 @@ from typing import Any, Optional
from ...config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ...utils import get_supabase_client
from ...utils.progress.progress_tracker import ProgressTracker
from ..credential_service import credential_service
# Import strategies
# Import operations
@@ -32,22 +33,35 @@ logger = get_logger(__name__)
# Global registry to track active orchestration services for cancellation support
_active_orchestrations: dict[str, "CrawlingService"] = {}
_orchestration_lock: asyncio.Lock | None = None
def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]:
def _ensure_orchestration_lock() -> asyncio.Lock:
global _orchestration_lock
if _orchestration_lock is None:
_orchestration_lock = asyncio.Lock()
return _orchestration_lock
async def get_active_orchestration(progress_id: str) -> Optional["CrawlingService"]:
"""Get an active orchestration service by progress ID."""
return _active_orchestrations.get(progress_id)
lock = _ensure_orchestration_lock()
async with lock:
return _active_orchestrations.get(progress_id)
def register_orchestration(progress_id: str, orchestration: "CrawlingService"):
async def register_orchestration(progress_id: str, orchestration: "CrawlingService"):
"""Register an active orchestration service."""
_active_orchestrations[progress_id] = orchestration
lock = _ensure_orchestration_lock()
async with lock:
_active_orchestrations[progress_id] = orchestration
def unregister_orchestration(progress_id: str):
async def unregister_orchestration(progress_id: str):
"""Unregister an orchestration service."""
if progress_id in _active_orchestrations:
del _active_orchestrations[progress_id]
lock = _ensure_orchestration_lock()
async with lock:
_active_orchestrations.pop(progress_id, None)
class CrawlingService:
@@ -246,7 +260,7 @@ class CrawlingService:
# Register this orchestration service for cancellation support
if self.progress_id:
register_orchestration(self.progress_id, self)
await register_orchestration(self.progress_id, self)
# Start the crawl as an async task in the main event loop
# Store the task reference for proper cancellation
@@ -477,15 +491,27 @@ class CrawlingService:
try:
# Extract provider from request or use credential service default
provider = request.get("provider")
embedding_provider = None
if not provider:
try:
from ..credential_service import credential_service
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
except Exception as e:
logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
logger.warning(
f"Failed to get provider from credential service: {e}, defaulting to openai"
)
provider = "openai"
try:
embedding_config = await credential_service.get_active_provider("embedding")
embedding_provider = embedding_config.get("provider")
except Exception as e:
logger.warning(
f"Failed to get embedding provider from credential service: {e}. Using configured default."
)
embedding_provider = None
code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
crawl_results,
storage_results["url_to_full_document"],
@@ -493,6 +519,7 @@ class CrawlingService:
code_progress_callback,
self._check_cancellation,
provider,
embedding_provider,
)
except RuntimeError as e:
# Code extraction failed, continue crawl with warning
@@ -548,7 +575,7 @@ class CrawlingService:
# Unregister after successful completion
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service after completion | progress_id={self.progress_id}"
)
@@ -567,7 +594,7 @@ class CrawlingService:
)
# Unregister on cancellation
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service on cancellation | progress_id={self.progress_id}"
)
@@ -591,7 +618,7 @@ class CrawlingService:
await self.progress_tracker.error(error_message)
# Unregister on error
if self.progress_id:
unregister_orchestration(self.progress_id)
await unregister_orchestration(self.progress_id)
safe_logfire_info(
f"Unregistered orchestration service on error | progress_id={self.progress_id}"
)

View File

@@ -352,6 +352,7 @@ class DocumentStorageOperations:
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -363,12 +364,19 @@ class DocumentStorageOperations:
progress_callback: Optional callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider to use for code summaries
embedding_provider: Optional embedding provider override for code example embeddings
Returns:
Number of code examples stored
"""
result = await self.code_extraction_service.extract_and_store_code_examples(
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
crawl_results,
url_to_full_document,
source_id,
progress_callback,
cancellation_check,
provider,
embedding_provider,
)
return result

View File

@@ -36,42 +36,6 @@ class CredentialItem:
description: str | None = None
def _detect_embedding_provider_from_model(embedding_model: str) -> str:
"""
Detect the appropriate embedding provider based on model name.
Args:
embedding_model: The embedding model name
Returns:
Provider name: 'google', 'openai', or 'openai' (default)
"""
if not embedding_model:
return "openai" # Default
model_lower = embedding_model.lower()
# Google embedding models
google_patterns = [
"text-embedding-004",
"text-embedding-005",
"text-multilingual-embedding",
"gemini-embedding",
"multimodalembedding"
]
if any(pattern in model_lower for pattern in google_patterns):
return "google"
# OpenAI embedding models (and default for unknown)
openai_patterns = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large"
]
# Default to OpenAI for OpenAI models or unknown models
return "openai"
class CredentialService:
@@ -475,26 +439,24 @@ class CredentialService:
# Get the selected provider based on service type
if service_type == "embedding":
# Get the LLM provider setting to determine embedding provider
llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
# First check for explicit EMBEDDING_PROVIDER setting (new split provider approach)
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
# Determine embedding provider based on LLM provider
if llm_provider == "google":
provider = "google"
elif llm_provider == "ollama":
provider = "ollama"
elif llm_provider == "openrouter":
# OpenRouter supports both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
elif llm_provider in ["anthropic", "grok"]:
# Anthropic and Grok support both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
# Validate that embedding provider actually supports embeddings
embedding_capable_providers = {"openai", "google", "ollama"}
if (explicit_embedding_provider and
explicit_embedding_provider != "" and
explicit_embedding_provider in embedding_capable_providers):
# Use the explicitly set embedding provider
provider = explicit_embedding_provider
logger.debug(f"Using explicit embedding provider: '{provider}'")
else:
# Default case (openai, or unknown providers)
# Fall back to OpenAI as default embedding provider for backward compatibility
if explicit_embedding_provider and explicit_embedding_provider not in embedding_capable_providers:
logger.warning(f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI")
provider = "openai"
logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
logger.debug(f"No explicit embedding provider set, defaulting to OpenAI for backward compatibility")
else:
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Ensure provider is a valid string, not a boolean or other type

View File

@@ -5,15 +5,19 @@ Handles all OpenAI embedding operations with proper rate limiting and error hand
"""
import asyncio
import inspect
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import httpx
import numpy as np
import openai
from ...config.logfire_config import safe_span, search_logger
from ..credential_service import credential_service
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
from ..llm_provider_service import get_embedding_model, get_llm_client
from ..threading_service import get_threading_service
from .embedding_exceptions import (
EmbeddingAPIError,
@@ -64,6 +68,167 @@ class EmbeddingBatchResult:
return self.success_count + self.failure_count
class EmbeddingProviderAdapter(ABC):
"""Adapter interface for embedding providers."""
@abstractmethod
async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
"""Create embeddings for the given texts."""
class OpenAICompatibleEmbeddingAdapter(EmbeddingProviderAdapter):
"""Adapter for providers using the OpenAI embeddings API shape."""
def __init__(self, client: Any):
self._client = client
async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
request_args: dict[str, Any] = {
"model": model,
"input": texts,
}
if dimensions is not None:
request_args["dimensions"] = dimensions
response = await self._client.embeddings.create(**request_args)
return [item.embedding for item in response.data]
class GoogleEmbeddingAdapter(EmbeddingProviderAdapter):
"""Adapter for Google's native embedding endpoint."""
async def create_embeddings(
self,
texts: list[str],
model: str,
dimensions: int | None = None,
) -> list[list[float]]:
try:
google_api_key = await credential_service.get_credential("GOOGLE_API_KEY")
if not google_api_key:
raise EmbeddingAPIError("Google API key not found")
async with httpx.AsyncClient(timeout=30.0) as http_client:
embeddings = await asyncio.gather(
*(
self._fetch_single_embedding(http_client, google_api_key, model, text, dimensions)
for text in texts
)
)
return embeddings
except httpx.HTTPStatusError as error:
error_content = error.response.text
search_logger.error(
f"Google embedding API returned {error.response.status_code} - {error_content}",
exc_info=True,
)
raise EmbeddingAPIError(
f"Google embedding API error: {error.response.status_code} - {error_content}",
original_error=error,
) from error
except Exception as error:
search_logger.error(f"Error calling Google embedding API: {error}", exc_info=True)
raise EmbeddingAPIError(
f"Google embedding error: {str(error)}", original_error=error
) from error
async def _fetch_single_embedding(
self,
http_client: httpx.AsyncClient,
api_key: str,
model: str,
text: str,
dimensions: int | None = None,
) -> list[float]:
if model.startswith("models/"):
url_model = model[len("models/") :]
payload_model = model
else:
url_model = model
payload_model = f"models/{model}"
url = f"https://generativelanguage.googleapis.com/v1beta/models/{url_model}:embedContent"
headers = {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
payload = {
"model": payload_model,
"content": {"parts": [{"text": text}]},
}
# Add output_dimensionality parameter if dimensions are specified and supported
if dimensions is not None and dimensions > 0:
model_name = payload_model.removeprefix("models/")
if model_name.startswith("textembedding-gecko"):
supported_dimensions = {128, 256, 512, 768}
else:
supported_dimensions = {128, 256, 512, 768, 1024, 1536, 2048, 3072}
if dimensions in supported_dimensions:
payload["outputDimensionality"] = dimensions
else:
search_logger.warning(
f"Requested dimension {dimensions} is not supported by Google model '{model_name}'. "
"Falling back to the provider default."
)
response = await http_client.post(url, headers=headers, json=payload)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding", {})
values = embedding.get("values") if isinstance(embedding, dict) else None
if not isinstance(values, list):
raise EmbeddingAPIError(f"Invalid embedding payload from Google: {result}")
# Normalize embeddings for dimensions < 3072 as per Google's documentation
actual_dimension = len(values)
if actual_dimension > 0 and actual_dimension < 3072:
values = self._normalize_embedding(values)
return values
def _normalize_embedding(self, embedding: list[float]) -> list[float]:
"""Normalize embedding vector for dimensions < 3072."""
try:
embedding_array = np.array(embedding, dtype=np.float32)
norm = np.linalg.norm(embedding_array)
if norm > 0:
normalized = embedding_array / norm
return normalized.tolist()
else:
search_logger.warning("Zero-norm embedding detected, returning unnormalized")
return embedding
except Exception as e:
search_logger.error(f"Failed to normalize embedding: {e}")
# Return original embedding if normalization fails
return embedding
def _get_embedding_adapter(provider: str, client: Any) -> EmbeddingProviderAdapter:
provider_name = (provider or "").lower()
if provider_name == "google":
return GoogleEmbeddingAdapter()
return OpenAICompatibleEmbeddingAdapter(client)
async def _maybe_await(value: Any) -> Any:
"""Await the value if it is awaitable, otherwise return as-is."""
return await value if inspect.isawaitable(value) else value
# Provider-aware client factory
get_openai_client = get_llm_client
@@ -185,27 +350,25 @@ async def create_embeddings_batch(
"create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
) as span:
try:
# Intelligent embedding provider routing based on model type
# Get the embedding model first to determine the correct provider
embedding_model = await get_embedding_model(provider=provider)
embedding_config = await _maybe_await(
credential_service.get_active_provider(service_type="embedding")
)
# Route to correct provider based on model type
if is_google_embedding_model(embedding_model):
embedding_provider = "google"
search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
embedding_provider = provider or embedding_config.get("provider")
if not isinstance(embedding_provider, str) or not embedding_provider.strip():
embedding_provider = "openai"
search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
else:
# Keep original provider for ollama and other providers
embedding_provider = provider
search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")
if not embedding_provider:
search_logger.error("No embedding provider configured")
raise ValueError("No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable.")
search_logger.info(f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)")
async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
# Load batch size and dimensions from settings
try:
rag_settings = await credential_service.get_credentials_by_category(
"rag_strategy"
rag_settings = await _maybe_await(
credential_service.get_credentials_by_category("rag_strategy")
)
batch_size = int(rag_settings.get("EMBEDDING_BATCH_SIZE", "100"))
embedding_dimensions = int(rag_settings.get("EMBEDDING_DIMENSIONS", "1536"))
@@ -215,6 +378,8 @@ async def create_embeddings_batch(
embedding_dimensions = 1536
total_tokens_used = 0
adapter = _get_embedding_adapter(embedding_provider, client)
dimensions_to_use = embedding_dimensions if embedding_dimensions > 0 else None
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
@@ -243,16 +408,14 @@ async def create_embeddings_batch(
try:
# Create embeddings for this batch
embedding_model = await get_embedding_model(provider=embedding_provider)
response = await client.embeddings.create(
model=embedding_model,
input=batch,
dimensions=embedding_dimensions,
embeddings = await adapter.create_embeddings(
batch,
embedding_model,
dimensions=dimensions_to_use,
)
# Add successful embeddings
for text, item in zip(batch, response.data, strict=False):
result.add_success(item.embedding, text)
for text, vector in zip(batch, embeddings, strict=False):
result.add_success(vector, text)
break # Success, exit retry loop
@@ -297,6 +460,17 @@ async def create_embeddings_batch(
await asyncio.sleep(wait_time)
else:
raise # Will be caught by outer try
except EmbeddingRateLimitError as e:
retry_count += 1
if retry_count < max_retries:
wait_time = 2**retry_count
search_logger.warning(
f"Embedding rate limit for batch {batch_index}: {e}. "
f"Waiting {wait_time}s before retry {retry_count}/{max_retries}"
)
await asyncio.sleep(wait_time)
else:
raise
except Exception as e:
# This batch failed - track failures but continue with next batch

View File

@@ -1091,6 +1091,7 @@ async def add_code_examples_to_supabase(
url_to_full_document: dict[str, str] | None = None,
progress_callback: Callable | None = None,
provider: str | None = None,
embedding_provider: str | None = None,
):
"""
Add code examples to the Supabase code_examples table in batches.
@@ -1105,6 +1106,8 @@ async def add_code_examples_to_supabase(
batch_size: Size of each batch for insertion
url_to_full_document: Optional mapping of URLs to full document content
progress_callback: Optional async callback for progress updates
provider: Optional LLM provider used for summary generation tracking
embedding_provider: Optional embedding provider override for vector generation
"""
if not urls:
return
@@ -1183,8 +1186,8 @@ async def add_code_examples_to_supabase(
# Use original combined texts
batch_texts = combined_texts
# Create embeddings for the batch
result = await create_embeddings_batch(batch_texts, provider=provider)
# Create embeddings for the batch (optionally overriding the embedding provider)
result = await create_embeddings_batch(batch_texts, provider=embedding_provider)
# Log any failures
if result.has_failures:
@@ -1201,7 +1204,7 @@ async def add_code_examples_to_supabase(
from ..llm_provider_service import get_embedding_model
# Get embedding model name
embedding_model_name = await get_embedding_model(provider=provider)
embedding_model_name = await get_embedding_model(provider=embedding_provider)
# Get LLM chat model (used for code summaries and contextual embeddings if enabled)
llm_chat_model = None

View File

@@ -111,8 +111,8 @@ class TestCodeExtractionSourceId:
assert args[2] == source_id
assert args[3] is None
assert args[4] is None
if len(args) > 5:
assert args[5] is None
assert args[5] is None
assert args[6] is None
assert result == 5
@pytest.mark.asyncio