Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:solve the dup retrieval issue #575

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions petercat_utils/rag_helper/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
from typing import Any, Dict


from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings

from .github_file_loader import GithubFileLoader
from ..data_class import GitDocConfig, RAGGitDocConfig, S3Config
from ..db.client.supabase import get_client


TABLE_NAME = "rag_docs"
QUERY_NAME = "match_embedding_docs"
CHUNK_SIZE = 2000
Expand Down Expand Up @@ -118,15 +116,16 @@
supabase = get_client()
is_doc_added_query = (
supabase.table(TABLE_NAME)
.select("id, repo_name, commit_id, file_path")
.select("id")

Check warning on line 119 in petercat_utils/rag_helper/retrieval.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/retrieval.py#L119

Added line #L119 was not covered by tests
.eq("repo_name", config.repo_name)
.eq("commit_id", loader.commit_id)
.eq("file_path", config.file_path)
.limit(1)
.execute()
)
if not is_doc_added_query.data:
is_doc_equal_query = (
supabase.table(TABLE_NAME).select("*").eq("file_sha", loader.file_sha)
supabase.table(TABLE_NAME).select("id").eq("file_sha", loader.file_sha).limit(1)

Check warning on line 128 in petercat_utils/rag_helper/retrieval.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/retrieval.py#L128

Added line #L128 was not covered by tests
).execute()
if not is_doc_equal_query.data:
# If there is no file with the same file_sha, perform embedding.
Expand All @@ -139,14 +138,26 @@
)
return store
else:
# Prioritize obtaining the minimal set of records to avoid overlapping with the original records.
minimum_repeat_result = supabase.rpc('count_rag_docs_by_sha', {'file_sha_input': loader.file_sha}).execute()

Check warning on line 142 in petercat_utils/rag_helper/retrieval.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/retrieval.py#L142

Added line #L142 was not covered by tests
xingwanying marked this conversation as resolved.
Show resolved Hide resolved
xingwanying marked this conversation as resolved.
Show resolved Hide resolved
target_filter = minimum_repeat_result.data[0]
# Copy the minimal set
insert_docs = (
supabase.table(TABLE_NAME)
.select("*")
.eq("repo_name", target_filter['repo_name'])
.eq("file_path", target_filter['file_path'])
.eq("file_sha", target_filter['file_sha'])
.execute()
)

Check warning on line 152 in petercat_utils/rag_helper/retrieval.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/retrieval.py#L151-L152

Added lines #L151 - L152 were not covered by tests
new_commit_list = [
{
**{k: v for k, v in item.items() if k != "id"},
"repo_name": config.repo_name,
"commit_id": loader.commit_id,
"file_path": config.file_path,
}
for item in is_doc_equal_query.data
for item in insert_docs.data
]
insert_result = supabase.table(TABLE_NAME).insert(new_commit_list).execute()
return insert_result
Expand All @@ -169,9 +180,9 @@


