Skip to content

Commit

Permalink
Fix plotting errors and add tests. Add utility tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbkalmbach committed Dec 10, 2024
1 parent 4ee2230 commit d06a75d
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 20 deletions.
11 changes: 11 additions & 0 deletions doc/versionHistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
Version History
##################

.._lsst.ts.donut.viz-1.3.0

-------------
1.3.0
-------------

* Fix failures that occur when detectors are missing data.
* Add tests for detectors missing data.
* Fix intra, extra labeling in donut plots.
* Add utilities tests.

.._lsst.ts.donut.viz-1.2.3

-------------
Expand Down
29 changes: 14 additions & 15 deletions python/lsst/donut/viz/aggregate_visit.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,21 +509,20 @@ def runQuantum(
inputRefs: pipeBase.InputQuantizedConnection,
outputRefs: pipeBase.OutputQuantizedConnection,
) -> None:
adc = butlerQC.get(inputRefs.aggregateDonutTable)
adt = butlerQC.get(inputRefs.aggregateDonutTable)
azr = butlerQC.get(inputRefs.aggregateZernikesRaw)
aza = butlerQC.get(inputRefs.aggregateZernikesAvg)

avg_table, raw_table = self.run(adc, azr, aza)
avg_table, raw_table = self.run(adt, azr, aza)

print(outputRefs.aggregateAOSAvg)
butlerQC.put(avg_table, outputRefs.aggregateAOSAvg)
butlerQC.put(raw_table, outputRefs.aggregateAOSRaw)

