diff --git a/pyproject.toml b/pyproject.toml index 4d7ff893..1cb978cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aind-behavior-vr-foraging" -description = "A library that defines AIND data schema for the Aind Behavior VR Foraing experiment." +description = "A library that defines AIND data schema for the Aind Behavior VR Foraging experiment." authors = [ {name = "Bruno Cruz", email = "bruno.cruz@alleninstitute.org"}] license = {text = "MIT"} requires-python = ">=3.11" @@ -15,7 +15,7 @@ readme = "README.md" dynamic = ["version"] dependencies = [ - "aind_behavior_services>=0.8.0", + "aind_behavior_services>=0.8, <0.9", ] [project.optional-dependencies] @@ -25,7 +25,7 @@ linters = [ 'codespell' ] -launcher = ["aind_behavior_experiment_launcher[aind-services]<0.2.0"] +launcher = ["aind_behavior_experiment_launcher[aind-services]>=0.2.0rc4"] docs = [ 'Sphinx<7.3', @@ -36,6 +36,10 @@ docs = [ 'sphinx-jsonschema' ] +[project.scripts] +clabe = "aind_behavior_vr_foraging.launcher:main" +regenerate = "aind_behavior_vr_foraging.regenerate:main" + [tool.setuptools.packages.find] where = ["src/DataSchemas"] diff --git a/scripts/regenerate.cmd b/scripts/regenerate.cmd deleted file mode 100644 index 7179544c..00000000 --- a/scripts/regenerate.cmd +++ /dev/null @@ -1,6 +0,0 @@ -@echo off -setlocal -set "scriptPath=%~dp0" -set "pythonScriptPath=%scriptPath%regenerate.ps1" -powershell -ExecutionPolicy Bypass -File "%pythonScriptPath%" -endlocal diff --git a/scripts/regenerate.ps1 b/scripts/regenerate.ps1 deleted file mode 100644 index 732f62b3..00000000 --- a/scripts/regenerate.ps1 +++ /dev/null @@ -1,4 +0,0 @@ -$scriptPath = Split-Path -Parent $MyInvocation.MyCommand.Path -Set-Location -Path (Split-Path -Parent $scriptPath) -.\.venv\Scripts\Activate.ps1 -& python .\scripts\regenerate.py \ No newline at end of file diff --git a/src/DataSchemas/aind_behavior_session_model.json b/src/DataSchemas/aind_behavior_session_model.json index e900c6a3..fd53b1d8 100644 --- a/src/DataSchemas/aind_behavior_session_model.json +++ b/src/DataSchemas/aind_behavior_session_model.json @@ -1,7 +1,7 @@ { "properties": { "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py new file mode 100644 index 00000000..47fa6b14 --- /dev/null +++ b/src/DataSchemas/aind_behavior_vr_foraging/data_mappers.py @@ -0,0 +1,431 @@ +import datetime +import logging +import os +from pathlib import Path +from typing import Dict, List, Optional, Self, Type, TypeVar, Union + +import aind_behavior_services.rig as AbsRig +import aind_data_schema +import aind_data_schema.base +import aind_data_schema.components.coordinates +import aind_data_schema.components.devices +import aind_data_schema.components.stimulus +import aind_data_schema.core.session +import git +import pydantic +from aind_behavior_experiment_launcher.data_mappers import data_mapper_service +from aind_behavior_experiment_launcher.launcher.behavior_launcher import BehaviorLauncher +from aind_behavior_experiment_launcher.records.subject import WaterLogResult +from aind_behavior_services.calibration import Calibration +from aind_behavior_services.calibration.olfactometer import OlfactometerChannelType +from aind_behavior_services.session import AindBehaviorSessionModel +from aind_behavior_services.utils import model_from_json_file, utcnow +from aind_data_schema.core.rig import Rig +from pydantic import BaseModel + +from aind_behavior_vr_foraging.rig import AindVrForagingRig +from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic + +TFrom = TypeVar("TFrom", bound=Union[BaseModel, dict]) +TTo = TypeVar("TTo", bound=BaseModel) + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + +_DATABASE_DIR = "AindDataSchemaRig" + + +class AindRigDataMapper(data_mapper_service.DataMapperService): + def __init__( + self, + *, + rig_schema_filename: str, + db_root: os.PathLike, + destination_dir: os.PathLike, + db_suffix: Optional[str] = None, + ): + super().__init__() + self.filename = rig_schema_filename + self.db_root = db_root + self.db_dir = db_suffix if db_suffix else f"{_DATABASE_DIR}/{os.environ['COMPUTERNAME']}" + self.target_file = Path(self.db_root) / self.db_dir / self.filename + self.destination_dir = destination_dir + self._mapped: Optional[Rig] = None + + def validate(self): + file_exists = self.target_file.exists() + if not file_exists: + raise FileNotFoundError(f"File {self.target_file} does not exist.") + return file_exists + + def map(self) -> Rig: + self._mapped = model_from_json_file(self.target_file, Rig) + return self.mapped + + @property + def mapped(self) -> Rig: + if self._mapped is None: + raise ValueError("Data has not been mapped yet.") + return self._mapped + + def write_standard_file(self) -> None: + self.mapped.write_standard_file(self.destination_dir) + + +class AindSessionDataMapper(data_mapper_service.DataMapperService): + def __init__( + self, + session_model: AindBehaviorSessionModel, + rig_model: AindVrForagingRig, + task_logic_model: AindVrForagingTaskLogic, + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[WaterLogResult] = None, + session_directory: Optional[os.PathLike] = None, + ): + self.session_model = session_model + self.rig_model = rig_model + self.task_logic_model = task_logic_model + self.session_directory = session_directory + self.repository = repository + self.script_path = script_path + self.session_end_time = session_end_time + self.output_parameters = output_parameters + self.subject_info = subject_info + self.mapped: Optional[aind_data_schema.core.session.Session] = None + + def validate(self, *args, **kwargs) -> bool: + return True + + def is_mapped(self) -> bool: + return self.mapped is not None + + def map(self) -> Optional[aind_data_schema.core.session.Session]: + logger.info("Mapping to aind-data-schema Session") + try: + ads_session = self._map( + session_model=self.session_model, + rig_model=self.rig_model, + task_logic_model=self.task_logic_model, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + output_parameters=self.output_parameters, + subject_info=self.subject_info, + ) + self.mapped = ads_session + if self.session_directory is not None: + logger.info("Writing session.json to %s", self.session_directory) + ads_session.write_standard_file(self.session_directory) + logger.info("Mapping successful.") + except (pydantic.ValidationError, ValueError, IOError) as e: + logger.error("Failed to map to aind-data-schema Session. %s", e) + raise e + else: + return ads_session + + @classmethod + def map_from_session_root( + cls, + schema_root: os.PathLike, + session_model: Type[AindBehaviorSessionModel], + rig_model: Type[AindVrForagingRig], + task_logic_model: Type[AindVrForagingTaskLogic], + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[WaterLogResult] = None, + ) -> Self: + return cls( + session_model=model_from_json_file(Path(schema_root) / "session_input.json", session_model), + rig_model=model_from_json_file(Path(schema_root) / "rig_input.json", rig_model), + task_logic_model=model_from_json_file(Path(schema_root) / "tasklogic_input.json", task_logic_model), + session_directory=schema_root, + repository=repository, + script_path=script_path, + session_end_time=session_end_time if session_end_time else utcnow(), + output_parameters=output_parameters, + subject_info=subject_info, + ) + + @classmethod + def map_from_json_files( + cls, + session_json: os.PathLike, + rig_json: os.PathLike, + task_logic_json: os.PathLike, + session_model: Type[AindBehaviorSessionModel], + rig_model: Type[AindVrForagingRig], + task_logic_model: Type[AindVrForagingTaskLogic], + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime], + session_directory: Optional[os.PathLike] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[WaterLogResult] = None, + **kwargs, + ) -> Self: + return cls( + session_model=model_from_json_file(session_json, session_model), + rig_model=model_from_json_file(rig_json, rig_model), + task_logic_model=model_from_json_file(task_logic_json, task_logic_model), + session_directory=session_directory, + repository=repository, + script_path=script_path, + session_end_time=session_end_time if session_end_time else utcnow(), + output_parameters=output_parameters, + subject_info=subject_info, + **kwargs, + ) + + @classmethod + def _map( + cls, + session_model: AindBehaviorSessionModel, + rig_model: AindVrForagingRig, + task_logic_model: AindVrForagingTaskLogic, + repository: Union[os.PathLike, git.Repo], + script_path: os.PathLike, + session_end_time: Optional[datetime.datetime] = None, + output_parameters: Optional[Dict] = None, + subject_info: Optional[WaterLogResult] = None, + **kwargs, + ) -> aind_data_schema.core.session.Session: + # Normalize repository + if isinstance(repository, os.PathLike | str): + repository = git.Repo(Path(repository)) + repository_remote_url = repository.remote().url + repository_sha = repository.head.commit.hexsha + repository_relative_script_path = Path(script_path).resolve().relative_to(repository.working_dir) + + # Populate calibrations: + calibrations = [cls._mapper_calibration(rig_model.calibration.water_valve)] + # Populate cameras + cameras = data_mapper_service.get_cameras(rig_model, exclude_without_video_writer=True) + # populate devices + devices = [ + device[0] + for device in data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpDeviceGeneric) + if device[0] + ] + # Populate modalities + modalities: list[aind_data_schema.core.session.Modality] = [ + getattr(aind_data_schema.core.session.Modality, "BEHAVIOR") + ] + if len(cameras) > 0: + modalities.append(getattr(aind_data_schema.core.session.Modality, "BEHAVIOR_VIDEOS")) + modalities = list(set(modalities)) + # Populate stimulus modalities + stimulus_modalities: list[aind_data_schema.core.session.StimulusModality] = [] + stimulation_parameters: List[ + aind_data_schema.core.session.AuditoryStimulation + | aind_data_schema.core.session.OlfactoryStimulation + | aind_data_schema.core.session.VisualStimulation + ] = [] + stimulation_devices: List[str] = [] + # Olfactory Stimulation + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.OLFACTORY) + olfactory_stimulus_channel_config: List[aind_data_schema.components.stimulus.OlfactometerChannelConfig] = [] + for _, channel in rig_model.harp_olfactometer.calibration.input.channel_config.items(): + if channel.channel_type == OlfactometerChannelType.ODOR: + olfactory_stimulus_channel_config.append( + coerce_to_aind_data_schema(channel, aind_data_schema.components.stimulus.OlfactometerChannelConfig) + ) + stimulation_parameters.append( + aind_data_schema.core.session.OlfactoryStimulation( + stimulus_name="Olfactory", channels=olfactory_stimulus_channel_config + ) + ) + + _olfactory_device = data_mapper_service.get_fields_of_type(rig_model, AbsRig.HarpOlfactometer) + if len(_olfactory_device) > 0: + if _olfactory_device[0][0]: + stimulation_devices.append(_olfactory_device[0][0]) + else: + logger.error("Olfactometer device not found in rig model.") + raise ValueError("Olfactometer device not found in rig model.") + + # Auditory Stimulation + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.AUDITORY) + + stimulation_parameters.append( + aind_data_schema.core.session.AuditoryStimulation(sitmulus_name="Beep", sample_frequency=0) + ) + speaker_config = aind_data_schema.core.session.SpeakerConfig(name="Speaker", volume=60) + stimulation_devices.append("speaker") + # Visual/VR Stimulation + stimulus_modalities.extend( + [ + aind_data_schema.core.session.StimulusModality.VISUAL, + aind_data_schema.core.session.StimulusModality.VIRTUAL_REALITY, + ] + ) + + stimulation_parameters.append( + aind_data_schema.core.session.VisualStimulation( + stimulus_name="VrScreen", + stimulus_parameters={}, + ) + ) + _screen_device = data_mapper_service.get_fields_of_type(rig_model, AbsRig.Screen) + if len(_screen_device) > 0: + if _screen_device[0][0]: + stimulation_devices.append(_screen_device[0][0]) + else: + logger.error("Screen device not found in rig model.") + raise ValueError("Screen device not found in rig model.") + + stimulus_modalities.append(aind_data_schema.core.session.StimulusModality.WHEEL_FRICTION) + # Mouse platform + mouse_platform: str = "wheel" + + # Reward delivery + if rig_model.manipulator.calibration is None: + logger.error("Manipulator calibration is not set.") + raise ValueError("Manipulator calibration is not set.") + initial_position = rig_model.manipulator.calibration.input.initial_position + reward_delivery_config = aind_data_schema.core.session.RewardDeliveryConfig( + reward_solution=aind_data_schema.core.session.RewardSolution.WATER, + reward_spouts=[ + aind_data_schema.core.session.RewardSpoutConfig( + side=aind_data_schema.components.devices.SpoutSide.CENTER, + variable_position=True, + starting_position=aind_data_schema.components.devices.RelativePosition( + device_position_transformations=[ + aind_data_schema.components.coordinates.Translation3dTransform( + translation=[initial_position.x, initial_position.y2, initial_position.z] + ) + ], + device_origin="Manipulator home", + device_axes=[ + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.X, direction="Left" + ), + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.Y, direction="Front" + ), + aind_data_schema.components.coordinates.Axis( + name=aind_data_schema.components.coordinates.AxisName.Z, direction="Top" + ), + ], + ), + ) + ], + ) + + end_time = datetime.datetime.now() + + # Construct aind-data-schema session + aind_data_schema_session = aind_data_schema.core.session.Session( + animal_weight_post=subject_info.weight_g if subject_info else None, + reward_consumed_total=subject_info.water_earned_ml if subject_info else None, + reward_delivery=reward_delivery_config, + experimenter_full_name=session_model.experimenter, + session_start_time=session_model.date, + session_end_time=session_end_time, + session_type=session_model.experiment, + rig_id=rig_model.rig_name, + subject_id=session_model.subject, + notes=session_model.notes, + data_streams=[ + aind_data_schema.core.session.Stream( + daq_names=devices, + stream_modalities=modalities, + stream_start_time=session_model.date, + stream_end_time=session_end_time if session_end_time else end_time, + camera_names=list(cameras.keys()), + ), + ], + calibrations=calibrations, + mouse_platform_name=mouse_platform, + active_mouse_platform=True, + stimulus_epochs=[ + aind_data_schema.core.session.StimulusEpoch( + stimulus_name=session_model.experiment, + stimulus_start_time=session_model.date, + stimulus_end_time=session_end_time if session_end_time else end_time, + stimulus_modalities=stimulus_modalities, + stimulus_parameters=stimulation_parameters, + software=[ + aind_data_schema.core.session.Software( + name="Bonsai", + version=f"{repository_remote_url}/blob/{repository_sha}/bonsai/Bonsai.config", + url=f"{repository_remote_url}/blob/{repository_sha}/bonsai", + parameters=data_mapper_service.snapshot_bonsai_environment( + config_file=kwargs.get("bonsai_config_path", Path("./bonsai/bonsai.config")) + ), + ), + aind_data_schema.core.session.Software( + name="Python", + version=f"{repository_remote_url}/blob/{repository_sha}/pyproject.toml", + url=f"{repository_remote_url}/blob/{repository_sha}", + parameters=data_mapper_service.snapshot_python_environment(), + ), + ], + script=aind_data_schema.core.session.Software( + name=Path(script_path).stem, + version=session_model.commit_hash if session_model.commit_hash else repository_sha, + url=f"{repository_remote_url}/blob/{repository_sha}/{repository_relative_script_path}", + parameters=task_logic_model.model_dump(), + ), + output_parameters=output_parameters if output_parameters else {}, + speaker_config=speaker_config, + reward_consumed_during_epoch=subject_info.total_water_ml if subject_info else None, + stimulus_device_names=stimulation_devices, + ) # type: ignore + ], + ) # type: ignore + return aind_data_schema_session + + @staticmethod + def _mapper_calibration(calibration: Calibration) -> aind_data_schema.components.devices.Calibration: + return aind_data_schema.components.devices.Calibration( + device_name=calibration.device_name, + input=calibration.input.model_dump() if calibration.input else {}, + output=calibration.output.model_dump() if calibration.output else {}, + calibration_date=calibration.date if calibration.date else utcnow(), + description=calibration.description if calibration.description else "", + notes=calibration.notes, + ) + + +def coerce_to_aind_data_schema(value: TFrom, target_type: Type[TTo]) -> TTo: + _normalized_input: dict + if isinstance(value, BaseModel): + _normalized_input = value.model_dump() + elif isinstance(value, dict): + _normalized_input = value + else: + raise ValueError(f"Expected value to be a BaseModel or a dict, got {type(value)}") + target_fields = target_type.model_fields + _normalized_input = {k: v for k, v in _normalized_input.items() if k in target_fields} + return target_type(**_normalized_input) + + +def aind_session_data_mapper_factory(launcher: BehaviorLauncher) -> AindSessionDataMapper: + now = utcnow() + return AindSessionDataMapper( + session_model=launcher.session_schema, + rig_model=launcher.rig_schema, + task_logic_model=launcher.task_logic_schema, + repository=launcher.repository, + script_path=launcher.services_factory_manager.bonsai_app.workflow, + session_directory=launcher.session_directory, + session_end_time=now, + ) + + +def aind_rig_data_mapper_factory( + launcher: BehaviorLauncher[AindVrForagingRig, AindBehaviorSessionModel, AindVrForagingTaskLogic], +) -> AindRigDataMapper: + rig_schema: AindVrForagingRig = launcher.rig_schema + return AindRigDataMapper( + rig_schema_filename=rig_schema.rig_name, + db_suffix=f"{_DATABASE_DIR}/{launcher.computer_name}", + db_root=launcher.config_library_dir, + destination_dir=launcher.session_directory, + ) diff --git a/src/DataSchemas/aind_behavior_vr_foraging/launcher.py b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py new file mode 100644 index 00000000..5e247e01 --- /dev/null +++ b/src/DataSchemas/aind_behavior_vr_foraging/launcher.py @@ -0,0 +1,51 @@ +import aind_behavior_experiment_launcher.launcher.behavior_launcher as behavior_launcher +from aind_behavior_experiment_launcher.apps.app_service import BonsaiApp +from aind_behavior_experiment_launcher.resource_monitor.resource_monitor_service import ( + ResourceMonitor, + available_storage_constraint_factory, + remote_dir_exists_constraint_factory, +) +from aind_behavior_services.session import AindBehaviorSessionModel + +from aind_behavior_vr_foraging.rig import AindVrForagingRig +from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic + + +def make_launcher() -> behavior_launcher.BehaviorLauncher: + data_dir = r"C:/Data" + remote_dir = r"\\allen\aind\scratch\vr-foraging\data" + srv = behavior_launcher.BehaviorServicesFactoryManager() + srv.bonsai_app = BonsaiApp(r"./src/vr-foraging.bonsai") + srv.data_transfer = behavior_launcher.robocopy_data_transfer_factory(remote_dir) + srv.resource_monitor = ResourceMonitor( + constrains=[ + available_storage_constraint_factory(data_dir, 2e11), + remote_dir_exists_constraint_factory(remote_dir), + ] + ) + + return behavior_launcher.BehaviorLauncher( + rig_schema_model=AindVrForagingRig, + session_schema_model=AindBehaviorSessionModel, + task_logic_schema_model=AindVrForagingTaskLogic, + data_dir=data_dir, + config_library_dir=r"\\allen\aind\scratch\AindBehavior.db\AindVrForaging", + temp_dir=r"./local/.temp", + repository_dir=None, + allow_dirty=False, + skip_hardware_validation=False, + debug_mode=False, + group_by_subject_log=True, + services=srv, + validate_init=True, + ) + + +def main(): + launcher = make_launcher() + launcher.main() + return None + + +if __name__ == "__main__": + main() diff --git a/scripts/regenerate.py b/src/DataSchemas/aind_behavior_vr_foraging/regenerate.py similarity index 99% rename from scripts/regenerate.py rename to src/DataSchemas/aind_behavior_vr_foraging/regenerate.py index b9db9e2e..279045f2 100644 --- a/scripts/regenerate.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/regenerate.py @@ -1,8 +1,6 @@ import inspect from pathlib import Path -import aind_behavior_vr_foraging.rig -import aind_behavior_vr_foraging.task_logic from aind_behavior_services.session import AindBehaviorSessionModel from aind_behavior_services.utils import ( convert_pydantic_to_bonsai, @@ -10,6 +8,9 @@ snake_to_pascal_case, ) +import aind_behavior_vr_foraging.rig +import aind_behavior_vr_foraging.task_logic + SCHEMA_ROOT = Path("./src/DataSchemas/") EXTENSIONS_ROOT = Path("./src/Extensions/") NAMESPACE_PREFIX = "AindVrForagingDataSchema" diff --git a/src/DataSchemas/aind_behavior_vr_foraging/rig.py b/src/DataSchemas/aind_behavior_vr_foraging/rig.py index 346068f0..a3921220 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging/rig.py +++ b/src/DataSchemas/aind_behavior_vr_foraging/rig.py @@ -39,7 +39,7 @@ class AindManipulatorDevice(aind_manipulator.AindManipulatorDevice): class HarpOlfactometer(rig.HarpOlfactometer): """Overrides the default settings for the olfactometer calibration""" - calibration: Optional[oc.OlfactometerCalibration] = Field(default=None, description="Olfactometer calibration") + calibration: oc.OlfactometerCalibration = Field(default=None, description="Olfactometer calibration") class RigCalibration(BaseModel): diff --git a/src/DataSchemas/aind_vr_foraging_rig.json b/src/DataSchemas/aind_vr_foraging_rig.json index 35c4fb57..829b0f93 100644 --- a/src/DataSchemas/aind_vr_foraging_rig.json +++ b/src/DataSchemas/aind_vr_foraging_rig.json @@ -2107,7 +2107,7 @@ }, "properties": { "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/DataSchemas/aind_vr_foraging_task_logic.json b/src/DataSchemas/aind_vr_foraging_task_logic.json index 7b6fe21f..aa022565 100644 --- a/src/DataSchemas/aind_vr_foraging_task_logic.json +++ b/src/DataSchemas/aind_vr_foraging_task_logic.json @@ -17,7 +17,7 @@ "title": "Rng Seed" }, "aind_behavior_services_pkg_version": { - "default": "0.8.1", + "default": "0.8.2", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", "title": "aind_behavior_services package version", "type": "string" diff --git a/src/Extensions/AindBehaviorSessionModel.cs b/src/Extensions/AindBehaviorSessionModel.cs index d7d58912..3e1c7e61 100644 --- a/src/Extensions/AindBehaviorSessionModel.cs +++ b/src/Extensions/AindBehaviorSessionModel.cs @@ -15,7 +15,7 @@ namespace AindVrForagingDataSchema.Session public partial class AindBehaviorSessionModel { - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private string _version = "0.3.0"; diff --git a/src/Extensions/AindVrForagingRig.cs b/src/Extensions/AindVrForagingRig.cs index 5ce1d97c..645830c7 100644 --- a/src/Extensions/AindVrForagingRig.cs +++ b/src/Extensions/AindVrForagingRig.cs @@ -5385,7 +5385,7 @@ public override string ToString() public partial class AindVrForagingRig { - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private string _version = "0.4.0"; diff --git a/src/Extensions/AindVrForagingTaskLogic.cs b/src/Extensions/AindVrForagingTaskLogic.cs index 46e63503..d4e2b4a9 100644 --- a/src/Extensions/AindVrForagingTaskLogic.cs +++ b/src/Extensions/AindVrForagingTaskLogic.cs @@ -17,7 +17,7 @@ public partial class AindVrForagingTaskParameters private double? _rngSeed; - private string _aindBehaviorServicesPkgVersion = "0.8.1"; + private string _aindBehaviorServicesPkgVersion = "0.8.2"; private System.Collections.Generic.IDictionary _updaters; diff --git a/tests/__init__.py b/tests/__init__.py index 2bc553f6..f7760d56 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,16 @@ import glob import importlib.util +import logging from pathlib import Path from types import ModuleType EXAMPLES_DIR = Path(__file__).parents[1] / "examples" JSON_ROOT = Path("./local").resolve() +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) +logging.disable(logging.CRITICAL) + def build_example(script_path: str) -> ModuleType: module_name = Path(script_path).stem diff --git a/tests/test_aind_data_mapper.py b/tests/test_aind_data_mapper.py new file mode 100644 index 00000000..300ff6a6 --- /dev/null +++ b/tests/test_aind_data_mapper.py @@ -0,0 +1,100 @@ +import sys +import unittest +from datetime import datetime +from pathlib import Path +from unittest.mock import MagicMock, patch + +from aind_behavior_vr_foraging.data_mappers import ( + AindBehaviorSessionModel, + AindRigDataMapper, + AindSessionDataMapper, + AindVrForagingRig, + AindVrForagingTaskLogic, +) +from aind_data_schema.core.rig import Rig +from git import Repo + +sys.path.append(".") +from examples.examples import mock_rig, mock_session, mock_task_logic # isort:skip # pylint: disable=wrong-import-position + + +class TestAindSessionDataMapper(unittest.TestCase): + def setUp(self): + self.session_model = mock_session() + self.rig_model = mock_rig() + self.task_logic_model = mock_task_logic() + self.repository = Repo(Path("./")) + self.script_path = Path("./src/vr-foraging.bonsai") + self.session_end_time = datetime.now() + self.session_directory = Path("./") + + self.mapper = AindSessionDataMapper( + session_model=self.session_model, + rig_model=self.rig_model, + task_logic_model=self.task_logic_model, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + session_directory=self.session_directory, + ) + + def test_validate(self): + self.assertTrue(self.mapper.validate()) + + @patch("aind_behavior_vr_foraging.data_mappers.logger") + @patch("aind_behavior_vr_foraging.data_mappers.AindSessionDataMapper._map") + def test_mock_map(self, mock_map, mock_logger): + mock_map.return_value = MagicMock() + result = self.mapper.map() + self.assertIsNotNone(result) + self.assertTrue(self.mapper.is_mapped()) + mock_logger.info.assert_called_with("Mapping successful.") + + def test_map(self): + mapped = self.mapper.map() + self.assertIsNotNone(mapped) + + @patch("aind_behavior_vr_foraging.data_mappers.model_from_json_file") + def test_map_from_json_files(self, mock_model_from_json_file): + mock_model_from_json_file.side_effect = [self.session_model, self.rig_model, self.task_logic_model] + session_json = MagicMock() + rig_json = MagicMock() + task_logic_json = MagicMock() + mapper = AindSessionDataMapper.map_from_json_files( + session_json=session_json, + rig_json=rig_json, + task_logic_json=task_logic_json, + session_model=AindBehaviorSessionModel, + rig_model=AindVrForagingRig, + task_logic_model=AindVrForagingTaskLogic, + repository=self.repository, + script_path=self.script_path, + session_end_time=self.session_end_time, + ) + self.assertIsInstance(mapper, AindSessionDataMapper) + + +class TestAindRigDataMapper(unittest.TestCase): + def setUp(self): + self.rig_schema_filename = "rig_schema.json" + self.db_root = MagicMock() + self.destination_dir = MagicMock() + self.db_suffix = "test_suffix" + self.mapper = AindRigDataMapper( + rig_schema_filename=self.rig_schema_filename, + db_root=self.db_root, + destination_dir=self.destination_dir, + db_suffix=self.db_suffix, + ) + + @patch("aind_behavior_vr_foraging.data_mappers.model_from_json_file") + def test_mock_map(self, mock_model_from_json_file): + mock_model_from_json_file.return_value = MagicMock(spec=Rig) + result = self.mapper.map() + self.assertIsNotNone(result) + self.assertTrue(self.mapper.mapped) + self.assertIsInstance(result, Rig) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bonsai.py b/tests/test_bonsai.py index c67b93e1..74d716dd 100644 --- a/tests/test_bonsai.py +++ b/tests/test_bonsai.py @@ -1,6 +1,7 @@ import os import sys import unittest +import warnings from pathlib import Path from typing import Generic, List, Optional, TypeVar, Union @@ -43,11 +44,13 @@ def test_deserialization(self): stdout = completed_proc.stdout.decode().split("\n") stdout = [line for line in stdout if (line or line != "")] - for model in models_to_test: - try: - model.try_deserialization(stdout) - except ValueError: - self.fail(f"Could not find a match for {model.input_model.__class__.__name__}.") + with warnings.catch_warnings(): # suppress the warnings relative to the coercion of version across schemas + warnings.simplefilter("ignore") + for model in models_to_test: + try: + model.try_deserialization(stdout) + except ValueError: + self.fail(f"Could not find a match for {model.input_model.__class__.__name__}.") class TestModel(Generic[TModel]):