Skip to content

Commit

Permalink
Revert sanitization changes added in #26 and #37
Browse files Browse the repository at this point in the history
Closes #38: Revert changes to aind-data-schema sanitization
  • Loading branch information
bruno-f-cruz committed Feb 4, 2025
1 parent 58d026e commit aabf844
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

import abc
import logging
from typing import Any, Generic, Type, TypeVar, Union
from typing import Generic, TypeVar, Union

from aind_data_schema.core import rig as ads_rig
from aind_data_schema.core import session as ads_session
from pydantic import BaseModel, create_model, model_validator

from aind_behavior_experiment_launcher.data_mapper import _base

Expand All @@ -34,45 +33,3 @@ class AindDataSchemaSessionDataMapper(AindDataSchemaDataMapper[ads_session.Sessi


class AindDataSchemaRigDataMapper(AindDataSchemaDataMapper[ads_rig.Rig], abc.ABC): ...


_TModel = TypeVar("_TModel", bound=BaseModel)


def create_encoding_model(model: Type[_TModel]) -> Type[_TModel]:
    """Create a subclass of ``model`` that escapes special characters in dict keys.

    The returned class derives from ``model`` (``create_model(__base__=model)``)
    and registers a "before" model validator. When the raw input handed to the
    validator is a dict, every string key containing one of
    ``_SPECIAL_CHARACTERS`` ("." or "$") has that character replaced by its
    escaped unicode representation (for example "$" becomes "\\u0024"). Nested
    dict values are processed recursively. Non-dict input is returned as-is.

    The input dict is not modified; a sanitized copy is handed to validation.

    Args:
        model: The pydantic model class to wrap.

    Returns:
        A new model class, named ``_Wrapped<ModelName>``, derived from ``model``.
    """

    _SPECIAL_CHARACTERS = ".$"

    def _to_unicode_repr(character: str) -> str:
        # Escaped (literal backslash-u) 4-digit hex code point of one character.
        if len(character) != 1:
            raise ValueError(f"Expected a single character, got {character}")
        return f"\\u{ord(character):04x}"

    # Built once so each key is escaped in a single pass. The replacement
    # strings contain neither "." nor "$", so this matches sequential replaces.
    _escape_table = str.maketrans({char: _to_unicode_repr(char) for char in _SPECIAL_CHARACTERS})

    def _sanitize_dict(value: dict) -> dict:
        # Build a fresh dict instead of renaming keys in place, so validating a
        # model never mutates the data the caller passed in.
        sanitized: dict = {}
        for key, item in value.items():
            if isinstance(item, dict):
                item = _sanitize_dict(item)
            if isinstance(key, str):
                key = key.translate(_escape_table)
            sanitized[key] = item
        return sanitized

    def _aind_data_schema_encoder(cls, data: Any) -> Any:
        if isinstance(data, dict):
            return _sanitize_dict(data)
        return data

    return create_model(
        # ``model`` is itself a class, so ``model.__class__`` would be its
        # metaclass; ``model.__name__`` is the name of the wrapped model.
        f"_Wrapped{model.__name__}",
        __base__=model,
        __validators__={
            "encoder": model_validator(mode="before")(_aind_data_schema_encoder)  # type: ignore
        },
    )
49 changes: 4 additions & 45 deletions tests/test_data_mapper.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import unittest
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional
from unittest.mock import patch

from aind_data_schema.base import AindGeneric
from pydantic import BaseModel

from aind_behavior_experiment_launcher.data_mapper.aind_data_schema import create_encoding_model
from aind_behavior_experiment_launcher.data_mapper.helpers import (
_sanity_snapshot_keys,
snapshot_bonsai_environment,
snapshot_python_environment,
)
Expand All @@ -28,58 +25,20 @@ class TestHelpers(unittest.TestCase):
@patch("importlib.metadata.distributions")
def test_snapshot_python_environment(self, mock_distributions):
mock_distributions.return_value = [
type("Distribution", (object,), {"name": "package1.sub$s", "version": "1.0.0"}),
type("Distribution", (object,), {"name": "package1", "version": "1.0.0"}),
type("Distribution", (object,), {"name": "package2", "version": "2.0.0"}),
]
expected_result = {"package1_sub_s": "1.0.0", "package2": "2.0.0"}
expected_result = {"package1": "1.0.0", "package2": "2.0.0"}
result = snapshot_python_environment()
self.assertEqual(result, expected_result)

def test_snapshot_bonsai_environment_from_mock(self):
out = snapshot_bonsai_environment(config_file=Path(TESTS_ASSETS) / "bonsai.config")
self.assertEqual(
out,
{"Bonsai": "2.8.5", "Bonsai_Core": "2.8.5", "Bonsai_Design": "2.8.5", "Bonsai_Design_Visualizers": "2.8.0"},
{"Bonsai": "2.8.5", "Bonsai.Core": "2.8.5", "Bonsai.Design": "2.8.5", "Bonsai.Design.Visualizers": "2.8.0"},
)

def test_sanity_snapshot_keys_no_special_chars(self):
    """Keys without '.' or '$' are expected to come back unchanged."""
    raw = {"key1": "value1", "key2": "value2"}
    self.assertEqual(_sanity_snapshot_keys(raw), {"key1": "value1", "key2": "value2"})

def test_sanity_snapshot_keys_with_dots(self):
    """Each '.' in a key is expected to be replaced by '_'."""
    raw = {"key.1": "value1", "key.2": "value2"}
    self.assertEqual(_sanity_snapshot_keys(raw), {"key_1": "value1", "key_2": "value2"})

def test_sanity_snapshot_keys_with_dollars(self):
    """Each '$' in a key is expected to be replaced by '_'."""
    raw = {"key$1": "value1", "key$2": "value2"}
    self.assertEqual(_sanity_snapshot_keys(raw), {"key_1": "value1", "key_2": "value2"})

def test_sanity_snapshot_keys_with_dots_and_dollars(self):
    """Keys mixing '.' and '$' are expected to have both replaced by '_'."""
    raw = {"key.1$": "value1", "key.2$": "value2"}
    self.assertEqual(_sanity_snapshot_keys(raw), {"key_1_": "value1", "key_2_": "value2"})


class TestAindDataMapper(unittest.TestCase):
    """Checks the key-encoding model produced by ``create_encoding_model``."""

    class MyMockModel(BaseModel):
        # Minimal model: one plain dict field and one AindGeneric field.
        a_dict: Dict[str, Any]
        a_generic: AindGeneric

    def test_encoding_with_illegal_characters(self):
        """Keys containing '$' or '.' come back in escaped unicode form."""
        payload = {"key": "value", "$key.key": "value"}
        escaped = {"key": "value", "\\u0024key\\u002ekey": "value"}
        wrapped_model = create_encoding_model(self.MyMockModel)
        instance = wrapped_model(a_dict=payload, a_generic=payload)
        self.assertEqual(instance.a_dict, escaped)
        self.assertEqual(instance.a_generic.model_dump(), escaped)


if __name__ == "__main__":
unittest.main()

0 comments on commit aabf844

Please sign in to comment.