mirror of
https://github.com/coleam00/Archon.git
synced 2025-12-27 04:00:29 -05:00
421 lines
16 KiB
Python
421 lines
16 KiB
Python
"""
RAG Agent - Conversational Search and Retrieval with PydanticAI

This agent enables users to search and chat with documents stored in the RAG system.
It uses the MCP client's RAG operations (perform_rag_query, search_code_examples and
get_available_sources) to retrieve relevant content and provide intelligent responses
based on the retrieved information.
"""
|
|
|
|
import json
import logging
import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext

from .base_agent import ArchonDependencies, BaseAgent
from .mcp_client import get_mcp_client
|
|
|
|
# Module-level logger; the RAG tools below report their failures through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class RagDependencies(ArchonDependencies):
    """Dependencies for RAG operations.

    Extends ``ArchonDependencies`` with the per-conversation search settings
    that the RAG agent's dynamic system prompt and tools read via ``ctx.deps``.
    """

    # Optional project scope; rendered in the dynamic system prompt
    # ("Global search" when unset) and appended by refine_search_query.
    project_id: str | None = None
    # Default source filter the search tools fall back to when the model
    # does not pass one explicitly.
    source_filter: str | None = None
    # Maximum number of matches requested from the MCP search calls.
    match_count: int = 5
    # NOTE(review): set by RagAgent.run_conversation but not read by any code
    # in this module — presumably consumed by BaseAgent or callers; verify.
    progress_callback: Any | None = None  # Callback for progress updates
|
|
|
|
|
|
class RagQueryResult(BaseModel):
    """Structured output for RAG query results.

    Assembled by ``RagAgent.run_conversation`` from the agent's text reply,
    or from the caught exception when the query fails.
    """

    query_type: str = Field(
        description="Type of query, e.g. search, list_sources, code_search, or error"
    )
    original_query: str = Field(description="The original user query")
    # Pydantic v2 treats an ``X | None`` annotation *without* a default as a
    # required field. The refined query only exists when it differs from the
    # original, so give it an explicit ``None`` default to make it optional.
    refined_query: str | None = Field(
        default=None,
        description="Refined query used for search if different from original",
    )
    results_found: int = Field(description="Number of relevant results found")
    sources: list[str] = Field(description="List of unique sources referenced")
    answer: str = Field(description="The synthesized answer based on retrieved content")
    citations: list[dict[str, Any]] = Field(description="Citations with source and relevance info")
    success: bool = Field(description="Whether the query was successful")
    message: str = Field(description="Status message or error description")
|
|
|
|
|
|
# Pulls a result count out of the agent's reply text, e.g. "found 3 relevant results".
_FOUND_COUNT_RE = re.compile(r"found (\d+)")

# Ordered (pattern, expansion) rules for refine_search_query; the first pattern
# that matches the lower-cased query wins. ``\b`` word boundaries stop
# "how"/"what"/"api" from firing inside words such as "show", "whatever" or
# "rapid". "error"/"issue" stay substring matches so identifiers like
# "TypeError" still trigger the troubleshooting expansion.
_QUERY_EXPANSION_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
    (re.compile(r"\bhow\b"), "tutorial guide example"),
    (re.compile(r"\bwhat\b"), "definition explanation overview"),
    (re.compile(r"error|issue"), "troubleshooting solution fix"),
    (re.compile(r"\bapis?\b"), "endpoint method parameters response"),
)


def _format_relevance(score: Any) -> str:
    """Format a similarity score as a percentage (``0.8`` -> ``"80.00%"``).

    Returns ``"n/a"`` for missing or non-numeric scores, so one malformed
    result cannot abort formatting of the whole result list.
    """
    try:
        return f"{float(score):.2%}"
    except (TypeError, ValueError):
        return "n/a"


