diff options
Diffstat (limited to 'silx/math/colormap.pyx')
-rw-r--r-- | silx/math/colormap.pyx | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/silx/math/colormap.pyx b/silx/math/colormap.pyx index 2495f3c..2cefe04 100644 --- a/silx/math/colormap.pyx +++ b/silx/math/colormap.pyx @@ -56,6 +56,8 @@ else: # Fallback DEFAULT_NUM_THREADS = 1 # Number of threads to use for the computation (initialized to up to 4) +cdef int USE_OPENMP_THRESHOLD = 1000 +"""OpenMP is not used for arrays with less elements than this threshold""" # Supported data types ctypedef fused data_types: @@ -312,7 +314,7 @@ cdef image_types[:, ::1] compute_cmap( cdef image_types[:, ::1] output cdef double scale, value, normalized_vmin, normalized_vmax cdef int length, nb_channels, nb_colors - cdef int channel, index, lut_index + cdef int channel, index, lut_index, num_threads nb_colors = <int> colors.shape[0] nb_channels = <int> colors.shape[1] @@ -332,8 +334,15 @@ cdef image_types[:, ::1] compute_cmap( else: scale = nb_colors / (normalized_vmax - normalized_vmin) + if length < USE_OPENMP_THRESHOLD: + num_threads = 1 + else: + num_threads = min( + DEFAULT_NUM_THREADS, + int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))) + with nogil: - for index in prange(length, num_threads=DEFAULT_NUM_THREADS): + for index in prange(length, num_threads=num_threads): value = normalization.apply_double( <double> data[index], vmin, vmax) @@ -386,7 +395,7 @@ cdef image_types[:, ::1] compute_cmap_with_lut( cdef image_types[:, ::1] lut cdef int type_min, type_max cdef int nb_channels, length - cdef int channel, index, lut_index + cdef int channel, index, lut_index, num_threads length = <int> data.size nb_channels = <int> colors.shape[1] @@ -412,9 +421,16 @@ cdef image_types[:, ::1] compute_cmap_with_lut( output = numpy.empty((length, nb_channels), dtype=colors_dtype) + if length < USE_OPENMP_THRESHOLD: + num_threads = 1 + else: + num_threads = min( + DEFAULT_NUM_THREADS, + int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))) + with nogil: # Apply LUT - for index in prange(length, num_threads=DEFAULT_NUM_THREADS): + for index in prange(length, num_threads=num_threads): lut_index = data[index] - type_min for channel in range(nb_channels): output[index, channel] = lut[lut_index, channel] |