Skip to content

palette

Palette

Colour palette methods.

Source code in phomo/palette.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class Palette:
    """Colour palette methods.

    Wraps an array whose last axis holds the colour channels, e.g. a master
    image of shape ``(height, width, 3)`` or a tile pool of shape
    ``(n_tiles, height, width, 3)``.
    """

    array: np.ndarray

    def __init__(self, array: ArrayLike):
        self.array = np.array(array)
        # Plotting helpers bound to this palette.
        self.plot = PalettePlotter(self)

    @property
    def pixels(self) -> np.ndarray:
        """Returns flattened pixels from the array.

        The result has shape ``(n_pixels, n_channels)`` where ``n_channels`` is
        the size of the array's last axis.
        """
        return self.array.reshape(-1, self.array.shape[-1])

    def equalize(self):
        """Equalize the colour distribution using `cv2.equalizeHist`.

        Each colour channel is equalized independently over every pixel of the
        array (for a tile pool, over all tiles at once). `cv2.equalizeHist`
        requires 8-bit data, so the array is expected to be ``uint8``.

        Returns:
            A new `Palette` with equalized colour distribution and the same
            shape as this one.
        """
        out_shape = self.array.shape
        # Flatten everything but the channel axis. This handles both the master
        # (height, width, C) and the pool (n_tiles, height, width, C) layouts and
        # never mixes values from different channels.
        pixels = self.pixels
        equalized = np.zeros_like(pixels)
        for channel in range(pixels.shape[1]):
            # equalizeHist hands back an (n_pixels, 1) column. Use ravel rather
            # than squeeze so that a single pixel still gives a 1-D result.
            equalized[:, channel] = cv2.equalizeHist(pixels[:, channel]).ravel()
        return self.__class__(equalized.reshape(out_shape))

    def match(self, other: "Palette"):
        """Match the colour distribution of this object to the colour distribution of the
        `other` using the Reinhard colour transfer algorithm.

        In CIELAB space, each channel of this palette is shifted and scaled so
        that its mean and standard deviation equal those of `other`.

        See:
            https://api.semanticscholar.org/CorpusID:14088925

        Args:
            other: The other `Palette` to match this `Palette`'s colour distribution to.

        Returns:
            A new `Palette` with its colour distribution matched to the `other`
            `Palette`, with the same shape as this one.
        """

        def as_image(array: np.ndarray) -> np.ndarray:
            # cv2.cvtColor needs a 3-D (rows, cols, 3) image. Arrays that already
            # have image or tile structure keep their row layout. A flat
            # (n_pixels, 3) pixel list becomes a single-row image; reshaping it
            # to (-1, 3, 3) would fail unless n_pixels were divisible by 3.
            if array.ndim >= 3:
                return array.reshape(-1, array.shape[1], 3)
            return array.reshape(1, -1, 3)

        self_shape = self.array.shape
        source_lab = cv2.cvtColor(as_image(self.array), cv2.COLOR_RGB2LAB)
        target_lab = cv2.cvtColor(as_image(other.array), cv2.COLOR_RGB2LAB)

        # Per-channel mean and standard deviation, flattened to length-3 vectors.
        src_mean, src_std = cv2.meanStdDev(source_lab)
        tgt_mean, tgt_std = cv2.meanStdDev(target_lab)
        src_mean, src_std = src_mean.flatten(), src_std.flatten()
        tgt_mean, tgt_std = tgt_mean.flatten(), tgt_std.flatten()

        # Avoid dividing by zero when a source channel is constant.
        epsilon = 1e-5
        src_std = np.where(src_std < epsilon, epsilon, src_std)

        # Transfer colour: remove the source statistics and apply the target's.
        # The (3,) statistic vectors broadcast along the channel axis.
        result_lab = (source_lab.astype(float) - src_mean) * (
            tgt_std / src_std
        ) + tgt_mean

        # Clip to the 8-bit range (astype truncates, it does not round) and
        # convert back to RGB.
        result_lab = np.clip(result_lab, 0, 255).astype(np.uint8)
        result_rgb = cv2.cvtColor(result_lab, cv2.COLOR_LAB2RGB)
        return self.__class__(result_rgb.reshape(self_shape))

