diff options
Diffstat (limited to 'src/silx/gui/plot/test/testStats.py')
-rw-r--r-- | src/silx/gui/plot/test/testStats.py | 701 |
1 files changed, 399 insertions, 302 deletions
diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py index c5d5181..2a2793e 100644 --- a/src/silx/gui/plot/test/testStats.py +++ b/src/silx/gui/plot/test/testStats.py @@ -34,13 +34,11 @@ from silx.gui.plot import StatsWidget from silx.gui.plot.stats import statshandler from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import Plot1D, Plot2D -from silx.gui.plot3d.SceneWidget import SceneWidget from silx.gui.plot.items.roi import RectangleROI, PolygonROI -from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.tools.roi import RegionOfInterestManager from silx.gui.plot.stats.stats import Stats from silx.gui.plot.CurvesROIWidget import ROI from silx.utils.testutils import ParametricTestCase -import unittest import logging import numpy @@ -49,6 +47,7 @@ _logger = logging.getLogger(__name__) class TestStatsBase(object): """Base class for stats TestCase""" + def setUp(self): self.createCurveContext() self.createImageContext() @@ -69,51 +68,52 @@ class TestStatsBase(object): self.plot1d = Plot1D() x = range(20) y = range(20) - self.plot1d.addCurve(x, y, legend='curve0') + self.plot1d.addCurve(x, y, legend="curve0") self.curveContext = stats._CurveContext( - item=self.plot1d.getCurve('curve0'), + item=self.plot1d.getCurve("curve0"), plot=self.plot1d, onlimits=False, - roi=None) + roi=None, + ) def createScatterContext(self): self.scatterPlot = Plot2D() - lgd = 'scatter plot' + lgd = "scatter plot" self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36]) self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18]) self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5]) - self.scatterPlot.addScatter(self.xScatterData, self.yScatterData, - self.valuesScatterData, legend=lgd) + self.scatterPlot.addScatter( + self.xScatterData, self.yScatterData, self.valuesScatterData, legend=lgd + ) self.scatterContext = stats._ScatterContext( item=self.scatterPlot.getScatter(lgd), plot=self.scatterPlot, onlimits=False, - roi=None + roi=None, ) def createImageContext(self): self.plot2d = Plot2D() - self._imgLgd = 'test image' - self.imageData = numpy.arange(32*128).reshape(32, 128) - self.plot2d.addImage(data=self.imageData, - legend=self._imgLgd, replace=False) + self._imgLgd = "test image" + self.imageData = numpy.arange(32 * 128).reshape(32, 128) + self.plot2d.addImage(data=self.imageData, legend=self._imgLgd, replace=False) self.imageContext = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, onlimits=False, - roi=None + roi=None, ) def getBasicStats(self): return { - 'min': stats.StatMin(), - 'minCoords': stats.StatCoordMin(), - 'max': stats.StatMax(), - 'maxCoords': stats.StatCoordMax(), - 'std': stats.Stat(name='std', fct=numpy.std), - 'mean': stats.Stat(name='mean', fct=numpy.mean), - 'com': stats.StatCOM() + "min": stats.StatMin(), + "minCoords": stats.StatCoordMin(), + "max": stats.StatMax(), + "maxCoords": stats.StatCoordMax(), + "std": stats.Stat(name="std", fct=numpy.std), + "mean": stats.Stat(name="mean", fct=numpy.mean), + "com": stats.StatCOM(), } @@ -121,6 +121,7 @@ class TestStats(TestStatsBase, TestCaseQt): """ Test :class:`BaseClass` class and inheriting classes """ + def setUp(self): TestCaseQt.setUp(self) TestStatsBase.setUp(self) @@ -133,41 +134,50 @@ class TestStats(TestStatsBase, TestCaseQt): """Test result for simple stats on a curve""" _stats = self.getBasicStats() xData = yData = numpy.array(range(20)) - self.assertEqual(_stats['min'].calculate(self.curveContext), 0) - self.assertEqual(_stats['max'].calculate(self.curveContext), 19) - self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,)) - self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,)) - self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData)) - self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData)) + self.assertEqual(_stats["min"].calculate(self.curveContext), 0) + self.assertEqual(_stats["max"].calculate(self.curveContext), 19) + self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (0,)) + self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (19,)) + self.assertEqual(_stats["std"].calculate(self.curveContext), numpy.std(yData)) + self.assertEqual(_stats["mean"].calculate(self.curveContext), numpy.mean(yData)) com = numpy.sum(xData * yData) / numpy.sum(yData) - self.assertEqual(_stats['com'].calculate(self.curveContext), com) + self.assertEqual(_stats["com"].calculate(self.curveContext), com) def testBasicStatsImage(self): """Test result for simple stats on an image""" _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(self.imageContext), 0) - self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1) - self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0)) - self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31)) - self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData)) - self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData)) + self.assertEqual(_stats["min"].calculate(self.imageContext), 0) + self.assertEqual(_stats["max"].calculate(self.imageContext), 128 * 32 - 1) + self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (0, 0)) + self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (127, 31)) + self.assertEqual( + _stats["std"].calculate(self.imageContext), numpy.std(self.imageData) + ) + self.assertEqual( + _stats["mean"].calculate(self.imageContext), numpy.mean(self.imageData) + ) yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1) xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0) dataXRange = range(self.imageData.shape[1]) dataYRange = range(self.imageData.shape[0]) - ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) - xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) - self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom)) def testStatsImageAdv(self): """Test that scale and origin are taking into account for images""" image2Data = numpy.arange(32 * 128).reshape(32, 128) - self.plot2d.addImage(data=image2Data, legend=self._imgLgd, - replace=True, origin=(100, 10), scale=(2, 0.5)) + self.plot2d.addImage( + data=image2Data, + legend=self._imgLgd, + replace=True, + origin=(100, 10), + scale=(2, 0.5), + ) image2Context = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, @@ -175,18 +185,19 @@ class TestStats(TestStatsBase, TestCaseQt): roi=None, ) _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(image2Context), 0) + self.assertEqual(_stats["min"].calculate(image2Context), 0) + self.assertEqual(_stats["max"].calculate(image2Context), 128 * 32 - 1) + self.assertEqual(_stats["minCoords"].calculate(image2Context), (100, 10)) self.assertEqual( - _stats['max'].calculate(image2Context), 128 * 32 - 1) + _stats["maxCoords"].calculate(image2Context), + (127 * 2.0 + 100, 31 * 0.5 + 10), + ) self.assertEqual( - _stats['minCoords'].calculate(image2Context), (100, 10)) + _stats["std"].calculate(image2Context), numpy.std(self.imageData) + ) self.assertEqual( - _stats['maxCoords'].calculate(image2Context), (127*2. + 100, - 31 * 0.5 + 10)) - self.assertEqual(_stats['std'].calculate(image2Context), - numpy.std(self.imageData)) - self.assertEqual(_stats['mean'].calculate(image2Context), - numpy.mean(self.imageData)) + _stats["mean"].calculate(image2Context), numpy.mean(self.imageData) + ) yData = numpy.sum(self.imageData, axis=1) xData = numpy.sum(self.imageData, axis=0) @@ -196,30 +207,36 @@ class TestStats(TestStatsBase, TestCaseQt): ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) ycom = (ycom * 0.5) + 10 xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) - xcom = (xcom * 2.) + 100 - self.assertTrue(numpy.allclose( - _stats['com'].calculate(image2Context), (xcom, ycom))) + xcom = (xcom * 2.0) + 100 + self.assertTrue( + numpy.allclose(_stats["com"].calculate(image2Context), (xcom, ycom)) + ) def testBasicStatsScatter(self): """Test result for simple stats on a scatter""" _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(self.scatterContext), 5) - self.assertEqual(_stats['max'].calculate(self.scatterContext), 90) - self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2)) - self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69)) - self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData)) - self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData)) + self.assertEqual(_stats["min"].calculate(self.scatterContext), 5) + self.assertEqual(_stats["max"].calculate(self.scatterContext), 90) + self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (0, 2)) + self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (50, 69)) + self.assertEqual( + _stats["std"].calculate(self.scatterContext), + numpy.std(self.valuesScatterData), + ) + self.assertEqual( + _stats["mean"].calculate(self.scatterContext), + numpy.mean(self.valuesScatterData), + ) data = self.valuesScatterData.astype(numpy.float64) comx = numpy.sum(self.xScatterData * data) / numpy.sum(data) comy = numpy.sum(self.yScatterData * data) / numpy.sum(data) - self.assertEqual(_stats['com'].calculate(self.scatterContext), - (comx, comy)) + self.assertEqual(_stats["com"].calculate(self.scatterContext), (comx, comy)) def testKindNotManagedByStat(self): """Make sure an exception is raised if we try to execute calculate of the base class""" - b = stats.StatBase(name='toto', compatibleKinds='curve') + b = stats.StatBase(name="toto", compatibleKinds="curve") with self.assertRaises(NotImplementedError): b.calculate(self.imageContext) @@ -228,7 +245,7 @@ class TestStats(TestStatsBase, TestCaseQt): Make sure an error is raised if we try to calculate a statistic with a context not managed """ - myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve')) + myStat = stats.Stat(name="toto", fct=numpy.std, kinds=("curve")) myStat.calculate(self.curveContext) with self.assertRaises(ValueError): myStat.calculate(self.scatterContext) @@ -240,43 +257,48 @@ class TestStats(TestStatsBase, TestCaseQt): self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) curveContextOnLimits = stats._CurveContext( - item=self.plot1d.getCurve('curve0'), + item=self.plot1d.getCurve("curve0"), plot=self.plot1d, onlimits=True, - roi=None) + roi=None, + ) self.assertEqual(stat.calculate(curveContextOnLimits), 2) self.plot2d.getXAxis().setLimitsConstraints(minPos=32) imageContextOnLimits = stats._ImageContext( - item=self.plot2d.getImage('test image'), + item=self.plot2d.getImage("test image"), plot=self.plot2d, onlimits=True, - roi=None) + roi=None, + ) self.assertEqual(stat.calculate(imageContextOnLimits), 32) self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) scatterContextOnLimits = stats._ScatterContext( - item=self.scatterPlot.getScatter('scatter plot'), + item=self.scatterPlot.getScatter("scatter plot"), plot=self.scatterPlot, onlimits=True, - roi=None) + roi=None, + ) self.assertEqual(stat.calculate(scatterContextOnLimits), 20) class TestStatsFormatter(TestCaseQt): """Simple test to check usage of the :class:`StatsFormatter`""" + def setUp(self): TestCaseQt.setUp(self) self.plot1d = Plot1D() x = range(20) y = range(20) - self.plot1d.addCurve(x, y, legend='curve0') + self.plot1d.addCurve(x, y, legend="curve0") self.curveContext = stats._CurveContext( - item=self.plot1d.getCurve('curve0'), + item=self.plot1d.getCurve("curve0"), plot=self.plot1d, onlimits=False, - roi=None) + roi=None, + ) self.stat = stats.StatMin() @@ -291,27 +313,30 @@ class TestStatsFormatter(TestCaseQt): simple cast to str""" emptyFormatter = statshandler.StatFormatter() self.assertEqual( - emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000') + emptyFormatter.format(self.stat.calculate(self.curveContext)), "0.000" + ) def testSettedFormatter(self): """Make sure a formatter with no formatter definition will return a simple cast to str""" - formatter= statshandler.StatFormatter(formatter='{0:.3f}') + formatter = statshandler.StatFormatter(formatter="{0:.3f}") self.assertEqual( - formatter.format(self.stat.calculate(self.curveContext)), '0.000') + formatter.format(self.stat.calculate(self.curveContext)), "0.000" + ) class TestStatsHandler(TestCaseQt): - """Make sure the StatHandler is correctly making the link between + """Make sure the StatHandler is correctly making the link between :class:`StatBase` and :class:`StatFormatter` and checking the API is valid """ + def setUp(self): TestCaseQt.setUp(self) self.plot1d = Plot1D() x = range(20) y = range(20) - self.plot1d.addCurve(x, y, legend='curve0') - self.curveItem = self.plot1d.getCurve('curve0') + self.plot1d.addCurve(x, y, legend="curve0") + self.curveItem = self.plot1d.getCurve("curve0") self.stat = stats.StatMin() @@ -324,91 +349,94 @@ class TestStatsHandler(TestCaseQt): def testConstructor(self): """Make sure the constructor can deal will all possible arguments: - + * tuple of :class:`StatBase` derivated classes * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`) * tuple of tuples (str, pointer to function, kind) """ - handler0 = statshandler.StatsHandler( - (stats.StatMin(), stats.StatMax()) - ) + handler0 = statshandler.StatsHandler((stats.StatMin(), stats.StatMax())) - res = handler0.calculate(item=self.curveItem, plot=self.plot1d, - onlimits=False) - self.assertTrue('min' in res) - self.assertEqual(res['min'], '0') - self.assertTrue('max' in res) - self.assertEqual(res['max'], '19') + res = handler0.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) + self.assertTrue("min" in res) + self.assertEqual(res["min"], "0") + self.assertTrue("max" in res) + self.assertEqual(res["max"], "19") handler1 = statshandler.StatsHandler( ( (stats.StatMin(), statshandler.StatFormatter(formatter=None)), - (stats.StatMax(), statshandler.StatFormatter()) + (stats.StatMax(), statshandler.StatFormatter()), ) ) - res = handler1.calculate(item=self.curveItem, plot=self.plot1d, - onlimits=False) - self.assertTrue('min' in res) - self.assertEqual(res['min'], '0') - self.assertTrue('max' in res) - self.assertEqual(res['max'], '19.000') + res = handler1.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) + self.assertTrue("min" in res) + self.assertEqual(res["min"], "0") + self.assertTrue("max" in res) + self.assertEqual(res["max"], "19.000") handler2 = statshandler.StatsHandler( + ((stats.StatMin(), None), (stats.StatMax(), statshandler.StatFormatter())) + ) + + res = handler2.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) + self.assertTrue("min" in res) + self.assertEqual(res["min"], "0") + self.assertTrue("max" in res) + self.assertEqual(res["max"], "19.000") + + handler3 = statshandler.StatsHandler( ( - (stats.StatMin(), None), - (stats.StatMax(), statshandler.StatFormatter()) - )) - - res = handler2.calculate(item=self.curveItem, plot=self.plot1d, - onlimits=False) - self.assertTrue('min' in res) - self.assertEqual(res['min'], '0') - self.assertTrue('max' in res) - self.assertEqual(res['max'], '19.000') - - handler3 = statshandler.StatsHandler(( - (('amin', numpy.argmin), statshandler.StatFormatter()), - ('amax', numpy.argmax) - )) - - res = handler3.calculate(item=self.curveItem, plot=self.plot1d, - onlimits=False) - self.assertTrue('amin' in res) - self.assertEqual(res['amin'], '0.000') - self.assertTrue('amax' in res) - self.assertEqual(res['amax'], '19') + (("amin", numpy.argmin), statshandler.StatFormatter()), + ("amax", numpy.argmax), + ) + ) + + res = handler3.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) + self.assertTrue("amin" in res) + self.assertEqual(res["amin"], "0.000") + self.assertTrue("amax" in res) + self.assertEqual(res["amax"], "19") with self.assertRaises(ValueError): - statshandler.StatsHandler(('name')) + statshandler.StatsHandler(("name")) class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): """Basic test for StatsWidget with curves""" + def setUp(self): TestCaseQt.setUp(self) self.plot = Plot1D() self.plot.show() x = range(20) y = range(20) - self.plot.addCurve(x, y, legend='curve0') + self.plot.addCurve(x, y, legend="curve0") y = range(12, 32) - self.plot.addCurve(x, y, legend='curve1') + self.plot.addCurve(x, y, legend="curve1") y = range(-2, 18) - self.plot.addCurve(x, y, legend='curve2') + self.plot.addCurve(x, y, legend="curve2") self.widget = StatsWidget.StatsWidget(plot=self.plot) self.statsTable = self.widget._statsTable - mystats = statshandler.StatsHandler(( - stats.StatMin(), - (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - stats.StatMax(), - (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - stats.StatDelta(), - ('std', numpy.std), - ('mean', numpy.mean), - stats.StatCOM() - )) + mystats = statshandler.StatsHandler( + ( + stats.StatMin(), + ( + stats.StatCoordMin(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + stats.StatMax(), + ( + stats.StatCoordMax(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + stats.StatDelta(), + ("std", numpy.std), + ("mean", numpy.mean), + stats.StatCOM(), + ) + ) self.statsTable.setStats(mystats) @@ -456,42 +484,44 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): def testRemoveCurve(self): """Make sure the Curves stats take into account the curve removal from plot""" - self.plot.removeCurve('curve2') + self.plot.removeCurve("curve2") self.assertEqual(self.statsTable.rowCount(), 2) for iRow in range(2): - self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1')) + self.assertTrue( + self.statsTable.item(iRow, 0).text() in ("curve0", "curve1") + ) - self.plot.removeCurve('curve0') + self.plot.removeCurve("curve0") self.assertEqual(self.statsTable.rowCount(), 1) - self.plot.removeCurve('curve1') + self.plot.removeCurve("curve1") self.assertEqual(self.statsTable.rowCount(), 0) def testAddCurve(self): """Make sure the Curves stats take into account the add curve action""" - self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) + self.plot.addCurve(legend="curve3", x=range(10), y=range(10)) self.assertEqual(self.statsTable.rowCount(), 4) def testUpdateCurveFromAddCurve(self): """Make sure the stats of the cuve will be removed after updating a curve""" - self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) + self.plot.addCurve(legend="curve0", x=range(10), y=range(10)) self.qapp.processEvents() self.assertEqual(self.statsTable.rowCount(), 3) - curve = self.plot._getItem(kind='curve', legend='curve0') + curve = self.plot._getItem(kind="curve", legend="curve0") tableItems = self.statsTable._itemToTableItems(curve) - self.assertEqual(tableItems['max'].text(), '9') + self.assertEqual(tableItems["max"].text(), "9") def testUpdateCurveFromCurveObj(self): - self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.plot.getCurve("curve0").setData(x=range(4), y=range(4)) self.qapp.processEvents() self.assertEqual(self.statsTable.rowCount(), 3) - curve = self.plot._getItem(kind='curve', legend='curve0') + curve = self.plot._getItem(kind="curve", legend="curve0") tableItems = self.statsTable._itemToTableItems(curve) - self.assertEqual(tableItems['max'].text(), '3') + self.assertEqual(tableItems["max"].text(), "3") def testSetAnotherPlot(self): plot2 = Plot1D() - plot2.addCurve(x=range(26), y=range(26), legend='new curve') + plot2.addCurve(x=range(26), y=range(26), legend="new curve") self.statsTable.setPlot(plot2) self.assertEqual(self.statsTable.rowCount(), 1) self.qapp.processEvents() @@ -501,50 +531,62 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): def testUpdateMode(self): """Make sure the update modes are well take into account""" - self.plot.setActiveCurve('curve0') + self.plot.setActiveCurve("curve0") for display_only_active in (True, False): with self.subTest(display_only_active=display_only_active): self.widget.setDisplayOnlyActiveItem(display_only_active) - self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) + self.plot.getCurve("curve0").setData(x=range(4), y=range(4)) self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO) update_stats_action = self.widget._options.getUpdateStatsAction() # test from api - self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO) + self.assertEqual( + self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO + ) self.widget.show() # check stats change in auto mode - self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3)) + self.plot.getCurve("curve0").setData(x=range(4), y=range(-1, 3)) self.qapp.processEvents() - tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) - curve0_min = tableItems['min'].text() - self.assertTrue(float(curve0_min) == -1.) + tableItems = self.statsTable._itemToTableItems( + self.plot.getCurve("curve0") + ) + curve0_min = tableItems["min"].text() + self.assertTrue(float(curve0_min) == -1.0) - self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) + self.plot.getCurve("curve0").setData(x=range(4), y=range(1, 5)) self.qapp.processEvents() - tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) - curve0_min = tableItems['min'].text() - self.assertTrue(float(curve0_min) == 1.) + tableItems = self.statsTable._itemToTableItems( + self.plot.getCurve("curve0") + ) + curve0_min = tableItems["min"].text() + self.assertTrue(float(curve0_min) == 1.0) # check stats change in manual mode only if requested self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL) - self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL) + self.assertEqual( + self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL + ) - self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6)) + self.plot.getCurve("curve0").setData(x=range(4), y=range(2, 6)) self.qapp.processEvents() - tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) - curve0_min = tableItems['min'].text() - self.assertTrue(float(curve0_min) == 1.) + tableItems = self.statsTable._itemToTableItems( + self.plot.getCurve("curve0") + ) + curve0_min = tableItems["min"].text() + self.assertTrue(float(curve0_min) == 1.0) update_stats_action.trigger() - tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) - curve0_min = tableItems['min'].text() - self.assertTrue(float(curve0_min) == 2.) + tableItems = self.statsTable._itemToTableItems( + self.plot.getCurve("curve0") + ) + curve0_min = tableItems["min"].text() + self.assertTrue(float(curve0_min) == 2.0) def testItemHidden(self): """Test if an item is hide, then the associated stats item is also hide""" - curve0 = self.plot.getCurve('curve0') - curve1 = self.plot.getCurve('curve1') - curve2 = self.plot.getCurve('curve2') + curve0 = self.plot.getCurve("curve0") + curve1 = self.plot.getCurve("curve1") + curve2 = self.plot.getCurve("curve2") self.plot.show() self.widget.show() @@ -563,8 +605,8 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): self.qapp.processEvents() self.assertTrue(self.statsTable.isRowHidden(1)) tableItems = self.statsTable._itemToTableItems(curve2) - curve2_min = tableItems['min'].text() - self.assertTrue(float(curve2_min) == -2.) + curve2_min = tableItems["min"].text() + self.assertTrue(float(curve2_min) == -2.0) curve0.setVisible(False) curve1.setVisible(False) @@ -578,27 +620,38 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): class TestStatsWidgetWithImages(TestCaseQt): """Basic test for StatsWidget with images""" - IMAGE_LEGEND = 'test image' + IMAGE_LEGEND = "test image" def setUp(self): TestCaseQt.setUp(self) self.plot = Plot2D() - self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128), - legend=self.IMAGE_LEGEND, replace=False) + self.plot.addImage( + data=numpy.arange(128 * 128).reshape(128, 128), + legend=self.IMAGE_LEGEND, + replace=False, + ) self.widget = StatsWidget.StatsTable(plot=self.plot) - mystats = statshandler.StatsHandler(( - (stats.StatMin(), statshandler.StatFormatter()), - (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - (stats.StatMax(), statshandler.StatFormatter()), - (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - (stats.StatDelta(), statshandler.StatFormatter()), - ('std', numpy.std), - ('mean', numpy.mean), - (stats.StatCOM(), statshandler.StatFormatter(None)) - )) + mystats = statshandler.StatsHandler( + ( + (stats.StatMin(), statshandler.StatFormatter()), + ( + stats.StatCoordMin(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + (stats.StatMax(), statshandler.StatFormatter()), + ( + stats.StatCoordMax(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + (stats.StatDelta(), statshandler.StatFormatter()), + ("std", numpy.std), + ("mean", numpy.mean), + (stats.StatCOM(), statshandler.StatFormatter(None)), + ) + ) self.widget.setStats(mystats) @@ -613,17 +666,16 @@ class TestStatsWidgetWithImages(TestCaseQt): TestCaseQt.tearDown(self) def test(self): - image = self.plot._getItem( - kind='image', legend=self.IMAGE_LEGEND) + image = self.plot._getItem(kind="image", legend=self.IMAGE_LEGEND) tableItems = self.widget._itemToTableItems(image) - maxText = '{0:.3f}'.format((128 * 128) - 1) - self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND) - self.assertEqual(tableItems['min'].text(), '0.000') - self.assertEqual(tableItems['max'].text(), maxText) - self.assertEqual(tableItems['delta'].text(), maxText) - self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') - self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') + maxText = "{0:.3f}".format((128 * 128) - 1) + self.assertEqual(tableItems["legend"].text(), self.IMAGE_LEGEND) + self.assertEqual(tableItems["min"].text(), "0.000") + self.assertEqual(tableItems["max"].text(), maxText) + self.assertEqual(tableItems["delta"].text(), maxText) + self.assertEqual(tableItems["coords min"].text(), "0.0, 0.0") + self.assertEqual(tableItems["coords max"].text(), "127.0, 127.0") def testItemHidden(self): """Test if an item is hide, then the associated stats item is also @@ -638,28 +690,37 @@ class TestStatsWidgetWithImages(TestCaseQt): class TestStatsWidgetWithScatters(TestCaseQt): - - SCATTER_LEGEND = 'scatter plot' + SCATTER_LEGEND = "scatter plot" def setUp(self): TestCaseQt.setUp(self) self.scatterPlot = Plot2D() - self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60], - [2, 3, 4, 26, 69, 6], - [5, 6, 7, 10, 90, 20], - legend=self.SCATTER_LEGEND) + self.scatterPlot.addScatter( + [0, 1, 2, 20, 50, 60], + [2, 3, 4, 26, 69, 6], + [5, 6, 7, 10, 90, 20], + legend=self.SCATTER_LEGEND, + ) self.widget = StatsWidget.StatsTable(plot=self.scatterPlot) - mystats = statshandler.StatsHandler(( - stats.StatMin(), - (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - stats.StatMax(), - (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)), - stats.StatDelta(), - ('std', numpy.std), - ('mean', numpy.mean), - stats.StatCOM() - )) + mystats = statshandler.StatsHandler( + ( + stats.StatMin(), + ( + stats.StatCoordMin(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + stats.StatMax(), + ( + stats.StatCoordMax(), + statshandler.StatFormatter(None, qt.QTableWidgetItem), + ), + stats.StatDelta(), + ("std", numpy.std), + ("mean", numpy.mean), + stats.StatCOM(), + ) + ) self.widget.setStats(mystats) @@ -674,15 +735,14 @@ class TestStatsWidgetWithScatters(TestCaseQt): TestCaseQt.tearDown(self) def testStats(self): - scatter = self.scatterPlot._getItem( - kind='scatter', legend=self.SCATTER_LEGEND) + scatter = self.scatterPlot._getItem(kind="scatter", legend=self.SCATTER_LEGEND) tableItems = self.widget._itemToTableItems(scatter) - self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND) - self.assertEqual(tableItems['min'].text(), '5') - self.assertEqual(tableItems['coords min'].text(), '0, 2') - self.assertEqual(tableItems['max'].text(), '90') - self.assertEqual(tableItems['coords max'].text(), '50, 69') - self.assertEqual(tableItems['delta'].text(), '85') + self.assertEqual(tableItems["legend"].text(), self.SCATTER_LEGEND) + self.assertEqual(tableItems["min"].text(), "5") + self.assertEqual(tableItems["coords min"].text(), "0, 2") + self.assertEqual(tableItems["max"].text(), "90") + self.assertEqual(tableItems["coords max"].text(), "50, 69") + self.assertEqual(tableItems["delta"].text(), "85") class TestEmptyStatsWidget(TestCaseQt): @@ -694,25 +754,26 @@ class TestEmptyStatsWidget(TestCaseQt): class TestLineWidget(TestCaseQt): """Some test for the StatsLineWidget.""" + def setUp(self): TestCaseQt.setUp(self) - mystats = statshandler.StatsHandler(( - (stats.StatMin(), statshandler.StatFormatter()), - )) + mystats = statshandler.StatsHandler( + ((stats.StatMin(), statshandler.StatFormatter()),) + ) self.plot = Plot1D() self.plot.show() self.x = range(20) self.y0 = range(20) - self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0') + self.plot.addCurve(self.x, self.y0, legend="curve0") self.y1 = range(12, 32) - self.plot.addCurve(self.x, self.y1, legend='curve1') + self.plot.addCurve(self.x, self.y1, legend="curve1") self.y2 = range(-2, 18) - self.plot.addCurve(self.x, self.y2, legend='curve2') - self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, - kind='curve', - stats=mystats) + self.plot.addCurve(self.x, self.y2, legend="curve2") + self.widget = StatsWidget.BasicGridStatsWidget( + plot=self.plot, kind="curve", stats=mystats + ) def tearDown(self): Stats._getContext.cache_clear() @@ -730,27 +791,37 @@ class TestLineWidget(TestCaseQt): def testProcessing(self): self.widget._lineStatsWidget.setStatsOnVisibleData(False) self.qapp.processEvents() - self.plot.setActiveCurve(legend='curve0') - self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000') - self.plot.setActiveCurve(legend='curve1') - self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000') + self.plot.setActiveCurve(legend="curve0") + self.assertTrue( + self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.000" + ) + self.plot.setActiveCurve(legend="curve1") + self.assertTrue( + self.widget._lineStatsWidget._statQlineEdit["min"].text() == "12.000" + ) self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) self.widget.setStatsOnVisibleData(True) self.qapp.processEvents() - self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') + self.assertTrue( + self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000" + ) self.plot.setActiveCurve(None) self.assertIsNone(self.plot.getActiveCurve()) self.widget.setStatsOnVisibleData(False) self.qapp.processEvents() - self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000') - self.widget.setKind('image') - self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) + self.assertFalse( + self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000" + ) + self.widget.setKind("image") + self.plot.addImage(numpy.arange(100 * 100).reshape(100, 100) + 0.312) self.qapp.processEvents() - self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312') + self.assertTrue( + self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.312" + ) def testUpdateMode(self): """Make sure the update modes are well take into account""" - self.plot.setActiveCurve(self.curve0) + self.plot.setActiveCurve("curve0") _autoRB = self.widget._options._autoRB _manualRB = self.widget._options._manualRB # test from api @@ -759,10 +830,10 @@ class TestLineWidget(TestCaseQt): self.assertFalse(_manualRB.isChecked()) # check stats change in auto mode - curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text() + curve0_min = self.widget._lineStatsWidget._statQlineEdit["min"].text() new_y = numpy.array(self.y0) - 2.56 - self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) - curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.plot.addCurve(x=self.x, y=new_y, legend="curve0") + curve0_min2 = self.widget._lineStatsWidget._statQlineEdit["min"].text() self.assertTrue(curve0_min != curve0_min2) # check stats change in manual mode only if requested @@ -771,11 +842,11 @@ class TestLineWidget(TestCaseQt): self.assertTrue(_manualRB.isChecked()) new_y = numpy.array(self.y0) - 1.2 - self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0) - curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + self.plot.addCurve(x=self.x, y=new_y, legend="curve0") + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text() self.assertTrue(curve0_min3 == curve0_min2) self.widget._options._updateRequested() - curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text() + curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text() self.assertTrue(curve0_min3 != curve0_min2) # test from gui @@ -791,6 +862,7 @@ class TestLineWidget(TestCaseQt): class TestUpdateModeWidget(TestCaseQt): """Test UpdateModeWidget""" + def setUp(self): TestCaseQt.setUp(self) self.widget = StatsWidget.UpdateModeWidget(parent=None) @@ -832,6 +904,7 @@ class TestStatsROI(TestStatsBase, TestCaseQt): """ Test stats based on ROI """ + def setUp(self): TestCaseQt.setUp(self) self.createRois() @@ -855,7 +928,7 @@ class TestStatsROI(TestStatsBase, TestCaseQt): TestCaseQt.tearDown(self) def createRois(self): - self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0) + self._1Droi = ROI(name="my1DRoi", fromdata=2.0, todata=5.0) self._2Droi_rect = RectangleROI() self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0)) self._2Droi_poly = PolygonROI() @@ -865,30 +938,32 @@ class TestStatsROI(TestStatsBase, TestCaseQt): def createCurveContext(self): TestStatsBase.createCurveContext(self) self.curveContext = stats._CurveContext( - item=self.plot1d.getCurve('curve0'), + item=self.plot1d.getCurve("curve0"), plot=self.plot1d, onlimits=False, - roi=self._1Droi) + roi=self._1Droi, + ) def createHistogramContext(self): self.plotHisto = Plot1D() x = range(20) y = range(20) - self.plotHisto.addHistogram(x, y, legend='histo0') + self.plotHisto.addHistogram(x, y, legend="histo0") self.histoContext = stats._HistogramContext( - item=self.plotHisto.getHistogram('histo0'), + item=self.plotHisto.getHistogram("histo0"), plot=self.plotHisto, onlimits=False, - roi=self._1Droi) + roi=self._1Droi, + ) def createScatterContext(self): TestStatsBase.createScatterContext(self) self.scatterContext = stats._ScatterContext( - item=self.scatterPlot.getScatter('scatter plot'), + item=self.scatterPlot.getScatter("scatter plot"), plot=self.scatterPlot, onlimits=False, - roi=self._1Droi + roi=self._1Droi, ) def createImageContext(self): @@ -898,56 +973,68 @@ class TestStatsROI(TestStatsBase, TestCaseQt): item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, onlimits=False, - roi=self._2Droi_rect + roi=self._2Droi_rect, ) self.imageContext_2 = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, onlimits=False, - roi=self._2Droi_poly + roi=self._2Droi_poly, ) def testErrors(self): # test if onlimits is True and give also a roi with self.assertRaises(ValueError): - stats._CurveContext(item=self.plot1d.getCurve('curve0'), - plot=self.plot1d, - onlimits=True, - roi=self._1Droi) + stats._CurveContext( + item=self.plot1d.getCurve("curve0"), + plot=self.plot1d, + onlimits=True, + roi=self._1Droi, + ) # test if is a curve context and give an invalid 2D roi with self.assertRaises(TypeError): - stats._CurveContext(item=self.plot1d.getCurve('curve0'), - plot=self.plot1d, - onlimits=False, - roi=self._2Droi_rect) + stats._CurveContext( + item=self.plot1d.getCurve("curve0"), + plot=self.plot1d, + onlimits=False, + roi=self._2Droi_rect, + ) def testBasicStatsCurve(self): """Test result for simple stats on a curve""" _stats = self.getBasicStats() xData = yData = numpy.array(range(0, 10)) - self.assertEqual(_stats['min'].calculate(self.curveContext), 2) - self.assertEqual(_stats['max'].calculate(self.curveContext), 5) - self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,)) - self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,)) - self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6])) - self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6])) + self.assertEqual(_stats["min"].calculate(self.curveContext), 2) + self.assertEqual(_stats["max"].calculate(self.curveContext), 5) + self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (2,)) + self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (5,)) + self.assertEqual( + _stats["std"].calculate(self.curveContext), numpy.std(yData[2:6]) + ) + self.assertEqual( + _stats["mean"].calculate(self.curveContext), numpy.mean(yData[2:6]) + ) com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6]) - self.assertEqual(_stats['com'].calculate(self.curveContext), com) + self.assertEqual(_stats["com"].calculate(self.curveContext), com) def testBasicStatsImageRectRoi(self): """Test result for simple stats on an image""" self.assertEqual(self.imageContext.values.compressed().size, 121) _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(self.imageContext), 10) - self.assertEqual(_stats['max'].calculate(self.imageContext), 1300) - self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0)) - self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0)) - self.assertAlmostEqual(_stats['std'].calculate(self.imageContext), - numpy.std(self.imageData[0:11, 10:21])) - self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext), - numpy.mean(self.imageData[0:11, 10:21])) + self.assertEqual(_stats["min"].calculate(self.imageContext), 10) + self.assertEqual(_stats["max"].calculate(self.imageContext), 1300) + self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (10, 0)) + self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (20.0, 10.0)) + self.assertAlmostEqual( + _stats["std"].calculate(self.imageContext), + numpy.std(self.imageData[0:11, 10:21]), + ) + self.assertAlmostEqual( + _stats["mean"].calculate(self.imageContext), + numpy.mean(self.imageData[0:11, 10:21]), + ) compressed_values = self.imageContext.values.compressed() compressed_values = compressed_values.reshape(11, 11) @@ -957,41 +1044,47 @@ class TestStatsROI(TestStatsBase, TestCaseQt): dataYRange = range(11) dataXRange = range(10, 21) - ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) - xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) - self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) + self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom)) def testBasicStatsImagePolyRoi(self): """Test a simple rectangle ROI""" _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0) - self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432) - self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0)) + self.assertEqual(_stats["min"].calculate(self.imageContext_2), 0) + self.assertEqual(_stats["max"].calculate(self.imageContext_2), 2432) + self.assertEqual(_stats["minCoords"].calculate(self.imageContext_2), (0.0, 0.0)) # not 0.0, 19.0 because not fully in. Should all pixel have a weight, # on to manage them in stats. For now 0 if the center is not in, else 1 - self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0)) + self.assertEqual( + _stats["maxCoords"].calculate(self.imageContext_2), (0.0, 19.0) + ) def testBasicStatsScatter(self): self.assertEqual(self.scatterContext.values.compressed().size, 2) _stats = self.getBasicStats() - self.assertEqual(_stats['min'].calculate(self.scatterContext), 6) - self.assertEqual(_stats['max'].calculate(self.scatterContext), 7) - self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3)) - self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4)) - self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7])) - self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7])) + self.assertEqual(_stats["min"].calculate(self.scatterContext), 6) + self.assertEqual(_stats["max"].calculate(self.scatterContext), 7) + self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (2, 3)) + self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (3, 4)) + self.assertEqual( + _stats["std"].calculate(self.scatterContext), numpy.std([6, 7]) + ) + self.assertEqual( + _stats["mean"].calculate(self.scatterContext), numpy.mean([6, 7]) + ) def testBasicHistogram(self): _stats = self.getBasicStats() xData = yData = numpy.array(range(2, 6)) - self.assertEqual(_stats['min'].calculate(self.histoContext), 2) - self.assertEqual(_stats['max'].calculate(self.histoContext), 5) - self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,)) - self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,)) - self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData)) - self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData)) + self.assertEqual(_stats["min"].calculate(self.histoContext), 2) + self.assertEqual(_stats["max"].calculate(self.histoContext), 5) + self.assertEqual(_stats["minCoords"].calculate(self.histoContext), (2,)) + self.assertEqual(_stats["maxCoords"].calculate(self.histoContext), (5,)) + self.assertEqual(_stats["std"].calculate(self.histoContext), numpy.std(yData)) + self.assertEqual(_stats["mean"].calculate(self.histoContext), numpy.mean(yData)) com = numpy.sum(xData * yData) / numpy.sum(yData) - self.assertEqual(_stats['com'].calculate(self.histoContext), com) + self.assertEqual(_stats["com"].calculate(self.histoContext), com) class TestAdvancedROIImageContext(TestCaseQt): @@ -1016,31 +1109,35 @@ class TestAdvancedROIImageContext(TestCaseQt): roi_origins = [(0, 0), (2, 10), (14, 20)] img_origins = [(0, 0), (14, 20), (2, 10)] img_scales = [1.0, 0.5, 2.0] - _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), } + _stats = { + "sum": stats.Stat(name="sum", fct=numpy.sum), + } for roi_origin in roi_origins: for img_origin in img_origins: for img_scale in img_scales: - with self.subTest(roi_origin=roi_origin, - img_origin=img_origin, - img_scale=img_scale): - self.plot.addImage(self.data, legend='img', - origin=img_origin, - scale=img_scale) + with self.subTest( + roi_origin=roi_origin, + img_origin=img_origin, + img_scale=img_scale, + ): + self.plot.addImage( + self.data, legend="img", origin=img_origin, scale=img_scale + ) roi = RectangleROI() roi.setGeometry(origin=roi_origin, size=(20, 20)) context = stats._ImageContext( - item=self.plot.getImage('img'), + item=self.plot.getImage("img"), plot=self.plot, onlimits=False, - roi=roi) + roi=roi, + ) x_start = int((roi_origin[0] - img_origin[0]) / img_scale) x_end = int(x_start + (20 / img_scale)) + 1 - y_start = int((roi_origin[1] - img_origin[1])/ img_scale) + y_start = int((roi_origin[1] - img_origin[1]) / img_scale) y_end = int(y_start + (20 / img_scale)) + 1 x_start = max(x_start, 0) x_end = min(max(x_end, 0), self.data_dims[1]) y_start = max(y_start, 0) y_end = min(max(y_end, 0), self.data_dims[0]) th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end]) - self.assertAlmostEqual(_stats['sum'].calculate(context), - th_sum) + self.assertAlmostEqual(_stats["sum"].calculate(context), th_sum) |