import asyncio
import json
import logging
import os
import subprocess
import requests
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import defaultdict
from groq import Groq, AsyncGroq
from langchain_groq import ChatGroq
from langchain_community.embeddings import OllamaEmbeddings
from typing import Callable, Dict, Any, List, Union, AsyncIterator, Optional
from .enhanced_web_tool import EnhancedWebTool
from .exceptions import GroqAPIKeyMissingError, GroqAPIError, OllamaServerNotRunningError
from .web_tool import WebTool
from .chain_of_thought.cot_manager import ChainOfThoughtManager
from .chain_of_thought.llm_interface import LLMInterface
from .rag_manager import RAGManager
logger = logging.getLogger(__name__)
@dataclass
class RateLimitInfo:
# Limits
requests_limit: int # RPD (Requests Per Day)
tokens_limit: int # TPM (Tokens Per Minute)
# Remaining
requests_remaining: int
tokens_remaining: int
# Reset times
requests_reset_time: datetime
tokens_reset_time: datetime
# Retry information (only set when rate limited)
retry_after: Optional[float] = None
    last_updated: datetime = field(default_factory=datetime.now)  # evaluated per instance, not once at class definition
class GroqProvider(LLMInterface):
    def __init__(self, api_key: Optional[str] = None, rag_persistent: bool = True, rag_index_path: str = "faiss_index.pkl", track_usage: bool = False):
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
if not self.api_key:
raise GroqAPIKeyMissingError("Groq API key is not provided")
self.client = Groq(api_key=self.api_key)
self.async_client = AsyncGroq(api_key=self.api_key)
self.tool_use_models = [
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview"
]
self.web_tool = WebTool()
self.cot_manager = ChainOfThoughtManager(llm=self)
self.rag_manager = None
self.tools = {}
self.rag_persistent = rag_persistent
self.rag_index_path = rag_index_path
self.enhanced_web_tool = EnhancedWebTool()
# Add rate limit tracking
self.rate_limits: Dict[str, RateLimitInfo] = {}
self.track_usage = track_usage
# Initialize conversation sessions
self.conversation_sessions = defaultdict(list) # session_id -> list of messages
# Check if Ollama server is running and initialize RAG if it is
if self.is_ollama_server_running():
if self.rag_persistent:
logger.info("Initializing RAG with persistence enabled.")
self.initialize_rag(index_path=self.rag_index_path)
else:
logger.warning("Ollama server is not running. RAG functionality will be limited.")
def _parse_time_to_seconds(self, time_str: str) -> float:
"""
Convert time string to seconds, handling both seconds and milliseconds formats.
Examples:
"2m59.56s" -> 179.56
"6s" -> 6.0
"48ms" -> 0.048
"""
try:
# Handle milliseconds format
if time_str.endswith('ms'):
return float(time_str.rstrip('ms')) / 1000.0
total_seconds = 0
# Handle minutes/seconds format
if 'm' in time_str:
minutes, rest = time_str.split('m')
total_seconds += float(minutes) * 60
rest = rest.replace('s', '')
if rest:
total_seconds += float(rest)
else:
# Handle seconds only format
total_seconds = float(time_str.replace('s', ''))
return total_seconds
except Exception as e:
logger.error(f"Error parsing time string '{time_str}': {e}")
return 0.0
def _update_rate_limits(self, model: str, headers: Dict[str, str], status_code: int):
"""Update rate limit information from response headers"""
now = datetime.now()
try:
# Extract and parse the rate limit headers
requests_limit = int(headers.get('x-ratelimit-limit-requests', 0))
tokens_limit = int(headers.get('x-ratelimit-limit-tokens', 0))
requests_remaining = int(headers.get('x-ratelimit-remaining-requests', 0))
tokens_remaining = int(headers.get('x-ratelimit-remaining-tokens', 0))
# Parse reset times
requests_reset = headers.get('x-ratelimit-reset-requests', '0s')
tokens_reset = headers.get('x-ratelimit-reset-tokens', '0s')
logger.debug(f"Raw reset times - requests: {requests_reset}, tokens: {tokens_reset}")
requests_reset_seconds = self._parse_time_to_seconds(requests_reset)
tokens_reset_seconds = self._parse_time_to_seconds(tokens_reset)
logger.debug(f"Parsed reset times - requests: {requests_reset_seconds}s, tokens: {tokens_reset_seconds}s")
# Create rate limit info
rate_info = RateLimitInfo(
requests_limit=requests_limit,
tokens_limit=tokens_limit,
requests_remaining=requests_remaining,
tokens_remaining=tokens_remaining,
requests_reset_time=now + timedelta(seconds=requests_reset_seconds),
tokens_reset_time=now + timedelta(seconds=tokens_reset_seconds),
retry_after=float(headers.get('retry-after', 0)) if status_code == 429 else None,
last_updated=now
)
# Store the rate info
self.rate_limits[model] = rate_info
except Exception as e:
logger.error(f"Error updating rate limits: {e}")
logger.debug(f"Headers: {headers}")
def get_rate_limits(self, model: str) -> Optional[RateLimitInfo]:
"""Get current rate limit information for a specific model"""
return self.rate_limits.get(model)
    def get_model_usage_status(self, model: str) -> Optional[Dict[str, Any]]:
"""Get detailed usage status for a specific model"""
logger.debug(f"Getting usage status for model {model}")
logger.debug(f"Current rate_limits: {self.rate_limits}")
rate_info = self.rate_limits.get(model)
if not rate_info:
logger.debug(f"No rate limit information found for model {model}")
return None
now = datetime.now()
status = {
"requests_remaining": rate_info.requests_remaining,
"tokens_remaining": rate_info.tokens_remaining,
"requests_reset_in": max(0, (rate_info.requests_reset_time - now).total_seconds()),
"tokens_reset_in": max(0, (rate_info.tokens_reset_time - now).total_seconds()),
"retry_after": rate_info.retry_after,
"last_updated": rate_info.last_updated,
"requests_limit": rate_info.requests_limit,
"tokens_limit": rate_info.tokens_limit
}
logger.debug(f"Returning status for model {model}: {status}")
return status
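    # Usage sketch: rate-limit data is populated from response headers only after
    # a call made with track_usage=True, so a status check looks like:
    #
    #     provider.generate("ping", model="llama3-8b-8192", track_usage=True)
    #     status = provider.get_model_usage_status("llama3-8b-8192")
    #     if status:
    #         print(status["tokens_remaining"], status["requests_reset_in"])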
def get_available_models(self) -> List[Dict[str, Any]]:
"""
Fetch the list of available models from the Groq provider.
Returns:
List[Dict[str, Any]]: A list of models with their details.
"""
url = "https://api.groq.com/openai/v1/models"
headers = {"Authorization": f"Bearer {self.api_key}"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
models = response.json().get("data", [])
return models
except requests.RequestException as e:
logger.error(f"Failed to fetch models: {e}")
raise GroqAPIError(f"Failed to fetch models: {e}")
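    # Sketch of listing model ids (the endpoint returns an OpenAI-style "data"
    # list, as parsed above):
    #
    #     for model in provider.get_available_models():
    #         print(model.get("id"))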
def crawl_website(self, url: str, formats: List[str] = ["markdown"], max_depth: int = 3, max_pages: int = 100) -> List[Dict[str, Any]]:
"""
Crawl a website and return its content in specified formats.
Args:
url (str): The starting URL to crawl.
formats (List[str]): List of desired output formats (e.g., ["markdown", "html", "structured_data"]).
max_depth (int): Maximum depth to crawl.
max_pages (int): Maximum number of pages to crawl.
Returns:
List[Dict[str, Any]]: List of crawled pages with their content in specified formats.
"""
self.enhanced_web_tool.max_depth = max_depth
self.enhanced_web_tool.max_pages = max_pages
return self.enhanced_web_tool.crawl(url, formats)
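    # Hypothetical crawl of a small site (example.com is a placeholder URL):
    #
    #     pages = provider.crawl_website("https://example.com",
    #                                    formats=["markdown"], max_depth=2, max_pages=10)
    #     print(f"Crawled {len(pages)} pages")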
def is_ollama_server_running(self) -> bool:
"""Check if the Ollama server is running."""
try:
response = requests.get("http://localhost:11434/api/tags")
return response.status_code == 200
except requests.RequestException:
return False
def ensure_ollama_server_running(func):
"""Decorator to ensure Ollama server is running for functions that require it."""
def wrapper(self, *args, **kwargs):
if not self.is_ollama_server_running():
raise OllamaServerNotRunningError("Ollama server is not running. Please start it and try again.")
return func(self, *args, **kwargs)
return wrapper
def evaluate_response(self, request: str, response: str) -> bool:
"""
Evaluate if a response satisfies a given request using an AI LLM.
Args:
request (str): The original request or question.
response (str): The response to be evaluated.
Returns:
bool: True if the response is deemed satisfactory, False otherwise.
"""
evaluation_prompt = f"""
You will be given a request and a response. Your task is to evaluate the response based on the following criteria:
1. **Informative and Correct**: The response must be accurate and provide clear, useful, and sufficient information to fully answer the request.
2. **No Uncertainty**: The response should not express any uncertainty, such as language indicating doubt (e.g., "maybe," "possibly," "it seems") or statements that are inconclusive.
Request: {request}
Response: {response}
Based on these criteria, is the response satisfactory? Answer with only 'Yes' or 'No'.
"""
evaluation = self.generate(evaluation_prompt, temperature=0.0, max_tokens=1)
# Clean up the response and convert to boolean
evaluation = evaluation.strip().lower()
return evaluation == 'yes'
    def register_tool(self, name: str, func: Callable):
self.tools[name] = func
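    # Sketch of wiring a local tool through an OpenAI-style function spec (the
    # "tools" schema below matches what _prepare_tools expects; get_weather is
    # a hypothetical example function):
    #
    #     def get_weather(city: str) -> dict:
    #         return {"city": city, "forecast": "sunny"}
    #
    #     provider.register_tool("get_weather", get_weather)
    #     spec = [{"type": "function", "function": {
    #         "name": "get_weather",
    #         "description": "Get the forecast for a city",
    #         "parameters": {"type": "object",
    #                        "properties": {"city": {"type": "string"}},
    #                        "required": ["city"]}}}]
    #     answer = provider.generate("What's the weather in Paris?", tools=spec)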
def scrape_url(self, url: str, formats: List[str] = ["markdown"]) -> Dict[str, Any]:
"""
Scrape a single URL and return its content in specified formats.
Args:
url (str): The URL to scrape.
formats (List[str]): List of desired output formats (e.g., ["markdown", "html", "structured_data"]).
Returns:
Dict[str, Any]: The scraped content in specified formats.
"""
return self.enhanced_web_tool.scrape_page(url, formats)
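    # Single-page counterpart to crawl_website (placeholder URL):
    #
    #     page = provider.scrape_url("https://example.com", formats=["markdown", "html"])
    #     print(page.get("markdown", "")[:200])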
def end_conversation(self, conversation_id: str):
"""
Ends a conversation and clears its history.
Args:
conversation_id (str): The ID of the conversation to end.
"""
        if conversation_id in self.conversation_sessions:
            del self.conversation_sessions[conversation_id]
logger.info(f"Ended conversation with ID: {conversation_id}")
else:
logger.warning(f"Attempted to end non-existent conversation ID: {conversation_id}")
def get_conversation_history(self, session_id: str) -> List[Dict[str, str]]:
"""
Retrieve the conversation history for a given session.
Args:
session_id (str): Unique identifier for the conversation session.
Returns:
List[Dict[str, str]]: List of messages in the conversation.
"""
return self.conversation_sessions.get(session_id, [])
def start_conversation(self, session_id: str):
"""
Initialize a new conversation session.
Args:
session_id (str): Unique identifier for the conversation session.
"""
if session_id in self.conversation_sessions:
logger.warning(f"Session '{session_id}' already exists. Overwriting.")
self.conversation_sessions[session_id] = []
logger.info(f"Started new conversation session '{session_id}'.")
def reset_conversation(self, session_id: str):
if session_id in self.conversation_sessions:
del self.conversation_sessions[session_id]
logger.info(f"Conversation session '{session_id}' has been reset.")
else:
logger.warning(f"Attempted to reset non-existent session '{session_id}'.")
def generate(self, prompt: str, session_id: Optional[str] = None, track_usage: Optional[bool] = None, **kwargs) -> Union[str, AsyncIterator[str]]:
if session_id:
messages = self.conversation_sessions[session_id]
messages.append({"role": "user", "content": prompt})
else:
messages = [{"role": "user", "content": prompt}]
# Use track_usage parameter if provided, otherwise use instance default
track_usage = track_usage if track_usage is not None else self.track_usage
response = self._create_completion(messages, track_usage=track_usage, **kwargs)
if session_id:
if isinstance(response, str):
self.conversation_sessions[session_id].append({"role": "assistant", "content": response})
elif asyncio.iscoroutine(response):
# Handle asynchronous streaming responses if needed
pass
return response
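    # One-off generation sketch showing kwargs that _create_completion (below)
    # passes through (model, temperature, json_mode, stream):
    #
    #     text = provider.generate("Summarize RAG in one sentence.",
    #                              model="llama3-8b-8192", temperature=0.2)
    #     as_json = provider.generate("Reply with a JSON object.", json_mode=True)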
def set_api_key(self, api_key: str):
self.api_key = api_key
self.client = Groq(api_key=self.api_key)
self.async_client = AsyncGroq(api_key=self.api_key)
def _create_completion(self, messages: List[Dict[str, str]], track_usage: Optional[bool] = None, **kwargs) -> Union[str, AsyncIterator[str]]:
completion_kwargs = {
"model": self._select_model(kwargs.get("model"), kwargs.get("tools")),
"messages": messages,
"temperature": kwargs.get("temperature", 0.5),
"max_tokens": kwargs.get("max_tokens", 1024),
"top_p": kwargs.get("top_p", 1),
"stop": kwargs.get("stop", None),
"stream": kwargs.get("stream", False),
}
if kwargs.get("json_mode", False):
completion_kwargs["response_format"] = {"type": "json_object"}
if kwargs.get("tools"):
completion_kwargs["tools"] = self._prepare_tools(kwargs["tools"])
completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto")
# Use track_usage parameter if provided, otherwise use instance default
track_usage = track_usage if track_usage is not None else self.track_usage
if kwargs.get("async_mode", False):
return self._async_create_completion(track_usage=track_usage, **completion_kwargs)
else:
return self._sync_create_completion(track_usage=track_usage, **completion_kwargs)
def _select_model(self, requested_model: str, tools: List[Dict[str, Any]]) -> str:
if tools and not requested_model:
return self.tool_use_models[0]
elif tools and requested_model not in self.tool_use_models:
print(f"Warning: {requested_model} is not optimized for tool use. Switching to {self.tool_use_models[0]}.")
return self.tool_use_models[0]
return requested_model or os.environ.get('GROQ_MODEL', 'llama3-8b-8192')
def _prepare_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
prepared_tools = []
for tool in tools:
prepared_tool = tool.copy()
if 'function' in prepared_tool:
prepared_tool['function'] = {k: v for k, v in prepared_tool['function'].items() if k != 'implementation'}
prepared_tools.append(prepared_tool)
return prepared_tools
def _sync_create_completion(self, track_usage: bool = False, **kwargs) -> Union[str, AsyncIterator[str]]:
try:
model = kwargs.get("model", "llama3-8b-8192")
if kwargs.get("stream", False):
if track_usage:
with self.client.chat.completions.with_streaming_response.create(**kwargs) as response:
# Debug log the headers
logger.debug(f"Raw headers from streaming response: {dict(response.headers)}")
self._update_rate_limits(model, response.headers, response.status_code)
return (chunk.choices[0].delta.content
for line in response.iter_lines()
for chunk in [json.loads(line)] if 'choices' in chunk)
else:
response = self.client.chat.completions.create(**kwargs)
return (chunk.choices[0].delta.content for chunk in response)
else:
if track_usage:
response = self.client.chat.completions.with_raw_response.create(**kwargs)
# Debug log the headers
logger.debug(f"Raw headers from response: {dict(response.headers)}")
self._update_rate_limits(model, response.headers, response.status_code)
completion = response.parse()
else:
completion = self.client.chat.completions.create(**kwargs)
return self._process_tool_calls(completion)
except Exception as e:
logger.error(f"Error in Groq API call: {str(e)}")
raise GroqAPIError(f"Error in Groq API call: {str(e)}")
async def _async_create_completion(self, track_usage: bool = False, **kwargs) -> Union[str, AsyncIterator[str]]:
try:
model = kwargs.get("model", "llama3-8b-8192")
if kwargs.get("stream", False):
if track_usage:
return await self._async_stream_with_raw_response(model=model, **kwargs)
else:
response = await self.async_client.chat.completions.create(**kwargs)
                    async def async_generator():
                        async for chunk in response:
                            # Skip chunks with no delta content (e.g., the final chunk)
                            if chunk.choices[0].delta.content is not None:
                                yield chunk.choices[0].delta.content
return async_generator()
else:
if track_usage:
response = await self.async_client.chat.completions.with_raw_response.create(**kwargs)
self._update_rate_limits(model, response.headers, response.status_code)
completion = await response.parse()
else:
completion = await self.async_client.chat.completions.create(**kwargs)
return await self._async_process_tool_calls(completion)
        except Exception as e:
            logger.error(f"Error in async Groq API call: {str(e)}")
            raise GroqAPIError(f"Error in async Groq API call: {str(e)}")
    async def _async_stream_with_raw_response(self, model: str, **kwargs):
        """Helper method to handle async streaming with raw response."""
        async def async_generator():
            # Keep the HTTP response open for the lifetime of the generator
            async with self.async_client.chat.completions.with_streaming_response.create(**kwargs) as response:
                self._update_rate_limits(model, response.headers, response.status_code)
                async for line in response.iter_lines():
                    # SSE lines carry a "data: " prefix; skip blanks and the [DONE] sentinel
                    if line.startswith("data: "):
                        line = line[len("data: "):]
                    if not line or line.strip() == "[DONE]":
                        continue
                    chunk = json.loads(line)
                    if 'choices' in chunk:
                        yield chunk['choices'][0]['delta'].get('content', '')
        return async_generator()
def _process_tool_calls(self, completion) -> str:
message = completion.choices[0].message
if hasattr(message, 'tool_calls') and message.tool_calls:
tool_results = self._execute_tool_calls(message.tool_calls)
new_message = {
"role": "assistant",
"content": message.content,
"tool_calls": message.tool_calls,
}
return self._create_completion([new_message] + tool_results)
return message.content
async def _async_process_tool_calls(self, completion) -> str:
message = completion.choices[0].message
if hasattr(message, 'tool_calls') and message.tool_calls:
tool_results = await self._async_execute_tool_calls(message.tool_calls)
new_message = {
"role": "assistant",
"content": message.content,
"tool_calls": message.tool_calls,
}
            # Feed the tool results back as "tool" messages, mirroring the sync path
            return await self._create_completion([new_message] + tool_results, async_mode=True)
return message.content
def _execute_tool_calls(self, tool_calls) -> List[Dict[str, Any]]:
results = []
for tool_call in tool_calls:
if tool_call.function.name in self.tools:
args = json.loads(tool_call.function.arguments)
result = self.tools[tool_call.function.name](**args)
else:
result = {"error": f"Unknown tool: {tool_call.function.name}"}
results.append({
"role": "tool",
"content": json.dumps(result),
"tool_call_id": tool_call.id,
})
return results
async def _async_execute_tool_calls(self, tool_calls) -> List[Dict[str, Any]]:
results = []
for tool_call in tool_calls:
if tool_call.function.name == "web_search":
args = json.loads(tool_call.function.arguments)
result = await asyncio.to_thread(self.web_tool.search, args.get("query", ""))
elif tool_call.function.name == "get_web_content":
args = json.loads(tool_call.function.arguments)
result = await asyncio.to_thread(self.web_tool.get_web_content, args.get("url", ""))
else:
result = {"error": f"Unknown tool: {tool_call.function.name}"}
results.append({
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": json.dumps(result),
})
return results
    def web_search(self, query: str, num_results: int = 10) -> List[Dict[str, Any]]:
        """Perform a web search using the integrated WebTool. num_results is
        accepted for API compatibility but the underlying search currently
        uses its own default."""
        return self.web_tool.search(query)
def get_web_content(self, url: str) -> str:
"""Retrieve the content of a web page using the integrated WebTool."""
return self.web_tool.get_web_content(url)
def is_url(self, text: str) -> bool:
"""Check if the given text is a valid URL using the integrated WebTool."""
return self.web_tool.is_url(text)
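    # Web helper sketch (the exact result shape depends on the WebTool
    # implementation; the "url" keys below are an assumption):
    #
    #     hits = provider.web_search("groq lpu inference engine")
    #     if hits and provider.is_url(hits[0].get("url", "")):
    #         print(provider.get_web_content(hits[0]["url"])[:200])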
def solve_problem_with_cot(self, problem: str, **kwargs) -> str:
"""
Solve a problem using Chain of Thought reasoning.
"""
return self.cot_manager.solve_problem(problem)
def generate_cot(self, problem: str, **kwargs) -> List[str]:
"""
Generate Chain of Thought steps for a given problem.
"""
return self.cot_manager.generate_cot(problem)
def synthesize_cot(self, cot_steps: List[str], **kwargs) -> str:
"""
Synthesize a final answer from Chain of Thought steps.
"""
return self.cot_manager.synthesize_response(cot_steps)
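    # Chain-of-Thought sketch (all three helpers delegate to ChainOfThoughtManager):
    #
    #     problem = "A train covers 60 km in 40 minutes. What is its speed in km/h?"
    #     steps = provider.generate_cot(problem)
    #     print(provider.synthesize_cot(steps))
    #     # or end-to-end:
    #     print(provider.solve_problem_with_cot(problem))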
@ensure_ollama_server_running
def initialize_rag(self, ollama_base_url: str = "http://localhost:11434", model_name: str = "nomic-embed-text", index_path: str = "faiss_index.pkl"):
try:
# Attempt to pull the model if it's not already available
subprocess.run(["ollama", "pull", model_name], check=True)
except subprocess.CalledProcessError:
logger.error(f"Failed to pull model {model_name}. Ensure Ollama is installed and running.")
raise
embeddings = OllamaEmbeddings(base_url=ollama_base_url, model=model_name)
self.rag_manager = RAGManager(embeddings, index_path=index_path)
logger.info("RAG initialized successfully.")
@ensure_ollama_server_running
    def load_documents(self, source: str, chunk_size: int = 1000, chunk_overlap: int = 200,
                       progress_callback: Optional[Callable[[int, int], None]] = None, timeout: int = 300,
                       persistent: Optional[bool] = None):
if persistent is None:
persistent = self.rag_persistent
if not self.rag_manager:
raise ValueError("RAG has not been initialized. Call initialize_rag first.")
# Use a separate index path if non-persistent
index_path = self.rag_index_path if persistent else f"temp_{self.rag_index_path}"
self.rag_manager.index_path = index_path
self.rag_manager.load_and_process_documents(source, chunk_size, chunk_overlap, progress_callback, timeout)
@ensure_ollama_server_running
def query_documents(self, query: str, session_id: Optional[str] = None, **kwargs) -> str:
if not self.rag_manager:
raise ValueError("RAG has not been initialized. Call initialize_rag first.")
llm = ChatGroq(groq_api_key=self.api_key, model_name=kwargs.get("model", "llama3-8b-8192"))
response = self.rag_manager.query_documents(llm, query)
return response['answer']
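    # End-to-end RAG sketch (assumes a running Ollama server with the
    # nomic-embed-text model pulled, and a local file to index; the path
    # below is hypothetical):
    #
    #     provider.initialize_rag()
    #     provider.load_documents("docs/overview.md")
    #     print(provider.query_documents("What does this project do?"))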