基于milvus搭建“以图搜图”服务(附代码)
“以图搜图”服务需要的关键功能和准备工作:
1 图像向量化功能,可选的模型有很多,本例选用resnet网络提取图像特征;
2 建表:在milvus中建表存放图像特征,每条特征通过唯一ID(此处称milvus_id)与图像一一对应;在sql中建表存放图像的其他信息(如url、来源等),并将milvus_id作为唯一索引;
3 异步添加图像,同步搜索图像:添加图像的量通常会很大,因此采用异步批量的方式将图像特征加载到milvus。图像添加服务会将每次的请求信息存到sql,再由一个专门的脚本定时批量加载图像特征到milvus;由于是异步操作,可能会出现重复加载的情况,此处使用redis进行去重。图像搜索的请求通常会比图像添加少很多,因此图像搜索采用同步方式返回结果;
(总结:需建立三个表:milvus表1,存放图像特征;sql表2,存放图像信息,数据与milvus表1一一对应;sql表3,存放图像添加请求信息,用于图像特征异步批量加载到milvus)
“以图搜图”服务关键功能及代码(代码仅做参考)
1 图像向量化
"""
功能:图像向量化
"""
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
from numpy import linalg as LA
import time
# Load ResNet50 with ImageNet weights once, at import time, so that every
# call to img2feature() reuses the same model instance.
# NOTE(review): no include_top argument is passed, so the output is presumably
# the 1000-value classification vector rather than a pooled feature — confirm.
model = ResNet50(weights='imagenet')
# model.summary()
def img2feature(img_path, input_dim=224):
    """Turn an image file into a normalised ResNet50 output vector.

    Args:
        img_path: Path of the image file to load (a path, not image data).
        input_dim: Side length, in pixels, that the image is resized to
            before it is fed to the network.

    Returns:
        The array returned by ``model.predict`` for a batch holding this one
        image, divided by its norm (``numpy.linalg.norm`` with default
        arguments).
    """
    resized = image.load_img(img_path, target_size=(input_dim, input_dim))
    # Add a leading batch axis before preprocessing and prediction.
    batch = np.expand_dims(image.img_to_array(resized), axis=0)
    prediction = model.predict(preprocess_input(batch))
    return prediction / LA.norm(prediction)
def main():
    """Vectorise the sample image '1.jpg' and print elapsed time and output shape."""
    img_path = '1.jpg'
    started = time.time()
    feature = img2feature(img_path)
    elapsed = time.time() - started
    print(elapsed, feature.shape)


if __name__ == "__main__":
    main()
2 milvus表的操作
# coding:utf-8
from functools import reduce
import numpy as np
import time
from img2feature import img2feature
from pymilvus import (
connections, list_collections,
FieldSchema, CollectionSchema, DataType,
Collection, utility
)
# Name of the Milvus collection. Despite the variable name, the functions
# below pass it as the collection name, not as a field name.
field_name = 'image_feature'
# Address of the Milvus server (the host is masked in this listing).
host = '***.***.***.***'
port = '19530'
# Dimension of the stored feature vectors.
dim = 1000
# Collection schema: primary key, feature vector, and a creation timestamp
# (insert_data() fills it with int(time.time())).
default_fields = [
FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
FieldSchema(name="create_time", dtype=DataType.INT64)
]
# create_table
def create_table():
    """Create the Milvus collection, index its ``feature`` field and load it.

    Connects to the server at ``host``/``port``, builds a schema from
    ``default_fields``, opens the collection named by ``field_name`` with that
    schema, creates a FLAT index with the L2 metric on ``feature`` and finally
    loads the collection. Progress messages are printed to stdout.
    """
    connections.connect(host=host, port=port)

    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    print("\nCreate collection...")
    collection = Collection(name=field_name, schema=default_schema)

    print("\nCreate index...")
    # NOTE(review): "nlist" is an IVF-style build parameter and is presumably
    # ignored for a FLAT index; kept unchanged — confirm before removing.
    default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    collection.create_index(field_name="feature", index_params=default_index)
    # Single print: nesting it inside another print() would also emit a stray
    # "None" line, because print() returns None.
    print("\nCreate index...is OKOKOKOKOK")
    collection.load()
