clip-retrieval检索本地数据集

# Standard library
import os
import socket
import urllib.parse
import urllib.request

# Third-party
import requests
from tqdm import tqdm

from clip_retrieval.clip_client import ClipClient, Modality

# Connect to the hosted LAION-5B KNN index service.
client = ClipClient(
    url="https://knn.laion.ai/knn-service",
    indice_name="laion5B-L-14",
)

# Query by text; each result is a dict carrying at least
# caption / url / id / similarity (consumed by the download loop below).
results = client.query(text="an image of a garbage can")
print("search len:", len(results))

# Alternative queries kept for reference:
# results = client.query(text="an image of a fire")
# results = client.query(text="an image of a cat")
# Query by image:
# results = client.query(image="cat.jpg")
# print("search len:", len(results))

# save_path = "./search_result/"
# Directory where the retrieved images are written.
save_path = "/data/home/linxu/PycharmProjects/clip-retrieval/data_result/"
# Create the output directory once, up front (exist_ok avoids the
# exists()-then-makedirs race the original loop had on every iteration).
os.makedirs(save_path, exist_ok=True)

# Global socket timeout so urlretrieve does not hang on dead hosts.
socket.setdefaulttimeout(10)

MAX_RETRIES = 5  # retry attempts per URL after a socket timeout

for result in tqdm(results):
    caption = result['caption']
    url = result['url']
    result_id = result['id']  # renamed: `id` shadowed the builtin
    similarity = result['similarity']

    # Derive a flat local filename from the URL path.  A '/' in the path
    # stem would point into a (nonexistent) subdirectory, so flatten it.
    parsed_url = urllib.parse.urlparse(url)
    stem = parsed_url.path.strip('/').split('.')[0].replace('/', '_')
    file_name = save_path + stem + '.jpg'
    print("file_name:", file_name, "url:", url)

    # Download with retries on timeout.
    # BUGFIX: the original caught `urllib.request.urlretrieve.timeout`,
    # an attribute that does not exist on the function — evaluating that
    # except clause raised AttributeError, so retries never ran.  The
    # correct exception for the socket-level timeout is socket.timeout.
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            urllib.request.urlretrieve(url, file_name)
            break
        except socket.timeout:
            err_info = ('Reloading for %d time' % attempt if attempt == 1
                        else 'Reloading for %d times' % attempt)
            print(err_info)
        except Exception:
            # Any non-timeout failure (HTTP error, DNS, bad URL):
            # report and skip this result — best-effort download.
            print("error url:", url)
            break
    else:
        # All retry attempts timed out.
        print("download job failed!")
    

Posted 2023-07-04 17:45 by Xu_Lin (cnblogs post; 262 views, 0 comments)