From ef64c43928c831956910dcc1908306f12608be57 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 15 Jan 2025 11:16:50 +0100 Subject: [PATCH] updated with async code and improvements --- .../connectors/memory/mongodb_atlas/const.py | 12 ++ .../mongodb_atlas/mongodb_atlas_collection.py | 151 +++++++++++------- .../mongodb_atlas/mongodb_atlas_store.py | 30 ++-- .../connectors/memory/mongodb_atlas/utils.py | 49 ++++++ 4 files changed, 169 insertions(+), 73 deletions(-) create mode 100644 python/semantic_kernel/connectors/memory/mongodb_atlas/const.py diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py new file mode 100644 index 000000000000..d05732788bfc --- /dev/null +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Final + +from semantic_kernel.data.const import DistanceFunction + +DISTANCE_FUNCTION_MAPPING: Final[dict[DistanceFunction, str]] = { + DistanceFunction.EUCLIDEAN_DISTANCE: "euclidean", + DistanceFunction.COSINE_SIMILARITY: "cosine", + DistanceFunction.DOT_PROD: "dotProduct", +} +MONGODB_ID_FIELD: Final[str] = "_id" diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py index 2655776f5a8b..d08f8e81d5f6 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio import logging import sys from collections.abc import Sequence @@ -11,22 +10,22 @@ else: from typing_extensions import override # pragma: no cover -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.database import Database from pydantic import ValidationError +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.database import AsyncDatabase +from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD +from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo from semantic_kernel.data.kernel_search_results import KernelSearchResults -from semantic_kernel.data.record_definition import VectorStoreRecordDefinition, VectorStoreRecordVectorField +from semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_search import ( - VectorizableTextSearchMixin, VectorSearchFilter, VectorSearchOptions, ) from semantic_kernel.data.vector_search.vector_search import VectorSearchBase from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult -from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin from semantic_kernel.exceptions import ( VectorSearchExecutionException, @@ -43,111 +42,135 @@ @experimental_class class MongoDBAtlasCollection( VectorSearchBase[str, TModel], - VectorizableTextSearchMixin[TModel], VectorizedSearchMixin[TModel], - VectorTextSearchMixin[TModel], Generic[TModel], ): """MongoDB Atlas collection implementation.""" - mongo_client: MongoClient + mongo_client: AsyncMongoClient + database_name: str + index_name: str supported_key_types: ClassVar[list[str] | None] = ["str"] supported_vector_types: ClassVar[list[str] | None] = ["float", "int"] - managed_mongo_client: bool = True def __init__( self, data_model_type: type[TModel], data_model_definition: VectorStoreRecordDefinition | None = None, collection_name: str | None = None, - mongo_client: MongoClient | None = None, + database_name: str | None = None, + mongo_client: AsyncMongoClient | None = None, + index_name: str | None = None, **kwargs: Any, ) -> None: """Initializes a new instance of the MongoDBAtlasCollection class. Args: - data_model_type (type[TModel]): The type of the data model. - data_model_definition (VectorStoreRecordDefinition): The model definition, optional. - collection_name (str): The name of the collection, optional. - mongo_client (MongoClient): The MongoDB client for interacting with MongoDB Atlas, + data_model_type: The type of the data model. + data_model_definition: The model definition, optional. + collection_name: The name of the collection, optional. + database_name: The name of the database, will be filled from the env when this is not set. + mongo_client: The MongoDB client for interacting with MongoDB Atlas, used for creating and deleting collections. + index_name: The name of the index to use for searching, when not passed, will use _idx. **kwargs: Additional keyword arguments, including: The same keyword arguments used for MongoDBAtlasStore: connection_string: str | None = None, - database_name: str | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None """ - if mongo_client: - if not collection_name: - raise VectorStoreInitializationException("Collection name is required.") + if not collection_name: + raise VectorStoreInitializationException("Collection name is required.") + if mongo_client and database_name: super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, - collection_name=collection_name, mongo_client=mongo_client, - managed_mongo_client=False, + collection_name=collection_name, + database_name=database_name, + index_name=index_name or f"{collection_name}_idx", + managed_client=False, ) return - from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import ( - MongoDBAtlasSettings, - ) + from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import MongoDBAtlasSettings try: mongodb_atlas_settings = MongoDBAtlasSettings.create( env_file_path=kwargs.get("env_file_path"), connection_string=kwargs.get("connection_string"), - database_name=kwargs.get("database_name"), + database_name=database_name, env_file_encoding=kwargs.get("env_file_encoding"), ) except ValidationError as exc: raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc - mongo_client = MongoClient(mongodb_atlas_settings.connection_string) + managed_client = not mongo_client + if not mongo_client: + mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string) if not mongodb_atlas_settings.database_name: raise VectorStoreInitializationException("Database name is required.") super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, - collection_name=mongodb_atlas_settings.database_name, + collection_name=collection_name, mongo_client=mongo_client, + managed_client=managed_client, + database_name=mongodb_atlas_settings.database_name, + index_name=index_name or f"{collection_name}_idx", ) + def _get_database(self) -> AsyncDatabase: + """Get the database.""" + return self.mongo_client.get_database(self.database_name) + + def _get_collection(self) -> AsyncCollection: + """Get the collection.""" + return self.mongo_client.get_database(self.database_name).get_collection(self.collection_name) + @override async def _inner_upsert( self, records: Sequence[Any], **kwargs: Any, ) -> Sequence[str]: - if not isinstance(records, list): - records = list(records) - collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name) - result = await collection.insert_many(records, **kwargs) - return [str(inserted_id) for inserted_id in result.inserted_ids] + result = await self._get_collection().update_many(update=records, upsert=True, **kwargs) + return [str(ids) for ids in result.upserted_id] @override async def _inner_get(self, keys: Sequence[str], **kwargs: Any) -> Sequence[dict[str, Any]]: - collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name) - result = await asyncio.gather( - *[collection.find_one({"_id": key}) for key in keys], - return_exceptions=True, - ) - return [res for res in result if not isinstance(res, BaseException)] + result = self._get_collection().find({MONGODB_ID_FIELD: {"$in": keys}}) + return await result.to_list(length=len(keys)) @override async def _inner_delete(self, keys: Sequence[str], **kwargs: Any) -> None: - collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name) - await collection.delete_many({"_id": {"$in": keys}}) + collection = self._get_collection() + await collection.delete_many({MONGODB_ID_FIELD: {"$in": keys}}) + + def _replace_key_field(self, record: dict[str, Any]) -> dict[str, Any]: + if self._key_field_name == MONGODB_ID_FIELD: + return record + return { + MONGODB_ID_FIELD: record.pop(self._key_field_name, None), + **record, + } + + def _reset_key_field(self, record: dict[str, Any]) -> dict[str, Any]: + if self._key_field_name == MONGODB_ID_FIELD: + return record + return { + self._key_field_name: record.pop(MONGODB_ID_FIELD, None), + **record, + } @override def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: - return records + return [self._replace_key_field(record) for record in records] @override def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]: - return records + return [self._reset_key_field(record) for record in records] @override async def create_collection(self, **kwargs) -> None: @@ -156,18 +179,17 @@ async def create_collection(self, **kwargs) -> None: Args: **kwargs: Additional keyword arguments. """ - database: Database = self.mongo_client.get_database() - await database.create_collection(self.collection_name, **kwargs) + database = self._get_database() + collection = await database.create_collection(self.collection_name, **kwargs) + await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name)) @override async def does_collection_exist(self, **kwargs) -> bool: - database: Database = self.mongo_client.get_database() - return self.collection_name in await database.list_collection_names() + return self.collection_name in await self._get_database().list_collection_names() @override async def delete_collection(self, **kwargs) -> None: - database: Database = self.mongo_client.get_database() - await database.drop_collection(self.collection_name, **kwargs) + await self._get_database().drop_collection(self.collection_name, **kwargs) @override async def _inner_search( @@ -178,25 +200,36 @@ async def _inner_search( vector: list[float | int] | None = None, **kwargs: Any, ) -> KernelSearchResults[VectorSearchResult[TModel]]: - collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name) - search_args: dict[str, Any] = { - "limit": options.top, - "skip": options.skip, + collection = self._get_collection() + vector_search_query: dict[str, Any] = { + "limit": options.top + options.skip, } if options.filter.filters: - search_args["filter"] = self._build_filter_dict(options.filter) + vector_search_query["filter"] = self._build_filter_dict(options.filter) if vector is not None: - search_args["vector"] = vector - if "vector" not in search_args: + vector_search_query["queryVector"] = vector + vector_search_query["path"] = options.vector_field_name + if "queryVector" not in vector_search_query: raise VectorStoreOperationException("Vector is required for search.") + projection_query: dict[str, int | dict] = { + field: 1 + for field in self.data_model_definition.get_field_names( + include_vector_fields=options.include_vectors, + include_key_field=False, # _id is always included + ) + } + projection_query["score"] = {"$meta": "vectorSearchScore"} try: - raw_results = await collection.find(search_args) + raw_results = await collection.aggregate([ + {"$vectorSearch": vector_search_query}, + {"$project": projection_query}, + ]) except Exception as exc: raise VectorSearchExecutionException("Failed to search the collection.") from exc return KernelSearchResults( results=self._get_vector_search_results_from_results(raw_results, options), - total_count=await raw_results.count() if options.include_total_count else None, + total_count=None, # no way to get a count before looping through the result cursor ) def _build_filter_dict(self, search_filter: VectorSearchFilter) -> dict[str, Any]: @@ -221,4 +254,4 @@ def _get_score_from_result(self, result: dict[str, Any]) -> float | None: async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager.""" if self.managed_mongo_client: - self.mongo_client.close() + await self.mongo_client.close() diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py index c19915a85a25..a57cd05fbc1a 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py @@ -9,9 +9,9 @@ else: from typing_extensions import override # pragma: no cover -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.database import Database +from pydantic import ValidationError +from pymongo import AsyncMongoClient +from pymongo.asynchronous.database import AsyncDatabase from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import ( MongoDBAtlasCollection, @@ -34,13 +34,14 @@ class MongoDBAtlasStore(VectorStore): """MongoDB Atlas store implementation.""" - mongo_client: MongoClient + mongo_client: AsyncMongoClient + database_name: str def __init__( self, connection_string: str | None = None, database_name: str | None = None, - mongo_client: MongoClient | None = None, + mongo_client: AsyncMongoClient | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: @@ -71,10 +72,14 @@ def __init__( ) except ValidationError as exc: raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc - mongo_client = MongoClient(mongodb_atlas_settings.connection_string) + mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string) managed_client = True - super().__init__(mongo_client=mongo_client, managed_client=managed_client) + super().__init__( + mongo_client=mongo_client, + managed_client=managed_client, + database_name=mongodb_atlas_settings.database_name, + ) @override def get_collection( @@ -82,7 +87,6 @@ def get_collection( collection_name: str, data_model_type: type[TModel], data_model_definition: VectorStoreRecordDefinition | None = None, - mongo_client: MongoClient | None = None, **kwargs: Any, ) -> "VectorStoreRecordCollection": """Get a MongoDBAtlasCollection tied to a collection. @@ -91,27 +95,25 @@ def get_collection( collection_name (str): The name of the collection. data_model_type (type[TModel]): The type of the data model. data_model_definition (VectorStoreRecordDefinition | None): The model fields, optional. - mongo_client (MongoClient | None): The MongoDB client for interacting with MongoDB Atlas, - will be created if not supplied. **kwargs: Additional keyword arguments, passed to the collection constructor. """ if collection_name not in self.vector_record_collections: self.vector_record_collections[collection_name] = MongoDBAtlasCollection( data_model_type=data_model_type, data_model_definition=data_model_definition, - mongo_client=mongo_client or self.mongo_client, + mongo_client=self.mongo_client, collection_name=collection_name, - managed_client=mongo_client is None, + database_name=self.database_name, **kwargs, ) return self.vector_record_collections[collection_name] @override async def list_collection_names(self, **kwargs: Any) -> list[str]: - database: Database = self.mongo_client.get_database() + database: AsyncDatabase = self.mongo_client.get_database(self.database_name) return await database.list_collection_names() async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager.""" if self.managed_client: - self.mongo_client.close() + await self.mongo_client.close() diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py index cb415f45377c..71898859d984 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py @@ -1,7 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. from numpy import array +from pymongo.operations import SearchIndexModel +from semantic_kernel.connectors.memory.mongodb_atlas.const import DISTANCE_FUNCTION_MAPPING +from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition +from semantic_kernel.data.record_definition.vector_store_record_fields import ( + VectorStoreRecordDataField, + VectorStoreRecordVectorField, +) +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError from semantic_kernel.memory.memory_record import MemoryRecord DEFAULT_DB_NAME = "default" @@ -66,3 +74,44 @@ def memory_record_to_mongo_document(record: MemoryRecord) -> dict: MONGODB_FIELD_EMBEDDING: record._embedding.tolist(), MONGODB_FIELD_TIMESTAMP: record._timestamp, } + + +def create_vector_field(field: VectorStoreRecordVectorField) -> dict: + """Create a vector field. + + Args: + field (VectorStoreRecordVectorField): The vector field. + + Returns: + dict: The vector field. + """ + if field.distance_function not in DISTANCE_FUNCTION_MAPPING: + raise ServiceInitializationError(f"Invalid distance function: {field.distance_function}") + return { + "type": "vector", + "numDimensions": field.dimensions, + "path": field.name, + "similarity": DISTANCE_FUNCTION_MAPPING[field.distance_function], + } + + +def create_index_definition(record_definition: VectorStoreRecordDefinition, index_name: str) -> SearchIndexModel: + """Create an index definition. + + Args: + record_definition (VectorStoreRecordDefinition): The record definition. + index_name (str): The index name. + + Returns: + SearchIndexModel: The index definition. + """ + vector_fields = [create_vector_field(field) for field in record_definition.vector_fields] + data_fields = [ + {"path": field.name, "type": "filter"} + for field in record_definition.fields + if isinstance(field, VectorStoreRecordDataField) and (field.is_filterable or field.is_full_text_searchable) + ] + key_field = [{"path": record_definition.key_field.name, "type": "filter"}] + return SearchIndexModel( + type="vectorSearch", name=index_name, definition={"fields": vector_fields + data_fields + key_field} + )