hnsw/hnsw.py
2025-04-19 18:22:18 -07:00

150 lines
5.3 KiB
Python

import math
import numpy as np
from heapq import heappush, heappop
from collections import Counter
import random
def distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return the norm of the element-wise difference between ``a`` and ``b``.

    For 1-D coordinate vectors this is the Euclidean (L2) distance.

    Args:
        a: First point; an ndarray or any array-like (list, tuple) of numbers.
        b: Second point, broadcast-compatible with ``a``.

    Returns:
        The value of ``np.linalg.norm`` applied to ``a - b``.
    """
    # np.subtract converts array-likes, so plain lists/tuples work as well as
    # ndarrays (for which this is exactly ``a - b``).
    return np.linalg.norm(np.subtract(a, b))
def printw(W):
    """Debug helper: dump a list of ``(distance, node)`` entries to stdout.

    The entries are printed one per line as ``distance coord`` (using each
    node's ``coord`` attribute), framed by a ``<==`` line before and a
    ``==>`` line after.
    """
    print('<==')
    for entry in W:
        dist, node = entry
        print(dist, node.coord)
    print('==>')
class HNSWTower:
    """One element of the HNSW graph together with its per-layer links.

    A tower whose top layer is ``level`` owns one neighbor set for each of
    the layers ``0..level``. Edges are kept undirected: ``connect`` and
    ``disconnect`` always update both endpoints.
    """

    def __init__(self, coord, level):
        """Create an unconnected tower.

        Args:
            coord: Coordinate of the element.
            level: Highest layer this element belongs to (0-based).
        """
        self.coord = coord
        self.level = level
        # neighbors[i] is the set of towers linked to this one on layer i
        self.neighbors = [set() for _ in range(level + 1)]

    def connect(self, other, level):
        """Add an undirected edge between this tower and ``other`` on ``level``.

        Connecting an already-connected pair is a no-op.
        """
        self.neighbors[level].add(other)
        other.neighbors[level].add(self)

    def disconnect(self, other, level):
        """Remove the undirected edge between this tower and ``other`` on ``level``.

        Both directions are removed; disconnecting an unconnected pair is a
        no-op.
        """
        # The reverse link lives in other.neighbors[level] (the set for this
        # layer), not in other.neighbors (the list of per-layer sets).
        self.neighbors[level].discard(other)
        other.neighbors[level].discard(self)

    def limit_degree(self, level, M):
        """Prune the farthest neighbors on ``level`` until at most ``M`` remain.

        Args:
            level: Layer whose neighbor set is pruned.
            M: Maximum number of neighbors to keep on that layer.
        """
        links = self.neighbors[level]
        limit = max(M, 0)
        while len(links) > limit:
            # max() with a key always yields a real neighbor, even when every
            # neighbor is at distance 0 (duplicate coordinates).
            farthest = max(
                links, key=lambda node: distance(self.coord, node.coord)
            )
            self.disconnect(farthest, level)
class HNSW:
    """
    Approximate k-Nearest Neighbor Search (k-ANNS) using Hierarchical Navigable Small World (HNSW) graph

    Elements are stored as ``HNSWTower`` nodes. Each node is assigned a random
    top level; upper layers are traversed greedily to find a good entry point
    and the lower layers are searched with a bounded best-first search.
    """

    def __init__(self, mL, M, M_max, ef_construction, seed):
        """Create an empty index.

        Args:
            mL: Normalizing factor for level generation.
            M: Number of connections to establish for a new element per layer.
            M_max: Maximum degree of a node on a layer.
            ef_construction: Size of the nearest-neighbor list kept while
                inserting.
            seed: Seed for the private random generator used to draw levels.
        """
        # Enter point
        self.ep = None
        # Normalizing factor for level generation
        self.mL = mL
        # Number of connections to establish
        self.M = M
        # Maximum degree for node
        self.M_max = M_max
        # Size of NN to return during construction
        self.ef_construction = ef_construction
        self.rng = random.Random(seed)
        # Number of elements created with each top level
        self.level_count = Counter()

    def generate_level(self):
        """Draw the top level for a new element.

        Returns:
            ``floor(-mL * ln(u))`` for ``u`` drawn uniformly from (0, 1).
        """
        # the condition of returning a level higher than L is
        # self.mL * math.log(1/u) >= L
        # u <= exp(-L/self.mL) = exp(-1/self.mL) ^ L
        # equivalent to a "level up probability" of exp(-1/self.mL)
        u = self.rng.random()
        # random() may return exactly 0.0, for which log() is undefined;
        # redraw instead of raising ValueError.
        while u == 0.0:
            u = self.rng.random()
        return math.floor(-self.mL * math.log(u))

    def insert_first(self, coord):
        """Insert the very first element; it becomes the enter point at level 0."""
        self.ep = HNSWTower(coord, 0)
        self.level_count[0] += 1

    def insert(self, coord):
        """Insert a new element with coordinate ``coord`` into the index."""
        if self.ep is None:
            self.insert_first(coord)
            return
        ep = self.ep
        graph_top = ep.level
        node_top = self.generate_level()
        self.level_count[node_top] += 1
        elem = HNSWTower(coord, node_top)
        # Greedy descent through the layers above the new element's top level.
        for level in range(graph_top, node_top, -1):
            ep = self.route_layer(coord, ep, level)
        # enter points (multiple, sorted by distance from near to far)
        eps = [(distance(coord, ep.coord), ep)]
        for level in range(min(node_top, graph_top), -1, -1):
            W = self.search_layer(coord, eps, self.ef_construction, level)
            # W is sorted nearest first, so its prefix holds the M closest.
            neighbors = [node for _, node in W[:self.M]]
            # Create node at this layer
            for other in neighbors:
                elem.connect(other, level)
            # Shrink connections if needed
            for other in neighbors:
                other.limit_degree(level, self.M_max)
            eps = W
        if node_top > graph_top:
            self.ep = elem

    def route_layer(self, coord, ep, level):
        """Greedy walk on ``level``.

        Starting from ``ep``, repeatedly move to the neighbor closest to
        ``coord`` as long as that neighbor is strictly closer than the current
        node.

        Returns:
            The node at which no neighbor is closer to ``coord``.
        """
        while True:
            best = None
            min_d = distance(coord, ep.coord)
            for other in ep.neighbors[level]:
                d = distance(coord, other.coord)
                if d < min_d:
                    min_d = d
                    best = other
            if best is None:
                return ep
            ep = best

    def search_layer(self, coord, eps, ef, level):
        """Best-first search on ``level`` starting from the enter points ``eps``.

        Args:
            coord: Query coordinate.
            eps: Iterable of ``(distance, node)`` enter points. It is not
                modified.
            ef: Size bound applied to the result list whenever a newly
                discovered node is added.
            level: Layer to search.

        Returns:
            A list of ``(distance, node)`` sorted from nearest to farthest (an
            ascending list also satisfies the heap condition).
        """
        visited = {node for _, node in eps}
        # Heap entries carry an insertion counter between distance and node so
        # that equal distances never fall through to comparing node objects.
        seq = 0
        # set of candidates organized as priority queue (nearest on top)
        candidates = []
        # found nearest neighbors as priority queue (farthest on top; the
        # distance is negated because heapq is a min-heap)
        results = []
        for dist, node in eps:
            heappush(candidates, (dist, seq, node))
            heappush(results, (-dist, seq, node))
            seq += 1
        while candidates:
            c_dist, _, c = heappop(candidates)
            if c_dist > -results[0][0]:
                break  # all elements in results are evaluated
            for e in c.neighbors[level]:
                if e in visited:
                    continue
                visited.add(e)
                e_dist = distance(coord, e.coord)
                if len(results) < ef or e_dist < -results[0][0]:
                    heappush(candidates, (e_dist, seq, e))
                    heappush(results, (-e_dist, seq, e))
                    seq += 1
                    if len(results) > ef:
                        heappop(results)
        ordered = sorted(results, key=lambda entry: (-entry[0], entry[1]))
        return [(-neg_dist, node) for neg_dist, _, node in ordered]

    def k_nn_search(self, coord, K, ef):
        """Return up to ``K`` approximate nearest neighbors of ``coord``.

        Args:
            coord: Query coordinate.
            K: Number of neighbors to return.
            ef: Size of the candidate list used for the layer-0 search.

        Returns:
            A list of ``(distance, node)`` sorted from nearest to farthest;
            empty when the index holds no elements.
        """
        if self.ep is None:
            return []
        ep = self.ep
        for level in range(ep.level, 0, -1):
            ep = self.route_layer(coord, ep, level)
        eps = [(distance(coord, ep.coord), ep)]
        W = self.search_layer(coord, eps, ef, 0)
        return W[:K]