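"""Benchmark harness for an HNSW index on synthetic Gaussian vectors.

Builds the index, measures recall against a brute-force nearest-neighbour
scan, times the queries, and appends one result row to a CSV file.
"""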
import math
import time

import numpy as np

from hnsw import *


class SyntheticVectors:
    """Random Gaussian corpus with a brute-force nearest-neighbour reference.

    Attributes:
        num_pts: Number of points in the corpus.
        dim: Dimensionality of every point.
        points: ``(num_pts, dim)`` array drawn from a standard normal.
    """

    def __init__(self, num_pts, dim, seed):
        """Generate ``num_pts`` points of dimension ``dim``.

        Note that this seeds NumPy's *global* random state with ``seed``, so
        it affects any later ``np.random`` call made by the caller.
        """
        np.random.seed(seed)
        self.num_pts = num_pts
        self.dim = dim
        self.points = np.random.normal(0, 1, (num_pts, dim))

    def knn(self, query, k):
        """Return the ``k`` corpus points nearest to ``query`` by exhaustive scan.

        Args:
            query: Vector to search for.
            k: Number of neighbours to return.

        Returns:
            A list of at most ``k`` ``(distance, point)`` tuples in ascending
            distance order; ties keep corpus order.
        """
        # `distance` is not defined in this file; presumably it is provided by
        # the star import from `hnsw` -- TODO confirm.
        scored = [(distance(query, point), point) for point in self.points]
        # Sort on the distance only. Sorting the raw tuples would compare the
        # ndarrays whenever two distances tie, which raises ValueError.
        return sorted(scored, key=lambda pair: pair[0])[:k]

    def validate_results(self, query, k, results):
        """Count how many true nearest neighbours appear in ``results``.

        Args:
            query: Vector that was searched for.
            k: Number of true neighbours to compare against.
            results: Iterable of ``(distance, coord)`` pairs to validate.

        Returns:
            The number of brute-force neighbours whose coordinates match
            (within ``np.isclose`` tolerance) some entry of ``results``.
        """
        ground_truth = self.knn(query, k)
        hits = 0
        for _, true_coord in ground_truth:
            # A neighbour is a hit when some result sits at (numerically) zero
            # Euclidean distance from it.
            if any(
                np.isclose(np.linalg.norm(true_coord - found_coord), 0)
                for _, found_coord in results
            ):
                hits += 1
        return hits


class HNSWTester:
    """Build an HNSW index over a synthetic corpus and measure its quality.

    Typical use: construct, ``build_hnsw``, ``gen_queries``, ``query_hnsw``,
    then ``log_results``. Each step stores its parameters and measurements on
    the instance so that ``log_results`` can write them out.
    """

    def __init__(self, num_corpus, dim, seed):
        """Create the corpus of ``num_corpus`` vectors of dimension ``dim``.

        ``seed`` seeds the corpus; NumPy's global random state is then
        reseeded with ``seed + 1`` so later draws (the queries) come from a
        different stream than the corpus.
        """
        self.num_corpus = num_corpus
        self.dim = dim
        self.seed = seed
        self.corpus = SyntheticVectors(num_corpus, dim, seed)
        np.random.seed(seed + 1)

    def build_hnsw(self, M, ef_construction):
        """Build the index from the whole corpus and record the build time.

        Args:
            M: Passed to ``HNSW`` twice (presumably the per-layer and maximum
                connection counts -- TODO confirm against ``hnsw.HNSW``).
                Must be greater than 1, since ``log(M)`` is a divisor.
            ef_construction: Passed through to ``HNSW``.
        """
        self.M, self.ef_construction = M, ef_construction
        # With this mL the "level up" probability is 1/M.
        mL = 1 / math.log(M)

        start_time = time.perf_counter()
        self.hnsw = HNSW(mL, M, M, ef_construction, self.seed)
        for point in self.corpus.points:
            self.hnsw.insert(point)
        self.build_time = time.perf_counter() - start_time

        print(f"Built HNSW in {self.build_time:.4f} seconds")
        self.print_hnsw_levels()

    def gen_queries(self, num_query):
        """Draw ``num_query`` standard-normal query vectors of size ``dim``."""
        self.num_query = num_query
        self.queries = np.random.normal(0, 1, (num_query, self.dim))

    def query_hnsw(self, K, ef):
        """Run every generated query and measure recall and latency.

        Args:
            K: Number of neighbours requested per query.
            ef: Passed through to ``HNSW.k_nn_search``.

        Returns:
            ``(recall, query_ms)``: the fraction of true neighbours found and
            the mean time per query in milliseconds. Both are also stored on
            the instance. Either value is 0.0 when it would otherwise divide
            by zero (no queries, or ``K == 0``).
        """
        self.K, self.ef = K, ef
        n_hit = 0
        tot_time = 0
        for query in self.queries[:self.num_query]:
            start_time = time.perf_counter()
            found = self.hnsw.k_nn_search(query, K, ef)
            found = [(dist, node.coord) for dist, node in found]
            # Validation is deliberately outside the timed region.
            tot_time += time.perf_counter() - start_time
            n_hit += self.corpus.validate_results(query, K, found)

        expected_hits = self.num_query * K
        self.recall = n_hit / expected_hits if expected_hits else 0.0
        self.query_ms = tot_time * 1000 / self.num_query if self.num_query else 0.0
        print(f"recall:{self.recall:.4f}, query time/ms: {self.query_ms:.4f}")
        return self.recall, self.query_ms

    def print_hnsw_levels(self):
        """Print the node count of each level up to the entry point's level."""
        counts = self.hnsw.level_count
        top_level = self.hnsw.ep.level
        for level in range(top_level + 1):
            print(f"level {level}: {counts[level]} nodes")

    def log_results(self, fname="result.csv"):
        """Append one CSV row describing the latest build and query run.

        Columns: num_corpus, dim, M, ef_construction, build_time, K, ef,
        recall, query_ms, num_query, seed. Requires ``build_hnsw``,
        ``gen_queries`` and ``query_hnsw`` to have populated those attributes.
        """
        row = (
            f"{self.num_corpus},{self.dim},"
            f"{self.M},{self.ef_construction},{self.build_time:.4f},"
            f"{self.K},{self.ef},{self.recall:.4f},{self.query_ms:.4f},"
            f"{self.num_query},{self.seed}\n"
        )
        # The context manager closes (and flushes) the file even on error.
        with open(fname, "a", encoding="utf-8") as f:
            f.write(row)


# Benchmark run: 5000 corpus vectors of dimension 100, 100 queries, top-5 search.
tester = HNSWTester(
    num_corpus=5000,
    dim=100,
    seed=43,
)
tester.build_hnsw(
    M=10,
    ef_construction=10,
)
tester.gen_queries(num_query=100)
tester.query_hnsw(
    K=5,
    ef=10,
)
# Appends one row to the default "result.csv".
tester.log_results()