summaryrefslogtreecommitdiff
path: root/silx/math/colormap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'silx/math/colormap.pyx')
-rw-r--r--silx/math/colormap.pyx24
1 files changed, 20 insertions, 4 deletions
diff --git a/silx/math/colormap.pyx b/silx/math/colormap.pyx
index 2495f3c..2cefe04 100644
--- a/silx/math/colormap.pyx
+++ b/silx/math/colormap.pyx
@@ -56,6 +56,8 @@ else: # Fallback
DEFAULT_NUM_THREADS = 1
# Number of threads to use for the computation (initialized to up to 4)
+cdef int USE_OPENMP_THRESHOLD = 1000
+"""OpenMP is not used for arrays with less elements than this threshold"""
# Supported data types
ctypedef fused data_types:
@@ -312,7 +314,7 @@ cdef image_types[:, ::1] compute_cmap(
cdef image_types[:, ::1] output
cdef double scale, value, normalized_vmin, normalized_vmax
cdef int length, nb_channels, nb_colors
- cdef int channel, index, lut_index
+ cdef int channel, index, lut_index, num_threads
nb_colors = <int> colors.shape[0]
nb_channels = <int> colors.shape[1]
@@ -332,8 +334,15 @@ cdef image_types[:, ::1] compute_cmap(
else:
scale = nb_colors / (normalized_vmax - normalized_vmin)
+ if length < USE_OPENMP_THRESHOLD:
+ num_threads = 1
+ else:
+ num_threads = min(
+ DEFAULT_NUM_THREADS,
+ int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
+
with nogil:
- for index in prange(length, num_threads=DEFAULT_NUM_THREADS):
+ for index in prange(length, num_threads=num_threads):
value = normalization.apply_double(
<double> data[index], vmin, vmax)
@@ -386,7 +395,7 @@ cdef image_types[:, ::1] compute_cmap_with_lut(
cdef image_types[:, ::1] lut
cdef int type_min, type_max
cdef int nb_channels, length
- cdef int channel, index, lut_index
+ cdef int channel, index, lut_index, num_threads
length = <int> data.size
nb_channels = <int> colors.shape[1]
@@ -412,9 +421,16 @@ cdef image_types[:, ::1] compute_cmap_with_lut(
output = numpy.empty((length, nb_channels), dtype=colors_dtype)
+ if length < USE_OPENMP_THRESHOLD:
+ num_threads = 1
+ else:
+ num_threads = min(
+ DEFAULT_NUM_THREADS,
+ int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
+
with nogil:
# Apply LUT
- for index in prange(length, num_threads=DEFAULT_NUM_THREADS):
+ for index in prange(length, num_threads=num_threads):
lut_index = data[index] - type_min
for channel in range(nb_channels):
output[index, channel] = lut[lut_index, channel]