# hnsw/knntest.py
# 2025-04-19 18:22:18 -07:00
#
# 80 lines
# 3.1 KiB
# Python

import math
import time

import numpy as np

from hnsw import *
class SyntheticVectors:
    """A reproducible corpus of Gaussian vectors with brute-force k-NN.

    ``points`` is a ``(num_pts, dim)`` array drawn from a standard normal
    distribution right after NumPy's global random generator has been seeded
    with ``seed``.
    """

    def __init__(self, num_pts, dim, seed):
        """Seed the global NumPy RNG and draw ``num_pts`` vectors of size ``dim``."""
        np.random.seed(seed)
        self.num_pts = num_pts
        self.dim = dim
        self.points = np.random.normal(0, 1, (num_pts, dim))

    def knn(self, query, k):
        """Return the ``k`` corpus points closest to ``query`` by exhaustive scan.

        Args:
            query: The query vector, passed to ``distance`` with each point.
            k: Maximum number of neighbours to return.

        Returns:
            A list of ``(distance, point)`` tuples in ascending distance
            order. It is shorter than ``k`` when the corpus holds fewer than
            ``k`` points.
        """
        scored = [(distance(query, point), point) for point in self.points]
        # Order on the distance alone. Sorting the raw tuples falls back to
        # comparing the point arrays whenever two distances tie, and NumPy
        # raises ValueError for that ambiguous array comparison.
        scored.sort(key=lambda pair: pair[0])
        return scored[:k]

    def validate_results(self, query, k, results):
        """Count how many true nearest neighbours appear in ``results``.

        Args:
            query: The query vector.
            k: Number of ground-truth neighbours to compare against.
            results: Re-iterable collection of ``(distance, coordinates)``
                pairs, e.g. a list; it is scanned once per ground-truth point.

        Returns:
            The number of ground-truth neighbours whose coordinates coincide
            with the coordinates of some entry in ``results``.
        """
        ground_truth = self.knn(query, k)
        hits = 0
        for _, true_coord in ground_truth:
            # Entries are matched by coordinates; the distances reported in
            # ``results`` are not consulted.
            if any(
                np.isclose(np.linalg.norm(true_coord - found_coord), 0)
                for _, found_coord in results
            ):
                hits += 1
        return hits
