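Preface
A while ago, in Halo's app store, we wanted the background behind each theme and plugin cover image to use the cover's dominant color. That meant extracting the dominant color from the cover image, and K-Means came to mind as a way to do it.
In image processing, an image is made up of pixels, each pixel has a color value, and a color value usually consists of three components: R, G, and B. We can therefore treat an image as a point cloud of color values in which every point represents one pixel.
To make this easier to picture, we can visualize the color values of an image as a 3D scatter plot. Each point's coordinates are its R, G, and B components, and the point is drawn in the color those coordinates describe.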
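Color quantization of an image
Take the following image as an example:
Its color values are distributed as shown in the following plot:
From the RGB 3D Scatter Plot above, grouping similar color values together shows that the image has roughly three dominant hues: blue, green, and pink.
Suppose we pick one center from each of the three groups, say the points A(50, 150, 200), B(240, 150, 200), and C(50, 100, 50), and assign every data point to the cluster of its nearest center. This process is called clustering, and the centers are called cluster centers. The result is K dominant colors whose values are the coordinates of the cluster centers. K-Means is a widely used clustering algorithm built on exactly this idea: partition the dataset into K clusters, each with a center point called the cluster center, and assign every data point to the cluster whose center is nearest.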
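The K-Means algorithm proceeds as follows:
- Initialize the cluster centers: randomly pick K points as the cluster centers.
- Assign each data point to the cluster of its nearest center: for every data point, compute its distance to each cluster center and assign it to the cluster whose center is closest.
- Update the cluster centers: for every cluster, compute the mean of all its data points and use that mean as the new cluster center.
- Repeat the assignment and update steps (steps 2 and 3) until the cluster centers stop changing or the maximum number of iterations is reached.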
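To make the steps concrete, here is a minimal sketch of a single assignment-and-update pass on a handful of made-up colors. The helper name `kMeansStep`, the sample points, and the initial centers are illustrative assumptions and are not taken from the implementation shown later.

```javascript
// Squared Euclidean distance between two RGB triples.
function squaredDistance(a, b) {
  return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2;
}

// One K-Means pass (hypothetical helper): assign every point to its nearest
// center (step 2), then move each center to the mean of its points (step 3).
function kMeansStep(points, centers) {
  const sums = centers.map(() => [0, 0, 0]);
  const counts = centers.map(() => 0);
  for (const p of points) {
    let best = 0;
    for (let j = 1; j < centers.length; j++) {
      if (squaredDistance(p, centers[j]) < squaredDistance(p, centers[best])) {
        best = j;
      }
    }
    for (let c = 0; c < 3; c++) sums[best][c] += p[c];
    counts[best]++;
  }
  // A cluster that received no points keeps its previous center.
  return centers.map((center, j) =>
    counts[j] === 0 ? center : sums[j].map((s) => s / counts[j])
  );
}

// Made-up sample data: two bluish and two greenish pixels.
const points = [[40, 140, 190], [60, 160, 210], [40, 90, 40], [60, 110, 60]];
const initial = [[0, 0, 255], [0, 255, 0]];
const updated = kMeansStep(points, initial);
console.log(updated); // [[50, 150, 200], [50, 100, 50]]
```

Running `kMeansStep` repeatedly, feeding each result back in as the new centers, until the output stops changing gives the full algorithm.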
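In image processing, we can treat each pixel's color value as a three-dimensional vector and use the Euclidean distance to measure how far apart two colors are. Each pixel is assigned to the cluster of its nearest cluster center, and its color value is then replaced by the color of that center. For example, A1(10, 140, 170) is represented by the coordinates of its nearest cluster center A, that is, A1 = A(50, 150, 200). In this way the color values of the image are quantized, and similar colors are grouped into one class.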
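As a quick check of the example above, the sketch below computes the Euclidean distance from A1 to each of the three centers and picks the nearest one. `nearestCenter` is a hypothetical helper name used only for this illustration.

```javascript
// Euclidean distance between two RGB colors treated as 3D vectors.
function distance(a, b) {
  return Math.hypot(a[0] - b[0], a[1] - b[1], a[2] - b[2]);
}

// Return the label of the center closest to the given color (hypothetical helper).
function nearestCenter(color, centers) {
  let bestLabel = null;
  let bestDist = Infinity;
  for (const [label, center] of Object.entries(centers)) {
    const d = distance(color, center);
    if (d < bestDist) {
      bestDist = d;
      bestLabel = label;
    }
  }
  return bestLabel;
}

const centers = { A: [50, 150, 200], B: [240, 150, 200], C: [50, 100, 50] };
const a1 = [10, 140, 170];

for (const [label, center] of Object.entries(centers)) {
  console.log(label, distance(a1, center).toFixed(2));
}
// A 50.99
// B 232.16
// C 132.66

const label = nearestCenter(a1, centers);
console.log(label, centers[label]); // A [50, 150, 200]
```

A is by far the closest center, so A1 is quantized to (50, 150, 200).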
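Finally, using the colors of the cluster centers, we count how many times each color occurs in the quantized image, sort the colors by count in descending order, and take the first few as the dominant colors.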
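The counting step can be sketched on its own. Assuming every pixel has already been assigned a cluster index (the `labels` array below is made-up data, and `rankCenters` is a hypothetical helper), we tally the cluster sizes and order the centers from most to least frequent.

```javascript
// Rank cluster centers by how many pixels were assigned to them.
// `labels[i]` is the index of the cluster that pixel i belongs to.
function rankCenters(centers, labels) {
  const counts = new Array(centers.length).fill(0);
  for (const label of labels) counts[label]++;
  return centers
    .map((center, i) => ({ center, count: counts[i] }))
    .sort((a, b) => b.count - a.count);
}

const centers = [[50, 150, 200], [240, 150, 200], [50, 100, 50]];
const labels = [0, 2, 2, 1, 2, 0]; // made-up assignments for six pixels
const ranked = rankCenters(centers, labels);
const dominant = ranked.map(({ center }) => `rgb(${center.join(", ")})`);
console.log(dominant);
// ["rgb(50, 100, 50)", "rgb(50, 150, 200)", "rgb(240, 150, 200)"]
```

Counting by cluster index avoids building a color string for every pixel, which may matter on large images.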
<script>
  const img = new Image();
  // Request the image with CORS so the canvas is not tainted and getImageData works.
  // This has to be set before src, and the server must send the matching CORS headers.
  img.crossOrigin = "anonymous";
  img.onload = function () {
    const canvas = document.createElement("canvas");
    const ctx = canvas.getContext("2d");
    canvas.width = img.width;
    canvas.height = img.height;
    ctx.drawImage(img, 0, 0);

    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    const data = imageData.data;
    const k = 3; // number of clusters
    const centers = quantize(data, k);
    console.log(centers);

    // Render one swatch per dominant color
    for (const color of centers) {
      const div = document.createElement("div");
      div.style.width = "50px";
      div.style.height = "50px";
      div.style.backgroundColor = color;
      document.body.appendChild(div);
    }
  };
  img.src = "https://guqing-blog.oss-cn-hangzhou.aliyuncs.com/image.jpg";

  function quantize(data, k) {
    // Convert the RGBA pixel data into 3D vectors (alpha is ignored)
    const vectors = [];
    for (let i = 0; i < data.length; i += 4) {
      vectors.push([data[i], data[i + 1], data[i + 2]]);
    }

    // Randomly pick K cluster centers
    const centers = [];
    for (let i = 0; i < k; i++) {
      centers.push(vectors[Math.floor(Math.random() * vectors.length)]);
    }

    // Iteratively update the cluster centers
    let iterations = 0;
    while (iterations < 100) {
      // Assign each data point to the cluster of its nearest center
      const clusters = Array.from({ length: k }, () => []);
      for (let i = 0; i < vectors.length; i++) {
        clusters[nearestIndex(vectors[i], centers)].push(vectors[i]);
      }

      // Update the cluster centers
      let converged = true;
      for (let i = 0; i < centers.length; i++) {
        const cluster = clusters[i];
        if (cluster.length > 0) {
          const newCenter = cluster
            .reduce((acc, cur) => [
              acc[0] + cur[0],
              acc[1] + cur[1],
              acc[2] + cur[2],
            ])
            .map((val) => val / cluster.length);
          if (!equal(centers[i], newCenter)) {
            centers[i] = newCenter;
            converged = false;
          }
        }
      }
      if (converged) {
        break;
      }
      iterations++;
    }

    // Replace each pixel's color with the color of its cluster center
    // (this overwrites `data`; values are rounded because it is a Uint8ClampedArray)
    for (let i = 0; i < data.length; i += 4) {
      const vector = [data[i], data[i + 1], data[i + 2]];
      const center = centers[nearestIndex(vector, centers)];
      data[i] = center[0];
      data[i + 1] = center[1];
      data[i + 2] = center[2];
    }

    // Count how often each color occurs and sort by count, descending
    const counts = {};
    for (let i = 0; i < data.length; i += 4) {
      const color = `rgb(${data[i]}, ${data[i + 1]}, ${data[i + 2]})`;
      counts[color] = counts[color] ? counts[color] + 1 : 1;
    }
    const sortedColors = Object.keys(counts).sort(
      (a, b) => counts[b] - counts[a]
    );

    // Take the first k colors as the dominant colors
    return sortedColors.slice(0, k);
  }

  // Index of the cluster center closest to the given vector
  function nearestIndex(vector, centers) {
    let minDist = Infinity;
    let minIndex = 0;
    for (let j = 0; j < centers.length; j++) {
      const dist = distance(vector, centers[j]);
      if (dist < minDist) {
        minDist = dist;
        minIndex = j;
      }
    }
    return minIndex;
  }

  // Euclidean distance between two RGB vectors
  function distance(a, b) {
    return Math.sqrt(
      (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2
    );
  }

  function equal(a, b) {
    return a[0] === b[0] && a[1] === b[1] && a[2] === b[2];
  }
</script>
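Choosing K automatically
In practice we may not know how many cluster centers to use, that is, what K should be. One common approach is the Gap statistic. Its basic idea is to compare the clustering result on the real data with the clustering result on random reference datasets, and to choose the K for which the difference is largest.
The Gap statistic method works as follows:
- For each candidate K, run K-Means on the original dataset to obtain a clustering result.
- Generate B random datasets and run K-Means on each of them to obtain their clustering results.
- Measure the difference between the clustering result on the original data and the results on the random datasets; this difference is the Gap statistic.
- Choose the K that maximizes the Gap statistic.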
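For reference, the quantity being compared can be written out explicitly. The notation below follows the usual textbook formulation of the Gap statistic, specialized to squared Euclidean distance; the symbols $C_r$, $\mu_r$, and $W^{*}_{kb}$ are introduced here for illustration and do not appear in the text above.

```latex
% Within-cluster dispersion for k clusters C_1, ..., C_k with centers \mu_r
W_k = \sum_{r=1}^{k} \sum_{x \in C_r} \lVert x - \mu_r \rVert^{2}

% Gap statistic, estimated from B random reference datasets,
% where W^{*}_{kb} is the dispersion of the b-th reference dataset
\mathrm{Gap}(k) = \frac{1}{B} \sum_{b=1}^{B} \log W^{*}_{kb} - \log W_k
```

A large $\mathrm{Gap}(k)$ means the real data clusters much more tightly than structureless random data does for the same k. The original formulation picks the smallest k whose gap is within one standard error of the next one; taking the maximum, as described here, is a simpler variant.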
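Below is sample JavaScript code implementing the Gap statistic method:
function gap(data, maxK, B = 5) {
  const pixelCount = data.length / 4;
  const gaps = [];
  for (let k = 1; k <= maxK; k++) {
    // Average log(W_k) over B uniformly random reference datasets
    let refLogWk = 0;
    for (let b = 0; b < B; b++) {
      refLogWk += logWk(randomData(pixelCount), k);
    }
    gaps.push(refLogWk / B - logWk(data, k));
  }
  // Pick the K with the largest gap
  const maxGap = Math.max(...gaps);
  return gaps.indexOf(maxGap) + 1;
}

// Log of the within-cluster dispersion W_k: the sum of squared distances
// between each pixel and the center of the cluster it was assigned to
function logWk(data, k) {
  // quantize() overwrites its input with the cluster-center colors, so work on a copy
  const quantized = new Uint8ClampedArray(data);
  quantize(quantized, k);
  let wk = 0;
  for (let i = 0; i < data.length; i += 4) {
    wk +=
      (data[i] - quantized[i]) ** 2 +
      (data[i + 1] - quantized[i + 1]) ** 2 +
      (data[i + 2] - quantized[i + 2]) ** 2;
  }
  return Math.log(Math.max(wk, 1)); // guard against log(0)
}

// Generate n random RGBA pixels, uniformly distributed over the whole RGB cube
function randomData(n) {
  const data = new Uint8ClampedArray(n * 4);
  for (let i = 0; i < data.length; i++) {
    data[i] = Math.floor(Math.random() * 256);
  }
  return data;
}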
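Usage:
const k = gap(data, 10);
// const k = 3; // number of clusters
const centers = quantize(data, k);
Well, this turned out to be quite a hassle. In the end I simply used the cover image itself as the background image and applied backdrop-filter on top of it.
Appendix
Drawing the 3D scatter plot of an image with Python: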
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

def visualize_rgb(image_path):
    # Assumes an 8-bit, three-channel RGB image such as a JPEG
    image = plt.imread(image_path)
    height, width, _ = image.shape

    # Reshape the image array to a 2D array of pixels
    pixels = np.reshape(image, (height * width, 3))

    # Extract RGB values
    red = pixels[:, 0]
    green = pixels[:, 1]
    blue = pixels[:, 2]

    # Create 3D scatter plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(red, green, blue, c=pixels/255, alpha=0.3)
    ax.set_xlabel('Red')
    ax.set_ylabel('Green')
    ax.set_zlabel('Blue')
    ax.set_title('RGB 3D Scatter Plot')
    plt.show()

# Call the function with the path to the image
visualize_rgb('image.jpg')