diff options
Diffstat (limited to 'src/silx/opencl/test/test_stats.py')
-rw-r--r-- | src/silx/opencl/test/test_stats.py | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/src/silx/opencl/test/test_stats.py b/src/silx/opencl/test/test_stats.py index f8ab1a7..7637211 100644 --- a/src/silx/opencl/test/test_stats.py +++ b/src/silx/opencl/test/test_stats.py @@ -39,17 +39,18 @@ import numpy import unittest from ..common import ocl + if ocl: import pyopencl import pyopencl.array from ..statistics import StatResults, Statistics from ..utils import get_opencl_code + logger = logging.getLogger(__name__) @unittest.skipUnless(ocl, "PyOpenCl is missing") class TestStatistics(unittest.TestCase): - @classmethod def setUpClass(cls): cls.size = 1 << 20 # 1 million elements @@ -57,9 +58,15 @@ class TestStatistics(unittest.TestCase): fdata = cls.data.astype("float64") t0 = time.perf_counter() std = fdata.std() - cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size), - fdata.sum(), fdata.mean(), std ** 2, - std) + cls.ref = StatResults( + fdata.min(), + fdata.max(), + float(fdata.size), + fdata.sum(), + fdata.mean(), + std**2, + std, + ) t1 = time.perf_counter() cls.ref_time = t1 - t0 @@ -70,11 +77,12 @@ class TestStatistics(unittest.TestCase): @classmethod def validate(cls, res): return ( - (res.min == cls.ref.min) and - (res.max == cls.ref.max) and - (res.cnt == cls.ref.cnt) and - abs(res.mean - cls.ref.mean) < 0.01 and - abs(res.std - cls.ref.std) < 0.1) + (res.min == cls.ref.min) + and (res.max == cls.ref.max) + and (res.cnt == cls.ref.cnt) + and abs(res.mean - cls.ref.mean) < 0.01 + and abs(res.std - cls.ref.std) < 0.1 + ) def test_measurement(self): """ @@ -95,11 +103,26 @@ class TestStatistics(unittest.TestCase): t0 = time.perf_counter() res = s(self.data, comp=comp) t1 = time.perf_counter() - logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0)) + logger.info( + "Runtime on %s/%s : %.3fms x%.1f", + platform, + device, + 1000 * (t1 - t0), + self.ref_time / (t1 - t0), + ) if failed_init or not self.validate(res): - logger.error("failed_init %s; Computation modes %s", failed_init, comp) - logger.error("Failed on platform %s device %s", platform, device) + logger.error( + "failed_init %s; Computation modes %s", + failed_init, + comp, + ) + logger.error( + "Failed on platform %s device %s", platform, device + ) logger.error("Reference results: %s", self.ref) logger.error("Faulty results: %s", res) - self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}") + self.assertTrue( + False, + f"Stat calculation failed on {platform},{device} in mode {comp}", + ) |