pixels: np.ndarray property

Returns flattened pixels from the array.

equalize()

Equalize the colour distribution using cv2.equalizeHist.

Returns:

Type Description

A new Palette with equalized colour distribution.

Source code in phomo/palette.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def equalize(self):
    """Equalize the colour distribution using `cv2.equalizeHist`.

    Each colour channel is equalized independently over every pixel of the
    array (for a tile pool, over all tiles at once). `cv2.equalizeHist`
    requires 8-bit data, so the array is expected to be ``uint8``.

    Returns:
        A new `Palette` with equalized colour distribution and the same
        shape as this one.
    """
    out_shape = self.array.shape
    # Flatten everything but the channel axis. This handles both the master
    # (height, width, C) and the pool (n_tiles, height, width, C) layouts and
    # never mixes values from different channels.
    pixels = self.pixels
    equalized = np.zeros_like(pixels)
    for channel in range(pixels.shape[1]):
        # equalizeHist hands back an (n_pixels, 1) column. Use ravel rather
        # than squeeze so that a single pixel still gives a 1-D result.
        equalized[:, channel] = cv2.equalizeHist(pixels[:, channel]).ravel()
    return self.__class__(equalized.reshape(out_shape))

match(other)

Match the colour distribution of this object to the colour distribution of the other using the Reinhard colour transfer algorithm.

See

https://api.semanticscholar.org/CorpusID:14088925

Returns:

Type Description

A new Palette with its colour distribution matched to the other Palette.

Source code in phomo/palette.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def match(self, other: "Palette"):
    """Match the colour distribution of this object to the colour distribution of the
    `other` using the Reinhard colour transfer algorithm.

    In CIELAB space, each channel of this palette is shifted and scaled so
    that its mean and standard deviation equal those of `other`.

    See:
        https://api.semanticscholar.org/CorpusID:14088925

    Args:
        other: The other `Palette` to match this `Palette`'s colour distribution to.

    Returns:
        A new `Palette` with its colour distribution matched to the `other`
        `Palette`, with the same shape as this one.
    """

    def as_image(array: np.ndarray) -> np.ndarray:
        # cv2.cvtColor needs a 3-D (rows, cols, 3) image. Arrays that already
        # have image or tile structure keep their row layout. A flat
        # (n_pixels, 3) pixel list becomes a single-row image; reshaping it
        # to (-1, 3, 3) would fail unless n_pixels were divisible by 3.
        if array.ndim >= 3:
            return array.reshape(-1, array.shape[1], 3)
        return array.reshape(1, -1, 3)

    self_shape = self.array.shape
    source_lab = cv2.cvtColor(as_image(self.array), cv2.COLOR_RGB2LAB)
    target_lab = cv2.cvtColor(as_image(other.array), cv2.COLOR_RGB2LAB)

    # Per-channel mean and standard deviation, flattened to length-3 vectors.
    src_mean, src_std = cv2.meanStdDev(source_lab)
    tgt_mean, tgt_std = cv2.meanStdDev(target_lab)
    src_mean, src_std = src_mean.flatten(), src_std.flatten()
    tgt_mean, tgt_std = tgt_mean.flatten(), tgt_std.flatten()

    # Avoid dividing by zero when a source channel is constant.
    epsilon = 1e-5
    src_std = np.where(src_std < epsilon, epsilon, src_std)

    # Transfer colour: remove the source statistics and apply the target's.
    # The (3,) statistic vectors broadcast along the channel axis.
    result_lab = (source_lab.astype(float) - src_mean) * (
        tgt_std / src_std
    ) + tgt_mean

    # Clip to the 8-bit range (astype truncates, it does not round) and
    # convert back to RGB.
    result_lab = np.clip(result_lab, 0, 255).astype(np.uint8)
    result_rgb = cv2.cvtColor(result_lab, cv2.COLOR_LAB2RGB)
    return self.__class__(result_rgb.reshape(self_shape))

