Skip to content

Commit

Permalink
updated with async code and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Jan 15, 2025
1 parent 79b1f91 commit 6cc63a7
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 73 deletions.
12 changes: 12 additions & 0 deletions python/semantic_kernel/connectors/memory/mongodb_atlas/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Final

from semantic_kernel.data.const import DistanceFunction

DISTANCE_FUNCTION_MAPPING: Final[dict[DistanceFunction, str]] = {
DistanceFunction.EUCLIDEAN_DISTANCE: "euclidean",
DistanceFunction.COSINE_SIMILARITY: "cosine",
DistanceFunction.DOT_PROD: "dotProduct",
}
MONGODB_ID_FIELD: Final[str] = "_id"
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import logging
import sys
from collections.abc import Sequence
Expand All @@ -11,22 +10,22 @@
else:
from typing_extensions import override # pragma: no cover

from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pydantic import ValidationError
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase

from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD
from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition
from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo
from semantic_kernel.data.kernel_search_results import KernelSearchResults
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition, VectorStoreRecordVectorField
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition
from semantic_kernel.data.vector_search import (
VectorizableTextSearchMixin,
VectorSearchFilter,
VectorSearchOptions,
)
from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions import (
VectorSearchExecutionException,
Expand All @@ -43,111 +42,135 @@
@experimental_class
class MongoDBAtlasCollection(
VectorSearchBase[str, TModel],
VectorizableTextSearchMixin[TModel],
VectorizedSearchMixin[TModel],
VectorTextSearchMixin[TModel],
Generic[TModel],
):
"""MongoDB Atlas collection implementation."""

mongo_client: MongoClient
mongo_client: AsyncMongoClient
database_name: str
index_name: str
supported_key_types: ClassVar[list[str] | None] = ["str"]
supported_vector_types: ClassVar[list[str] | None] = ["float", "int"]
managed_mongo_client: bool = True

def __init__(
self,
data_model_type: type[TModel],
data_model_definition: VectorStoreRecordDefinition | None = None,
collection_name: str | None = None,
mongo_client: MongoClient | None = None,
database_name: str | None = None,
mongo_client: AsyncMongoClient | None = None,
index_name: str | None = None,
**kwargs: Any,
) -> None:
"""Initializes a new instance of the MongoDBAtlasCollection class.
Args:
data_model_type (type[TModel]): The type of the data model.
data_model_definition (VectorStoreRecordDefinition): The model definition, optional.
collection_name (str): The name of the collection, optional.
mongo_client (MongoClient): The MongoDB client for interacting with MongoDB Atlas,
data_model_type: The type of the data model.
data_model_definition: The model definition, optional.
collection_name: The name of the collection, optional.
database_name: The name of the database, will be filled from the env when this is not set.
mongo_client: The MongoDB client for interacting with MongoDB Atlas,
used for creating and deleting collections.
index_name: The name of the index to use for searching, when not passed, will use <collection_name>_idx.
**kwargs: Additional keyword arguments, including:
The same keyword arguments used for MongoDBAtlasStore:
connection_string: str | None = None,
database_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None
"""
if mongo_client:
if not collection_name:
raise VectorStoreInitializationException("Collection name is required.")
if not collection_name:
raise VectorStoreInitializationException("Collection name is required.")
if mongo_client and database_name:
super().__init__(
data_model_type=data_model_type,
data_model_definition=data_model_definition,
collection_name=collection_name,
mongo_client=mongo_client,
managed_mongo_client=False,
collection_name=collection_name,
database_name=database_name,
index_name=index_name or f"{collection_name}_idx",
managed_client=False,
)
return

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import (
MongoDBAtlasSettings,
)
from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import MongoDBAtlasSettings

try:
mongodb_atlas_settings = MongoDBAtlasSettings.create(
env_file_path=kwargs.get("env_file_path"),
connection_string=kwargs.get("connection_string"),
database_name=kwargs.get("database_name"),
database_name=database_name,
env_file_encoding=kwargs.get("env_file_encoding"),
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
mongo_client = MongoClient(mongodb_atlas_settings.connection_string)
managed_client = not mongo_client
if not mongo_client:
mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string)
if not mongodb_atlas_settings.database_name:
raise VectorStoreInitializationException("Database name is required.")

super().__init__(
data_model_type=data_model_type,
data_model_definition=data_model_definition,
collection_name=mongodb_atlas_settings.database_name,
collection_name=collection_name,
mongo_client=mongo_client,
managed_client=managed_client,
database_name=mongodb_atlas_settings.database_name,
index_name=index_name or f"{collection_name}_idx",
)

def _get_database(self) -> AsyncDatabase:
"""Get the database."""
return self.mongo_client.get_database(self.database_name)

def _get_collection(self) -> AsyncCollection:
"""Get the collection."""
return self.mongo_client.get_database(self.database_name).get_collection(self.collection_name)

@override
async def _inner_upsert(
self,
records: Sequence[Any],
**kwargs: Any,
) -> Sequence[str]:
if not isinstance(records, list):
records = list(records)
collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name)
result = await collection.insert_many(records, **kwargs)
return [str(inserted_id) for inserted_id in result.inserted_ids]
result = await self._get_collection().update_many(update=records, upsert=True, **kwargs)
return [str(ids) for ids in result.upserted_id]