def search_knowledge(
query: str,
repo_name: str,
meta_filter: Dict[str, Any] = {},
query: str,
repo_name: str,
meta_filter: Dict[str, Any] = {},
):
retriever = init_retriever(
{"filter": {"metadata": meta_filter, "repo_name": repo_name}}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "petercat_utils"
version = "0.1.39"
version = "0.1.40"
description = ""
authors = ["raoha.rh <[email protected]>"]
readme = "README.md"
Expand Down
137 changes: 68 additions & 69 deletions server/auth/middleware.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,87 @@
import traceback
from typing import Awaitable, Callable

from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from petercat_utils import get_env_variable
from fastapi.security import OAuth2PasswordBearer
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi.security import OAuth2PasswordBearer

from core.dao.botDAO import BotDAO

WEB_URL = get_env_variable("WEB_URL")
ENVRIMENT = get_env_variable("PETERCAT_ENV", "development")
from env import ENVIRONMENT, WEB_URL

ALLOW_LIST = [
"/",
"/favicon.ico",
"/api/health_checker",
"/api/bot/list",
"/api/bot/detail",
"/api/github/app/webhook",
"/app/installation/callback",
"/",
"/favicon.ico",
"/api/health_checker",
"/api/bot/list",
"/api/bot/detail",
"/api/github/app/webhook",
"/app/installation/callback",
]

ANONYMOUS_USER_ALLOW_LIST = [
"/api/auth/userinfo",
"/api/chat/qa",
"/api/chat/stream_qa",
"/api/auth/userinfo",
"/api/chat/qa",
"/api/chat/stream_qa",
]

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")


class AuthMiddleWare(BaseHTTPMiddleware):

async def oauth(self, request: Request):
try:
referer = request.headers.get('referer')
origin = request.headers.get('origin')
if referer and referer.startswith(WEB_URL):
return True
token = await oauth2_scheme(request=request)
if token:
bot_dao = BotDAO()
bot = bot_dao.get_bot(bot_id=token)
return bot and (
"*" in bot.domain_whitelist
or
origin in bot.domain_whitelist
)
except HTTPException:
return False
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
# if ENVRIMENT == "development":
# return await call_next(request)
# Auth 相关的直接放过
if request.url.path.startswith("/api/auth"):
return await call_next(request)
if request.url.path in ALLOW_LIST:
return await call_next(request)
if await self.oauth(request=request):
return await call_next(request)

# 获取 session 中的用户信息
user = request.session.get("user")
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
if user['sub'].startswith("client|"):
if request.url.path in ANONYMOUS_USER_ALLOW_LIST:
return await call_next(request)
else:
# 如果没有用户信息,返回 401 Unauthorized 错误
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow")
return await call_next(request)
except HTTPException as e:
print(traceback.format_exception(e))
# 处理 HTTP 异常
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
except Exception as e:
# 处理其他异常
return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"})
async def oauth(self, request: Request):
try:
referer = request.headers.get('referer')
origin = request.headers.get('origin')
if referer and referer.startswith(WEB_URL):
return True

Check warning on line 40 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L40

Added line #L40 was not covered by tests
token = await oauth2_scheme(request=request)
if token:
bot_dao = BotDAO()
bot = bot_dao.get_bot(bot_id=token)
return bot and (
"*" in bot.domain_whitelist

Check warning on line 46 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L43-L46

Added lines #L43 - L46 were not covered by tests
or
origin in bot.domain_whitelist
)
except HTTPException:
return False

async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if ENVIRONMENT == "development":
return await call_next(request)

# Auth 相关的直接放过
if request.url.path.startswith("/api/auth"):
return await call_next(request)

Check warning on line 61 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L61

Added line #L61 was not covered by tests
if request.url.path in ALLOW_LIST:
return await call_next(request)

if await self.oauth(request=request):
return await call_next(request)

# 获取 session 中的用户信息
user = request.session.get("user")
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")

if user['sub'].startswith("client|"):
if request.url.path in ANONYMOUS_USER_ALLOW_LIST:
return await call_next(request)
else:

Check warning on line 76 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L75-L76

Added lines #L75 - L76 were not covered by tests
# 如果没有用户信息,返回 401 Unauthorized 错误
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow")

Check warning on line 79 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L79

Added line #L79 was not covered by tests
return await call_next(request)
except HTTPException as e:
print(traceback.format_exception(e))
# 处理 HTTP 异常
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
except Exception as e:
# 处理其他异常

Check warning on line 86 in server/auth/middleware.py

View check run for this annotation

Codecov / codecov/patch

server/auth/middleware.py#L86

Added line #L86 was not covered by tests
return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"})
22 changes: 11 additions & 11 deletions server/auth/router.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from github import Github
from core.dao.profilesDAO import ProfilesDAO
import secrets
from typing import Annotated, Optional

from authlib.integrations.starlette_client import OAuth
from fastapi import APIRouter, Request, HTTPException, status, Depends
from fastapi.responses import RedirectResponse, JSONResponse
import secrets
from petercat_utils import get_client, get_env_variable
from github import Github
from starlette.config import Config
from authlib.integrations.starlette_client import OAuth
from typing import Annotated, Optional

from auth.get_user_info import generateAnonymousUser, getUserInfoByToken, get_user_id
from auth.get_user_info import (
generateAnonymousUser,
getUserAccessToken,
getUserInfoByToken,
get_user_id,
)
from core.dao.profilesDAO import ProfilesDAO
from petercat_utils import get_client, get_env_variable

AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN")

Expand All @@ -26,6 +25,7 @@
LOGIN_URL = f"{API_URL}/api/auth/login"

WEB_URL = get_env_variable("WEB_URL")

WEB_LOGIN_SUCCESS_URL = f"{WEB_URL}/user/login"
MARKET_URL = f"{WEB_URL}/market"

Expand Down Expand Up @@ -133,8 +133,8 @@ async def get_agreement_status(user_id: Optional[str] = Depends(get_user_id)):

@router.post("/accept/agreement", status_code=200)
async def bot_generator(
request: Request,
user_id: Annotated[str | None, Depends(get_user_id)] = None,
request: Request,
user_id: Annotated[str | None, Depends(get_user_id)] = None,
):
if not user_id:
raise HTTPException(status_code=401, detail="User not found")
Expand Down
6 changes: 6 additions & 0 deletions server/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# list all env variables
from petercat_utils import get_env_variable

WEB_URL = get_env_variable("WEB_URL")
ENVIRONMENT = get_env_variable("PETERCAT_ENV", "development")
API_URL = get_env_variable("API_URL")
21 changes: 10 additions & 11 deletions server/github_app/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
from typing import Annotated

from fastapi import (
APIRouter,
BackgroundTasks,
Expand All @@ -8,27 +10,24 @@
Request,
status,
)
import logging
from fastapi.responses import RedirectResponse

from github import Auth, Github

from auth.get_user_info import get_user
from core.dao.repositoryConfigDAO import RepositoryConfigDAO
from core.models.bot import RepoBindBotRequest
from core.models.user import User

from env import WEB_URL
from github_app.handlers import get_handler
from github_app.purchased import PurchaseServer
from github_app.utils import (
get_private_key,
)

from petercat_utils import get_env_variable

REGIN_NAME = get_env_variable("AWS_REGION")
AWS_GITHUB_SECRET_NAME = get_env_variable("AWS_GITHUB_SECRET_NAME")
APP_ID = get_env_variable("X_GITHUB_APP_ID")
WEB_URL = get_env_variable("WEB_URL")

logger = logging.getLogger()
logger.setLevel("INFO")
Expand All @@ -51,9 +50,9 @@ def github_app_callback(code: str, installation_id: str, setup_action: str):

@router.post("/app/webhook")
async def github_app_webhook(
request: Request,
background_tasks: BackgroundTasks,
x_github_event: str = Header(...),
request: Request,
background_tasks: BackgroundTasks,
x_github_event: str = Header(...),
):
payload = await request.json()
if x_github_event == "marketplace_purchase":
Expand Down Expand Up @@ -86,7 +85,7 @@ async def github_app_webhook(

@router.get("/user/repos_installed_app")
def get_user_repos_installed_app(
user: Annotated[User | None, Depends(get_user)] = None
user: Annotated[User | None, Depends(get_user)] = None
):
"""
Get github user installed app repositories which saved in platform database.
Expand Down Expand Up @@ -116,8 +115,8 @@ def get_user_repos_installed_app(

@router.post("/repo/bind_bot", status_code=200)
def bind_bot_to_repo(
request: RepoBindBotRequest,
user: Annotated[User | None, Depends(get_user)] = None,
request: RepoBindBotRequest,
user: Annotated[User | None, Depends(get_user)] = None,
):
if user is None:
raise HTTPException(
Expand Down
Loading
Loading