# insert data
def insert_data():
    """Insert the feature vector of '1.jpg' into the collection with id 123."""
    connections.connect(host=host, port=port)
    schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=schema)

    feature = img2feature('1.jpg').tolist()[0]
    print(type(feature), len(feature))

    # Column-oriented payload: one list per schema field, in the same order
    # as default_fields (milvus_id, feature, create_time).
    milvus_ids = [123]
    features = [feature]
    create_times = [int(time.time())]
    collection.insert([milvus_ids, features, create_times])
    print('insert compete')
# search data
def search_data():
    """Search the collection for the 10 nearest neighbours of '1.jpg' and print them."""
    print('search')
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('连接成功')

    # Before the very first query the collection needs an index and must be
    # loaded; create_table() performs both steps.
    query_vector = img2feature('1.jpg').tolist()[0]
    topK = 10
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    # Positional arguments: query vectors, vector field, search parameters,
    # result limit, boolean filter expression.
    res = collection.search(
        [query_vector],
        "feature",
        search_params,
        topK,
        "create_time > {}".format(0),
        output_fields=["milvus_id"],
    )
    print('>>>', res)

    # One hit list per query vector.
    for hit_list in res:
        print(len(hit_list))
        for single_hit in hit_list:
            print(single_hit)
    print('查询结束')
def show_nums():
    """Print 'ok' and then the number of entities stored in the collection."""
    connections.connect(host=host, port=port)
    image_collection = Collection(name=field_name)
    print('ok')
    print(image_collection.num_entities)
# delete data
def delete_table():
    """Drop the collection, printing whether it exists before and after the drop."""
    connections.connect(host=host, port=port)
    schema = CollectionSchema(fields=default_fields, description="test collection")
    target = Collection(name=field_name, schema=schema)

    exists_before = utility.has_collection(field_name)
    print('>>>', exists_before)
    target.drop()
    exists_after = utility.has_collection(field_name)
    print('>>>', exists_after)
if __name__ == "__main__":
    # Manual driver: uncomment the operation you want to run and time.
    started = time.time()
    # create_table()
    # insert_data()
    # search_data()
    show_nums()
    # delete_table()
    elapsed = time.time() - started
    print('time cost: {}'.format(elapsed))
