diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index b26833333238..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index b26833333238..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 3 } diff --git a/sdks/python/apache_beam/ml/rag/enrichment/__init__.py b/sdks/python/apache_beam/ml/rag/enrichment/__init__.py new file mode 100644 index 000000000000..efcb5ac31950 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Enrichment components for RAG pipelines. +This module provides components for vector search enrichment in RAG pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py new file mode 100644 index 000000000000..b958273b29fe --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py @@ -0,0 +1,367 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.cloud import bigquery + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + + +@dataclass +class BigQueryVectorSearchParameters: + """Parameters for configuring BigQuery vector similarity search. + + This class is used by BigQueryVectorSearchEnrichmentHandler to perform + vector similarity search using BigQuery's VECTOR_SEARCH function. It + processes :class:`~apache_beam.ml.rag.types.Chunk` objects that contain + :class:`~apache_beam.ml.rag.types.Embedding` and returns similar vectors + from a BigQuery table. + + BigQueryVectorSearchEnrichmentHandler is used with + :class:`~apache_beam.transforms.enrichment.Enrichment` transform to enrich + Chunks with similar content from a vector database. For example: + + >>> # Create search parameters + >>> params = BigQueryVectorSearchParameters( + ... table_name='project.dataset.embeddings', + ... embedding_column='embedding', + ... columns=['content'], + ... neighbor_count=5 + ... ) + >>> # Use in pipeline + >>> enriched = ( + ... chunks + ... | "Generate Embeddings" >> MLTransform(...) + ... | "Find Similar" >> Enrichment( + ... BigQueryVectorSearchEnrichmentHandler( + ... project='my-project', + ... vector_search_parameters=params + ... ) + ... ) + ... ) + + BigQueryVectorSearchParameters encapsulates the configuration needed to + perform vector similarity search using BigQuery's VECTOR_SEARCH function. + It handles formatting the query with proper embedding vectors and metadata + restrictions. + + Example with flattened metadata column: + + Table schema:: + + embedding: ARRAY # Vector embedding + content: STRING # Document content + language: STRING # Direct metadata column + + Code:: + + >>> params = BigQueryVectorSearchParameters( + ... table_name='project.dataset.embeddings', + ... embedding_column='embedding', + ... columns=['content', 'language'], + ... neighbor_count=5, + ... # For column 'language', value comes from + ... # chunk.metadata['language'] + ... metadata_restriction_template="language = '{language}'" + ... ) + >>> # When processing a chunk with metadata={'language': 'en'}, + >>> # generates: WHERE language = 'en' + + Example with nested repeated metadata: + + Table schema:: + + embedding: ARRAY # Vector embedding + content: STRING # Document content + metadata: ARRAY # Nested repeated metadata + key: STRING, + value: STRING + >> + + Code:: + + >>> params = BigQueryVectorSearchParameters( + ... table_name='project.dataset.embeddings', + ... embedding_column='embedding', + ... columns=['content', 'metadata'], + ... neighbor_count=5, + ... # check_metadata(field_name, key_to_search, value_from_chunk) + ... metadata_restriction_template=( + ... "check_metadata(metadata, 'language', '{language}')" + ... ) + ... ) + >>> # When processing a chunk with metadata={'language': 'en'}, + >>> # generates: WHERE check_metadata(metadata, 'language', 'en') + >>> # Searches for {key: 'language', value: 'en'} in metadata array + + Args: + project: GCP project ID containing the BigQuery dataset + table_name: Fully qualified BigQuery table name containing vectors. + embedding_column: Column name containing the embedding vectors. + columns: List of columns to retrieve from matched vectors. + neighbor_count: Number of similar vectors to return (top-k). + metadata_restriction_template: Template string for filtering vectors. + Two formats supported: + + 1. For flattened metadata columns: + ``column_name = '{metadata_key}'`` where column_name is the + BigQuery column and metadata_key is used to get the value from + chunk.metadata[metadata_key]. + 2. For nested repeated metadata (ARRAY>): + ``check_metadata(field_name, 'key_to_match', '{metadata_key}')`` + where field_name is the ARRAY column in BigQuery, + key_to_match is the literal key to search for in the array, and + metadata_key is used to get value from + chunk.metadata[metadata_key]. + + Multiple conditions can be combined using AND/OR operators. For + example:: + + >>> # Combine metadata check with column filter + >>> template = ( + ... "check_metadata(metadata, 'language', '{language}') " + ... "AND source = '{source}'" + ... ) + >>> # When chunk.metadata = {'language': 'en', 'source': 'web'} + >>> # Generates: WHERE + >>> # check_metadata(metadata, 'language', 'en') + >>> # AND source = 'web' + + distance_type: Optional distance metric to use. Supported values: + COSINE (default), EUCLIDEAN, or DOT_PRODUCT. + options: Optional dictionary of additional VECTOR_SEARCH options. + """ + project: str + table_name: str + embedding_column: str + columns: List[str] + neighbor_count: int + metadata_restriction_template: Optional[str] = None + distance_type: Optional[str] = None + options: Optional[Dict[str, Any]] = None + + def format_query(self, chunks: List[Chunk]) -> str: + """Format the vector search query template.""" + base_columns_str = ", ".join(f"base.{col}" for col in self.columns) + columns_str = ", ".join(self.columns) + distance_clause = ( + f", distance_type => '{self.distance_type}'" + if self.distance_type else "") + options_clause = (f", options => {self.options}" if self.options else "") + + # Create metadata check function only if needed + metadata_fn = """ + CREATE TEMP FUNCTION check_metadata( + metadata ARRAY>, + search_key STRING, + search_value STRING + ) + AS (( + SELECT COUNT(*) > 0 + FROM UNNEST(metadata) + WHERE key = search_key AND value = search_value + )); + """ if self.metadata_restriction_template else "" + + # Group chunks by their metadata conditions + condition_groups = defaultdict(list) + if self.metadata_restriction_template: + for chunk in chunks: + condition = self.metadata_restriction_template.format(**chunk.metadata) + condition_groups[condition].append(chunk) + else: + # No metadata filtering - all chunks in one group + condition_groups[""] = chunks + + # Generate VECTOR_SEARCH subqueries for each condition group + vector_searches = [] + for condition, group_chunks in condition_groups.items(): + # Create embeddings subquery for this group + embedding_unions = [] + for chunk in group_chunks: + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f"Chunk {chunk.id} missing embedding") + embedding_str = ( + f"SELECT '{chunk.id}' as id, " + f"{[float(x) for x in chunk.embedding.dense_embedding]} " + f"as embedding") + embedding_unions.append(embedding_str) + group_embeddings = " UNION ALL ".join(embedding_unions) + + # Create VECTOR_SEARCH for this condition group + where_clause = f"WHERE {condition}" if condition else "" + # Create VECTOR_SEARCH for this condition group + vector_search = f""" + SELECT + query.id, + ARRAY_AGG( + STRUCT({base_columns_str}) + ) as chunks + FROM VECTOR_SEARCH( + (SELECT {columns_str}, {self.embedding_column} + FROM `{self.table_name}` + {where_clause}), + '{self.embedding_column}', + (SELECT * FROM ({group_embeddings})), + top_k => {self.neighbor_count} + {distance_clause} + {options_clause} + ) + GROUP BY query.id + """ + vector_searches.append(vector_search) + + # Combine all vector searches + combined_searches = " UNION ALL ".join(vector_searches) + + return f""" + {metadata_fn} + + {combined_searches} + """ + + +class BigQueryVectorSearchEnrichmentHandler( + EnrichmentSourceHandler[Union[Chunk, List[Chunk]], + List[Tuple[Chunk, Dict[str, Any]]]]): + """Enrichment handler that performs vector similarity search using BigQuery. + + This handler enriches Chunks by finding similar vectors in a BigQuery table + using the VECTOR_SEARCH function. It supports batching requests for efficiency + and preserves the original Chunk metadata while adding the search results. + + Example: + >>> from apache_beam.ml.rag.types import Chunk, Content, Embedding + >>> + >>> # Configure vector search + >>> params = BigQueryVectorSearchParameters( + ... table_name='project.dataset.embeddings', + ... embedding_column='embedding', + ... columns=['content', 'metadata'], + ... neighbor_count=2, + ... metadata_restriction_template="language = '{language}'" + ... ) + >>> + >>> # Create handler + >>> handler = BigQueryVectorSearchEnrichmentHandler( + ... project='my-project', + ... vector_search_parameters=params, + ... min_batch_size=100, + ... max_batch_size=1000 + ... ) + >>> + >>> # Use in pipeline + >>> with beam.Pipeline() as p: + ... enriched = ( + ... p + ... | beam.Create([ + ... Chunk( + ... id='query1', + ... embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + ... content=Content(text='test query'), + ... metadata={'language': 'en'} + ... ) + ... ]) + ... | Enrichment(handler) + ... ) + + Args: + vector_search_parameters: Configuration for the vector search query + min_batch_size: Minimum number of chunks to batch before processing + max_batch_size: Maximum number of chunks to process in one batch + **kwargs: Additional arguments passed to bigquery.Client + + The handler will: + 1. Batch incoming chunks according to batch size parameters + 2. Format and execute vector search query for each batch + 3. Join results back to original chunks + 4. Return tuples of (original_chunk, search_results) + """ + def __init__( + self, + vector_search_parameters: BigQueryVectorSearchParameters, + *, + min_batch_size: int = 1, + max_batch_size: int = 1000, + **kwargs): + self.project = vector_search_parameters.project + self.vector_search_parameters = vector_search_parameters + self.kwargs = kwargs + self._batching_kwargs = { + 'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size + } + self.join_fn = join_fn + self.use_custom_types = True + + def __enter__(self): + self.client = bigquery.Client(project=self.project, **self.kwargs) + + def __call__(self, request: Union[Chunk, List[Chunk]], *args, + **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + """Process request(s) using BigQuery vector search. + + Args: + request: Single Chunk with embedding or list of Chunk's with + embeddings to process + + Returns: + Chunk(s) where chunk.metadata['enrichment_output'] contains the + data retrieved via BigQuery VECTOR_SEARCH. + """ + # Convert single request to list for uniform processing + requests = request if isinstance(request, list) else [request] + + # Generate and execute query + query = self.vector_search_parameters.format_query(requests) + query_job = self.client.query(query) + results = query_job.result() + + # Create results dict with empty chunks list as default + results_by_id = {} + for result_row in results: + result_dict = dict(result_row.items()) + results_by_id[result_row.id] = result_dict + + # Return all chunks in original order, with empty results if no matches + response = [] + for chunk in requests: + result_dict = results_by_id.get(chunk.id, {}) + response.append((chunk, result_dict)) + + return response + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + def batch_elements_kwargs(self) -> Dict[str, int]: + """Returns kwargs for beam.BatchElements.""" + return self._batching_kwargs + + +def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding: + left.metadata['enrichment_data'] = right + return left diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py new file mode 100644 index 000000000000..a038f3760f24 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py @@ -0,0 +1,717 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import secrets +import time +import unittest + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=ungrouped-imports +try: + from google.api_core.exceptions import BadRequest + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.ml.rag.enrichment.bigquery_vector_search import \ + BigQueryVectorSearchEnrichmentHandler + from apache_beam.ml.rag.enrichment.bigquery_vector_search import \ + BigQueryVectorSearchParameters +except ImportError: + raise unittest.SkipTest('BigQuery dependencies not installed') + +_LOGGER = logging.getLogger(__name__) + + +class BigQueryVectorSearchIT(unittest.TestCase): + bigquery_dataset_id = 'python_vector_search_test_' + project = "apache-beam-testing" + + @classmethod + def setUpClass(cls): + cls.bigquery_client = BigQueryWrapper() + cls.dataset_id = '%s%d%s' % ( + cls.bigquery_dataset_id, int(time.time()), secrets.token_hex(3)) + cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", cls.dataset_id, cls.project) + + @classmethod + def tearDownClass(cls): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) + try: + cls.bigquery_client.client.datasets.Delete(request) + except Exception: + _LOGGER.warning( + 'Failed to clean up dataset %s in project %s', + cls.dataset_id, + cls.project) + + +class TestBigQueryVectorSearchIT(BigQueryVectorSearchIT): + # Test data with embeddings + table_data = [{ + "id": "doc1", + "content": "This is a test document", + "domain": "medical", + "embedding": [0.1, 0.2, 0.3], + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "id": "doc2", + "content": "Another test document", + "domain": "legal", + "embedding": [0.2, 0.3, 0.4], + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "id": "doc3", + "content": "Un document de test", + "domain": "financial", + "embedding": [0.3, 0.4, 0.5], + "metadata": [{ + "key": "language", "value": "fr" + }] + }] + + @classmethod + def create_table(cls, table_name): + fields = [('id', 'STRING'), ('content', 'STRING'), ('domain', 'STRING'), + ('embedding', 'FLOAT64', 'REPEATED'), + ( + 'metadata', + 'RECORD', + 'REPEATED', [('key', 'STRING'), ('value', 'STRING')])] + table_schema = bigquery.TableSchema() + for field_def in fields: + field = bigquery.TableFieldSchema() + field.name = field_def[0] + field.type = field_def[1] + if len(field_def) > 2: + field.mode = field_def[2] + if len(field_def) > 3: + for subfield_def in field_def[3]: + subfield = bigquery.TableFieldSchema() + subfield.name = subfield_def[0] + subfield.type = subfield_def[1] + field.fields.append(subfield) + table_schema.fields.append(field) + + table = bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset_id, + tableId=table_name), + schema=table_schema) + + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.dataset_id, table=table) + cls.bigquery_client.client.tables.Insert(request) + cls.bigquery_client.insert_rows( + cls.project, cls.dataset_id, table_name, cls.table_data) + cls.table_name = f"{cls.project}.{cls.dataset_id}.{table_name}" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.create_table('vector_test') + + def test_basic_vector_search(self): + """Test basic vector similarity search.""" + test_chunks = [ + Chunk( + id="query1", + index=0, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={"language": "en"}) + ] + # Expected chunk will have enrichment_data in metadata + expected_chunks = [ + Chunk( + id="query1", + index=0, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_batched_metadata_filter_vector_search(self): + """Test vector similarity search with batching.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query 1"), + metadata={"language": "en"}, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content(text="test query 2"), + metadata={"language": "en"}, + index=1), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content(text="test query 3"), + metadata={"language": "fr"}, + index=2) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params, + min_batch_size=2, # Force batching + max_batch_size=2 # Process 2 chunks at a time + ) + + expected_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None), + content=Content(text="test query 1"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], sparse_embedding=None), + content=Content(text="test query 2"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query2", + "chunks": [{ + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=1), + Chunk( + id="query3", + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], sparse_embedding=None), + content=Content(text="test query 3"), + metadata={ + "language": "fr", + "enrichment_data": { + "id": "query3", + "chunks": [{ + "content": "Un document de test", + "metadata": [{ + "key": "language", "value": "fr" + }] + }] + } + }, + index=2) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_euclidean_distance_search(self): + """Test vector similarity search using Euclidean distance.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query 1"), + metadata={"language": "en"}, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content(text="test query 2"), + metadata={"language": "en"}, + index=1) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')"), + distance_type='EUCLIDEAN' # Use Euclidean distance + ) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params, min_batch_size=2, max_batch_size=2) + + expected_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None), + content=Content(text="test query 1"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], sparse_embedding=None), + content=Content(text="test query 2"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query2", + "chunks": [{ + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=1) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_no_metadata_restriction(self): + """Test vector search without metadata filtering.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={'language': 'fr'}, + index=0) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, # Get top 2 matches + metadata_restriction_template=None # No filtering + ) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + # Should get matches regardless of metadata + assert_that( + result, + equal_to([ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={ + "language": "fr", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0) + ])) + + def test_metadata_filter_leakage(self): + """Test that metadata filters don't leak between batched chunks.""" + + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="medical query"), + metadata={ + "domain": "medical", "language": "en" + }, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="unmatched query"), + metadata={ + "domain": "doesntexist", "language": "en" + }, + index=1) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata', 'domain'], + neighbor_count=1, + metadata_restriction_template=( + "domain = '{domain}' AND " + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params, + min_batch_size=2, # Force batching + max_batch_size=2 + ) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that( + result, + equal_to([ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="medical query"), + metadata={ + "domain": "medical", + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "domain": "medical", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="unmatched query"), + metadata={ + "domain": "doesntexist", + "language": "en", + "enrichment_data": {} + }, + index=1) + ])) + + def test_condition_batching(self): + """Test that queries with same metadata conditions are batched together.""" + + # Create three queries with same conditions but different embeddings + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="english query 1"), + metadata={"language": "en"}, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.4, 0.5, 0.6]), + content=Content(text="english query 2"), + metadata={"language": "en"}, + index=1), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.7, 0.8, 0.9]), + content=Content(text="french query 3"), + metadata={"language": "fr"}, + index=2) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'domain', 'metadata'], + neighbor_count=3, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params, + min_batch_size=10, # Force batching + max_batch_size=100 + ) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + # All queries should be handled in a single VECTOR_SEARCH + # Each should get its closest match + assert_that( + result, + equal_to([ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="english query 1"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "domain": "medical", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "domain": "legal", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.4, 0.5, 0.6]), + content=Content(text="english query 2"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query2", + "chunks": [{ + "content": "Another test document", + "domain": "legal", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "This is a test document", + "domain": "medical", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=1), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.7, 0.8, 0.9]), + content=Content(text="french query 3"), + metadata={ + "language": "fr", + "enrichment_data": { + "id": "query3", + "chunks": [{ + "content": "Un document de test", + "domain": "financial", + "metadata": [{ + "key": "language", "value": "fr" + }] + }] + } + }, + index=2) + ])) + + def test_invalid_query(self): + """Test error handling for invalid queries.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={"language": "en"}) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='nonexistent_column', # Invalid column + columns=['content'], + neighbor_count=1, + metadata_restriction_template=( + "language = '{language}'" # Invalid template + ) + ) + + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params) + + with self.assertRaises(BadRequest): + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + def test_empty_input(self): + """Test handling of empty input.""" + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content'], + neighbor_count=1) + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create([]) | Enrichment(handler)) + assert_that(result, equal_to([])) + + def test_missing_embedding(self): + """Test handling of chunks with missing embeddings.""" + test_chunks = [ + Chunk( + id="query1", + embedding=None, # Missing embedding + content=Content(text="test query"), + metadata={"language": "en"}, + index=0 + ) + ] + + params = BigQueryVectorSearchParameters( + project=self.project, + table_name=self.table_name, + embedding_column='embedding', + columns=['content'], + neighbor_count=1) + handler = BigQueryVectorSearchEnrichmentHandler( + vector_search_parameters=params) + + with self.assertRaises(ValueError) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + self.assertIn("missing embedding", str(context.exception)) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/__init__.py b/sdks/python/apache_beam/ml/rag/ingestion/__init__.py new file mode 100644 index 000000000000..6a81a3586b57 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Vector storage ingestion components for RAG pipelines. +This module provides components for storing vectors in RAG pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/ingestion/base.py b/sdks/python/apache_beam/ml/rag/ingestion/base.py new file mode 100644 index 000000000000..d79aa7778405 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/base.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from abc import abstractmethod +from typing import Any + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk + + +class VectorDatabaseWriteConfig(ABC): + """Abstract base class for vector database configurations in RAG pipelines. + + VectorDatabaseWriteConfig defines the interface for configuring vector + database writes in RAG pipelines. Implementations should provide + database-specific configuration and create appropriate write transforms. + + The configuration flow: + 1. Subclass provides database-specific configuration (table names, etc) + 2. create_write_transform() creates appropriate PTransform for writing + 3. Transform handles converting Chunks to database-specific format + + Example implementation: + >>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): + ... def __init__(self, table: str): + ... self.embedding_column = embedding_column + ... + ... def create_write_transform(self): + ... return beam.io.WriteToBigQuery( + ... table=self.table + ... ) + """ + @abstractmethod + def create_write_transform(self) -> beam.PTransform[Chunk, Any]: + """Creates a PTransform that writes embeddings to the vector database. + + Returns: + A PTransform that accepts PCollection[Chunk] and writes the chunks' + embeddings and metadata to the configured vector database. + The transform should handle: + - Converting Chunk format to database schema + - Setting up database connection/client + - Writing with appropriate batching/error handling + """ + raise NotImplementedError(type(self)) + + +class VectorDatabaseWriteTransform(beam.PTransform): + """A PTransform for writing embedded chunks to vector databases. + + This transform uses a VectorDatabaseWriteConfig to write chunks with + embeddings to vector database. It handles validating the config and applying + the database-specific write transform. + + Example usage: + >>> config = BigQueryVectorConfig( + ... table='project.dataset.embeddings', + ... embedding_column='embedding' + ... ) + >>> + >>> with beam.Pipeline() as p: + ... chunks = p | beam.Create([...]) # PCollection[Chunk] + ... chunks | VectorDatabaseWriteTransform(config) + + Args: + database_config: Configuration for the target vector database. + Must be a subclass of VectorDatabaseWriteConfig that implements + create_write_transform(). + + Raises: + TypeError: If database_config is not a VectorDatabaseWriteConfig instance. + """ + def __init__(self, database_config: VectorDatabaseWriteConfig): + """Initialize transform with database config. + + Args: + database_config: Configuration for target vector database. + """ + if not isinstance(database_config, VectorDatabaseWriteConfig): + raise TypeError( + f"database_config must be VectorDatabaseWriteConfig, " + f"got {type(database_config)}") + self.database_config = database_config + + def expand(self, + pcoll: beam.PCollection[Chunk]) -> beam.PTransform[Chunk, Any]: + """Creates and applies the database-specific write transform. + + Args: + pcoll: PCollection of Chunks with embeddings to write to the + vector database. Each Chunk must have: + - An embedding + - An ID + - Metadata used to filter results as specified by database config + + Returns: + Result of writing to database (implementation specific). + """ + write_transform = self.database_config.create_write_transform() + return pcoll | write_transform diff --git a/sdks/python/apache_beam/ml/rag/ingestion/base_test.py b/sdks/python/apache_beam/ml/rag/ingestion/base_test.py new file mode 100644 index 000000000000..57e3c8b10e68 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/base_test.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +class MockWriteTransform(beam.PTransform): + """Mock transform that returns element.""" + def expand(self, pcoll): + return pcoll | beam.Map(lambda x: x) + + +class MockDatabaseConfig(VectorDatabaseWriteConfig): + """Mock database config for testing.""" + def __init__(self): + self.write_transform = MockWriteTransform() + + def create_write_transform(self) -> beam.PTransform: + return self.write_transform + + +class VectorDatabaseBaseTest(unittest.TestCase): + def test_write_transform_creation(self): + """Test that write transform is created correctly.""" + config = MockDatabaseConfig() + transform = VectorDatabaseWriteTransform(config) + self.assertEqual(transform.database_config, config) + + def test_pipeline_integration(self): + """Test writing through pipeline.""" + test_data = [ + Chunk( + content=Content(text="foo"), + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2])), + Chunk( + content=Content(text="bar"), + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4])) + ] + + with TestPipeline() as p: + result = ( + p + | beam.Create(test_data) + | VectorDatabaseWriteTransform(MockDatabaseConfig())) + + # Verify data was written + assert_that(result, equal_to(test_data)) + + def test_invalid_config(self): + """Test error handling for invalid config.""" + with self.assertRaises(TypeError): + VectorDatabaseWriteTransform(None) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py new file mode 100644 index 000000000000..7d2caa67868a --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py @@ -0,0 +1,183 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict +from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.types import Chunk +from apache_beam.typehints.row_type import RowTypeConstraint + +ChunkToDictFn = Callable[[Chunk], Dict[str, any]] + + +@dataclass +class SchemaConfig: + """Configuration for custom BigQuery schema and row conversion. + + Allows overriding the default schema and row conversion logic for BigQuery + vector storage. This enables custom table schemas beyond the default + id/embedding/content/metadata structure. + + Attributes: + schema: BigQuery TableSchema dict defining the table structure. + Example: + >>> { + ... 'fields': [ + ... {'name': 'id', 'type': 'STRING'}, + ... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, + ... {'name': 'custom_field', 'type': 'STRING'} + ... ] + ... } + chunk_to_dict_fn: Function that converts a Chunk to a dict matching the + schema. Takes a Chunk and returns Dict[str, Any] with keys matching + schema fields. + Example: + >>> def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]: + ... return { + ... 'id': chunk.id, + ... 'embedding': chunk.embedding.dense_embedding, + ... 'custom_field': chunk.metadata.get('custom_field') + ... } + """ + schema: Dict + chunk_to_dict_fn: ChunkToDictFn + + +class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): + def __init__( + self, + write_config: Dict[str, Any], + *, # Force keyword arguments + schema_config: Optional[SchemaConfig] = None + ): + """Configuration for writing vectors to BigQuery using managed transforms. + + Supports both default schema (id, embedding, content, metadata columns) and + custom schemas through SchemaConfig. + + Example with default schema: + >>> config = BigQueryVectorWriterConfig( + ... write_config={'table': 'project.dataset.embeddings'}) + + Example with custom schema: + >>> schema_config = SchemaConfig( + ... schema={ + ... 'fields': [ + ... {'name': 'id', 'type': 'STRING'}, + ... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'}, + ... {'name': 'source_url', 'type': 'STRING'} + ... ] + ... }, + ... chunk_to_dict_fn=lambda chunk: { + ... 'id': chunk.id, + ... 'embedding': chunk.embedding.dense_embedding, + ... 'source_url': chunk.metadata.get('url') + ... } + ... ) + >>> config = BigQueryVectorWriterConfig( + ... write_config={'table': 'project.dataset.embeddings'}, + ... schema_config=schema_config + ... ) + + Args: + write_config: BigQuery write configuration dict. Must include 'table'. + Other options like create_disposition, write_disposition can be + specified. + schema_config: Optional configuration for custom schema and row + conversion. + If not provided, uses default schema with id, embedding, content and + metadata columns. + + Raises: + ValueError: If write_config doesn't include table specification. + """ + if 'table' not in write_config: + raise ValueError("write_config must be provided with 'table' specified") + + self.write_config = write_config + self.schema_config = schema_config + + def create_write_transform(self) -> beam.PTransform: + """Creates transform to write to BigQuery.""" + return _WriteToBigQueryVectorDatabase(self) + + +def _default_chunk_to_dict_fn(chunk: Chunk): + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError("chunk must contain dense embedding") + return { + 'id': chunk.id, + 'embedding': chunk.embedding.dense_embedding, + 'content': chunk.content.text, + 'metadata': [ + { + "key": k, "value": str(v) + } for k, v in chunk.metadata.items() + ] + } + + +def _default_schema(): + return { + 'fields': [{ + 'name': 'id', 'type': 'STRING' + }, { + 'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED' + }, { + 'name': 'content', 'type': 'STRING' + }, + { + 'name': 'metadata', + 'type': 'RECORD', + 'mode': 'REPEATED', + 'fields': [{ + 'name': 'key', 'type': 'STRING' + }, { + 'name': 'value', 'type': 'STRING' + }] + }] + } + + +class _WriteToBigQueryVectorDatabase(beam.PTransform): + """Implementation of BigQuery vector database write. """ + def __init__(self, config: BigQueryVectorWriterConfig): + self.config = config + + def expand(self, pcoll: beam.PCollection[Chunk]): + schema = ( + self.config.schema_config.schema + if self.config.schema_config else _default_schema()) + chunk_to_dict_fn = ( + self.config.schema_config.chunk_to_dict_fn + if self.config.schema_config else _default_chunk_to_dict_fn) + return ( + pcoll + | "Chunk to dict" >> beam.Map(chunk_to_dict_fn) + | "Chunk dict to schema'd row" >> beam.Map( + lambda chunk_dict: beam_row_from_dict( + row=chunk_dict, schema=schema)).with_output_types( + RowTypeConstraint.from_fields( + get_beam_typehints_from_tableschema(schema))) + | "Write to BigQuery" >> beam.managed.Write( + beam.managed.BIGQUERY, config=self.config.write_config)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py new file mode 100644 index 000000000000..6c034a1aeae7 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py @@ -0,0 +1,241 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import os +import secrets +import time +import unittest + +import hamcrest as hc +import pytest + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher +from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig +from apache_beam.ml.rag.ingestion.bigquery import SchemaConfig +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.periodicsequence import PeriodicImpulse + + +@pytest.mark.uses_gcp_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +class BigQueryVectorWriterConfigTest(unittest.TestCase): + BIG_QUERY_DATASET_ID = 'python_rag_bigquery_' + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self._runner = type(self.test_pipeline.runner).__name__ + self.project = self.test_pipeline.get_option('project') + + self.bigquery_client = BigQueryWrapper() + self.dataset_id = '%s%d%s' % ( + self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3)) + self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id) + _LOGGER = logging.getLogger(__name__) + _LOGGER.info( + "Created dataset %s in project %s", self.dataset_id, self.project) + + def tearDown(self): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=self.project, datasetId=self.dataset_id, deleteContents=True) + try: + _LOGGER = logging.getLogger(__name__) + _LOGGER.info( + "Deleting dataset %s in project %s", self.dataset_id, self.project) + self.bigquery_client.client.datasets.Delete(request) + # Failing to delete a dataset should not cause a test failure. + except Exception: + _LOGGER = logging.getLogger(__name__) + _LOGGER.debug( + 'Failed to clean up dataset %s in project %s', + self.dataset_id, + self.project) + + def test_default_schema(self): + table_name = 'python_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("1", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }])]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_default_schema_missing_embedding(self): + table_name = 'python_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk(id="1", content=Content(text="foo"), metadata={"a": "b"}), + Chunk(id="2", content=Content(text="bar"), metadata={"c": "d"}) + ] + with self.assertRaises(ValueError): + with beam.Pipeline() as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_custom_schema(self): + table_name = 'python_custom_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + schema_config = SchemaConfig( + schema={ + 'fields': [{ + 'name': 'id', 'type': 'STRING' + }, + { + 'name': 'embedding', + 'type': 'FLOAT64', + 'mode': 'REPEATED' + }, { + 'name': 'source', 'type': 'STRING' + }] + }, + chunk_to_dict_fn=lambda chunk: { + 'id': chunk.id, + 'embedding': chunk.embedding.dense_embedding, + 'source': chunk.metadata.get('source') + }) + config = BigQueryVectorWriterConfig( + write_config={'table': table_id}, schema_config=schema_config) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo content"), + metadata={"source": "foo"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar content"), + metadata={"source": "bar"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, embedding, source FROM %s" % table_id, + data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_streaming_default_schema(self): + self.skip_if_not_dataflow_runner() + + table_name = 'python_streaming_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}), + Chunk( + id="3", + embedding=Embedding(dense_embedding=[0.5, 0.6]), + content=Content(text="foo"), + metadata={"e": "f"}), + Chunk( + id="4", + embedding=Embedding(dense_embedding=[0.7, 0.8]), + content=Content(text="bar"), + metadata={"g": "h"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("1", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }]), ("3", "foo", [0.5, 0.6], [{ + "key": "e", "value": "f" + }]), ("4", "bar", [0.7, 0.8], [{ + "key": "g", "value": "h" + }])]) + ] + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers), + streaming=True, + allow_unsafe_triggers=True) + + with beam.Pipeline(argv=args) as p: + _ = ( + p + | PeriodicImpulse(0, 4, 1) + | beam.Map(lambda t: chunks[t]) + | config.create_write_transform()) + + def skip_if_not_dataflow_runner(self): + # skip if dataflow runner is not specified + if not self._runner or "dataflowrunner" not in self._runner.lower(): + self.skipTest( + "Streaming with exactly-once route has the requirement " + "`beam:requirement:pardo:on_window_expiration:v1`, " + "which is currently only supported by the Dataflow runner") + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index 5bb1e2024e79..b4ebc31f2e74 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -133,16 +133,22 @@ class Enrichment(beam.PTransform[beam.PCollection[InputT], def __init__( self, source_handler: EnrichmentSourceHandler, - join_fn: JoinFn = cross_join, + join_fn: Optional[JoinFn] = None, timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, repeater: Repeater = ExponentialBackOffRepeater(), - throttler: PreCallThrottler = DefaultThrottler()): + throttler: PreCallThrottler = DefaultThrottler(), + use_custom_types: bool = False): self._cache = None self._source_handler = source_handler - self._join_fn = join_fn + self._join_fn = ( + join_fn if join_fn else source_handler.join_fn if hasattr( + source_handler, 'join_fn') else cross_join) self._timeout = timeout self._repeater = repeater self._throttler = throttler + self._use_custom_types = ( + source_handler.use_custom_types if hasattr( + source_handler, 'use_custom_types') else use_custom_types) def expand(self, input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: @@ -165,8 +171,9 @@ def expand(self, # EnrichmentSourceHandler returns a tuple of (request,response). return ( fetched_data - | "enrichment_join" >> - beam.Map(lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict()))) + | "enrichment_join" >> beam.Map( + lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict()) + if not self._use_custom_types else self._join_fn(x[0], x[1]))) def with_redis_cache( self, diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index e290e8003b13..1dd15ecb09f9 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -447,6 +447,7 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> pythonPipelineOptions: [ "--runner=TestDirectRunner", "--project=${gcpProject}", + "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it", ], pytestOptions: [ "--capture=no", // print stdout instantly