Wanted to submit PR (usage tracking via headers when client starts with track_usage = True) #10

Open
MitchMigala opened this issue Oct 22, 2024 · 0 comments

import asyncio
import functools
import json
import logging
import os
import subprocess
import requests

from dataclasses import dataclass, field
from datetime import datetime, timedelta

from collections import defaultdict
from groq import Groq, AsyncGroq
from langchain_groq import ChatGroq
from langchain_community.embeddings import OllamaEmbeddings
from typing import Callable, Dict, Any, List, Union, AsyncIterator, Optional

from .enhanced_web_tool import EnhancedWebTool
from .exceptions import GroqAPIKeyMissingError, GroqAPIError, OllamaServerNotRunningError
from .web_tool import WebTool
from .chain_of_thought.cot_manager import ChainOfThoughtManager
from .chain_of_thought.llm_interface import LLMInterface
from .rag_manager import RAGManager

logger = logging.getLogger(__name__)

@dataclass
class RateLimitInfo:
    # Limits
    requests_limit: int  # RPD (Requests Per Day)
    tokens_limit: int    # TPM (Tokens Per Minute)
    
    # Remaining
    requests_remaining: int
    tokens_remaining: int
    
    # Reset times
    requests_reset_time: datetime
    tokens_reset_time: datetime
    
    # Retry information (only set when rate limited)
    retry_after: Optional[float] = None
    # Use default_factory so the timestamp is taken per instance, not once at class definition.
    last_updated: datetime = field(default_factory=datetime.now)

class GroqProvider(LLMInterface):
    def __init__(self, api_key: Optional[str] = None, rag_persistent: bool = True, rag_index_path: str = "faiss_index.pkl", track_usage: bool = False):
        self.api_key = api_key or os.environ.get("GROQ_API_KEY")
        if not self.api_key:
            raise GroqAPIKeyMissingError("Groq API key is not provided")
        self.client = Groq(api_key=self.api_key)
        self.async_client = AsyncGroq(api_key=self.api_key)
        self.tool_use_models = [
            "llama3-groq-70b-8192-tool-use-preview",
            "llama3-groq-8b-8192-tool-use-preview"
        ]
        self.web_tool = WebTool()
        self.cot_manager = ChainOfThoughtManager(llm=self)
        self.rag_manager = None
        self.tools = {}
        self.rag_persistent = rag_persistent
        self.rag_index_path = rag_index_path
        self.enhanced_web_tool = EnhancedWebTool()

        # Add rate limit tracking
        self.rate_limits: Dict[str, RateLimitInfo] = {}
        self.track_usage = track_usage 

        # Initialize conversation sessions
        self.conversation_sessions = defaultdict(list)  # session_id -> list of messages

        # Check if Ollama server is running and initialize RAG if it is
        if self.is_ollama_server_running():
            if self.rag_persistent:
                logger.info("Initializing RAG with persistence enabled.")
                self.initialize_rag(index_path=self.rag_index_path)
        else:
            logger.warning("Ollama server is not running. RAG functionality will be limited.")


    def _parse_time_to_seconds(self, time_str: str) -> float:
        """
        Convert time string to seconds, handling both seconds and milliseconds formats.
        Examples:
            "2m59.56s" -> 179.56
            "6s" -> 6.0
            "48ms" -> 0.048
        """
        try:
            # Handle milliseconds format
            if time_str.endswith('ms'):
                return float(time_str.rstrip('ms')) / 1000.0

            total_seconds = 0
            # Handle minutes/seconds format
            if 'm' in time_str:
                minutes, rest = time_str.split('m')
                total_seconds += float(minutes) * 60
                rest = rest.replace('s', '')
                if rest:
                    total_seconds += float(rest)
            else:
                # Handle seconds only format
                total_seconds = float(time_str.replace('s', ''))
                
            return total_seconds
        except Exception as e:
            logger.error(f"Error parsing time string '{time_str}': {e}")
            return 0.0

    def _update_rate_limits(self, model: str, headers: Dict[str, str], status_code: int):
        """Update rate limit information from response headers"""
        now = datetime.now()
        
        try:
            # Extract and parse the rate limit headers
            requests_limit = int(headers.get('x-ratelimit-limit-requests', 0))
            tokens_limit = int(headers.get('x-ratelimit-limit-tokens', 0))
            requests_remaining = int(headers.get('x-ratelimit-remaining-requests', 0))
            tokens_remaining = int(headers.get('x-ratelimit-remaining-tokens', 0))
            
            # Parse reset times
            requests_reset = headers.get('x-ratelimit-reset-requests', '0s')
            tokens_reset = headers.get('x-ratelimit-reset-tokens', '0s')
            
            logger.debug(f"Raw reset times - requests: {requests_reset}, tokens: {tokens_reset}")
            
            requests_reset_seconds = self._parse_time_to_seconds(requests_reset)
            tokens_reset_seconds = self._parse_time_to_seconds(tokens_reset)
            
            logger.debug(f"Parsed reset times - requests: {requests_reset_seconds}s, tokens: {tokens_reset_seconds}s")
            
            # Create rate limit info
            rate_info = RateLimitInfo(
                requests_limit=requests_limit,
                tokens_limit=tokens_limit,
                requests_remaining=requests_remaining,
                tokens_remaining=tokens_remaining,
                requests_reset_time=now + timedelta(seconds=requests_reset_seconds),
                tokens_reset_time=now + timedelta(seconds=tokens_reset_seconds),
                retry_after=float(headers.get('retry-after', 0)) if status_code == 429 else None,
                last_updated=now
            )
            
            # Store the rate info
            self.rate_limits[model] = rate_info
            
        except Exception as e:
            logger.error(f"Error updating rate limits: {e}")
            logger.debug(f"Headers: {headers}")

    def get_rate_limits(self, model: str) -> Optional[RateLimitInfo]:
        """Get current rate limit information for a specific model"""
        return self.rate_limits.get(model)

    def get_model_usage_status(self, model: str) -> Optional[Dict[str, Any]]:
        """Get detailed usage status for a specific model"""
        logger.debug(f"Getting usage status for model {model}")
        logger.debug(f"Current rate_limits: {self.rate_limits}")
        
        rate_info = self.rate_limits.get(model)
        if not rate_info:
            logger.debug(f"No rate limit information found for model {model}")
            return None
                
        now = datetime.now()
        
        status = {
            "requests_remaining": rate_info.requests_remaining,
            "tokens_remaining": rate_info.tokens_remaining,
            "requests_reset_in": max(0, (rate_info.requests_reset_time - now).total_seconds()),
            "tokens_reset_in": max(0, (rate_info.tokens_reset_time - now).total_seconds()),
            "retry_after": rate_info.retry_after,
            "last_updated": rate_info.last_updated,
            "requests_limit": rate_info.requests_limit,
            "tokens_limit": rate_info.tokens_limit
        }
        
        logger.debug(f"Returning status for model {model}: {status}")
        return status

    def get_available_models(self) -> List[Dict[str, Any]]:
        """
        Fetch the list of available models from the Groq provider.

        Returns:
            List[Dict[str, Any]]: A list of models with their details.
        """
        url = "https://api.groq.com/openai/v1/models"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            models = response.json().get("data", [])
            return models
        except requests.RequestException as e:
            logger.error(f"Failed to fetch models: {e}")
            raise GroqAPIError(f"Failed to fetch models: {e}")

    def crawl_website(self, url: str, formats: List[str] = ["markdown"], max_depth: int = 3, max_pages: int = 100) -> List[Dict[str, Any]]:
        """
        Crawl a website and return its content in specified formats.
        
        Args:
            url (str): The starting URL to crawl.
            formats (List[str]): List of desired output formats (e.g., ["markdown", "html", "structured_data"]).
            max_depth (int): Maximum depth to crawl.
            max_pages (int): Maximum number of pages to crawl.
        
        Returns:
            List[Dict[str, Any]]: List of crawled pages with their content in specified formats.
        """
        self.enhanced_web_tool.max_depth = max_depth
        self.enhanced_web_tool.max_pages = max_pages
        return self.enhanced_web_tool.crawl(url, formats)

    def is_ollama_server_running(self) -> bool:
        """Check if the Ollama server is running."""
        try:
            response = requests.get("http://localhost:11434/api/tags")
            return response.status_code == 200
        except requests.RequestException:
            return False

    def ensure_ollama_server_running(func):
        """Decorator to ensure the Ollama server is running for methods that require it."""
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.is_ollama_server_running():
                raise OllamaServerNotRunningError("Ollama server is not running. Please start it and try again.")
            return func(self, *args, **kwargs)
        return wrapper
    
    def evaluate_response(self, request: str, response: str) -> bool:
        """
        Evaluate if a response satisfies a given request using an AI LLM.
        
        Args:
            request (str): The original request or question.
            response (str): The response to be evaluated.
        
        Returns:
            bool: True if the response is deemed satisfactory, False otherwise.
        """
        evaluation_prompt = f"""
        You will be given a request and a response. Your task is to evaluate the response based on the following criteria:
        1. **Informative and Correct**: The response must be accurate and provide clear, useful, and sufficient information to fully answer the request.
        2. **No Uncertainty**: The response should not express any uncertainty, such as language indicating doubt (e.g., "maybe," "possibly," "it seems") or statements that are inconclusive.

        Request: {request}
        Response: {response}

        Based on these criteria, is the response satisfactory? Answer with only 'Yes' or 'No'.
        """

        evaluation = self.generate(evaluation_prompt, temperature=0.0, max_tokens=1)
        
        # Clean up the response and convert to boolean
        evaluation = evaluation.strip().lower()
        return evaluation == 'yes'

    def register_tool(self, name: str, func: callable):
        self.tools[name] = func

    def scrape_url(self, url: str, formats: List[str] = ["markdown"]) -> Dict[str, Any]:
        """
        Scrape a single URL and return its content in specified formats.
        
        Args:
            url (str): The URL to scrape.
            formats (List[str]): List of desired output formats (e.g., ["markdown", "html", "structured_data"]).
        
        Returns:
            Dict[str, Any]: The scraped content in specified formats.
        """
        return self.enhanced_web_tool.scrape_page(url, formats)

    def end_conversation(self, conversation_id: str):
        """
        Ends a conversation and clears its history.

        Args:
            conversation_id (str): The ID of the conversation to end.
        """
        if conversation_id in self.conversation_sessions:
            del self.conversation_sessions[conversation_id]
            logger.info(f"Ended conversation with ID: {conversation_id}")
        else:
            logger.warning(f"Attempted to end non-existent conversation ID: {conversation_id}")

    def get_conversation_history(self, session_id: str) -> List[Dict[str, str]]:
        """
        Retrieve the conversation history for a given session.

        Args:
            session_id (str): Unique identifier for the conversation session.

        Returns:
            List[Dict[str, str]]: List of messages in the conversation.
        """
        return self.conversation_sessions.get(session_id, [])            

    def start_conversation(self, session_id: str):
        """
        Initialize a new conversation session.

        Args:
            session_id (str): Unique identifier for the conversation session.
        """
        if session_id in self.conversation_sessions:
            logger.warning(f"Session '{session_id}' already exists. Overwriting.")
        self.conversation_sessions[session_id] = []
        logger.info(f"Started new conversation session '{session_id}'.")
    
    def reset_conversation(self, session_id: str):
        if session_id in self.conversation_sessions:
            del self.conversation_sessions[session_id]
            logger.info(f"Conversation session '{session_id}' has been reset.")
        else:
            logger.warning(f"Attempted to reset non-existent session '{session_id}'.")   

    def generate(self, prompt: str, session_id: Optional[str] = None, track_usage: Optional[bool] = None, **kwargs) -> Union[str, AsyncIterator[str]]:
        if session_id:
            messages = self.conversation_sessions[session_id]
            messages.append({"role": "user", "content": prompt})
        else:
            messages = [{"role": "user", "content": prompt}]

        # Use track_usage parameter if provided, otherwise use instance default
        track_usage = track_usage if track_usage is not None else self.track_usage
        response = self._create_completion(messages, track_usage=track_usage, **kwargs)

        if session_id:
            if isinstance(response, str):
                self.conversation_sessions[session_id].append({"role": "assistant", "content": response})
            elif asyncio.iscoroutine(response):
                # Handle asynchronous streaming responses if needed
                pass

        return response

    def set_api_key(self, api_key: str):
        self.api_key = api_key
        self.client = Groq(api_key=self.api_key)
        self.async_client = AsyncGroq(api_key=self.api_key)

    def _create_completion(self, messages: List[Dict[str, str]], track_usage: Optional[bool] = None, **kwargs) -> Union[str, AsyncIterator[str]]:
        completion_kwargs = {
            "model": self._select_model(kwargs.get("model"), kwargs.get("tools")),
            "messages": messages,
            "temperature": kwargs.get("temperature", 0.5),
            "max_tokens": kwargs.get("max_tokens", 1024),
            "top_p": kwargs.get("top_p", 1),
            "stop": kwargs.get("stop", None),
            "stream": kwargs.get("stream", False),
        }

        if kwargs.get("json_mode", False):
            completion_kwargs["response_format"] = {"type": "json_object"}

        if kwargs.get("tools"):
            completion_kwargs["tools"] = self._prepare_tools(kwargs["tools"])
            completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto")

        # Use track_usage parameter if provided, otherwise use instance default
        track_usage = track_usage if track_usage is not None else self.track_usage

        if kwargs.get("async_mode", False):
            return self._async_create_completion(track_usage=track_usage, **completion_kwargs)
        else:
            return self._sync_create_completion(track_usage=track_usage, **completion_kwargs)


    def _select_model(self, requested_model: str, tools: List[Dict[str, Any]]) -> str:
        if tools and not requested_model:
            return self.tool_use_models[0]
        elif tools and requested_model not in self.tool_use_models:
            print(f"Warning: {requested_model} is not optimized for tool use. Switching to {self.tool_use_models[0]}.")
            return self.tool_use_models[0]
        return requested_model or os.environ.get('GROQ_MODEL', 'llama3-8b-8192')
    
    def _prepare_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        prepared_tools = []
        for tool in tools:
            prepared_tool = tool.copy()
            if 'function' in prepared_tool:
                prepared_tool['function'] = {k: v for k, v in prepared_tool['function'].items() if k != 'implementation'}
            prepared_tools.append(prepared_tool)
        return prepared_tools

    def _sync_create_completion(self, track_usage: bool = False, **kwargs) -> Union[str, AsyncIterator[str]]:
        try:
            model = kwargs.get("model", "llama3-8b-8192")
            
            if kwargs.get("stream", False):
                if track_usage:
                    def stream_generator():
                        # Generator owns the context so the stream stays open
                        # while it is consumed (returning from inside `with`
                        # would close the response before iteration starts).
                        with self.client.chat.completions.with_streaming_response.create(**kwargs) as response:
                            logger.debug(f"Raw headers from streaming response: {dict(response.headers)}")
                            self._update_rate_limits(model, response.headers, response.status_code)
                            for line in response.iter_lines():
                                if line:
                                    # json.loads returns dicts, so use dict access
                                    chunk = json.loads(line)
                                    if 'choices' in chunk:
                                        yield chunk['choices'][0]['delta'].get('content', '')
                    return stream_generator()
                else:
                    response = self.client.chat.completions.create(**kwargs)
                    return (chunk.choices[0].delta.content for chunk in response)
            else:
                if track_usage:
                    response = self.client.chat.completions.with_raw_response.create(**kwargs)
                    # Debug log the headers
                    logger.debug(f"Raw headers from response: {dict(response.headers)}")
                    self._update_rate_limits(model, response.headers, response.status_code)
                    completion = response.parse()
                else:
                    completion = self.client.chat.completions.create(**kwargs)
                
                return self._process_tool_calls(completion)
        except Exception as e:
            logger.error(f"Error in Groq API call: {str(e)}")
            raise GroqAPIError(f"Error in Groq API call: {str(e)}")

    async def _async_create_completion(self, track_usage: bool = False, **kwargs) -> Union[str, AsyncIterator[str]]:
        try:
            model = kwargs.get("model", "llama3-8b-8192")
            
            if kwargs.get("stream", False):
                if track_usage:
                    return await self._async_stream_with_raw_response(model=model, **kwargs)
                else:
                    response = await self.async_client.chat.completions.create(**kwargs)
                    async def async_generator():
                        async for chunk in response:
                            yield chunk.choices[0].delta.content
                    return async_generator()
            else:
                if track_usage:
                    response = await self.async_client.chat.completions.with_raw_response.create(**kwargs)
                    self._update_rate_limits(model, response.headers, response.status_code)
                    completion = await response.parse()
                else:
                    completion = await self.async_client.chat.completions.create(**kwargs)
                
                return await self._async_process_tool_calls(completion)
        except Exception as e:
            raise GroqAPIError(f"Error in async Groq API call: {str(e)}")

    async def _async_stream_with_raw_response(self, model: str, **kwargs):
        """Helper method to handle async streaming with raw response.

        The generator owns the `async with` block so the stream stays open
        while it is consumed.
        """
        async def async_generator():
            async with self.async_client.chat.completions.with_streaming_response.create(**kwargs) as response:
                self._update_rate_limits(model, response.headers, response.status_code)
                async for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        if 'choices' in chunk:
                            yield chunk['choices'][0]['delta'].get('content', '')
        return async_generator()

    def _process_tool_calls(self, completion) -> str:
        message = completion.choices[0].message
        if hasattr(message, 'tool_calls') and message.tool_calls:
            tool_results = self._execute_tool_calls(message.tool_calls)
            new_message = {
                "role": "assistant",
                "content": message.content,
                "tool_calls": message.tool_calls,
            }
            return self._create_completion([new_message] + tool_results)
        return message.content

    async def _async_process_tool_calls(self, completion) -> str:
        message = completion.choices[0].message
        if hasattr(message, 'tool_calls') and message.tool_calls:
            tool_results = await self._async_execute_tool_calls(message.tool_calls)
            new_message = {
                "role": "assistant",
                "content": message.content,
                "tool_calls": message.tool_calls,
            }
            # Mirror the sync path: send the assistant message plus the tool
            # results back through the completion pipeline.
            return await self._create_completion([new_message] + tool_results, async_mode=True)
        return message.content

    def _execute_tool_calls(self, tool_calls) -> List[Dict[str, Any]]:
        results = []
        for tool_call in tool_calls:
            if tool_call.function.name in self.tools:
                args = json.loads(tool_call.function.arguments)
                result = self.tools[tool_call.function.name](**args)
            else:
                result = {"error": f"Unknown tool: {tool_call.function.name}"}
            
            results.append({
                "role": "tool",
                "content": json.dumps(result),
                "tool_call_id": tool_call.id,
            })
        return results

    async def _async_execute_tool_calls(self, tool_calls) -> List[Dict[str, Any]]:
        results = []
        for tool_call in tool_calls:
            if tool_call.function.name == "web_search":
                args = json.loads(tool_call.function.arguments)
                result = await asyncio.to_thread(self.web_tool.search, args.get("query", ""))
            elif tool_call.function.name == "get_web_content":
                args = json.loads(tool_call.function.arguments)
                result = await asyncio.to_thread(self.web_tool.get_web_content, args.get("url", ""))
            else:
                result = {"error": f"Unknown tool: {tool_call.function.name}"}
            
            results.append({
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": json.dumps(result),
            })
        return results

    def web_search(self, query: str, num_results: int = 10) -> List[Dict[str, Any]]:
        """Perform a web search using the integrated WebTool."""
        return self.web_tool.search(query)

    def get_web_content(self, url: str) -> str:
        """Retrieve the content of a web page using the integrated WebTool."""
        return self.web_tool.get_web_content(url)

    def is_url(self, text: str) -> bool:
        """Check if the given text is a valid URL using the integrated WebTool."""
        return self.web_tool.is_url(text)
    
    def solve_problem_with_cot(self, problem: str, **kwargs) -> str:
        """
        Solve a problem using Chain of Thought reasoning.
        """
        return self.cot_manager.solve_problem(problem)

    def generate_cot(self, problem: str, **kwargs) -> List[str]:
        """
        Generate Chain of Thought steps for a given problem.
        """
        return self.cot_manager.generate_cot(problem)

    def synthesize_cot(self, cot_steps: List[str], **kwargs) -> str:
        """
        Synthesize a final answer from Chain of Thought steps.
        """
        return self.cot_manager.synthesize_response(cot_steps)
    
    @ensure_ollama_server_running
    def initialize_rag(self, ollama_base_url: str = "http://localhost:11434", model_name: str = "nomic-embed-text", index_path: str = "faiss_index.pkl"):
        try:
            # Attempt to pull the model if it's not already available
            subprocess.run(["ollama", "pull", model_name], check=True)
        except subprocess.CalledProcessError:
            logger.error(f"Failed to pull model {model_name}. Ensure Ollama is installed and running.")
            raise

        embeddings = OllamaEmbeddings(base_url=ollama_base_url, model=model_name)
        self.rag_manager = RAGManager(embeddings, index_path=index_path)
        logger.info("RAG initialized successfully.")

    @ensure_ollama_server_running
    def load_documents(self, source: str, chunk_size: int = 1000, chunk_overlap: int = 200, 
                       progress_callback: Callable[[int, int], None] = None, timeout: int = 300, 
                       persistent: Optional[bool] = None):
        if persistent is None:
            persistent = self.rag_persistent
        if not self.rag_manager:
            raise ValueError("RAG has not been initialized. Call initialize_rag first.")

        # Use a separate index path if non-persistent
        index_path = self.rag_index_path if persistent else f"temp_{self.rag_index_path}"
        self.rag_manager.index_path = index_path

        self.rag_manager.load_and_process_documents(source, chunk_size, chunk_overlap, progress_callback, timeout)

    @ensure_ollama_server_running
    def query_documents(self, query: str, session_id: Optional[str] = None, **kwargs) -> str:
        if not self.rag_manager:
            raise ValueError("RAG has not been initialized. Call initialize_rag first.")
        
        llm = ChatGroq(groq_api_key=self.api_key, model_name=kwargs.get("model", "llama3-8b-8192"))
        response = self.rag_manager.query_documents(llm, query)
        return response['answer']

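
For reference, a minimal usage sketch of the header-based tracking (the model name and prompt are illustrative assumptions, not taken from the issue):

# Sketch only: assumes GROQ_API_KEY is set in the environment and that the
# chosen model is available on the account.
provider = GroqProvider(track_usage=True)

# A completion made with tracking enabled parses the x-ratelimit-* response
# headers and stores a RateLimitInfo per model in provider.rate_limits.
answer = provider.generate("Hello!", model="llama3-8b-8192")
print(answer)

# Inspect the parsed usage data for that model.
status = provider.get_model_usage_status("llama3-8b-8192")
if status:
    print(f"Requests remaining: {status['requests_remaining']}")
    print(f"Tokens remaining:   {status['tokens_remaining']}")
    print(f"Requests reset in:  {status['requests_reset_in']:.1f}s")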