PalettePlotter

Source code in phomo/palette.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
class PalettePlotter:
    """Plotting helpers for a `Palette`: dominant colours and colour distributions."""

    def __init__(self, palette: Palette):
        self._palette = palette

    def _colour_hist(self, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the 1D colour distributions.

        Args:
            **kwargs: passed to `numpy.histogram`. By default ``bins`` is the
                edge sequence ``0, 1, ..., 256``, which gives each 8-bit
                intensity its own unit-width bin.

        Returns:
            Histogram edges, shape ``(n_edges, n_channels)``, and counts,
            shape ``(n_edges - 1, n_channels)``.
        """
        # range(257) yields 256 bins [0, 1), ..., [255, 256]. range(256) would
        # give only 255 bins, so intensities 254 and 255 would share the last
        # (closed) bin.
        bins = kwargs.pop("bins", range(257))
        pixels = self._palette.pixels
        bin_edges = []
        values = []
        for channel in range(pixels.shape[1]):
            freqs, edges = np.histogram(pixels[:, channel], bins=bins, **kwargs)
            bin_edges.append(edges)
            values.append(freqs)
        return np.vstack(bin_edges).T, np.vstack(values).T

    def _colour_hist_3d(
        self, bins: int = 256
    ) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
        """Compute the 3D colour distribution.

        Args:
            bins: Number of bins along each channel axis, spanning [0, 255].

        Returns:
            The per-axis bin edges and the ``(bins, bins, bins)`` count array.
        """
        hist, edges = np.histogramdd(
            self._palette.pixels,
            bins=bins,
            range=[(0, 255), (0, 255), (0, 255)],
        )
        return edges, hist

    def _colour_palette(self, depth: int = 3) -> np.ndarray:
        """Compute the dominant colours with a median cut.

        Args:
            depth: Recursion depth; each level splits every region in two.

        Returns:
            A ``uint8`` array of shape ``(n_colours, 3)`` holding the mean colour
            of each non-empty region (at most ``2**depth``), in reverse split
            order.
        """
        pixels = self._palette.array.reshape(-1, 3)

        # The annotation is quoted so it is never evaluated; the builtin
        # generic ``list[...]`` needs Python >= 3.9.
        def split(pixels: np.ndarray, depth: int) -> "list[np.ndarray]":
            if len(pixels) == 0 or depth == 0:
                return [pixels]

            # Cut along the channel with the widest spread, at its median.
            ranges = np.ptp(pixels, axis=0)
            axis = np.argmax(ranges)
            median = np.median(pixels[:, axis])

            left = pixels[pixels[:, axis] <= median]
            right = pixels[pixels[:, axis] > median]

            return split(left, depth - 1) + split(right, depth - 1)

        quantized = split(pixels, depth)

        palette = [np.mean(region, axis=0) for region in quantized if len(region) > 0]
        palette = np.array(palette, dtype=np.uint8)

        return palette[::-1]

    def palette(self, depth: int = 3) -> Tuple[Figure, np.ndarray]:
        """Show the dominant colours of the palette using a median cut algorithm.

        See:
            https://en.wikipedia.org/wiki/Median_cut

        Args:
            depth: The recursion depth of the median cut; up to ``2**depth``
                colours are shown.

        Returns:
            `Figure` and the single `Axes` showing the palette as a horizontal
            strip.
        """
        palette = self._colour_palette(depth=depth)

        # One square swatch per colour, laid out left to right.
        square_size = 50
        palette_ar = np.zeros(
            (square_size, len(palette) * square_size, 3), dtype="uint8"
        )

        for i, color in enumerate(palette):
            palette_ar[:, i * square_size : (i + 1) * square_size, :] = color

        # The strip is n times wider than it is tall, so give the figure the
        # same aspect ratio. A (5, 5 * n) figure is the transpose and would
        # leave the strip in a tall, mostly empty canvas.
        n_colours = max(len(palette), 1)  # no division by zero for an empty palette
        fig, ax = plt.subplots(
            1,
            figsize=(5, 5 / n_colours),
            frameon=False,
        )
        ax.imshow(palette_ar, aspect="equal")
        ax.set_axis_off()
        ax.margins(0, 0)
        fig.tight_layout(pad=0)
        return fig, ax

    def distribution(self, log: bool = False) -> Tuple[Figure, np.ndarray]:
        """Plot the colour distribution of each channel.

        Args:
            log: Plot y axis in log scale.

        Returns:
            `Figure` and `np.array` of `Axes`.
        """

        bin_edges, values = self._colour_hist()
        fig, axs = plt.subplots(3, figsize=(12, 6))
        channels = ["Red", "Green", "Blue"]
        for i, (ax, channel) in enumerate(zip(axs, channels)):
            ax.bar(
                bin_edges[:-1, i],
                values[:, i],
                width=np.diff(bin_edges[:, i]),
                align="edge",
                color=channel,
            )
            if log:
                ax.set_yscale("log")
            ax.set_title(channel)
            # Span exactly the histogram's edges, so the last bin (intensity
            # 255) is not clipped whatever bins were used.
            ax.set_xlim(bin_edges[0, i], bin_edges[-1, i])
        fig.tight_layout()
        return fig, axs

    def distribution_2d(self) -> Tuple[Figure, np.ndarray]:
        """Plot 2D projections of the 3D colour distribution.

        Returns:
            `Figure` and `np.array` of `Axes`.
        """
        _, hist = self._colour_hist_3d()
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs = axs.ravel()

        # (title, histogram axis summed out, transpose?). Each panel puts the
        # first channel named in the title on x and the second on y
        # (imshow draws rows along y).
        panels = (
            ("Red-Green", 2, True),  # sum out B -> (R, G) -> transpose to (G, R)
            ("Green-Blue", 0, True),  # sum out R -> (G, B) -> transpose to (B, G)
            ("Blue-Red", 1, False),  # sum out G -> (R, B): rows R, columns B
        )
        for ax, (title, summed_axis, transpose) in zip(axs, panels):
            proj = np.sum(hist, axis=summed_axis)
            if transpose:
                proj = proj.T
            ax.imshow(
                proj,
                origin="lower",
                extent=[0, 255, 0, 255],
                aspect="auto",
                # Cap the colour scale so a few dominant colours do not wash
                # out the rest of the projection.
                vmax=np.mean(proj) + 3 * np.std(proj),
            )
            x_name, y_name = title.split("-")
            ax.set_title(title)
            ax.set_xlabel(x_name)
            ax.set_ylabel(y_name)

        fig.tight_layout()
        return fig, axs

    def __call__(self):
        """Plot all the plots."""
        self.palette()
        self.distribution()
        self.distribution_2d()

__call__()

Plot all the plots.

Source code in phomo/palette.py
232
233
234
235
236
def __call__(self):
    """Draw every available plot: the palette strip, the per-channel
    distributions and the 2D colour projections."""
    for draw in (self.palette, self.distribution, self.distribution_2d):
        draw()

distribution(log=False)

Plot the colour distribution of each channel.

Parameters:

Name Type Description Default
log bool

Plot y axis in log scale.

False

Returns:

Type Description
Tuple[Figure, ndarray]

Figure and np.array of Axes.

Source code in phomo/palette.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def distribution(self, log: bool = False) -> Tuple[Figure, np.ndarray]:
    """Plot the colour distribution of each channel.

    Args:
        log: Plot y axis in log scale.

    Returns:
        `Figure` and `np.array` of `Axes`.
    """

    bin_edges, values = self._colour_hist()
    fig, axs = plt.subplots(3, figsize=(12, 6))
    channels = ["Red", "Green", "Blue"]
    for i, (ax, channel) in enumerate(zip(axs, channels)):
        ax.bar(
            bin_edges[:-1, i],
            values[:, i],
            width=np.diff(bin_edges[:, i]),
            align="edge",
            color=channel,
        )
        if log:
            ax.set_yscale("log")
        ax.set_title(channel)
        # Span exactly the histogram's edges rather than a hard-coded (0, 255),
        # so the last bin is never clipped whatever bins were used.
        ax.set_xlim(bin_edges[0, i], bin_edges[-1, i])
    fig.tight_layout()
    return fig, axs

distribution_2d()

Plot 2D projections of the 3D colour distribution.

Returns:

Type Description
Tuple[Figure, ndarray]

Figure and np.array of Axes.

Source code in phomo/palette.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def distribution_2d(self) -> Tuple[Figure, np.ndarray]:
    """Plot 2D projections of the 3D colour distribution.

    Each panel sums the 3D histogram over one channel and shows the remaining
    pair as an image.

    Returns:
        `Figure` and `np.array` of `Axes`.
    """
    _, hist = self._colour_hist_3d()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs = axs.ravel()

    # (title, histogram axis summed out, transpose?). Each panel puts the
    # first channel named in the title on x and the second on y
    # (imshow draws rows along y).
    panels = (
        ("Red-Green", 2, True),  # sum out B -> (R, G) -> transpose to (G, R)
        ("Green-Blue", 0, True),  # sum out R -> (G, B) -> transpose to (B, G)
        ("Blue-Red", 1, False),  # sum out G -> (R, B): rows R, columns B
    )
    for ax, (title, summed_axis, transpose) in zip(axs, panels):
        proj = np.sum(hist, axis=summed_axis)
        if transpose:
            proj = proj.T
        ax.imshow(
            proj,
            origin="lower",
            extent=[0, 255, 0, 255],
            aspect="auto",
            # Cap the colour scale so a few dominant colours do not wash out
            # the rest of the projection.
            vmax=np.mean(proj) + 3 * np.std(proj),
        )
        x_name, y_name = title.split("-")
        ax.set_title(title)
        ax.set_xlabel(x_name)
        ax.set_ylabel(y_name)

    fig.tight_layout()
    return fig, axs

palette(depth=3)

Show the dominant colours of the palette using a median cut algorithm.

See

https://en.wikipedia.org/wiki/Median_cut

Parameters:

Name Type Description Default
depth int

The recursion depth of the median cut; each level splits every colour region in two, so up to 2**depth colours are shown.

3

Returns:

Type Description
Tuple[Figure, ndarray]

Figure and the single Axes showing the palette.

Source code in phomo/palette.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def palette(self, depth: int = 3) -> Tuple[Figure, np.ndarray]:
    """Show the dominant colours of the palette using a median cut algorithm.

    See:
        https://en.wikipedia.org/wiki/Median_cut

    Args:
        depth: The recursion depth of the median cut; up to ``2**depth``
            colours are shown.

    Returns:
        `Figure` and the single `Axes` showing the palette as a horizontal
        strip.
    """
    palette = self._colour_palette(depth=depth)

    # One square swatch per colour, laid out left to right.
    square_size = 50
    palette_ar = np.zeros(
        (square_size, len(palette) * square_size, 3), dtype="uint8"
    )

    for i, color in enumerate(palette):
        palette_ar[:, i * square_size : (i + 1) * square_size, :] = color

    # The strip is n times wider than it is tall, so give the figure the same
    # aspect ratio. A (5, 5 * n) figure is the transpose and would leave the
    # strip in a tall, mostly empty canvas.
    n_colours = max(len(palette), 1)  # no division by zero for an empty palette
    fig, ax = plt.subplots(
        1,
        figsize=(5, 5 / n_colours),
        frameon=False,
    )
    ax.imshow(palette_ar, aspect="equal")
    ax.set_axis_off()
    ax.margins(0, 0)
    fig.tight_layout(pad=0)
    return fig, ax