@timeMethod
def run(
self, adc: typing.List[Table], azr: typing.List[Table], aza: typing.List[Table]
self, adt: typing.List[Table], azr: typing.List[Table], aza: typing.List[Table]
) -> tuple[Table, Table]:
dets = np.unique(adc["detector"])
dets = np.unique(adt["detector"])
avg_table = aza.copy()
avg_keys = [
"coord_ra",
Expand All @@ -543,22 +542,22 @@ def run(
for det in dets:
w = avg_table["detector"] == det
for k in avg_keys:
avg_table[k][w] = np.mean(adc[k][adc["detector"] == det])
avg_table[k][w] = np.mean(adt[k][adt["detector"] == det])

raw_table = azr.copy()
for k in avg_keys:
raw_table[k] = np.nan # Allocate
for det in dets:
w = raw_table["detector"] == det
wadc = adc["detector"] == det
fzmin = adc[wadc]["focusZ"].min()
fzmax = adc[wadc]["focusZ"].max()
wadt = adt["detector"] == det
fzmin = adt[wadt]["focusZ"].min()
fzmax = adt[wadt]["focusZ"].max()
if fzmin == fzmax: # single-sided Zernike estimates
for k in avg_keys:
raw_table[k][w] = adc[k][wadc]
raw_table[k][w] = adt[k][wadt]
else: # double-sided Zernike estimates
wintra = adc[wadc]["focusZ"] == fzmin
wextra = adc[wadc]["focusZ"] == fzmax
wintra = adt[wadt]["focusZ"] == fzmin
wextra = adt[wadt]["focusZ"] == fzmax
for k in avg_keys:
# If one table has more rows than the other,
# trim the longer one
Expand All @@ -572,13 +571,13 @@ def run(
)
# ought to be the same length now
raw_table[k][w] = 0.5 * (
adc[k][wadc][wintra] + adc[k][wadc][wextra]
adt[k][wadt][wintra] + adt[k][wadt][wextra]
)
if k + "_intra" not in raw_table.colnames:
raw_table[k + "_intra"] = np.nan
raw_table[k + "_extra"] = np.nan
raw_table[k + "_intra"][w] = adc[k][wadc][wintra]
raw_table[k + "_extra"][w] = adc[k][wadc][wextra]
raw_table[k + "_intra"][w] = adt[k][wadt][wintra]
raw_table[k + "_extra"][w] = adt[k][wadt][wextra]

return avg_table, raw_table

Expand Down
2 changes: 1 addition & 1 deletion python/lsst/donut/viz/plot_aos_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,6 @@ def run(self, zernikes, **kwargs) -> plt.figure:
fig = psfPanel(xs, ys, psf, dname, fig=fig)

# draw rose
add_coordinate_roses(fig, rtp, q)
add_coordinate_roses(fig, rtp, q, [(0.15, 0.94), (0.85, 0.94)])

return fig
2 changes: 1 addition & 1 deletion python/lsst/donut/viz/psf_from_zern.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def psfPanel(

# cycling through the axes.
for i, dn in enumerate(detname):
axs[i].set(xlim=det_lim_x, ylim=det_lim_y, xticks=[], yticks=[], aspect="equal")
if len(psf[i]) == 0:
continue
im = axs[i].scatter(xs[i], ys[i], c=psf[i], cmap=cmap, vmax=pmax, vmin=pmin)
axs[i].set_title(f"{dn}: {np.nanmean(psf[i]):.3f} +/- {np.nanstd(psf[i]):.3f}")
axs[i].set(xlim=det_lim_x, ylim=det_lim_y, xticks=[], yticks=[], aspect="equal")

# setting the colorbar
cb = fig.colorbar(im, cax=ax_cbar, location="bottom")
Expand Down
16 changes: 13 additions & 3 deletions python/lsst/donut/viz/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_day_obs_seq_num_from_visitid(visit):
return day_obs, seq_num


def add_coordinate_roses(fig, rtp, q):
def add_coordinate_roses(fig, rtp, q, p0=None):
"""Add coordinate system roses to the figure.
Parameters
Expand All @@ -192,19 +192,29 @@ def add_coordinate_roses(fig, rtp, q):
q : `float`
The boresight parallactic angle in radians, used to
determine the position of the North and East vectors.
p0 : list or None
If list should be list of two (x, y) coordinates on the figure. The
first is for the x,y coordinates system rose and the latter is for
the compass directional rose. Default locations when None is passed
are (0.15, 0.8), (0.85, 0.8). (The default is None.)
"""
if p0 is None:
p0 = [(0.15, 0.8), (0.85, 0.8)]
elif np.shape(p0) != (2, 2):
raise ValueError("If p0 is not None, it must be a pair of (x, y) coordinates")

vecs_xy = {
r"$x_\mathrm{Opt}$": (1, 0),
r"$y_\mathrm{Opt}$": (0, -1),
r"$x_\mathrm{Cam}$": (np.cos(rtp), -np.sin(rtp)),
r"$y_\mathrm{Cam}$": (-np.sin(rtp), -np.cos(rtp)),
}
rose(fig, vecs_xy, p0=(0.15, 0.8))
rose(fig, vecs_xy, p0=p0[0])

vecs_NE = {
"az": (1, 0),
"alt": (0, +1),
"N": (np.sin(q), np.cos(q)),
"E": (np.sin(q - np.pi / 2), np.cos(q - np.pi / 2)),
}
rose(fig, vecs_NE, p0=(0.85, 0.8))
rose(fig, vecs_NE, p0=p0[1])
19 changes: 19 additions & 0 deletions tests/test_donut_viz_pipeline_science_sensors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from copy import copy

import numpy as np
from lsst.daf.butler import Butler
Expand All @@ -9,6 +10,8 @@
AggregateDonutTablesTaskConfig,
AggregateZernikeTablesTask,
AggregateZernikeTablesTaskConfig,
PlotPsfZernTask,
PlotPsfZernTaskConfig,
)
from lsst.ts.wep.task.generateDonutCatalogUtils import convertDictToVisitInfo
from lsst.ts.wep.task.pairTask import ExposurePairer
Expand Down Expand Up @@ -378,3 +381,19 @@ def testAggDonutTablesRunMissingDate(self):
camera, visitInfoDict, pairs, donutTables, qualityTables
)
self.assertEqual(len(agg_donut_table[4021123106001]), 6)

def testPlotPsfZernTaskMissingData(self):
# Test that if detectors have different numbers of zernikes
# the plot still gets made.
zernike_datasets = self.butler.query_datasets(
"zernikes", collections=self.test_run_name
)
zernikes = [self.butler.get(dataset) for dataset in zernike_datasets]
zernikes_missing_data = copy(zernikes)
zernikes_missing_data[0].remove_rows(np.arange(len(zernikes_missing_data[0])))
task = PlotPsfZernTask(config=PlotPsfZernTaskConfig())
for input_data in [zernikes, zernikes_missing_data]:
try:
task.run(zernikes)
except Exception:
self.fail(f"Unexpected exception raised with input {input_data}")
31 changes: 31 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from lsst.donut.viz.utilities import (
add_coordinate_roses,
get_day_obs_seq_num_from_visitid,
get_instrument_channel_name,
)
from lsst.utils.tests import TestCase


class TestDonutVizUtilities(TestCase):
def testGetInstrumentChannelName(self):
self.assertTrue(get_instrument_channel_name("LSSTCam"), "lsstcam_aos")
self.assertTrue(get_instrument_channel_name("LSSTCamSim"), "lsstcam_sim_aos")
self.assertTrue(get_instrument_channel_name("LSSTComCam"), "comcam_aos")
self.assertTrue(get_instrument_channel_name("LSSTComCamSim"), "comcam_sim_aos")
with self.assertRaises(ValueError) as context:
get_instrument_channel_name("LSSTCAM")
expected_msg = "Unknown instrument LSSTCAM"
self.assertEqual(str(context.exception), expected_msg)

def testGetDayObsSeqNumFromVisitId(self):
day_obs, seq_num = get_day_obs_seq_num_from_visitid(4021123106001)
self.assertEqual(day_obs, 20211231)
self.assertEqual(seq_num, 6001)

def testAddCoordinateRoses(self):
with self.assertRaises(ValueError):
add_coordinate_roses(None, None, None, [1, 2])
with self.assertRaises(ValueError) as context:
add_coordinate_roses(None, None, None, [(1, 2)])
expected_msg = "If p0 is not None, it must be a pair of (x, y) coordinates"
self.assertEqual(str(context.exception), expected_msg)

0 comments on commit d06a75d

Please sign in to comment.