3 创建sql表2、表3
略
4 图像添加、搜索服务
from rest_framework.views import APIView as View
from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
from kpmysql.base import Kpmysql
from core import search_image
import kplog
import logging
# Module-level logger used by both views in this listing.
log = logging.getLogger("console")
class add_image(View):
    """Endpoint that records an image-add request in MySQL.

    The view only writes a row to ``t_image_search_image_add_log``; it does
    not compute or store any image feature itself.
    """

    def post(self, requests):
        """Store ``image_path`` and ``image_info`` taken from the POST form data.

        Returns a success response once the row has been committed, or an
        error response if any step raises.
        """
        try:
            conn = Kpmysql.connect("db168")
            cursor = conn.cursor()
            image_info = requests.POST.get('image_info')
            image_path = requests.POST.get('image_path')
            insert_sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
            cursor.execute(insert_sql, (image_path, image_info))
            conn.commit()
            log.info('添加图像成功:{}-{}'.format(image_path, image_info))
            return SucessAPIResponse(msg="Success")
        except Exception as e:
            log.info('添加图像失败:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")
class search_image(View):
    """Endpoint that runs an image similarity search and returns the result."""

    def post(self, requests):
        """Search for images similar to ``image_path`` from the POST form data.

        Returns a success response carrying the search result under
        ``data["data"]``, or an error response if any step raises.
        """
        try:
            # This class has the same name as the ``search_image`` helper
            # imported from ``core`` and therefore shadows it at module level.
            # Import the helper under a distinct local name so the call below
            # reaches the helper instead of instantiating this view.
            from core import search_image as core_search_image

            image_path = requests.POST.get('image_path')
            res = core_search_image(image_path)
            log.info('查询图像成功:{}-{}'.format(image_path, res))
            return SucessAPIResponse(msg="Success", data={"data": res})
        except Exception as e:
            # Failure branch: the message must report a failure, not a success.
            log.info('查询图像失败:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")
5 图像异步批量加载
import time, datetime
from kpmysql.base import Kpmysql
from core import insert_data_many
from concurrent.futures import ThreadPoolExecutor
import redis
from conf.setting import REDIS
from core import str2time
import kplog
import logging
# Named loggers for this batch-loading script. Handler configuration is not
# shown in this listing (presumably done by the kplog import — confirm).
log = logging.getLogger("console")
log_addimgs = logging.getLogger("console_addimgs")
def worker(datas):
    """Load one batch of pending images into Milvus and mark them as loaded.

    Args:
        datas: Rows of ``(id, image_path, create_time)`` as selected from
            ``t_image_search_image_add_log``.

    Rows whose id is already a member of the Redis sorted set
    ``image_search`` are skipped. The remaining rows are passed to
    ``insert_data_many`` and then flagged ``is_load=1`` in MySQL. Any
    exception is logged and swallowed; the function never raises.
    """
    try:
        redis_cli = redis.Redis(
            host=REDIS.get('host'),
            port=REDIS.get('port'),
            password=REDIS.get('password'),
            db=REDIS.get('db'),
        )
        dics = []
        ids = []
        for data in datas:
            member = str(data[0])
            # De-duplicate through Redis. ZADD with nx=True only adds members
            # that do not exist yet and returns how many were added, so the
            # "already seen?" check and the claim happen in one atomic step;
            # a separate zscore()/zadd() pair lets two concurrent workers
            # both pick up the same row.
            newly_claimed = redis_cli.zadd('image_search', {member: str2time(data[2])}, nx=True)
            if not newly_claimed:
                continue
            dics.append({'image_path': data[1], 'create_time': data[2]})
            # One-element tuple: executemany() takes a sequence of parameter
            # sequences, one per statement execution.
            ids.append((data[0],))

        if not ids:
            # Every row was a duplicate (or the batch was empty).
            return

        # NOTE(review): ids are claimed in Redis before the Milvus insert; if
        # the insert fails they stay claimed and are skipped on later runs —
        # consider releasing them on failure.
        insert_data_many(dics)

        # Flag the rows as loaded in t_image_search_image_add_log.
        sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
        db168 = Kpmysql.connect("db168")
        cur168 = db168.cursor()
        cur168.executemany(sql_update, ids)
        db168.commit()
    except Exception:
        # Keep the best-effort behaviour, but record the traceback through
        # the script's logger instead of a bare print().
        log_addimgs.exception('worker failed to load image batch')
def _drop_finished(tasks):
    """Return the futures in *tasks* that have not completed yet."""
    return [task for task in tasks if not task.done()]


def main():
    """Poll MySQL about once an hour and hand unloaded rows to the worker pool.

    Runs forever. Every 600 seconds it checks whether more than an hour has
    passed since the last load; if so it selects the rows of
    ``t_image_search_image_add_log`` with ``is_load=0`` created at or after
    the last seen ``create_time`` and submits them to ``worker`` on a thread
    pool, waiting first until fewer than 90% of the pool slots are occupied.
    """
    max_workers = 20  # maximum number of threads
    pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
    task_list = []
    # Start 13 hours in the past so the first iteration loads immediately.
    init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
    create_time_init = '2020-2-22 00:00:00'
    while True:
        now = datetime.datetime.now()
        # total_seconds() covers the whole interval; timedelta.seconds is
        # only the sub-day component and ignores whole days.
        if (now - init_time).total_seconds() > 3600:
            # Fetch the rows that have not been loaded into Milvus yet.
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
            cur168.execute(sql, (create_time_init,))
            datas = cur168.fetchall()
            # An empty result must not crash the loop on datas[-1].
            if datas:
                # Rows are ordered by create_time, so the last one is the
                # high-water mark for the next query.
                create_time_init = datas[-1][2]
                while True:
                    # Rebuild the list instead of pop()-ing while iterating,
                    # which would skip the element after each removal.
                    task_list = _drop_finished(task_list)
                    if len(task_list) < int(max_workers * 0.9):
                        break
                    # Avoid busy-waiting while the pool is saturated.
                    time.sleep(1)
                task_list.append(pool.submit(worker, datas))
            init_time = now
        time.sleep(600)


if __name__ == "__main__":
    main()
优化(重点)
经过实际测试和使用的建议:
1. keras在调用GPU并开启多线程时不如pytorch方便,且pytorch占用显存更少;
2. 将“定时从数据库拿数据”的方式改成kafka生产消费模型,代码更简洁,逻辑更简单;

浙公网安备 33010602011771号