diff --git a/python/lsst/donut/viz/aggregateVisit.py b/python/lsst/donut/viz/aggregateVisit.py index a921166..3214ef6 100644 --- a/python/lsst/donut/viz/aggregateVisit.py +++ b/python/lsst/donut/viz/aggregateVisit.py @@ -24,6 +24,9 @@ "AggregateDonutStampsTaskConnections", "AggregateDonutStampsTaskConfig", "AggregateDonutStampsTask", + "AggregateStarStampsTaskConnections", + "AggregateStarStampsTaskConfig", + "AggregateStarStampsTask", ] @@ -494,3 +497,79 @@ def runQuantum( butlerQC.put(DonutStamps(extraStampsListRavel, metadata=extra.metadata), outputRefs.donutStampsExtraVisit) + + +class AggregateStarStampsTaskConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit"), +): + starStamps = ct.Input( + doc="In-Focus Postage Stamp Images", + dimensions=("visit", "detector", "instrument"), + storageClass="StampsBase", + name="starStamps", + multiple=True, + ) + starStampsVisit = ct.Output( + doc="In-Focus Star Stamps", + dimensions=("visit", "instrument"), + storageClass="StampsBase", + name="starStampsVisit", + ) + + +class AggregateStarStampsTaskConfig( + pipeBase.PipelineTaskConfig, + pipelineConnections=AggregateStarStampsTaskConnections, +): + maxStarsPerDetector = pexConfig.Field[int]( + doc="Maximum number of stars to use per detector", + default=1, + ) + + def validate(self): + if self.maxStarsPerDetector < 1: + raise pexConfig.FieldValidationError("maxStarsPerDetector must be at least 1") + + +class AggregateStarStampsTask(pipeBase.PipelineTask): + ConfigClass = AggregateStarStampsTaskConfig + _DefaultName = "AggregateStarStamps" + + @timeMethod + def runQuantum( + self, + butlerQC: pipeBase.QuantumContext, + inputRefs: pipeBase.InputQuantizedConnection, + outputRefs: pipeBase.OutputQuantizedConnection + ) -> None: + stampsList = [] + for ref in inputRefs.starStamps: + stampRef = butlerQC.get(ref) + print(f'Number of stampRefs for the detector: {len(stampRef)}') + bad_idx = [] + for idx in range(len(stampRef)): + maskedImage = stampRef[idx].stamp_im + bit = maskedImage.mask.getPlaneBitMask(('SAT', 'BAD', 'NO_DATA')) + countAffectedPixels = len(np.where(np.bitwise_and(maskedImage.mask.array, bit))[0]) + if countAffectedPixels > 50: + print(f'{idx} has {countAffectedPixels} affected pixels') + bad_idx.append(idx) + + all_idx = np.arange(len(stampRef)) + good_idx = np.array([idx for idx in all_idx if idx not in bad_idx]) + if len(good_idx) == 0: + raise RuntimeError("All stars are saturated") + else: + if self.config.maxStarsPerDetector > len(good_idx): + self.log.warning(f"maxStarsPerDetector ({self.config.maxStarsPerDetector })larger \ +than number of available unsaturated stars ({len(good_idx)})") + self.config.maxStarsPerDetector = len(good_idx)-1 + self.log.warning(f"Reducing maxStarsPerDetector to {self.config.maxStarsPerDetector }") + # No need to change the number of selected stars + select_idx = good_idx[:self.config.maxStarsPerDetector] + stampsList.append(np.array(stampRef)[select_idx]) + stampsListRavel = np.ravel(stampsList) + + butlerQC.put(DonutStamps(stampsListRavel, metadata=stampRef.metadata), + outputRefs.starStampsVisit)