@override
async def _inner_get(self, keys: Sequence[str], **kwargs: Any) -> Sequence[dict[str, Any]]:
collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name)
result = await asyncio.gather(
*[collection.find_one({"_id": key}) for key in keys],
return_exceptions=True,
)
return [res for res in result if not isinstance(res, BaseException)]
result = self._get_collection().find({MONGODB_ID_FIELD: {"$in": keys}})
return await result.to_list(length=len(keys))

@override
async def _inner_delete(self, keys: Sequence[str], **kwargs: Any) -> None:
collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name)
await collection.delete_many({"_id": {"$in": keys}})
collection = self._get_collection()
await collection.delete_many({MONGODB_ID_FIELD: {"$in": keys}})

def _replace_key_field(self, record: dict[str, Any]) -> dict[str, Any]:
if self._key_field_name == MONGODB_ID_FIELD:
return record
return {
MONGODB_ID_FIELD: record.pop(self._key_field_name, None),
**record,
}

def _reset_key_field(self, record: dict[str, Any]) -> dict[str, Any]:
if self._key_field_name == MONGODB_ID_FIELD:
return record
return {
self._key_field_name: record.pop(MONGODB_ID_FIELD, None),
**record,
}

@override
def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]:
return records
return [self._replace_key_field(record) for record in records]

@override
def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]:
return records
return [self._reset_key_field(record) for record in records]

