Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Migrate Vertex AI Search Retriever from v1beta to v1 #630

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libs/community/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/lint_imports.sh
poetry run ruff .
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
poetry run ruff check --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml
Expand Down
76 changes: 8 additions & 68 deletions libs/community/langchain_google_community/vertex_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from google.protobuf.json_format import MessageToDict
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.load import Serializable, load
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import BaseTool
Expand All @@ -25,7 +24,7 @@
from langchain_google_community._utils import get_client_info

if TYPE_CHECKING:
from google.cloud.discoveryengine_v1beta import ( # type: ignore[import, attr-defined]
from google.cloud.discoveryengine_v1 import ( # type: ignore[import, attr-defined]
ConversationalSearchServiceClient,
SearchRequest,
SearchResult,
Expand Down Expand Up @@ -69,7 +68,7 @@ def __reduce__(self) -> Any:
def validate_environment(cls, values: Dict) -> Any:
"""Validates the environment."""
try:
from google.cloud import discoveryengine_v1beta # noqa: F401
from google.cloud import discoveryengine_v1 # noqa: F401
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
Expand Down Expand Up @@ -279,23 +278,6 @@ class VertexAISearchRetriever(BaseRetriever, _BaseVertexAISearchRetriever):
https://cloud.google.com/generative-ai-app-builder/docs/boost-search-results
https://cloud.google.com/generative-ai-app-builder/docs/reference/rest/v1beta/BoostSpec
"""
custom_embedding: Optional[Embeddings] = None
lgesuellip marked this conversation as resolved.
Show resolved Hide resolved
"""Custom embedding model for the retriever. (Bring your own embedding)
It needs to match the embedding model that was used to embed docs in the datastore.
It needs to be a langchain embedding VertexAIEmbeddings(project="{PROJECT}")
If you provide an embedding model, you also need to provide a ranking_expression and
a custom_embedding_field_path.
https://cloud.google.com/generative-ai-app-builder/docs/bring-embeddings
"""
custom_embedding_field_path: Optional[str] = None
""" The field path for the custom embedding used in the Vertex AI datastore schema.
"""
custom_embedding_ratio: Optional[float] = 0.0
"""Controls the ranking of results. Value should be between 0 and 1.
It will generate the ranking_expression in the following manner:
"{custom_embedding_ratio} * dotProduct({custom_embedding_field_path}) +
{1 - custom_embedding_ratio} * relevance_score"
"""

_client: SearchServiceClient = PrivateAttr()
_serving_config: str = PrivateAttr()
Expand All @@ -308,7 +290,7 @@ class VertexAISearchRetriever(BaseRetriever, _BaseVertexAISearchRetriever):
def __init__(self, **kwargs: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
from google.cloud.discoveryengine_v1 import SearchServiceClient
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
Expand Down Expand Up @@ -340,7 +322,7 @@ def __init__(self, **kwargs: Any) -> None:
def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:
"""Prepares a ContentSpec object."""

from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

if self.engine_data_type == 0:
if self.get_extractive_answers:
Expand Down Expand Up @@ -382,7 +364,7 @@ def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:

def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

query_expansion_spec = SearchRequest.QueryExpansionSpec(
condition=self.query_expansion_condition,
Expand All @@ -401,46 +383,6 @@ def _create_search_request(self, query: str) -> SearchRequest:
else:
content_search_spec = None

if (
self.custom_embedding is not None
or self.custom_embedding_field_path is not None
):
if self.custom_embedding is None:
raise ValueError(
"Please provide a custom embedding model if you provide a "
"custom_embedding_field_path."
)
if self.custom_embedding_field_path is None:
raise ValueError(
"Please provide a custom_embedding_field_path if you provide a "
"custom embedding model."
)
if self.custom_embedding_ratio is None:
raise ValueError(
"Please provide a custom_embedding_ratio if you provide a "
"custom embedding model or a custom_embedding_field_path."
)
if not 0 <= self.custom_embedding_ratio <= 1:
raise ValueError(
"Custom embedding ratio must be between 0 and 1 "
f"when using custom embeddings. Got {self.custom_embedding_ratio}"
)
embedding_vector = SearchRequest.EmbeddingSpec.EmbeddingVector(
field_path=self.custom_embedding_field_path,
vector=self.custom_embedding.embed_query(query),
)
embedding_spec = SearchRequest.EmbeddingSpec(
embedding_vectors=[embedding_vector]
)
ranking_expression = (
f"{self.custom_embedding_ratio} * "
f"dotProduct({self.custom_embedding_field_path}) + "
f"{1 - self.custom_embedding_ratio} * relevance_score"
)
else:
embedding_spec = None
ranking_expression = None

return SearchRequest(
query=query,
filter=self.filter,
Expand All @@ -454,8 +396,6 @@ def _create_search_request(self, query: str) -> SearchRequest:
boost_spec=SearchRequest.BoostSpec(**self.boost_spec)
if self.boost_spec
else None,
embedding_spec=embedding_spec,
ranking_expression=ranking_expression,
)

def _get_relevant_documents(
Expand Down Expand Up @@ -517,7 +457,7 @@ class VertexAIMultiTurnSearchRetriever(BaseRetriever, _BaseVertexAISearchRetriev

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from google.cloud.discoveryengine_v1beta import (
from google.cloud.discoveryengine_v1 import (
ConversationalSearchServiceClient,
)

Expand Down Expand Up @@ -545,7 +485,7 @@ def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.cloud.discoveryengine_v1beta import (
from google.cloud.discoveryengine_v1 import (
ConverseConversationRequest,
TextInput,
)
Expand Down Expand Up @@ -599,7 +539,7 @@ def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:
Returns:
kwargs for the specification of the content.
"""
from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

kwargs = super()._get_content_spec_kwargs() or {}

Expand Down
Loading
Loading