Skip to content

Commit

Permalink
Add BigQueryVectorWriterConfig tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude authored and claudevdm committed Dec 19, 2024
1 parent a22a48c commit ba5ef2a
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 2
"modification": 3
}
36 changes: 16 additions & 20 deletions sdks/python/apache_beam/ml/rag/ingestion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,14 @@ class VectorDatabaseWriteConfig(ABC):
3. Transform handles converting Chunks to database-specific format
Example implementation:
```python
class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
def __init__(self, table: str):
self.embedding_column = embedding_column
def create_write_transform(self):
return beam.io.WriteToBigQuery(
table=self.table
)
```
>>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
... def __init__(self, table: str):
... self.embedding_column = embedding_column
...
... def create_write_transform(self):
... return beam.io.WriteToBigQuery(
... table=self.table
... )
"""
@abstractmethod
def create_write_transform(self) -> beam.PTransform:
Expand All @@ -67,16 +65,14 @@ class VectorDatabaseWriteTransform(beam.PTransform):
the database-specific write transform.
Example usage:
```python
config = BigQueryVectorConfig(
table='project.dataset.embeddings',
embedding_column='embedding'
)
with beam.Pipeline() as p:
chunks = p | beam.Create([...]) # PCollection[Chunk]
chunks | VectorDatabaseWriteTransform(config)
```
>>> config = BigQueryVectorConfig(
... table='project.dataset.embeddings',
... embedding_column='embedding'
... )
>>>
>>> with beam.Pipeline() as p:
... chunks = p | beam.Create([...]) # PCollection[Chunk]
... chunks | VectorDatabaseWriteTransform(config)
Args:
database_config: Configuration for the target vector database.
Expand Down
8 changes: 4 additions & 4 deletions sdks/python/apache_beam/ml/rag/ingestion/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.

import unittest
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk, Embedding, Content
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to

import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import (
VectorDatabaseWriteConfig, VectorDatabaseWriteTransform)
from apache_beam.ml.rag.types import Chunk, Content, Embedding
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to


class MockWriteTransform(beam.PTransform):
Expand Down
80 changes: 38 additions & 42 deletions sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from typing import Optional, List, Dict, Any
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import apache_beam as beam
from apache_beam.io.gcp.bigquery_tools import (
beam_row_from_dict, get_beam_typehints_from_tableschema)
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.types import Chunk
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict, get_beam_typehints_from_tableschema

ChunkToDictFn = Callable[[Chunk], Dict[str, any]]

Expand All @@ -39,23 +39,23 @@ class SchemaConfig:
Attributes:
schema: BigQuery TableSchema dict defining the table structure.
Example:
{
'fields': [
{'name': 'id', 'type': 'STRING'},
{'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
{'name': 'custom_field', 'type': 'STRING'}
]
}
>>> {
... 'fields': [
... {'name': 'id', 'type': 'STRING'},
... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
... {'name': 'custom_field', 'type': 'STRING'}
... ]
... }
chunk_to_dict_fn: Function that converts a Chunk to a dict matching the
schema. Takes a Chunk and returns Dict[str, Any] with keys matching
schema fields.
Example:
def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]:
return {
'id': chunk.id,
'embedding': chunk.embedding.dense_embedding,
'custom_field': chunk.metadata.get('custom_field')
}
>>> def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]:
... return {
... 'id': chunk.id,
... 'embedding': chunk.embedding.dense_embedding,
... 'custom_field': chunk.metadata.get('custom_field')
... }
"""
schema: Dict
chunk_to_dict_fn: ChunkToDictFn
Expand All @@ -66,40 +66,36 @@ def __init__(
self,
write_config: Dict[str, Any],
*, # Force keyword arguments
schema_config: Optional[SchemaConfig]
schema_config: Optional[SchemaConfig] = None
):
"""Configuration for writing vectors to BigQuery using managed transforms.
Supports both default schema (id, embedding, content, metadata columns) and
custom schemas through SchemaConfig.
Example with default schema:
```python
config = BigQueryVectorWriterConfig(
write_config={'table': 'project.dataset.embeddings'})
```
>>> config = BigQueryVectorWriterConfig(
... write_config={'table': 'project.dataset.embeddings'})
Example with custom schema:
```python
schema_config = SchemaConfig(
schema={
'fields': [
{'name': 'id', 'type': 'STRING'},
{'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
{'name': 'source_url', 'type': 'STRING'}
]
},
chunk_to_dict_fn=lambda chunk: {
'id': chunk.id,
'embedding': chunk.embedding.dense_embedding,
'source_url': chunk.metadata.get('url')
}
)
config = BigQueryVectorWriterConfig(
write_config={'table': 'project.dataset.embeddings'},
schema_config=schema_config
)
```
>>> schema_config = SchemaConfig(
... schema={
... 'fields': [
... {'name': 'id', 'type': 'STRING'},
... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
... {'name': 'source_url', 'type': 'STRING'}
... ]
... },
... chunk_to_dict_fn=lambda chunk: {
... 'id': chunk.id,
... 'embedding': chunk.embedding.dense_embedding,
... 'source_url': chunk.metadata.get('url')
... }
... )
>>> config = BigQueryVectorWriterConfig(
... write_config={'table': 'project.dataset.embeddings'},
... schema_config=schema_config
... )
Args:
write_config: BigQuery write configuration dict. Must include 'table'.
Expand Down
Loading

0 comments on commit ba5ef2a

Please sign in to comment.