summaryrefslogtreecommitdiff
path: root/src/silx/opencl/test/test_stats.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/opencl/test/test_stats.py')
-rw-r--r--  src/silx/opencl/test/test_stats.py  49
1 file changed, 36 insertions(+), 13 deletions(-)
diff --git a/src/silx/opencl/test/test_stats.py b/src/silx/opencl/test/test_stats.py
index f8ab1a7..7637211 100644
--- a/src/silx/opencl/test/test_stats.py
+++ b/src/silx/opencl/test/test_stats.py
@@ -39,17 +39,18 @@ import numpy
import unittest
from ..common import ocl
+
if ocl:
import pyopencl
import pyopencl.array
from ..statistics import StatResults, Statistics
from ..utils import get_opencl_code
+
logger = logging.getLogger(__name__)
@unittest.skipUnless(ocl, "PyOpenCl is missing")
class TestStatistics(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.size = 1 << 20 # 1 million elements
@@ -57,9 +58,15 @@ class TestStatistics(unittest.TestCase):
fdata = cls.data.astype("float64")
t0 = time.perf_counter()
std = fdata.std()
- cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size),
- fdata.sum(), fdata.mean(), std ** 2,
- std)
+ cls.ref = StatResults(
+ fdata.min(),
+ fdata.max(),
+ float(fdata.size),
+ fdata.sum(),
+ fdata.mean(),
+ std**2,
+ std,
+ )
t1 = time.perf_counter()
cls.ref_time = t1 - t0
@@ -70,11 +77,12 @@ class TestStatistics(unittest.TestCase):
@classmethod
def validate(cls, res):
return (
- (res.min == cls.ref.min) and
- (res.max == cls.ref.max) and
- (res.cnt == cls.ref.cnt) and
- abs(res.mean - cls.ref.mean) < 0.01 and
- abs(res.std - cls.ref.std) < 0.1)
+ (res.min == cls.ref.min)
+ and (res.max == cls.ref.max)
+ and (res.cnt == cls.ref.cnt)
+ and abs(res.mean - cls.ref.mean) < 0.01
+ and abs(res.std - cls.ref.std) < 0.1
+ )
def test_measurement(self):
"""
@@ -95,11 +103,26 @@ class TestStatistics(unittest.TestCase):
t0 = time.perf_counter()
res = s(self.data, comp=comp)
t1 = time.perf_counter()
- logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
+ logger.info(
+ "Runtime on %s/%s : %.3fms x%.1f",
+ platform,
+ device,
+ 1000 * (t1 - t0),
+ self.ref_time / (t1 - t0),
+ )
if failed_init or not self.validate(res):
- logger.error("failed_init %s; Computation modes %s", failed_init, comp)
- logger.error("Failed on platform %s device %s", platform, device)
+ logger.error(
+ "failed_init %s; Computation modes %s",
+ failed_init,
+ comp,
+ )
+ logger.error(
+ "Failed on platform %s device %s", platform, device
+ )
logger.error("Reference results: %s", self.ref)
logger.error("Faulty results: %s", res)
- self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}")
+ self.assertTrue(
+ False,
+ f"Stat calculation failed on {platform},{device} in mode {comp}",
+ )