Cloud Search Service (CSS) - Performance Testing and Comparison of Elasticsearch Vector Search: Script "base_test_example.py"
Updated: 2025-04-10 20:27:00
Script "base_test_example.py"
# -*- coding: UTF-8 -*-
import json
import time

import h5py
from elasticsearch import Elasticsearch
from elasticsearch import helpers


def get_client(hosts: list, user: str = None, password: str = None):
    # Return an Elasticsearch client; use HTTP basic authentication if credentials are provided.
    if user and password:
        return Elasticsearch(hosts, http_auth=(user, password), verify_certs=False, ssl_show_warn=False)
    else:
        return Elasticsearch(hosts)


# For descriptions of the index parameters, see "Creating a Vector Index in an Elasticsearch Cluster".
def create(es_client, index_name, shards, replicas, dim, algorithm="GRAPH", metric="euclidean", neighbors=64, efc=200, shrink=1.0):
    index_mapping = {
        "settings": {
            "index": {
                "vector": True
            },
            "number_of_shards": shards,
            "number_of_replicas": replicas,
        },
        "mappings": {
            "properties": {
                "id": {
                    "type": "integer"
                },
                "vec": {
                    "type": "vector",
                    "indexing": True,
                    "dimension": dim,
                    "algorithm": algorithm,
                    "metric": metric,
                    "neighbors": neighbors,
                    "efc": efc,
                    "shrink": shrink,
                }
            }
        }
    }
    es_client.indices.create(index=index_name, body=index_mapping)
    print(f"Create index success! Index name: {index_name}")


def write(es_client, index_name, vectors, bulk_size=1000):
    # Bulk-write the base vectors, then force-merge the index segments.
    print("Start write! Index name: " + index_name)
    start = time.time()
    for i in range(0, len(vectors), bulk_size):
        actions = [{
            "_index": index_name,
            "id": i + j,
            "vec": v.tolist()
        } for j, v in enumerate(vectors[i: i + bulk_size])]
        helpers.bulk(es_client, actions, request_timeout=180)
    print(f"Write success! Docs count: {len(vectors)}, total cost: {time.time() - start:.2f} seconds")
    merge(es_client, index_name)


def merge(es_client, index_name, seg_cnt=1):
    print(f"Start merge! Index name: {index_name}")
    start = time.time()
    es_client.indices.forcemerge(index=index_name, max_num_segments=seg_cnt, request_timeout=7200)
    print(f"Merge success! Total cost: {time.time() - start:.2f} seconds")


# For descriptions of the query parameters, see "Using a Vector Index to Search for Data in an Elasticsearch Cluster".
def query(es_client, index_name, queries, gts, size=10, k=10, ef=200, msn=10000):
    # Run a vector query for each test vector and compute the average precision
    # against the ground-truth neighbors.
    print("Start query! Index name: " + index_name)
    i = 0
    precision = []
    for vec in queries:
        hits = set()
        dsl = {
            "size": size,
            "stored_fields": ["_none_"],
            "docvalue_fields": ["id"],
            "query": {
                "vector": {
                    "vec": {
                        "vector": vec.tolist(),
                        "topk": k,
                        "ef": ef,
                        "max_scan_num": msn
                    }
                }
            }
        }
        res = es_client.search(index=index_name, body=json.dumps(dsl))
        for hit in res['hits']['hits']:
            hits.add(int(hit['fields']['id'][0]))
        precision.append(len(hits.intersection(set(gts[i, :size]))) / size)
        i += 1
    print(f"Query complete! Average precision: {sum(precision) / len(precision)}")


def load_test_data(src):
    # Load base vectors, query vectors, and ground-truth neighbors from an ann-benchmarks HDF5 file.
    hdf5_file = h5py.File(src, "r")
    base_vectors = hdf5_file["train"]
    query_vectors = hdf5_file["test"]
    ground_truths = hdf5_file["neighbors"]
    return base_vectors, query_vectors, ground_truths


def test_sift(es_client):
    index_name = "index_sift_graph"
    vectors, queries, gts = load_test_data(r"sift-128-euclidean.hdf5")
    # Adjust the numbers of shards and replicas, the index algorithm, and the index parameters
    # based on the actual test requirements. The performance tests in this document all use
    # 1 shard and 2 replicas.
    create(es_client, index_name, shards=1, replicas=2, dim=128)
    write(es_client, index_name, vectors)
    query(es_client, index_name, queries, gts)


if __name__ == "__main__":
    # Replace with the actual access address of the CSS cluster.
    client = get_client(['http://x.x.x.x:9200'])
    test_sift(client)
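The script expects the ann-benchmarks SIFT dataset file sift-128-euclidean.hdf5 in the working directory and connects to the cluster without authentication. The following is a minimal usage sketch, assuming the script above is saved as base_test_example.py, showing how one might fetch the dataset and run the same test against a security-mode cluster over HTTPS; the download URL, host address, user name, and password are illustrative assumptions, not values from this document.

# -*- coding: UTF-8 -*-
# run_test_secure.py: a minimal usage sketch reusing get_client/test_sift from base_test_example.py.
# The dataset URL, host, user name, and password below are placeholders; replace them with your own.
import os
import urllib.request

from base_test_example import get_client, test_sift

DATASET = "sift-128-euclidean.hdf5"
# Assumed public location of the ann-benchmarks SIFT dataset; verify before use.
DATASET_URL = "http://ann-benchmarks.com/sift-128-euclidean.hdf5"

if __name__ == "__main__":
    # Download the test dataset once if it is not already present locally.
    if not os.path.exists(DATASET):
        print(f"Downloading {DATASET} ...")
        urllib.request.urlretrieve(DATASET_URL, DATASET)
    # For a security-mode cluster, use https and pass the user name and password.
    client = get_client(['https://x.x.x.x:9200'], user='admin', password='your-password')
    test_sift(client)

Note that get_client disables certificate verification (verify_certs=False), which is convenient for performance testing but should not be relied on for production access.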
support.huaweicloud.com/bestpractice-css/css_07_0050.html