Skip to content

Commit

Permalink
refactor: RAG Refactor (#985)
Browse files Browse the repository at this point in the history
Co-authored-by: Aralhi <[email protected]>
Co-authored-by: csunny <[email protected]>
  • Loading branch information
3 people authored Jan 3, 2024
1 parent 90775aa commit 9ad70a2
Show file tree
Hide file tree
Showing 206 changed files with 5,766 additions and 2,419 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ setup: ## Set up the Python development environment
$(VENV_BIN)/pip install -r requirements/lint-requirements.txt

testenv: setup ## Set up the Python test environment
$(VENV_BIN)/pip install -e ".[simple_framework]"
$(VENV_BIN)/pip install -e ".[default]"

.PHONY: fmt
fmt: setup ## Format Python code
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def initialize_components(
system_app.register_instance(controller)

# Register global default RAGGraphFactory
# from dbgpt.graph_engine.graph_factory import DefaultRAGGraphFactory
# from dbgpt.graph.graph_factory import DefaultRAGGraphFactory

# system_app.register(DefaultRAGGraphFactory)

Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/initialization/embedding_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/_cli/knowledge_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
DocumentQueryRequest,
)

from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
from dbgpt.app.knowledge.request.request import DocumentSyncRequest

from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import KnowledgeType

HTTP_HEADERS = {"Content-Type": "application/json"}

Expand Down
79 changes: 69 additions & 10 deletions dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
import tempfile
import logging
from typing import List

from fastapi import APIRouter, File, UploadFile, Form

Expand All @@ -13,10 +14,10 @@
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator

from dbgpt.app.openapi.api_view_model import Result
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.app.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse,
Expand All @@ -27,9 +28,14 @@
SpaceArgumentRequest,
EntityExtractRequest,
DocumentSummaryRequest,
KnowledgeSyncRequest,
)

from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.tracer import root_tracer, SpanType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,6 +109,39 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
return Result.failed(code="E000X", msg=f"document add error {e}")


@router.get("/knowledge/document/chunkstrategies")
def chunk_strategies():
    """List every chunk strategy together with the knowledge classes using it.

    One entry is produced per ``ChunkStrategy`` member. Each entry carries the
    member name, three items read from the member's value (index 2 as
    ``name``, index 3 as ``description``, index 1 as ``parameters``), the
    document-type values of the knowledge classes that support the strategy
    (``suffix``), and the distinct knowledge-type values of those classes
    (``type``).

    Returns:
        ``Result.succ`` wrapping the list of entries, or ``Result.failed``
        with code ``E000X`` if any exception is raised while building it.
    """
    print(f"/document/chunkstrategies:")
    try:
        entries = []
        for strategy in ChunkStrategy:
            entry = {
                "strategy": strategy.name,
                "name": strategy.value[2],
                "description": strategy.value[3],
                "parameters": strategy.value[1],
            }
            # Only knowledge classes that declare a document type contribute
            # a suffix.
            entry["suffix"] = [
                knowledge.document_type().value
                for knowledge in KnowledgeFactory.subclasses()
                if strategy in knowledge.support_chunk_strategy()
                and knowledge.document_type() is not None
            ]
            # A set collapses knowledge types shared by several classes.
            entry["type"] = {
                knowledge.type().value
                for knowledge in KnowledgeFactory.subclasses()
                if strategy in knowledge.support_chunk_strategy()
            }
            entries.append(entry)
        return Result.succ(entries)
    except Exception as e:
        return Result.failed(code="E000X", msg=f"chunk strategies error {e}")


@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
Expand Down Expand Up @@ -189,6 +228,18 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
return Result.failed(code="E000X", msg=f"document sync error {e}")


@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
    """Trigger synchronisation for a batch of documents in one space.

    Args:
        space_name: Knowledge space name taken from the URL path.
        request: The sync requests, forwarded unchanged to the service as
            ``sync_requests``.

    Returns:
        ``Result.succ`` wrapping ``{"tasks": <service return value>}``, or
        ``Result.failed`` with code ``E000X`` if any exception is raised.
    """
    logger.info(f"Received params: {space_name}, {request}")
    try:
        synced = knowledge_space_service.batch_document_sync(
            space_name=space_name,
            sync_requests=request,
        )
        payload = {"tasks": synced}
        return Result.succ(payload)
    except Exception as e:
        return Result.failed(code="E000X", msg=f"document sync error {e}")


