diff --git a/.env.example b/.env.example index 18b34cb7..4b99373b 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,2 @@ # Copy this file to .env and add your actual API key -ANTHROPIC_API_KEY=your-anthropic-api-key-here \ No newline at end of file +MINMAX_API_KEY=your-minimax-api-key-here \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..c2d5da08 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,41 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## 项目概述 + +这是一个基于 FastAPI 的 RAG (Retrieval-Augmented Generation) 系统,用于回答关于课程材料的问题。核心组件: + +- **ChromaDB** - 向量存储,用于语义搜索 +- **MiniMax** - AI 生成模型 +- **FastAPI** - Web 框架和 API + +## 运行命令 + +```bash +# 快速启动 +./run.sh + +# 手动启动 +cd backend && uv run uvicorn app:app --reload --port 8000 + +# 安装依赖 +uv sync +``` + +## 架构要点 + +- `backend/rag_system.py` - 主协调器,整合所有组件 +- `backend/ai_generator.py` - 调用 MiniMax API,处理工具执行 +- `backend/search_tools.py` - 语义搜索工具,基于 ChromaDB +- `backend/document_processor.py` - 文档分块 (chunk_size=800, overlap=100) +- `backend/session_manager.py` - 维护对话历史 + +## 配置 + +需要 `.env` 文件包含 `MINMAX_API_KEY`。使用 `.env.example` 作为模板。 + +## API 端点 + +- `POST /api/query` - 处理查询,返回答案和来源 +- `GET /api/courses` - 获取课程统计信息 \ No newline at end of file diff --git a/backend/ai_generator.py b/backend/ai_generator.py index 0363ca90..66eca676 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -1,14 +1,24 @@ -import anthropic +import requests +import json from typing import List, Optional, Dict, Any + class AIGenerator: - """Handles interactions with Anthropic's Claude API for generating responses""" - + """Handles interactions with MiniMax API for generating responses""" + # Static system prompt to avoid rebuilding on each call - SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course 
information. + SYSTEM_PROMPT = """You are an AI assistant specialized in course materials and educational content with access to comprehensive search tools for course information. + +Tools Available: +1. search_course_content - Search for specific content within courses +2. get_course_outline - Get course outline with title, link, and all lessons + +Tool Selection Guidelines: +- Use **get_course_outline** for: course outline requests, listing all lessons, what lessons are in a course, course structure, syllabus queries +- Use **search_course_content** for: specific content questions, detailed information about topics Search Tool Usage: -- Use the search tool **only** for questions about specific course content or detailed educational materials +- Use the search tools **only** for questions about course content or outline - **One search per query maximum** - Synthesize search results into accurate, fact-based responses - If search yields no results, state this clearly without offering alternatives @@ -20,6 +30,10 @@ class AIGenerator: - Provide direct answers only — no reasoning process, search explanations, or question-type analysis - Do not mention "based on the search results" +When responding to outline queries, include: +- Course title +- Course link (if available) +- Number and title of each lesson All responses must be: 1. **Brief, Concise and focused** - Get to the point quickly @@ -28,108 +42,164 @@ class AIGenerator: 4. **Example-supported** - Include relevant examples when they aid understanding Provide only the direct answer to what was asked. 
""" - - def __init__(self, api_key: str, model: str): - self.client = anthropic.Anthropic(api_key=api_key) + + def __init__(self, api_key: str, base_url: str, model: str): + self.api_key = api_key + self.base_url = base_url self.model = model - - # Pre-build base API parameters - self.base_params = { + self.temperature = 0 + self.max_tokens = 800 + + def _make_request(self, messages: List[Dict], tools: Optional[List] = None, stream: bool = False): + """POST the conversation to the MiniMax Anthropic-compatible /v1/messages endpoint and return the decoded JSON response; raises Exception carrying the response body on HTTP errors""" + url = f"{self.base_url}/v1/messages" + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # Format messages for MiniMax Anthropic-compatible API + formatted_messages = [] + system_prompt = "" + + for msg in messages: + if msg.get("role") == "system": + # Combine system messages + system_prompt += msg.get("text", "") + "\n" + else: + # Keep pre-built content blocks (tool_use / tool_result) as-is; wrap plain "text" messages in a text block + content = msg.get("text", "") + formatted_messages.append({ + "role": msg.get("role", "user"), + "content": msg.get("content") or [{"type": "text", "text": content}] + }) + + payload = { "model": self.model, - "temperature": 0, - "max_tokens": 800 + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "stream": stream } - + + if system_prompt: + payload["system"] = system_prompt.strip() + + if formatted_messages: + payload["messages"] = formatted_messages + + if tools: + payload["tools"] = tools + payload["tool_choice"] = {"type": "auto"} + + try: + response = requests.post(url, headers=headers, json=payload, timeout=60) + print(f"MiniMax request: {payload}") + print(f"MiniMax response status: {response.status_code}") + print(f"MiniMax response: {response.text[:500]}") + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + # A Response is falsy for 4xx/5xx statuses, so test for None explicitly to keep the error body + error_detail = e.response.text if e.response is not None else str(e) + print(f"MiniMax API error detail: {error_detail}") + raise Exception(f"MiniMax API error: {error_detail}") from e + def 
generate_response(self, query: str, conversation_history: Optional[str] = None, tools: Optional[List] = None, tool_manager=None) -> str: """ - Generate AI response with optional tool usage and conversation context. - + Generate AI response with sequential tool calling (up to 2 rounds). + Args: query: The user's question or request conversation_history: Previous messages for context tools: Available tools the AI can use tool_manager: Manager to execute tools - + Returns: Generated response as string """ - - # Build system content efficiently - avoid string ops when possible - system_content = ( - f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}" - if conversation_history - else self.SYSTEM_PROMPT - ) - - # Prepare API call parameters efficiently - api_params = { - **self.base_params, - "messages": [{"role": "user", "content": query}], - "system": system_content - } - - # Add tools if available - if tools: - api_params["tools"] = tools - api_params["tool_choice"] = {"type": "auto"} - - # Get response from Claude - response = self.client.messages.create(**api_params) - - # Handle tool execution if needed - if response.stop_reason == "tool_use" and tool_manager: - return self._handle_tool_execution(response, api_params, tool_manager) - - # Return direct response - return response.content[0].text - - def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager): - """ - Handle execution of tool calls and get follow-up response. 
- - Args: - initial_response: The response containing tool use requests - base_params: Base API parameters - tool_manager: Manager to execute tools - - Returns: - Final response text after tool execution - """ - # Start with existing messages - messages = base_params["messages"].copy() - - # Add AI's tool use response - messages.append({"role": "assistant", "content": initial_response.content}) - - # Execute all tool calls and collect results - tool_results = [] - for content_block in initial_response.content: - if content_block.type == "tool_use": - tool_result = tool_manager.execute_tool( - content_block.name, - **content_block.input - ) - - tool_results.append({ - "type": "tool_result", - "tool_use_id": content_block.id, - "content": tool_result - }) - - # Add tool results as single message - if tool_results: - messages.append({"role": "user", "content": tool_results}) - - # Prepare final API call without tools - final_params = { - **self.base_params, - "messages": messages, - "system": base_params["system"] + # Build initial messages + messages = [] + messages.append({"role": "system", "text": self.SYSTEM_PROMPT}) + if conversation_history: + messages.append({"role": "system", "text": f"Previous conversation:\n{conversation_history}"}) + messages.append({"role": "user", "text": query}) + + # Sequential tool calling: max 2 rounds + max_rounds = 2 + + for round_num in range(1, max_rounds + 1): + # Make request with tools enabled + response = self._make_request(messages, tools=tools) + + stop_reason = response.get("stop_reason") + content_blocks = response.get("content", []) + + # Check if model wants to use a tool + if stop_reason == "tool_use" and tool_manager: + tool_block = self._extract_tool_use(content_blocks) + if not tool_block: + break + + # Execute tool with error handling + tool_result = self._execute_tool_safely(tool_manager, tool_block) + + # Accumulate messages for next round + messages.append({"role": "assistant", "content": [tool_block]}) + 
messages.append(self._build_tool_result_message(tool_block, tool_result)) + + # If not last round, continue to next round + if round_num < max_rounds: + continue + else: + # Last round - make final call without tools + final_response = self._make_request(messages, tools=None) + content_blocks = final_response.get("content", []) + else: + # No tool use - return response + pass + + # Extract text and return + return self._extract_text_from_blocks(content_blocks) + + # Fallback: return whatever we have + return self._extract_text_from_blocks(content_blocks) + + def _extract_tool_use(self, content_blocks: List[Dict]) -> Optional[Dict]: + """Extract tool_use block from response content.""" + for block in content_blocks: + if block.get("type") == "tool_use": + return block + return None + + def _execute_tool_safely(self, tool_manager, tool_block: Dict) -> str: + """Execute tool with graceful error handling.""" + try: + tool_name = tool_block.get("name") + tool_input = tool_block.get("input", {}) + return tool_manager.execute_tool(tool_name, **tool_input) + except KeyError as e: + return f"Tool parameter error: {str(e)}" + except Exception as e: + return f"Tool execution error: {str(e)}" + + def _build_tool_result_message(self, tool_block: Dict, result: str) -> Dict: + """Build tool_result message in API format.""" + return { + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_block.get("id"), + "content": result + }] } - - # Get final response - final_response = self.client.messages.create(**final_params) - return final_response.content[0].text \ No newline at end of file + + def _extract_text_from_blocks(self, content_blocks: List[Dict]) -> str: + """Extract text from content blocks.""" + response_text = "" + for block in content_blocks: + if block.get("type") == "text": + response_text += block.get("text", "") + return response_text \ No newline at end of file diff --git a/backend/app.py b/backend/app.py index 5a69d741..a1755b80 100644 --- 
a/backend/app.py +++ b/backend/app.py @@ -51,8 +51,32 @@ class CourseStats(BaseModel): total_courses: int course_titles: List[str] +class SessionClearRequest(BaseModel): + """Request model for clearing session""" + session_id: Optional[str] = None + +class SessionClearResponse(BaseModel): + """Response model for session clear""" + success: bool + cleared_session_id: Optional[str] = None + # API Endpoints +@app.post("/api/session/clear", response_model=SessionClearResponse) +async def clear_session(request: SessionClearRequest): + """Clear session history on the backend""" + try: + session_id = request.session_id + if session_id: + rag_system.session_manager.clear_session(session_id) + return SessionClearResponse( + success=True, + cleared_session_id=session_id + ) + return SessionClearResponse(success=True, cleared_session_id=None) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/query", response_model=QueryResponse) async def query_documents(request: QueryRequest): """Process a query and return response with sources""" @@ -61,16 +85,18 @@ async def query_documents(request: QueryRequest): session_id = request.session_id if not session_id: session_id = rag_system.session_manager.create_session() - + # Process query using RAG system answer, sources = rag_system.query(request.query, session_id) - + return QueryResponse( answer=answer, sources=sources, session_id=session_id ) except Exception as e: + import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/courses", response_model=CourseStats) diff --git a/backend/config.py b/backend/config.py index d9f6392e..551c1c6c 100644 --- a/backend/config.py +++ b/backend/config.py @@ -8,19 +8,24 @@ @dataclass class Config: """Configuration settings for the RAG system""" - # Anthropic API settings - ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") - ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" - + # MiniMax API 
settings + MINMAX_API_KEY: str = os.getenv("MINMAX_API_KEY", "") + MINMAX_BASE_URL: str = "https://api.minimaxi.com/anthropic" + MINMAX_MODEL: str = "MiniMax-M2.5" + + # Anthropic API settings (legacy - not used) + # ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") + # ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" + # Embedding model settings EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" - + # Document processing settings CHUNK_SIZE: int = 800 # Size of text chunks for vector storage - CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks + CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks MAX_RESULTS: int = 5 # Maximum search results to return - MAX_HISTORY: int = 2 # Number of conversation messages to remember - + MAX_HISTORY: int = 2 # Number of conversation messages to remember + # Database paths CHROMA_PATH: str = "./chroma_db" # ChromaDB storage location diff --git a/backend/rag_system.py b/backend/rag_system.py index 50d848c8..cabd72f7 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -4,7 +4,7 @@ from vector_store import VectorStore from ai_generator import AIGenerator from session_manager import SessionManager -from search_tools import ToolManager, CourseSearchTool +from search_tools import ToolManager, CourseSearchTool, CourseOutlineTool from models import Course, Lesson, CourseChunk class RAGSystem: @@ -16,13 +16,15 @@ def __init__(self, config): # Initialize core components self.document_processor = DocumentProcessor(config.CHUNK_SIZE, config.CHUNK_OVERLAP) self.vector_store = VectorStore(config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS) - self.ai_generator = AIGenerator(config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL) + self.ai_generator = AIGenerator(config.MINMAX_API_KEY, config.MINMAX_BASE_URL, config.MINMAX_MODEL) self.session_manager = SessionManager(config.MAX_HISTORY) # Initialize search tools self.tool_manager = ToolManager() self.search_tool = CourseSearchTool(self.vector_store) 
+ self.outline_tool = CourseOutlineTool(self.vector_store) self.tool_manager.register_tool(self.search_tool) + self.tool_manager.register_tool(self.outline_tool) def add_course_document(self, file_path: str) -> Tuple[Course, int]: """ @@ -101,42 +103,80 @@ def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> T def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]: """ - Process a user query using the RAG system with tool-based search. - + Process a user query using the RAG system with pre-search approach. + Args: query: User's question session_id: Optional session ID for conversation context - + Returns: - Tuple of (response, sources list - empty for tool-based approach) + Tuple of (response, sources list) """ - # Create prompt for the AI with clear instructions - prompt = f"""Answer this question about course materials: {query}""" - + # Step 1: Search for relevant content first (instead of using tool calling) + search_results = self.vector_store.search(query=query, limit=5) + + # Step 2: Build context from search results + context = "" + sources = [] + + if not search_results.is_empty() and search_results.documents: + context_parts = [] + for doc, meta in zip(search_results.documents, search_results.metadata): + course_title = meta.get('course_title', 'unknown') + lesson_num = meta.get('lesson_number') + + # Build header + header = f"[{course_title}" + if lesson_num is not None: + header += f" - Lesson {lesson_num}" + header += "]" + + context_parts.append(f"{header}\n{doc}") + + # Track source with link + source_name = course_title + if lesson_num is not None: + source_name += f" - Lesson {lesson_num}" + + # Get link + course_link = self.vector_store.get_course_link(course_title) + lesson_link = None + if lesson_num is not None: + lesson_link = self.vector_store.get_lesson_link(course_title, lesson_num) + link = lesson_link or course_link + + source_with_link = f"{source_name}|{link}" if link else 
source_name + sources.append(source_with_link) + + context = "\n\n".join(context_parts) + context = f"Here are relevant course materials:\n\n{context}\n\n" + else: + context = "No relevant course materials found.\n\n" + + # Step 3: Create prompt with context injected + if context: + prompt = f"""{context}Answer this question about course materials: {query}""" + else: + prompt = f"""Answer this question about course materials: {query}""" + # Get conversation history if session exists history = None if session_id: history = self.session_manager.get_conversation_history(session_id) - - # Generate response using AI with tools + + # Generate response using AI (without tools - we did search manually) response = self.ai_generator.generate_response( query=prompt, conversation_history=history, - tools=self.tool_manager.get_tool_definitions(), - tool_manager=self.tool_manager + tools=None, # No tools - we pre-searched + tool_manager=None ) - - # Get sources from the search tool - sources = self.tool_manager.get_last_sources() - # Reset sources after retrieving them - self.tool_manager.reset_sources() - # Update conversation history if session_id: self.session_manager.add_exchange(session_id, query, response) - - # Return response with sources from tool searches + + # Return response with sources return response, sources def get_course_analytics(self) -> Dict: diff --git a/backend/search_tools.py b/backend/search_tools.py index adfe8235..0bd7af9f 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -1,22 +1,106 @@ -from typing import Dict, Any, Optional, Protocol +from typing import Dict, Any, Optional from abc import ABC, abstractmethod from vector_store import VectorStore, SearchResults class Tool(ABC): """Abstract base class for all tools""" - + @abstractmethod def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return tool definition for this tool""" pass - + @abstractmethod def execute(self, 
**kwargs) -> str: """Execute the tool with given parameters""" pass +class CourseOutlineTool(Tool): + """Tool for retrieving course outline - title, link, and complete lesson list""" + + def __init__(self, vector_store: VectorStore): + self.store = vector_store + self.last_sources = [] + + def get_tool_definition(self) -> Dict[str, Any]: + """Return tool definition for course outline""" + return { + "type": "function", + "function": { + "name": "get_course_outline", + "description": "Get the complete outline of a course including course title, course link, and all lessons with their numbers and titles", + "parameters": { + "type": "object", + "properties": { + "course_title": { + "type": "string", + "description": "The title of the course to get outline for" + } + }, + "required": ["course_title"] + } + } + } + + def execute(self, course_title: str) -> str: + """ + Execute the course outline tool. + + Args: + course_title: The course title to get outline for + + Returns: + Formatted course outline with title, link, and lessons + """ + import json + + try: + # Get all courses metadata to find the matching course + all_courses = self.store.get_all_courses_metadata() + + # Find the course by title (partial match support) + matched_course = None + for course in all_courses: + if course_title.lower() in course.get('title', '').lower() or \ + course.get('title', '').lower() in course_title.lower(): + matched_course = course + break + + if not matched_course: + return f"No course found matching '{course_title}'" + + # Extract course info + title = matched_course.get('title', 'Unknown') + link = matched_course.get('course_link', '') + lessons = matched_course.get('lessons', []) + + # Build the formatted output + output = f"Course: {title}" + if link: + output += f"\nLink: {link}" + output += "\n\nLessons:" + + for lesson in lessons: + lesson_num = lesson.get('lesson_number', '?') + lesson_title = lesson.get('lesson_title', 'Untitled') + output += f"\n- Lesson {lesson_num}: 
{lesson_title}" + + # Track source for the UI + source = title + if link: + source_with_link = f"{source}|{link}" + self.last_sources = [source_with_link] + else: + self.last_sources = [source] + + return output + + except Exception as e: + return f"Error getting course outline: {str(e)}" + + class CourseSearchTool(Tool): """Tool for searching course content with semantic course name matching""" @@ -25,27 +109,30 @@ def __init__(self, vector_store: VectorStore): self.last_sources = [] # Track sources from last search def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return MiniMax/OpenAI tool definition for this tool""" return { - "name": "search_course_content", - "description": "Search course materials with smart course name matching and lesson filtering", - "input_schema": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in the course content" + "type": "function", + "function": { + "name": "search_course_content", + "description": "Search course materials with smart course name matching and lesson filtering", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in the course content" + }, + "course_name": { + "type": "string", + "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" + }, + "lesson_number": { + "type": "integer", + "description": "Specific lesson number to search within (e.g. 1, 2, 3)" + } }, - "course_name": { - "type": "string", - "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" - }, - "lesson_number": { - "type": "integer", - "description": "Specific lesson number to search within (e.g. 
1, 2, 3)" - } - }, - "required": ["query"] + "required": ["query"] + } } } @@ -89,28 +176,40 @@ def _format_results(self, results: SearchResults) -> str: """Format search results with course and lesson context""" formatted = [] sources = [] # Track sources for the UI - + for doc, meta in zip(results.documents, results.metadata): course_title = meta.get('course_title', 'unknown') lesson_num = meta.get('lesson_number') - - # Build context header + + # Get course and lesson links + course_link = self.store.get_course_link(course_title) + lesson_link = None + if lesson_num is not None: + lesson_link = self.store.get_lesson_link(course_title, lesson_num) + + # Prefer lesson link, fall back to course link + link = lesson_link or course_link + + # Build context header "[course title - Lesson N]" (the link is not part of the header) header = f"[{course_title}" if lesson_num is not None: header += f" - Lesson {lesson_num}" header += "]" - - # Track source for the UI + + # Build the human-readable source label for the UI source = course_title if lesson_num is not None: source += f" - Lesson {lesson_num}" - sources.append(source) - + + # Append the link after a "|" separator ("label|link"); keep the plain label when no link is available + source_with_link = f"{source}|{link}" if link else source + sources.append(source_with_link) + formatted.append(f"{header}\n{doc}") - + # Store sources for retrieval self.last_sources = sources - + return "\n\n".join(formatted) class ToolManager: @@ -122,7 +221,8 @@ def __init__(self): def register_tool(self, tool: Tool): """Register any tool that implements the Tool interface""" tool_def = tool.get_tool_definition() - tool_name = tool_def.get("name") + # Handle both Anthropic format (direct name) and OpenAI/MiniMax format (nested in function) + tool_name = tool_def.get("name") or tool_def.get("function", {}).get("name") if not tool_name: raise ValueError("Tool must have a 'name' in its definition") self.tools[tool_name] = tool diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 
00000000..f80384eb --- /dev/null +++ b/backend/tests/__init__.py @@ -0,0 +1 @@ +# Test package for RAG system \ No newline at end of file diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 00000000..7e7ac938 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,251 @@ +""" +Pytest configuration and shared fixtures for RAG system tests. +""" +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from typing import List, Dict, Any +from datetime import datetime + + +# ============================================================================= +# Mock Data Fixtures +# ============================================================================= + +@pytest.fixture +def sample_course_data() -> Dict[str, Any]: + """Sample course data for testing.""" + return { + "course_id": "course-001", + "title": "Introduction to Python", + "description": "Learn Python programming basics", + "lessons": [ + { + "lesson_id": "lesson-001", + "title": "Getting Started", + "content": "Welcome to Python programming.", + "lesson_number": 1 + }, + { + "lesson_id": "lesson-002", + "title": "Variables and Types", + "content": "Learn about variables and data types in Python.", + "lesson_number": 2 + } + ] + } + + +@pytest.fixture +def sample_course_chunks() -> List[Dict[str, Any]]: + """Sample course chunks for testing vector store operations.""" + return [ + { + "chunk_id": "chunk-001", + "course_id": "course-001", + "course_title": "Introduction to Python", + "lesson_number": 1, + "content": "Welcome to Python programming.", + "content_hash": "abc123" + }, + { + "chunk_id": "chunk-002", + "course_id": "course-001", + "course_title": "Introduction to Python", + "lesson_number": 2, + "content": "Learn about variables and data types in Python.", + "content_hash": "def456" + } + ] + + +@pytest.fixture +def sample_search_results() -> Dict[str, Any]: + """Sample search results from vector store.""" + return { + "documents": [ + "Welcome 
to Python programming.", + "Learn about variables and data types in Python." + ], + "metadatas": [ + { + "course_id": "course-001", + "course_title": "Introduction to Python", + "lesson_number": 1, + "chunk_id": "chunk-001" + }, + { + "course_id": "course-001", + "course_title": "Introduction to Python", + "lesson_number": 2, + "chunk_id": "chunk-002" + } + ], + "ids": [["chunk-001"], ["chunk-002"]], + "distances": [[0.1], [0.2]] + } + + +@pytest.fixture +def sample_query_request() -> Dict[str, Any]: + """Sample query request payload.""" + return { + "query": "What is Python?", + "session_id": "test-session-001" + } + + +@pytest.fixture +def sample_query_response() -> Dict[str, Any]: + """Sample query response payload.""" + return { + "answer": "Python is a high-level programming language.", + "sources": ["Introduction to Python|Lesson 1"], + "session_id": "test-session-001" + } + + +@pytest.fixture +def sample_course_stats() -> Dict[str, Any]: + """Sample course statistics response.""" + return { + "total_courses": 2, + "course_titles": ["Introduction to Python", "Advanced Python"] + } + + +# ============================================================================= +# Mock Component Fixtures +# ============================================================================= + +@pytest.fixture +def mock_config(): + """Mock configuration object.""" + config = MagicMock() + config.CHUNK_SIZE = 1000 + config.CHUNK_OVERLAP = 200 + config.CHROMA_PATH = "./chroma_db" + config.EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" + config.MAX_RESULTS = 5 + config.MINMAX_API_KEY = "test-api-key" + config.MINMAX_BASE_URL = "https://api.minimax.chat/v1" + config.MINMAX_MODEL = "abab6.5s-chat" + config.MAX_HISTORY = 10 + return config + + +@pytest.fixture +def mock_vector_store(): + """Mock vector store for testing.""" + mock = MagicMock() + mock.search = MagicMock(return_value=MagicMock( + documents=["Sample content"], + metadata=[{"course_title": "Test Course", 
"lesson_number": 1}], + is_empty=False + )) + mock.get_course_count = MagicMock(return_value=5) + mock.get_existing_course_titles = MagicMock(return_value=["Test Course"]) + mock.add_course_metadata = MagicMock() + mock.add_course_content = MagicMock() + mock.get_course_link = MagicMock(return_value="https://example.com/course") + mock.get_lesson_link = MagicMock(return_value="https://example.com/lesson") + return mock + + +@pytest.fixture +def mock_ai_generator(): + """Mock AI generator for testing.""" + mock = MagicMock() + mock.generate_response = MagicMock(return_value="This is a test response.") + return mock + + +@pytest.fixture +def mock_session_manager(): + """Mock session manager for testing.""" + mock = MagicMock() + mock.create_session = MagicMock(return_value="test-session-001") + mock.get_conversation_history = MagicMock(return_value=[]) + mock.add_exchange = MagicMock() + mock.clear_session = MagicMock() + return mock + + +@pytest.fixture +def mock_document_processor(): + """Mock document processor for testing.""" + mock = MagicMock() + mock.process_course_document = MagicMock(return_value=( + MagicMock( + course_id="test-course", + title="Test Course", + description="Test description" + ), + [MagicMock(chunk_id="chunk-1"), MagicMock(chunk_id="chunk-2")] + )) + return mock + + +# ============================================================================= +# FastAPI Test Client Fixtures +# ============================================================================= + +@pytest.fixture +def mock_rag_system(mock_config, mock_vector_store, mock_ai_generator, mock_session_manager): + """Create a mock RAG system with all components mocked.""" + with patch('app.RAGSystem') as MockRAG: + mock_rag = MagicMock() + mock_rag.query = MagicMock(return_value=( + "Test answer", + ["Source 1"] + )) + mock_rag.get_course_analytics = MagicMock(return_value={ + "total_courses": 1, + "course_titles": ["Test Course"] + }) + mock_rag.session_manager = mock_session_manager 
+ MockRAG.return_value = mock_rag + yield mock_rag + + +# ============================================================================= +# Async Fixtures +# ============================================================================= + +@pytest.fixture +def mock_async_search_results(): + """Mock async search results.""" + async def mock_search(query: str, limit: int = 5): + return MagicMock( + documents=["Async result"], + metadata=[{"course_title": "Async Course"}], + is_empty=False + ) + return mock_search + + +# ============================================================================= +# Test Helpers +# ============================================================================= + +def create_mock_search_results(documents: List[str], metadatas: List[Dict]) -> MagicMock: + """Helper to create mock search results with customizable data.""" + result = MagicMock() + result.documents = documents + result.metadata = metadatas + result.is_empty = len(documents) == 0 + return result + + +def create_mock_config(**kwargs) -> MagicMock: + """Helper to create mock config with custom values.""" + config = MagicMock() + config.CHUNK_SIZE = kwargs.get("chunk_size", 1000) + config.CHUNK_OVERLAP = kwargs.get("chunk_overlap", 200) + config.CHROMA_PATH = kwargs.get("chroma_path", "./chroma_db") + config.EMBEDDING_MODEL = kwargs.get("embedding_model", "test-model") + config.MAX_RESULTS = kwargs.get("max_results", 5) + config.MINMAX_API_KEY = kwargs.get("minmax_api_key", "test-key") + config.MINMAX_BASE_URL = kwargs.get("minmax_base_url", "https://api.test.com") + config.MINMAX_MODEL = kwargs.get("minmax_model", "test-model") + config.MAX_HISTORY = kwargs.get("max_history", 10) + return config \ No newline at end of file diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py new file mode 100644 index 00000000..8b1ace5f --- /dev/null +++ b/backend/tests/test_api.py @@ -0,0 +1,388 @@ +""" +API endpoint tests for the RAG system FastAPI application. 
+ +These tests cover the core API endpoints: +- POST /api/query - Query processing endpoint +- GET /api/courses - Course statistics endpoint +- POST /api/session/clear - Session clearing endpoint + +Note: The actual app.py imports static files that don't exist in test environment, +so we recreate the API endpoints inline for testing to avoid mount issues. +""" +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing import List, Optional +from unittest.mock import MagicMock, patch + + +# ============================================================================= +# Re-create API Models and Endpoints (avoiding static file mount issues) +# ============================================================================= + +class QueryRequest(BaseModel): + """Request model for course queries""" + query: str + session_id: Optional[str] = None + + +class QueryResponse(BaseModel): + """Response model for course queries""" + answer: str + sources: List[str] + session_id: str + + +class CourseStats(BaseModel): + """Response model for course statistics""" + total_courses: int + course_titles: List[str] + + +class SessionClearRequest(BaseModel): + """Request model for clearing session""" + session_id: Optional[str] = None + + +class SessionClearResponse(BaseModel): + """Response model for session clear""" + success: bool + cleared_session_id: Optional[str] = None + + +def create_test_app(mock_rag_system: MagicMock) -> FastAPI: + """Create a test FastAPI app with mocked RAG system.""" + app = FastAPI(title="Test RAG System") + + @app.post("/api/query", response_model=QueryResponse) + async def query_documents(request: QueryRequest): + """Process a query and return response with sources""" + try: + session_id = request.session_id + if not session_id: + session_id = mock_rag_system.session_manager.create_session() + + answer, sources = mock_rag_system.query(request.query, session_id) + + 
return QueryResponse( + answer=answer, + sources=sources, + session_id=session_id + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.get("/api/courses", response_model=CourseStats) + async def get_course_stats(): + """Get course analytics and statistics""" + try: + analytics = mock_rag_system.get_course_analytics() + return CourseStats( + total_courses=analytics["total_courses"], + course_titles=analytics["course_titles"] + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/api/session/clear", response_model=SessionClearResponse) + async def clear_session(request: SessionClearRequest): + """Clear session history on the backend""" + try: + session_id = request.session_id + if session_id: + mock_rag_system.session_manager.clear_session(session_id) + return SessionClearResponse( + success=True, + cleared_session_id=session_id + ) + return SessionClearResponse(success=True, cleared_session_id=None) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return app + + +# ============================================================================= +# Test Cases +# ============================================================================= + +class TestQueryEndpoint: + """Tests for POST /api/query endpoint.""" + + def test_query_with_valid_request(self, mock_rag_system, sample_query_request): + """Test successful query with valid request.""" + # Setup mock + mock_rag_system.query.return_value = ( + "Python is a high-level programming language.", + ["Introduction to Python|Lesson 1"] + ) + mock_rag_system.session_manager.create_session.return_value = "new-session-001" + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + response = client.post("/api/query", json=sample_query_request) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert "answer" in data 
+ assert "sources" in data + assert "session_id" in data + assert data["answer"] == "Python is a high-level programming language." + assert data["sources"] == ["Introduction to Python|Lesson 1"] + + def test_query_with_session_id(self, mock_rag_system): + """Test query with provided session ID.""" + # Setup mock + mock_rag_system.query.return_value = ("Test answer", ["Source 1"]) + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {"query": "What is Python?", "session_id": "existing-session-123"} + response = client.post("/api/query", json=payload) + + # Verify + assert response.status_code == 200 + data = response.json() + assert data["session_id"] == "existing-session-123" + mock_rag_system.query.assert_called_once_with("What is Python?", "existing-session-123") + + def test_query_without_session_creates_new_session(self, mock_rag_system): + """Test that query without session creates new session.""" + # Setup mock + mock_rag_system.query.return_value = ("Answer", ["Source"]) + mock_rag_system.session_manager.create_session.return_value = "auto-generated-session" + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request without session_id + payload = {"query": "Test question"} + response = client.post("/api/query", json=payload) + + # Verify new session was created + assert response.status_code == 200 + mock_rag_system.session_manager.create_session.assert_called_once() + + def test_query_with_empty_query(self, mock_rag_system): + """Test that an empty query string passes validation and is forwarded to the RAG system.""" + # Setup mock + mock_rag_system.query.return_value = ("Empty query handled", []) + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {"query": ""} + response = client.post("/api/query", json=payload) + + # Verify - QueryRequest.query is an unconstrained str, so "" is valid and the mocked handler answers + assert 
response.status_code == 200 + + def test_query_internal_error(self, mock_rag_system): + """Test query handles internal errors gracefully.""" + # Setup mock to raise exception + mock_rag_system.query.side_effect = RuntimeError("Database error") + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {"query": "Test question"} + response = client.post("/api/query", json=payload) + + # Verify error handling + assert response.status_code == 500 + assert "detail" in response.json() + + +class TestCoursesEndpoint: + """Tests for GET /api/courses endpoint.""" + + def test_courses_returns_statistics(self, mock_rag_system): + """Test successful course statistics retrieval.""" + # Setup mock + mock_rag_system.get_course_analytics.return_value = { + "total_courses": 5, + "course_titles": ["Python Basics", "Advanced Python", "Data Science", "Web Dev", "ML"] + } + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + response = client.get("/api/courses") + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["total_courses"] == 5 + assert len(data["course_titles"]) == 5 + assert "Python Basics" in data["course_titles"] + + def test_courses_empty_catalog(self, mock_rag_system): + """Test course endpoint with empty catalog.""" + # Setup mock + mock_rag_system.get_course_analytics.return_value = { + "total_courses": 0, + "course_titles": [] + } + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + response = client.get("/api/courses") + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["total_courses"] == 0 + assert data["course_titles"] == [] + + def test_courses_internal_error(self, mock_rag_system): + """Test course endpoint handles internal errors.""" + # Setup mock to raise 
exception + mock_rag_system.get_course_analytics.side_effect = Exception("Vector store error") + + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + response = client.get("/api/courses") + + # Verify error handling + assert response.status_code == 500 + + +class TestSessionClearEndpoint: + """Tests for POST /api/session/clear endpoint.""" + + def test_clear_specific_session(self, mock_rag_system): + """Test clearing a specific session.""" + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {"session_id": "session-to-clear"} + response = client.post("/api/session/clear", json=payload) + + # Verify + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["cleared_session_id"] == "session-to-clear" + mock_rag_system.session_manager.clear_session.assert_called_once_with("session-to-clear") + + def test_clear_without_session_id(self, mock_rag_system): + """Test clearing without providing session ID.""" + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {} + response = client.post("/api/session/clear", json=payload) + + # Verify - should succeed without clearing any session + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["cleared_session_id"] is None + mock_rag_system.session_manager.clear_session.assert_not_called() + + def test_clear_none_session(self, mock_rag_system): + """Test clearing with null session ID.""" + # Create app and client + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request + payload = {"session_id": None} + response = client.post("/api/session/clear", json=payload) + + # Verify + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + 
mock_rag_system.session_manager.clear_session.assert_not_called() + + +class TestRequestValidation: + """Tests for request validation.""" + + def test_query_missing_query_field(self, mock_rag_system): + """Test query endpoint rejects missing query field.""" + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request with missing query + payload = {"session_id": "test-session"} + response = client.post("/api/query", json=payload) + + # Verify validation error + assert response.status_code == 422 + + def test_query_invalid_type(self, mock_rag_system): + """Test query endpoint rejects invalid types.""" + app = create_test_app(mock_rag_system) + client = TestClient(app) + + # Execute request with wrong type + payload = {"query": 12345} + response = client.post("/api/query", json=payload) + + # Verify validation error + assert response.status_code == 422 + + +class TestResponseSchema: + """Tests for response schema validation.""" + + def test_query_response_schema(self, mock_rag_system): + """Test query response has correct schema.""" + mock_rag_system.query.return_value = ("Test answer", ["Source 1", "Source 2"]) + mock_rag_system.session_manager.create_session.return_value = "test-123" + + app = create_test_app(mock_rag_system) + client = TestClient(app) + + response = client.post("/api/query", json={"query": "Test?"}) + + assert response.status_code == 200 + data = response.json() + # Check all required fields exist + assert "answer" in data and isinstance(data["answer"], str) + assert "sources" in data and isinstance(data["sources"], list) + assert "session_id" in data and isinstance(data["session_id"], str) + + def test_courses_response_schema(self, mock_rag_system): + """Test courses response has correct schema.""" + mock_rag_system.get_course_analytics.return_value = { + "total_courses": 3, + "course_titles": ["A", "B", "C"] + } + + app = create_test_app(mock_rag_system) + client = TestClient(app) + + response = 
client.get("/api/courses") + + assert response.status_code == 200 + data = response.json() + assert "total_courses" in data and isinstance(data["total_courses"], int) + assert "course_titles" in data and isinstance(data["course_titles"], list) \ No newline at end of file diff --git a/frontend-changes.md b/frontend-changes.md new file mode 100644 index 00000000..51befb32 --- /dev/null +++ b/frontend-changes.md @@ -0,0 +1,42 @@ +# Frontend Changes - Theme Toggle Feature + +## Overview +Added a dark/light theme toggle button that allows users to switch between themes with smooth transitions. + +## Files Modified + +### 1. `frontend/index.html` +- Added a theme toggle button in the top-right corner with sun/moon icons +- Button includes `aria-label` and `title` for accessibility +- Uses SVG icons for sun and moon + +### 2. `frontend/style.css` +- Added light theme CSS variables (under `[data-theme="light"]`) +- Added `.theme-toggle` button styles: + - Fixed position at top-right + - Circular design with shadow + - Hover effects with scale transform + - Focus ring for keyboard accessibility + - Smooth icon transitions between sun/moon + +### 3. 
`frontend/script.js` +- Added `initTheme()` function to load saved theme from localStorage +- Added `toggleTheme()` function to switch themes +- Added click event listener for the theme toggle button +- Theme preference persists in localStorage + +## Theme Variables Added +- `--background`: Off-white (#f8fafc) +- `--surface`: White (#ffffff) +- `--surface-hover`: Light gray (#f1f5f9) +- `--text-primary`: Dark slate (#1e293b) +- `--text-secondary`: Medium gray (#64748b) +- `--border-color`: Light border (#e2e8f0) +- `--shadow`: Subtle shadow for light theme + +## Features +- Smooth 0.3s transition animations between themes +- Theme preference saved to localStorage +- Keyboard accessible (focus ring on button) +- Icon switches between sun (light mode) and moon (dark mode) +- Works in both desktop and mobile (responsive) layouts \ No newline at end of file diff --git a/frontend/index.html b/frontend/index.html index f8e25a62..6edf3493 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -10,6 +10,24 @@
+ + +