class HNSWTester:
    """Benchmark harness for an HNSW index built over synthetic vectors.

    Intended call order: construct, ``build_hnsw``, ``gen_queries``,
    ``query_hnsw``, then ``log_results`` to append the measurements to a CSV
    file.
    """

    def __init__(self, num_corpus, dim, seed):
        """Generate the corpus and re-seed the RNG used for query generation.

        Args:
            num_corpus: Number of corpus vectors to generate.
            dim: Dimensionality of corpus and query vectors.
            seed: Seed for the corpus; also handed to the index constructor
                in ``build_hnsw``.
        """
        self.num_corpus = num_corpus
        self.dim = dim
        self.seed = seed
        self.corpus = SyntheticVectors(num_corpus, dim, seed)
        # The corpus was drawn after seeding with ``seed``; re-seed with a
        # different value, presumably so queries do not replay that stream.
        np.random.seed(seed + 1)

    def build_hnsw(self, M, ef_construction):
        """Insert every corpus point into a fresh HNSW index and time it.

        Stores the index on ``self.hnsw`` and the elapsed seconds on
        ``self.build_time``, then prints the build time and the per-level
        node counts.

        Args:
            M: Passed to ``HNSW`` as its second and third arguments and used
                to derive ``mL``.
            ef_construction: Passed to ``HNSW`` as its fourth argument.

        Raises:
            ValueError: If ``M`` is not greater than 1, for which
                ``1 / log(M)`` is undefined or negative.
        """
        if M <= 1:
            raise ValueError(f"M must be greater than 1, got {M}")
        self.M, self.ef_construction = M, ef_construction
        # With mL = 1 / ln(M), the "level up" probability is 1/M.
        mL = 1 / math.log(M)

        start_time = time.perf_counter()
        self.hnsw = HNSW(mL, M, M, ef_construction, self.seed)
        for point in self.corpus.points:
            self.hnsw.insert(point)
        build_time = time.perf_counter() - start_time

        print(f"Built HNSW in {build_time:.4f} seconds")
        self.build_time = build_time
        self.print_hnsw_levels()

    def gen_queries(self, num_query):
        """Draw ``num_query`` standard-normal query vectors of size ``self.dim``.

        The vectors come from NumPy's global RNG, so they depend on anything
        that has drawn from or re-seeded it since ``__init__``.
        """
        self.num_query = num_query
        self.queries = np.random.normal(0, 1, (num_query, self.dim))

    def query_hnsw(self, K, ef):
        """Run every generated query and measure recall and latency.

        Args:
            K: Number of neighbours requested per query.
            ef: Forwarded to ``k_nn_search`` as its third argument.

        Returns:
            ``(recall, query_ms)``: the fraction of brute-force K-nearest
            neighbours that the index returned, over all queries, and the
            mean time per query in milliseconds. Both are also stored on
            ``self.recall`` and ``self.query_ms``.

        Raises:
            ValueError: If ``K`` is not positive or there are no queries,
                since both averages would divide by zero.
        """
        if K <= 0:
            raise ValueError(f"K must be positive, got {K}")
        if self.num_query <= 0:
            raise ValueError("no queries to run; call gen_queries with num_query > 0")
        self.K, self.ef = K, ef

        n_hit = 0
        tot_time = 0.0
        for query in self.queries:
            # The timed region covers the search and the unpacking of the
            # returned nodes into coordinates; validation is not timed.
            start_time = time.perf_counter()
            found = self.hnsw.k_nn_search(query, K, ef)
            found = [(dist, node.coord) for dist, node in found]
            tot_time += time.perf_counter() - start_time
            n_hit += self.corpus.validate_results(query, K, found)

        self.recall = n_hit / (self.num_query * K)
        self.query_ms = tot_time * 1000 / self.num_query
        print(f"recall:{self.recall:.4f}, query time/ms: {self.query_ms:.4f}")
        return self.recall, self.query_ms

    def print_hnsw_levels(self):
        """Print the node count of each level from 0 to ``self.hnsw.ep.level``."""
        level_count = self.hnsw.level_count
        top_level = self.hnsw.ep.level
        for level in range(top_level + 1):
            print(f"level {level}: {level_count[level]} nodes")

    def log_results(self, fname="result.csv"):
        """Append one CSV row describing the latest build and query run.

        Columns, in order: num_corpus, dim, M, ef_construction, build_time,
        K, ef, recall, query_ms, num_query, seed. No header row is written.

        Args:
            fname: Path of the CSV file to append to.
        """
        # Format the whole row before opening the file, so a missing
        # attribute (e.g. no query run yet) cannot leave a partial row behind.
        row = ",".join([
            f"{self.num_corpus}",
            f"{self.dim}",
            f"{self.M}",
            f"{self.ef_construction}",
            f"{self.build_time:.4f}",
            f"{self.K}",
            f"{self.ef}",
            f"{self.recall:.4f}",
            f"{self.query_ms:.4f}",
            f"{self.num_query}",
            f"{self.seed}",
        ])
        with open(fname, "a", encoding="utf-8") as f:
            f.write(row + "\n")
# Benchmark driver (runs when the module is executed or imported): index 5000
# synthetic 100-dimensional vectors, answer 100 random queries, and append the
# measurements to the default CSV log.
_CORPUS_ARGS = {"num_corpus": 5000, "dim": 100, "seed": 43}
_BUILD_ARGS = {"M": 10, "ef_construction": 10}
_SEARCH_ARGS = {"K": 5, "ef": 10}

tester = HNSWTester(**_CORPUS_ARGS)
tester.build_hnsw(**_BUILD_ARGS)
tester.gen_queries(num_query=100)
tester.query_hnsw(**_SEARCH_ARGS)
tester.log_results()