@router.post("/knowledge/{space_name}/chunk/list")
def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
Expand All @@ -204,15 +255,23 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={"vector_store_name": space_name},
embedding_factory=embedding_factory,
config = VectorStoreConfig(
name=space_name,
embedding_fn=embedding_factory.create(
EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
),
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
retriever = EmbeddingRetriever(
top_k=query_request.top_k, vector_store_connector=vector_store_connector
)
docs = client.similar_search(query_request.query, query_request.top_k)
chunks = retriever.retrieve(query_request.query)
res = [
KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
for d in docs
KnowledgeQueryResponse(text=d.content, source=d.metadata["source"])
for d in chunks
]
return {"response": res}

Expand Down Expand Up @@ -254,7 +313,7 @@ async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
from dbgpt.app.scene import ChatScene
from dbgpt._private.chat_util import llm_chat_response_nostream
from dbgpt.util.chat_util import llm_chat_response_nostream
import uuid

chat_param = {
Expand Down
24 changes: 24 additions & 0 deletions dbgpt/app/knowledge/document_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import List

from sqlalchemy import Column, String, DateTime, Integer, Text, func

Expand Down Expand Up @@ -51,6 +52,12 @@ def create_knowledge_document(self, document: KnowledgeDocumentEntity):
return doc_id

def get_knowledge_documents(self, query, page=1, page_size=20):
"""Get a list of documents that match the given query.
Args:
query: A KnowledgeDocumentEntity object containing the query parameters.
page: The page number to return.
page_size: The number of documents to return per page.
"""
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
Expand Down Expand Up @@ -85,6 +92,23 @@ def get_knowledge_documents(self, query, page=1, page_size=20):
session.close()
return result

def documents_by_ids(self, ids) -> List[KnowledgeDocumentEntity]:
    """Get a list of documents by their IDs.

    Args:
        ids: A list of document IDs.

    Returns:
        A list of KnowledgeDocumentEntity objects whose ``id`` is in ``ids``.
    """
    session = self.get_raw_session()
    try:
        print(f"current session:{session}")
        return (
            session.query(KnowledgeDocumentEntity)
            .filter(KnowledgeDocumentEntity.id.in_(ids))
            .all()
        )
    finally:
        # Close the session on the error path as well; previously a failing
        # query skipped close() and left the raw session open.
        session.close()

def get_documents(self, query):
session = self.get_raw_session()
print(f"current session:{session}")
Expand Down
18 changes: 18 additions & 0 deletions dbgpt/app/knowledge/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile

from dbgpt.rag.chunk_manager import ChunkParameters


class KnowledgeQueryRequest(BaseModel):
"""query: knowledge query"""
Expand Down Expand Up @@ -43,6 +45,8 @@ class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""

doc_name: str = None
"""doc_ids: doc ids"""
doc_ids: Optional[List] = None
"""doc_type: doc type"""
doc_type: str = None
"""status: status"""
Expand Down Expand Up @@ -76,6 +80,20 @@ class DocumentSyncRequest(BaseModel):
chunk_overlap: Optional[int] = None


class KnowledgeSyncRequest(BaseModel):
    """Sync request for a single document.

    Carries the document id, an optional model name, and the chunk
    parameters to use for the sync.
    """

    # NOTE(review): the label below says "doc_ids", but the field holds one
    # integer id (``doc_id``), not a list.
    """doc_ids: doc ids"""
    doc_id: int

    # Optional; defaults to None when the caller does not supply it.
    """model_name: model name"""
    model_name: Optional[str] = None

    # Required: declared without a default value.
    """chunk_parameters: chunk parameters
    """
    chunk_parameters: ChunkParameters


class ChunkQueryRequest(BaseModel):
"""id: id"""

Expand Down
Loading

0 comments on commit 9ad70a2

Please sign in to comment.