@override
async def create_collection(self, **kwargs) -> None:
Expand All @@ -156,18 +179,17 @@ async def create_collection(self, **kwargs) -> None:
Args:
**kwargs: Additional keyword arguments.
"""
database: Database = self.mongo_client.get_database()
await database.create_collection(self.collection_name, **kwargs)
database = self._get_database()
collection = await database.create_collection(self.collection_name, **kwargs)
await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name))

@override
async def does_collection_exist(self, **kwargs) -> bool:
database: Database = self.mongo_client.get_database()
return self.collection_name in await database.list_collection_names()
return self.collection_name in await self._get_database().list_collection_names()

@override
async def delete_collection(self, **kwargs) -> None:
database: Database = self.mongo_client.get_database()
await database.drop_collection(self.collection_name, **kwargs)
await self._get_database().drop_collection(self.collection_name, **kwargs)

@override
async def _inner_search(
Expand All @@ -178,25 +200,36 @@ async def _inner_search(
vector: list[float | int] | None = None,
**kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
collection: Collection = self.mongo_client.get_database().get_collection(self.collection_name)
search_args: dict[str, Any] = {
"limit": options.top,
"skip": options.skip,
collection = self._get_collection()
vector_search_query: dict[str, Any] = {
"limit": options.top + options.skip,
}
if options.filter.filters:
search_args["filter"] = self._build_filter_dict(options.filter)
vector_search_query["filter"] = self._build_filter_dict(options.filter)
if vector is not None:
search_args["vector"] = vector
if "vector" not in search_args:
vector_search_query["queryVector"] = vector
vector_search_query["path"] = options.vector_field_name
if "queryVector" not in vector_search_query:
raise VectorStoreOperationException("Vector is required for search.")

projection_query: dict[str, int | dict] = {
field: 1
for field in self.data_model_definition.get_field_names(
include_vector_fields=options.include_vectors,
include_key_field=False, # _id is always included
)
}
projection_query["score"] = {"$meta": "vectorSearchScore"}
try:
raw_results = await collection.find(search_args)
raw_results = await collection.aggregate([
{"$vectorSearch": vector_search_query},
{"$project": projection_query},
])
except Exception as exc:
raise VectorSearchExecutionException("Failed to search the collection.") from exc
return KernelSearchResults(
results=self._get_vector_search_results_from_results(raw_results, options),
total_count=await raw_results.count() if options.include_total_count else None,
total_count=None, # no way to get a count before looping through the result cursor
)

def _build_filter_dict(self, search_filter: VectorSearchFilter) -> dict[str, Any]:
Expand All @@ -221,4 +254,4 @@ def _get_score_from_result(self, result: dict[str, Any]) -> float | None:
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
if self.managed_mongo_client:
self.mongo_client.close()
await self.mongo_client.close()
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
else:
from typing_extensions import override # pragma: no cover

from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pydantic import ValidationError
from pymongo import AsyncMongoClient
from pymongo.asynchronous.database import AsyncDatabase

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import (
MongoDBAtlasCollection,
Expand All @@ -34,13 +34,14 @@
class MongoDBAtlasStore(VectorStore):
"""MongoDB Atlas store implementation."""

mongo_client: MongoClient
mongo_client: AsyncMongoClient
database_name: str

def __init__(
self,
connection_string: str | None = None,
database_name: str | None = None,
mongo_client: MongoClient | None = None,
mongo_client: AsyncMongoClient | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
Expand Down Expand Up @@ -71,18 +72,21 @@ def __init__(
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
mongo_client = MongoClient(mongodb_atlas_settings.connection_string)
mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string)
managed_client = True

super().__init__(mongo_client=mongo_client, managed_client=managed_client)
super().__init__(
mongo_client=mongo_client,
managed_client=managed_client,
database_name=mongodb_atlas_settings.database_name,
)

@override
def get_collection(
self,
collection_name: str,
data_model_type: type[TModel],
data_model_definition: VectorStoreRecordDefinition | None = None,
mongo_client: MongoClient | None = None,
**kwargs: Any,
) -> "VectorStoreRecordCollection":
"""Get a MongoDBAtlasCollection tied to a collection.
Expand All @@ -91,27 +95,25 @@ def get_collection(
collection_name (str): The name of the collection.
data_model_type (type[TModel]): The type of the data model.
data_model_definition (VectorStoreRecordDefinition | None): The model fields, optional.
mongo_client (MongoClient | None): The MongoDB client for interacting with MongoDB Atlas,
will be created if not supplied.
**kwargs: Additional keyword arguments, passed to the collection constructor.
"""
if collection_name not in self.vector_record_collections:
self.vector_record_collections[collection_name] = MongoDBAtlasCollection(
data_model_type=data_model_type,
data_model_definition=data_model_definition,
mongo_client=mongo_client or self.mongo_client,
mongo_client=self.mongo_client,
collection_name=collection_name,
managed_client=mongo_client is None,
database_name=self.database_name,
**kwargs,
)
return self.vector_record_collections[collection_name]

@override
async def list_collection_names(self, **kwargs: Any) -> list[str]:
database: Database = self.mongo_client.get_database()
database: AsyncDatabase = self.mongo_client.get_database(self.database_name)
return await database.list_collection_names()

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
if self.managed_client:
self.mongo_client.close()
await self.mongo_client.close()
Loading

0 comments on commit 6cc63a7

Please sign in to comment.