Python: Add vector search to Postgres connector #10213

Open · wants to merge 10 commits into main
416 changes: 395 additions & 21 deletions python/samples/getting_started/third_party/postgres-memory.ipynb

Large diffs are not rendered by default.

python/semantic_kernel/connectors/memory/postgres/constants.py
@@ -5,6 +5,10 @@
# Limitation based on pgvector documentation https://github.com/pgvector/pgvector#what-if-i-want-to-index-vectors-with-more-than-2000-dimensions
MAX_DIMENSIONALITY = 2000

# The name of the column that returns distance value in the database.
# It is used in the similarity search query. Must not conflict with model property.
DISTANCE_COLUMN_NAME = "sk_pg_distance"
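
# Illustrative only (not part of this diff): in the generated search SQL this constant
# becomes the alias of the distance expression, e.g.
#   SELECT "id", "description", "embedding" <=> %s AS "sk_pg_distance" FROM ...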

# Environment Variables
PGHOST_ENV_VAR = "PGHOST"
PGPORT_ENV_VAR = "PGPORT"
python/semantic_kernel/connectors/memory/postgres/postgres_collection.py
@@ -3,7 +3,17 @@
import logging
import sys
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, TypeVar

from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo
from semantic_kernel.data.kernel_search_results import KernelSearchResults
from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions.vector_store_exceptions import VectorSearchExecutionException

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -14,25 +24,27 @@
from psycopg_pool import AsyncConnectionPool
from pydantic import PrivateAttr

from semantic_kernel.connectors.memory.postgres.constants import (
    DEFAULT_SCHEMA,
    DISTANCE_COLUMN_NAME,
    MAX_DIMENSIONALITY,
)
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings
from semantic_kernel.connectors.memory.postgres.utils import (
convert_dict_to_row,
convert_row_to_dict,
get_vector_distance_ops_str,
get_vector_index_ops_str,
python_type_to_postgres,
)
from semantic_kernel.data.const import DistanceFunction, IndexKind
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.record_definition.vector_store_record_fields import (
VectorStoreRecordField,
VectorStoreRecordKeyField,
VectorStoreRecordVectorField,
)
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.exceptions import VectorStoreModelValidationError, VectorStoreOperationException
from semantic_kernel.kernel_types import OneOrMany
from semantic_kernel.utils.experimental_decorator import experimental_class

@@ -43,7 +55,11 @@


@experimental_class
class PostgresCollection(
    VectorSearchBase[TKey, TModel],
    VectorizedSearchMixin[TModel],
    Generic[TKey, TModel],
):
"""PostgreSQL collection implementation."""

connection_pool: AsyncConnectionPool | None = None
@@ -84,26 +100,31 @@ def __init__(
data_model_definition=data_model_definition,
connection_pool=connection_pool,
db_schema=db_schema,
# This controls whether the connection pool is managed by the collection
# in the __aenter__ and __aexit__ methods.
managed_client=connection_pool is None,
)

self._settings = settings or PostgresSettings.create(
env_file_path=env_file_path, env_file_encoding=env_file_encoding
)
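
        # Illustrative note (not in the diff): passing an existing AsyncConnectionPool
        # shares it across collections and the collection will not close it; omitting it
        # lets the collection create and manage its own pool in __aenter__/__aexit__.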

# region: VectorStoreRecordCollection implementation

@override
async def __aenter__(self) -> "PostgresCollection":
# If the connection pool was not provided, create a new one.
if not self.connection_pool:
self.connection_pool = await self._settings.create_connection_pool()
self.managed_client = True
return self

@override
    async def __aexit__(self, *args):
        # Only close the connection pool if it was created by the collection.
        if self.managed_client and self.connection_pool:
            await self.connection_pool.close()
        # If the pool was created by the collection, set it to None to enable reusing the collection.
        if self.managed_client:
            self.connection_pool = None

@override
@@ -313,6 +334,42 @@ async def create_collection(self, **kwargs: Any) -> None:
if vector_field.index_kind:
await self._create_index(table_name, vector_field)

@override
async def does_collection_exist(self, **kwargs: Any) -> bool:
"""Check if the collection exists."""
if self.connection_pool is None:
raise VectorStoreOperationException(
"Connection pool is not available, use the collection as a context manager."
)

async with self.connection_pool.connection() as conn, conn.cursor() as cur:
await cur.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_name = %s
""",
(self.db_schema, self.collection_name),
)
row = await cur.fetchone()
return bool(row)

@override
async def delete_collection(self, **kwargs: Any) -> None:
"""Delete the collection."""
if self.connection_pool is None:
raise VectorStoreOperationException(
"Connection pool is not available, use the collection as a context manager."
)

async with self.connection_pool.connection() as conn, conn.cursor() as cur:
await cur.execute(
sql.SQL("DROP TABLE {scm}.{tbl} CASCADE").format(
scm=sql.Identifier(self.db_schema), tbl=sql.Identifier(self.collection_name)
),
)
await conn.commit()

async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVectorField) -> None:
"""Create an index on a column in the table.

@@ -360,38 +417,167 @@ async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVe

logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.")

# endregion
# region: VectorSearchBase implementation

    @override
    async def _inner_search(
        self,
        options: VectorSearchOptions,
        search_text: str | None = None,
        vectorizable_text: str | None = None,
        vector: list[float | int] | None = None,
        **kwargs: Any,
    ) -> KernelSearchResults[VectorSearchResult[TModel]]:
        if self.connection_pool is None:
            raise VectorStoreOperationException(
                "Connection pool is not available, use the collection as a context manager."
            )

        fields = [(field.name, field) for field in self.data_model_definition.fields.values()]

        if vector is not None:
            query, params, return_fields = self._construct_vector_query(vector, fields, options, **kwargs)
        elif search_text:
            raise VectorSearchExecutionException("Text search not supported.")
        elif vectorizable_text:
            raise VectorSearchExecutionException("Vectorizable text search not supported.")
        else:
            # Guard against `query` being unbound below when no search input is given.
            raise VectorSearchExecutionException("No vector provided for vectorized search.")

        async with self.connection_pool.connection() as conn, conn.cursor() as cur:
            await cur.execute(query, params)
            rows = await cur.fetchall()
            # Convert rows to dicts; the distance is included under DISTANCE_COLUMN_NAME.
            results = [convert_row_to_dict(row, return_fields) for row in rows]
            return KernelSearchResults(
                results=self._get_vector_search_results_from_results(results, options),
                total_count=len(results) if options.include_total_count else None,
            )

    def _construct_vector_query(
        self,
        vector: list[float | int],
        fields: list[tuple[str, VectorStoreRecordField]],
        options: VectorSearchOptions,
        **kwargs: Any,
    ) -> tuple[sql.Composed, list[Any], list[tuple[str, VectorStoreRecordField | None]]]:
"""Construct a vector search query.

Args:
vector: The vector to search for.
fields: The fields.
options: The search options.
**kwargs: Additional arguments.

Returns:
The query, parameters, and the fields representing the columns in the result.
"""
# Get the vector field we will be searching against,
# defaulting to the first vector field if not specified
vector_fields = self.data_model_definition.vector_fields
if not vector_fields:
raise VectorSearchExecutionException("No vector fields defined.")
if options.vector_field_name:
vector_field = next((f for f in vector_fields if f.name == options.vector_field_name), None)
if not vector_field:
raise VectorSearchExecutionException(f"Vector field '{options.vector_field_name}' not found.")
else:
vector_field = vector_fields[0]

# Default to cosine distance if not set
distance_function = vector_field.distance_function or DistanceFunction.COSINE_DISTANCE
ops_str = get_vector_distance_ops_str(distance_function)
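
        # Illustrative only (assumed mapping inside get_vector_distance_ops_str):
        # cosine distance -> "<=>", euclidean (L2) -> "<->", inner product -> "<#>".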

# Select all fields except all vector fields if include_vectors is False
select_fields = [(name, f) for (name, f) in fields if (name != vector_field.name or options.include_vectors)]
select_list = [name for (name, _) in select_fields]

where_clause = self._build_where_clauses_from_filter(options.filter)

query = sql.SQL("SELECT {}, {} {} %s as {} FROM {}.{}").format(
sql.SQL(", ").join(sql.Identifier(name) for name in select_list),
sql.Identifier(vector_field.name),
sql.SQL(ops_str),
sql.Identifier(DISTANCE_COLUMN_NAME),
sql.Identifier(self.db_schema),
sql.Identifier(self.collection_name),
)

if where_clause:
query += where_clause

query += sql.SQL(" ORDER BY {} LIMIT {}").format(
sql.Identifier(DISTANCE_COLUMN_NAME),
sql.Literal(options.top),
)

if options.skip:
query += sql.SQL(" OFFSET {}").format(sql.Literal(options.skip))

# For cosine similarity, we need to take 1 - cosine distance.
# However, we can't use an expression in the ORDER BY clause or else the index won't be used.
# Instead we'll wrap the query in a subquery and modify the distance in the outer query.
if distance_function == DistanceFunction.COSINE_SIMILARITY:
query = sql.SQL("SELECT subquery.*, 1 - subquery.{} AS {} FROM ({}) AS subquery").format(
sql.Identifier(DISTANCE_COLUMN_NAME),
sql.Identifier(DISTANCE_COLUMN_NAME),
query,
)
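
        # Illustrative only: the wrapped query has the shape
        #   SELECT subquery.*, 1 - subquery."sk_pg_distance" AS "sk_pg_distance"
        #   FROM (SELECT ... ORDER BY "sk_pg_distance" LIMIT 3) AS subquery
        # keeping the index-friendly ORDER BY on the raw distance inside the subquery.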

        # For inner product, we need to take -1 * inner product.
        # However, we can't use an expression in the ORDER BY clause or else the index won't be used.
        # Instead we'll wrap the query in a subquery and modify the distance in the outer query.
        if distance_function == DistanceFunction.DOT_PROD:
            query = sql.SQL("SELECT subquery.*, -1 * subquery.{} AS {} FROM ({}) AS subquery").format(
                sql.Identifier(DISTANCE_COLUMN_NAME),
                sql.Identifier(DISTANCE_COLUMN_NAME),
                query,
            )

# Convert the vector to a string for the query
params = ["[" + ",".join([str(float(v)) for v in vector]) + "]"]
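        # Illustrative only: a query vector [0.1, 0.2, 0.3] becomes the pgvector
        # literal string "[0.1,0.2,0.3]".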

return query, params, [*select_fields, (DISTANCE_COLUMN_NAME, None)]

def _build_where_clauses_from_filter(self, filters: VectorSearchFilter | None) -> sql.Composed | None:
"""Build the WHERE clause for the search query from the filter in the search options.

Args:
filters: The filters.

Returns:
The WHERE clause.
"""
if not filters or not filters.filters:
return None

where_clauses = []
for filter in filters.filters:
match filter:
case EqualTo():
where_clauses.append(
sql.SQL("{field} = {value}").format(
field=sql.Identifier(filter.field_name),
value=sql.Literal(filter.value),
)
)
                case AnyTagsEqualTo():
                    where_clauses.append(
                        # Note the closing bracket: the tag value is matched as a one-element TEXT array.
                        sql.SQL("{field} @> ARRAY[{value}::TEXT]").format(
                            field=sql.Identifier(filter.field_name),
                            value=sql.Literal(filter.value),
                        )
                    )
case _:
raise ValueError(f"Unsupported filter: {filter}")

        # Leading space so the clause concatenates cleanly onto the SELECT statement.
        return sql.SQL(" WHERE {}").format(sql.SQL(" AND ").join(where_clauses))
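        # Illustrative only: EqualTo(field_name="category", value="hotel") combined with
        # AnyTagsEqualTo(field_name="tags", value="spa") renders as
        #   WHERE "category" = 'hotel' AND "tags" @> ARRAY['spa'::TEXT]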

@override
def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]:
return {k: v for (k, v) in result.items() if k != DISTANCE_COLUMN_NAME}

@override
def _get_score_from_result(self, result: Any) -> float | None:
return result.pop(DISTANCE_COLUMN_NAME, None)

# endregion
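
A minimal end-to-end sketch of the new search surface (not part of the diff; the data model, collection name, dimensions, and the decorator import path are assumptions for illustration):

from dataclasses import dataclass
from typing import Annotated

from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
from semantic_kernel.data import vectorstoremodel  # assumed re-export location
from semantic_kernel.data.record_definition.vector_store_record_fields import (
    VectorStoreRecordDataField,
    VectorStoreRecordKeyField,
    VectorStoreRecordVectorField,
)
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions


@vectorstoremodel
@dataclass
class Hotel:  # hypothetical model, not part of this PR
    id: Annotated[str, VectorStoreRecordKeyField()]
    description: Annotated[str, VectorStoreRecordDataField()]
    embedding: Annotated[list[float] | None, VectorStoreRecordVectorField(dimensions=3)] = None


async def search_hotels(query_vector: list[float]) -> None:
    # Connection settings are read from the PG* environment variables via PostgresSettings.
    collection = PostgresCollection[str, Hotel](collection_name="hotels", data_model_type=Hotel)
    async with collection:  # creates and later closes the managed connection pool
        results = await collection.vectorized_search(
            vector=query_vector,
            options=VectorSearchOptions(top=3, include_total_count=True),
        )
        async for result in results.results:
            print(result.record, result.score)  # score is the sk_pg_distance value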