def _unwrap_code_fence(code: str) -> tuple[str, str]:
    """Split a snippet into ``(language, body)``, removing a surrounding Markdown fence.

    Snippets that already start with a triple-backtick fence are unwrapped
    (opening line and a trailing closing fence removed) so that re-fencing
    them for display does not produce nested fences. Unfenced snippets are
    returned unchanged with the generic label ``"code"``.
    """
    if not code.startswith("```"):
        return "code", code
    first_line, _, rest = code.partition("\n")
    lang = first_line[3:].strip() or "code"
    body_lines = rest.rstrip().split("\n")
    if body_lines[-1].strip() == "```":
        body_lines.pop()
    return lang, "\n".join(body_lines)


class RagAgent(BaseAgent[RagDependencies, str]):
    """
    Conversational agent for RAG-based document search and retrieval.

    Capabilities:
    - Search documents using natural language queries
    - Filter by specific sources
    - Search code examples
    - Provide synthesized answers with citations
    - Explain concepts found in documentation
    """

    def __init__(self, model: str | None = None, **kwargs):
        """Create the RAG agent.

        Args:
            model: PydanticAI model identifier. When omitted, the
                ``RAG_AGENT_MODEL`` environment variable is used, falling back
                to ``"openai:gpt-4o-mini"``.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
        """
        if model is None:
            model = os.getenv("RAG_AGENT_MODEL", "openai:gpt-4o-mini")

        super().__init__(
            model=model, name="RagAgent", retries=3, enable_rate_limiting=True, **kwargs
        )

    def _create_agent(self, **kwargs) -> Agent:
        """Create the PydanticAI agent with its prompts and RAG tools.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``pydantic_ai.Agent``.

        Returns:
            The configured ``Agent``, with a dynamic search-context system
            prompt and the ``search_documents``, ``list_available_sources``,
            ``search_code_examples`` and ``refine_search_query`` tools.
        """
        agent = Agent(
            model=self.model,
            deps_type=RagDependencies,
            system_prompt="""You are a RAG (Retrieval-Augmented Generation) Assistant that helps users search and understand documentation through conversation.

**Your Capabilities:**
- Search through crawled documentation using semantic search
- Filter searches by specific sources or domains
- Find relevant code examples
- Synthesize information from multiple sources
- Provide clear, cited answers based on retrieved content
- Explain technical concepts found in documentation

**Your Approach:**
1. **Understand the query** - Interpret what the user is looking for
2. **Search effectively** - Use appropriate search terms and filters
3. **Analyze results** - Review retrieved content for relevance
4. **Synthesize answers** - Combine information from multiple sources
5. **Cite sources** - Always provide references to source documents

**Common Queries:**
- "What resources/sources are available?" → Use list_available_sources tool
- "Search for X" → Use search_documents tool
- "Find code examples for Y" → Use search_code_examples tool
- "What documentation do you have?" → Use list_available_sources tool

**Search Strategies:**
- For conceptual questions: Use broader search terms
- For specific features: Use exact terminology
- For code examples: Search for function names, patterns
- For comparisons: Search for each item separately

**Response Guidelines:**
- Provide direct answers based on retrieved content
- Include relevant quotes from sources
- Cite sources with URLs when available
- Admit when information is not found
- Suggest alternative searches if needed""",
            **kwargs,
        )

        # Dynamic system prompt: describes the current search scope to the model.
        @agent.system_prompt
        async def add_search_context(ctx: RunContext[RagDependencies]) -> str:
            source_info = (
                f"Source Filter: {ctx.deps.source_filter}"
                if ctx.deps.source_filter
                else "No source filter"
            )
            return f"""
**Current Search Context:**
- Project ID: {ctx.deps.project_id or "Global search"}
- {source_info}
- Max Results: {ctx.deps.match_count}
- Timestamp: {datetime.now().isoformat()}
"""

        # Tools. PydanticAI derives each tool's description from its docstring,
        # so those docstrings are part of what the model sees — edit deliberately.
        @agent.tool
        async def search_documents(
            ctx: RunContext[RagDependencies], query: str, source_filter: str | None = None
        ) -> str:
            """Search through documents using RAG query."""
            try:
                # Fall back to the conversation-level filter when none is given.
                if source_filter is None:
                    source_filter = ctx.deps.source_filter

                mcp_client = await get_mcp_client()
                result_json = await mcp_client.perform_rag_query(
                    query=query, source=source_filter, match_count=ctx.deps.match_count
                )
                result = json.loads(result_json)

                if not result.get("success", False):
                    return f"Search failed: {result.get('error', 'Unknown error')}"

                results = result.get("results", [])
                if not results:
                    return "No results found for your query. Try using different search terms or removing filters."

                formatted_results = []
                for i, res in enumerate(results, 1):
                    similarity = res.get("similarity_score", res.get("similarity", 0))
                    # ``or`` guards against keys that are present but null.
                    metadata = res.get("metadata") or {}
                    source = metadata.get("source", "Unknown")
                    url = metadata.get("url", res.get("url", ""))
                    content = res.get("content") or ""

                    # Truncate long chunks to keep the tool output compact.
                    if len(content) > 500:
                        content = content[:500] + "..."

                    formatted_results.append(
                        f"**Result {i}** (Relevance: {_format_relevance(similarity)})\n"
                        f"Source: {source}\n"
                        f"URL: {url}\n"
                        f"Content: {content}\n"
                    )

                return f"Found {len(results)} relevant results:\n\n" + "\n---\n".join(
                    formatted_results
                )

            except Exception as e:
                logger.exception(f"Error searching documents: {e}")
                return f"Error performing search: {str(e)}"

        @agent.tool
        async def list_available_sources(ctx: RunContext[RagDependencies]) -> str:
            """List all available sources that can be searched."""
            try:
                mcp_client = await get_mcp_client()
                result_json = await mcp_client.get_available_sources()
                result = json.loads(result_json)

                if not result.get("success", False):
                    return f"Failed to get sources: {result.get('error', 'Unknown error')}"

                sources = result.get("sources", [])
                if not sources:
                    return "No sources are currently available. You may need to crawl some documentation first."

                source_list = []
                for source in sources:
                    source_id = source.get("source_id", "Unknown")
                    title = source.get("title", "Untitled")
                    description = source.get("description", "")
                    # ``or`` guards against a present-but-null timestamp.
                    created = source.get("created_at") or ""

                    desc_text = f" - {description}" if description else ""
                    # Keep the first 10 characters (the date, assuming an
                    # ISO-8601 timestamp); omit the suffix when no time is recorded.
                    added_text = f" (added {str(created)[:10]})" if created else ""

                    source_list.append(f"- **{source_id}**: {title}{desc_text}{added_text}")

                return f"Available sources ({len(sources)} total):\n" + "\n".join(source_list)

            except Exception as e:
                logger.exception(f"Error listing sources: {e}")
                return f"Error retrieving sources: {str(e)}"

        @agent.tool
        async def search_code_examples(
            ctx: RunContext[RagDependencies], query: str, source_filter: str | None = None
        ) -> str:
            """Search for code examples related to the query."""
            try:
                # Fall back to the conversation-level filter when none is given.
                if source_filter is None:
                    source_filter = ctx.deps.source_filter

                mcp_client = await get_mcp_client()
                result_json = await mcp_client.search_code_examples(
                    query=query, source_id=source_filter, match_count=ctx.deps.match_count
                )
                result = json.loads(result_json)

                if not result.get("success", False):
                    return f"Code search failed: {result.get('error', 'Unknown error')}"

                examples = result.get("results", result.get("code_examples", []))
                if not examples:
                    return "No code examples found for your query."

                formatted_examples = []
                for i, example in enumerate(examples, 1):
                    # Accept either score key, consistent with search_documents.
                    similarity = example.get("similarity", example.get("similarity_score", 0))
                    summary = example.get("summary", "No summary")
                    code = example.get("code", example.get("code_block", "")) or ""
                    url = example.get("url", "")

                    # Unwrap an existing fence so the snippet is not double-fenced.
                    lang, code = _unwrap_code_fence(code)

                    formatted_examples.append(
                        f"**Example {i}** (Relevance: {_format_relevance(similarity)})\n"
                        f"Summary: {summary}\n"
                        f"Source: {url}\n"
                        f"```{lang}\n{code}\n```"
                    )

                return f"Found {len(examples)} code examples:\n\n" + "\n---\n".join(
                    formatted_examples
                )

            except Exception as e:
                logger.exception(f"Error searching code examples: {e}")
                return f"Error searching code: {str(e)}"

        @agent.tool
        async def refine_search_query(
            ctx: RunContext[RagDependencies], original_query: str, context: str
        ) -> str:
            """Refine a search query based on context to get better results."""
            # Rule-based query expansion. NOTE: ``context`` is part of the tool
            # schema but is not currently used by the expansion rules.
            try:
                refined_parts = [original_query]

                lowered = original_query.lower()
                for pattern, expansion in _QUERY_EXPANSION_RULES:
                    if pattern.search(lowered):
                        refined_parts.append(expansion)
                        break

                # Add project-specific context if available.
                if ctx.deps.project_id:
                    refined_parts.append(f"project:{ctx.deps.project_id}")

                refined_query = " ".join(refined_parts)
                return f"Refined query: '{refined_query}' (original: '{original_query}')"

            except Exception as e:
                return f"Could not refine query: {str(e)}"

        return agent

    def get_system_prompt(self) -> str:
        """Get the base system prompt for this agent.

        Loads the ``"rag_assistant"`` prompt from the prompt service, falling
        back to a built-in default if the service cannot be used.
        """
        try:
            from ..services.prompt_service import prompt_service

            return prompt_service.get_prompt(
                "rag_assistant",
                default="RAG Assistant for intelligent document search and retrieval.",
            )
        except Exception as e:
            logger.warning(f"Could not load prompt from service: {e}")
            return "RAG Assistant for intelligent document search and retrieval."

    async def run_conversation(
        self,
        user_message: str,
        project_id: str | None = None,
        source_filter: str | None = None,
        match_count: int = 5,
        user_id: str | None = None,
        progress_callback: Any = None,
    ) -> RagQueryResult:
        """
        Run the agent for conversational RAG queries.

        Args:
            user_message: The user's search query or question
            project_id: Optional project ID for context
            source_filter: Optional source domain to filter results
            match_count: Maximum number of results to return
            user_id: ID of the user making the request
            progress_callback: Optional callback for progress updates

        Returns:
            Structured RagQueryResult. Failures are reported through
            ``success=False`` rather than raised.
        """
        deps = RagDependencies(
            project_id=project_id,
            source_filter=source_filter,
            match_count=match_count,
            user_id=user_id,
            progress_callback=progress_callback,
        )

        try:
            response_text = await self.run(user_message, deps)
            self.logger.info("RAG query completed successfully")

            # The agent returns free text, so the metadata below is a
            # best-effort heuristic over the reply, not data from tool calls.
            lowered = response_text.lower()
            query_type = "search"  # Default type
            results_found = 0

            if "found" in lowered and "results" in lowered:
                match = _FOUND_COUNT_RE.search(lowered)
                if match:
                    results_found = int(match.group(1))

            if "available sources" in lowered:
                query_type = "list_sources"
            elif "code example" in lowered:
                query_type = "code_search"
            elif "no results" in lowered:
                results_found = 0

            # Collect "Source: ..." references, de-duplicated in first-seen
            # order (deterministic, unlike a set) and skipping empty values.
            source_refs = [
                line.split("Source:")[-1].strip()
                for line in response_text.split("\n")
                if "Source:" in line
            ]
            sources = list(dict.fromkeys(ref for ref in source_refs if ref))

            return RagQueryResult(
                query_type=query_type,
                original_query=user_message,
                refined_query=None,
                results_found=results_found,
                sources=sources,
                answer=response_text,
                citations=[],  # Could be enhanced to extract citations
                success=True,
                message="Query completed successfully",
            )

        except Exception as e:
            self.logger.error(f"RAG query failed: {str(e)}")
            return RagQueryResult(
                query_type="error",
                original_query=user_message,
                refined_query=None,
                results_found=0,
                sources=[],
                answer=f"I encountered an error while searching: {str(e)}",
                citations=[],
                success=False,
                message=f"Failed to process query: {str(e)}",
            )
|
|
|
|
|
# Note: RagAgent instances should be created on demand in API endpoints
# to avoid initialization issues during module import.
|