Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update Vertex AI Search Component APIs to v1 #427

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions libs/community/langchain_google_community/vertex_check_grounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from langchain_google_community._utils import get_client_info

if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore


class VertexAICheckGroundingWrapper(
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, **kwargs: Any):

def _get_check_grounding_service_client(
self,
) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient":
) -> "discoveryengine_v1.GroundedGenerationServiceClient":
"""
Returns a GroundedGenerationServiceClient instance using provided credentials.
Raises ImportError if necessary packages are not installed.
Expand All @@ -77,14 +77,14 @@ def _get_check_grounding_service_client(
A GroundedGenerationServiceClient instance.
"""
try:
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
"Please install vertexaisearch dependency group: "
"`pip install langchain-google-community[vertexaisearch]`"
) from exc
return discoveryengine_v1alpha.GroundedGenerationServiceClient(
return discoveryengine_v1.GroundedGenerationServiceClient(
credentials=(
self.credentials
or Credentials.from_service_account_file(self.credentials_path) # type: ignore[attr-defined]
Expand All @@ -106,7 +106,7 @@ def invoke(
answer_candidate (str): The candidate answer to be evaluated for grounding.
documents (List[Document]): The documents against which grounding is
checked. This will be converted to facts:
facts (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
facts (MutableSequence[google.cloud.discoveryengine_v1.types.\
GroundingFact]):
List of facts for the grounding check.
We support up to 200 facts.
Expand All @@ -121,31 +121,31 @@ def invoke(
provided facts. This is always set when a
response is returned.

cited_chunks (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
cited_chunks (MutableSequence[google.cloud.discoveryengine_v1.types.\
FactChunk]):
List of facts cited across all claims in the
answer candidate. These are derived from the
facts supplied in the request.

claims (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
claims (MutableSequence[google.cloud.discoveryengine_v1.types.\
CheckGroundingResponse.Claim]):
Claim texts and citation info across all
claims in the answer candidate.

answer_with_citations (str):
Complete formed answer formatted with inline citations
"""
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore

answer_candidate = input
documents = self.extract_documents(config)

grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec(
grounding_spec = discoveryengine_v1.CheckGroundingSpec(
citation_threshold=self.citation_threshold,
)

facts = [
discoveryengine_v1alpha.GroundingFact(
discoveryengine_v1.GroundingFact(
fact_text=doc.page_content,
attributes={
key: value
Expand All @@ -162,15 +162,18 @@ def invoke(
if not facts:
raise ValueError("No valid documents provided for grounding.")

request = discoveryengine_v1alpha.CheckGroundingRequest(
grounding_config=f"projects/{self.project_id}/locations/{self.location_id}/groundingConfigs/{self.grounding_config}",
if self.client is None:
raise ValueError("Client not initialized.")

request = discoveryengine_v1.CheckGroundingRequest(
grounding_config=self.client.grounding_config_path(
self.project_id, self.location_id, self.grounding_config
),
answer_candidate=answer_candidate,
facts=facts,
grounding_spec=grounding_spec,
)

if self.client is None:
raise ValueError("Client not initialized.")
try:
response = self.client.check_grounding(request=request)
except core_exceptions.GoogleAPICallError as e:
Expand Down
26 changes: 10 additions & 16 deletions libs/community/langchain_google_community/vertex_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from langchain_google_community._utils import get_client_info

if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha # type: ignore

if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore


class VertexAIRank(BaseDocumentCompressor):
Expand Down Expand Up @@ -78,7 +75,7 @@ def __init__(self, **kwargs: Any):
if not self.client:
self.client = self._get_rank_service_client()

def _get_rank_service_client(self) -> "discoveryengine_v1alpha.RankServiceClient":
def _get_rank_service_client(self) -> "discoveryengine_v1.RankServiceClient":
"""
Returns a RankServiceClient instance for making API calls to the
Vertex AI Ranking service.
Expand All @@ -87,14 +84,14 @@ def _get_rank_service_client(self) -> "discoveryengine_v1alpha.RankServiceClient
A RankServiceClient instance.
"""
try:
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
"Please, install vertexaisearch dependency group: "
"`pip install langchain-google-community[vertexaisearch]`"
) from exc
return discoveryengine_v1alpha.RankServiceClient(
return discoveryengine_v1.RankServiceClient(
credentials=(
self.credentials
or Credentials.from_service_account_file(self.credentials_path) # type: ignore[attr-defined]
Expand All @@ -117,11 +114,11 @@ def _rerank_documents(
Returns:
A list of reranked documents.
"""
from google.cloud import discoveryengine_v1alpha # type: ignore
from google.cloud import discoveryengine_v1 # type: ignore

try:
records = [
discoveryengine_v1alpha.RankingRecord(
discoveryengine_v1.RankingRecord(
id=(doc.metadata.get(self.id_field) if self.id_field else str(idx)),
content=doc.page_content,
**(
Expand All @@ -137,13 +134,10 @@ def _rerank_documents(
except KeyError:
warnings.warn(f"id_field '{self.id_field}' not found in document metadata.")

ranking_config_path = (
f"projects/{self.project_id}/locations/{self.location_id}"
f"/rankingConfigs/{self.ranking_config}"
)

request = discoveryengine_v1alpha.RankRequest(
ranking_config=ranking_config_path,
request = discoveryengine_v1.RankRequest(
ranking_config=self.client.ranking_config_path(
self.project_id, self.location_id, self.ranking_config
),
model=self.model,
query=query,
records=records,
Expand Down
2 changes: 1 addition & 1 deletion libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ google-cloud-speech = { version = "^2.26.0", optional = true }
googlemaps = { version = "^4.10.0", optional = true }
google-cloud-texttospeech = { version = "^2.16.3", optional = true }
google-cloud-translate = { version = "^3.15.3", optional = true }
google-cloud-discoveryengine = { version = "^0.11.13", optional = true }
google-cloud-discoveryengine = { version = "^0.12.0", optional = true }
google-cloud-vision = { version = "^3.7.2", optional = true }
beautifulsoup4 = { version = "^4.12.3", optional = true }
pandas = [
Expand Down
Loading