ANNOY(Approximate Nearest Neighbors Oh Yeah)算法能够帮助我们高效的查找近邻的 N 个向量。其基本原理:就是将所有向量按照空间进行划分,直到子空间小于等于 K 个向量位置。如下图所示:
随机选择两个向量,在两点的直线的中心垂直画一条直线将样本分割成两部分。接下来,按照这个思路继续划分每个子空间,直到空间向量数量小于等于 K 个为止。
上面的划分过程,也可以看成构建二叉树的过程,如下图所示:
ANNOY 构建的二叉树都是随机构建的,并且 ANNOY 会构建多个这样的随机二叉树,树的数量可以由我们自己来指定。当来了一个新向量时,该向量在每棵树上必然属于某个子空间,假设我们有 5 棵树,则将新向量所在的 5 个子空间中的所有向量中,找出 N 个最相似的向量。
上图中,可以看到新向量所属的不同子空间,并且每个子空间都有多个向量,ANNOY 就是从这些向量中找到 N 个相近的向量返回。
pip install annoy
示例代码:
from annoy import AnnoyIndex
import torch
# 1. annoy 构建搜索
def test01():
# 构建索引
# metric 为距离度量方法,可选的有:
# 向量角: "angular"
# 欧式距离: "euclidean"
# 曼哈顿距离: "manhattan"
# 汉明距离: "hamming"
# 点积: "dot"
# f 参数表示向量的维度
index = AnnoyIndex(f=3, metric='euclidean')
# 插入向量: 第一个参数为插入位置,第二个参数为插入向量
index.add_item(0, torch.tensor([1, 2, 3]))
index.add_item(1, torch.tensor([7, 8, 9]))
index.add_item(2, torch.tensor([4, 5, 6]))
# 构建二叉树
# on_disk_build 方法将树构建到文件中
# build 方法会将树构建到内存中
index.build(n_trees=5)
# 查询向量:
# 第一个参数为待查询向量的索引
# 第二个参数为要返回的向量个数
# 返回值为已查到向量的索引
find_index = index.get_nns_by_item(i=0, n=2)
print('find_index:', find_index)
# 第一个参数为待查询的向量
# 第二个参数为要返回的向量个数
# 返回值为已查到向量的索引
find_index = index.get_nns_by_vector(vector=torch.tensor([0, 1, 1]), n=2)
print('find_index:', find_index)
# 2. annoy 存储加载
def test02():
index = AnnoyIndex(f=3, metric='euclidean')
index.add_item(0, torch.tensor([1, 2, 3]))
index.add_item(1, torch.tensor([7, 8, 9]))
index.add_item(2, torch.tensor([4, 5, 6]))
index.build(n_trees=5)
# 存储索引
index.save('index.ann')
# 加载索引
index = AnnoyIndex(f=3, metric='euclidean')
index.load('index.ann')
find_index = index.get_nns_by_item(i=0, n=2)
print('find_index:', find_index)
# 3. annoy 其他函数
def test03():
index = AnnoyIndex(f=3, metric='euclidean')
index.add_item(0, torch.tensor([1, 2, 3]))
index.add_item(1, torch.tensor([7, 8, 9]))
index.add_item(2, torch.tensor([4, 5, 6]))
index.build(n_trees=5)
# 获得指定索引位置的向量
print('位置向量:', index.get_item_vector(0))
# 返回索引向量个数
print('向量个数:', index.get_n_items())
# 返回索引树的数量
print('树的数量:', index.get_n_trees())
# 获得索引中指定下标位置的两个向量的距离
print('向量距离:', index.get_distance(0, 1))
if __name__ == '__main__':
test01()
print('-' * 30)
test02()
print('-' * 30)
test03()
程序输出结果:
find_index: [0, 2]
find_index: [0, 2]
------------------------------
find_index: [0, 2]
------------------------------
位置向量: [1.0, 2.0, 3.0]
向量个数: 3
树的数量: 5
向量距离: 10.392304420471191