最新要闻

广告

手机

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

家电

全球百事通!以图搜图实现

来源:博客园

此次需要使用到的工具:


(相关资料图)

IDE:eclipse,pydevPython:3.10Packages:Keras + TensorFlow + Pillow + Numpy

keras

Keras是一个高层神经网络API,Keras由纯Python编写而成并基​​Tensorflow​​​、​​Theano​​​以及​​CNTK​​后端。简单来说,keras就是对TF等框架的再一次封装,使得使用起来更加方便。

基于vgg16网络提取图像特征 我们都知道,vgg网络在图像领域有着广泛的应用,后续许多层次更深,网络更宽的模型都是基于此扩展的,vgg网络能很好的提取到图片的有用特征,本次实现是基于Keras实现的,提取的是最后一层卷积特征。

思路

主要思路是基于CVPR2015的论文​​《Deep Learning of Binary Hash Codes for Fast Image Retrieval》​​实现的海量数据下的基于内容图片检索系统。简单说来就是对图片数据库的每张图片抽取特征(一般形式为特征向量),

存储于数据库中,对于待检索图片,抽取同样的特征向量,然后并对该向量和数据库中向量的距离(相似度计算),找出最接近的一些特征向量,其对应的图片即为检索结果

from keras.applications.vgg16 import VGG16from keras.preprocessing import imagefrom keras.utils.image_utils import load_img,img_to_arrayfrom keras.applications.vgg16 import preprocess_inputimport numpy as npfrom numpy import linalg as LAimport osimport h5pyimport matplotlib.image as mpimgimport matplotlib.pyplot as pltfrom PIL import Imageclass soutu2:    def __init__(self):        self.input_shape = (244, 244, 3)        self.weights = "imagenet"        self.pooling = "max"        self.model = VGG16(weights=self.weights, input_shape=self.input_shape, pooling=self.pooling, include_top=False)        self.model.predict(np.zeros((1, 244, 244, 3)))        self.h5f_index = "models/vgg_featureCNN.h5"            #抽取某个目录中的图片特征并保存    def xunlian(self, dirpath):        print("开始特征训练...")        feats = []        names = []        img_list = self.get_imglist(dirpath)        for i, img_path in enumerate(img_list):            norm_feat = self.extract_feat(img_path)            img_name = os.path.split(img_path)[1]            feats.append(norm_feat)            names.append(img_name)            print("正在处理%s/%s图片" % (i+1, len(img_list)))        feats = np.array(feats)                h5f = h5py.File(self.h5f_index, "w")        h5f.create_dataset("database_1", data=feats)        h5f.create_dataset("database_2", data=np.string_(names))        h5f.close()                print("训练完毕\n")            #查找相似图    def chazhao(self, img_path):        print("开始按输入图片特征查找...")        h5f = h5py.File(self.h5f_index, "r")        feats = h5f["database_1"][:]        names = h5f["database_2"][:]        h5f.close()                img_paths = []        img_paths.append(img_path)                query_feat = self.extract_feat(img_path)        scores = np.dot(query_feat, feats.T)        rank_id = np.argsort(scores)[::-1]        rank_score = scores[rank_id]                print("查找到以下图片")        img_list = []        for i, index in enumerate(rank_id[0:3]):            img_list.append(names[index])            img_path_in_db = "database/%s" % str(names[index], "utf-8")            print("图片名称:%s,得分:%s" % (img_path_in_db, rank_score[i]))            img_paths.append(img_path_in_db)                    self.show_imgs(img_paths)        print("查找完毕\n")            #提取图片的特征向量    def extract_feat(self, img_path):        img = load_img(img_path, False, "rgb", target_size=(224, 224))        img = img_to_array(img)        #扩展维度,因为preprocess_input需要4D的格式        img = np.expand_dims(img, axis=0)        #对张量进行预处理        img = preprocess_input(img)        feat = self.model.predict(img)        norm_feat = feat[0] / LA.norm(feat[0])                return norm_feat        def get_imglist(self, path):        return [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg") or f.endswith(".jpeg") or f.endswith(".png") or f.endswith(".gif")]        def show_img(self, img_path):        query_img = mpimg.imread(img_path)        plt.imshow(query_img)        plt.show()                def show_imgs(self, img_paths):        fig = plt.figure(figsize=(12,4))        fig.canvas.manager.set_window_title("第一张为输入的搜索图,其余3张为搜索结果")        for i,img_path in enumerate(img_paths):            query_img = mpimg.imread(img_path)            img_name = os.path.split(img_path)[1]            ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])            ax.set_title(img_name, color=("black" ), fontsize=6, ha="center")            plt.subplots_adjust(wspace=0.05, hspace=0)            plt.imshow(query_img)        plt.show()    if __name__ == "__main__":    soutu2obj = soutu2()    #soutu2obj.xunlian("database/")        soutu2obj.chazhao("query/3.jpg")

训练:

搜索验证:

还挺好用

参考:

https://blog.51cto.com/captainbed/5572330

https://blog.csdn.net/starter_____/article/details/79340715

关键词: 特征向量 检索系统 我们都知道