55.人工智能——h5、pkl、npz三种格式文件存储数据集

人工智能
后台-插件-广告管理-内容页头部广告(手机)

在人工智能项目中,经常会用到h5、pkl、npz这种格式文件存储的数据集。下面通过具体实例来了解一下这种三格式文件是如何存储和读取的。

h5是h5py文件格式,h5py是python的一个模块,可以用来存储数据集。

pkl是pickle文件格式,pickle是python的一个模块,可以将几乎所有的数据类型(列表,字典,集合,类等)都可以用pickle来序列化后保存文件。

npz是numpy存储的压缩文件格式。

数据集分别对应在三个目录:car、cat、flower,每个类别200张,图像:(64,64,3),可以用着图像分类。

55.人工智能——h5、pkl、npz三种格式文件存储数据集

car

55.人工智能——h5、pkl、npz三种格式文件存储数据集

cat

55.人工智能——h5、pkl、npz三种格式文件存储数据集

flower

一、创建h5格式数据集

import osimport numpy as npimport cv2import matplotlib.image as mpimg import h5py#创建数据集 h5pydef save_image_to_h5py(path):    img_list = []    label_list = []    classes=[]#分类    #train 文件夹名是分类名称        for id,child_dir in enumerate(os.listdir(path)):        classes_name=child_dir.encode("utf-8") #byte方式保存分类名        classes.append(classes_name)         child_path = os.path.join(path,child_dir)        for dir_image in os.listdir(child_path): #遍历                                   #BGR            #img = cv2.imread(os.path.join(child_path,dir_image))            #img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)            #或者RGB             img=mpimg.imread(os.path.join(child_path,dir_image))             img_list.append(img)            label_list.append(id)#classes id  分类ID        img_np = np.array(img_list) #列表转ndarray    label_np = np.array(label_list)    classes_np=np.array(classes)    #方法二:    # img_label=np.array([img_np,label_np]),    # np.random.shuffle(img_label).T    #随机产生测试数据集(方法一:)    rate=int(img_np.shape[0]*0.3)#数据集的30%为测试集    img_np_test=np.empty((rate,img_np.shape[1],img_np.shape[2],img_np.shape[3]),dtype=np.uint8)    label_np_test=np.empty((rate,),dtype=np.uint8)    img_np_test_rndnum=np.random.choice(range(img_np.shape[0]),rate,replace=False)    for i,num in enumerate(img_np_test_rndnum):        img_np_test[i]=img_np[num]        label_np_test[i]=label_np[num]    #print(img_np.shape,img_np_test.shape)    #print(label_np_test)        #print('数据集标签顺序:\n',label_np)        print(classes_np)           #print(img_np_test[:,])                                                   #'a' ,如果已经有这个名字的h5文件存在将不会打开,目的为了防止误删信息。    #‘w' ,如果有同名文件也能打开,但会覆盖上次的内容。    if path=="train": #训练数据集        with h5py.File('datasets/train.h5','w') as f:            f.create_dataset('training_data',data = img_np)              #创建两个数据集,分别为training_cat            f.create_dataset('training_label',data = label_np)          #和training_label的数组集            f.create_dataset("training_classes",data=classes_np)            f.close()        with h5py.File('datasets/test.h5','w') as f:            f.create_dataset('testing_data',data = img_np_test)              #创建两个数据集,分别为training_cat            f.create_dataset('testing_label',data = label_np_test)          #和training_label的数组集            f.create_dataset("testing_classes",data=classes_np)            f.close()import timestart_time = time.time()save_image_to_h5py('train')#traindt=time.time()-start_timeprint(dt)

二、读取h5格式文件数据集

import numpy as npimport matplotlib.pyplot as pltimport h5pydef load_dataset():        dataset=h5py.File('datasets/train.h5',"r")    train_set_x=dataset["training_data"]    train_set_y=dataset["training_label"]    classes=dataset["training_classes"]    # for key in dataset.keys():    #     print(dataset[key],key,dataset[key].name)    return train_set_x,train_set_y,classestrain_set_x,train_set_y,classes=load_dataset()num_classes=len(classes)rndnum=4#随机取数目,可视化显示 idxs=np.random.choice(range(len(train_set_x)), rndnum*num_classes, replace=False)print(idxs)for y in range(num_classes):    for i,idx in enumerate(idxs):        plt.subplot(rndnum,num_classes,i+1)        plt.imshow(train_set_x[idx])        plt.axis("off")        plt.title(classes[train_set_y[idx]].decode("utf-8"))plt.show()
55.人工智能——h5、pkl、npz三种格式文件存储数据集

运行结果

三、存储和读取pkl和npz格式文件的数据集

import osimport numpy as npimport matplotlib.image as mpimgimport pickle#存储pkldef save_image_to_pickle(path):    img_list=[]    label_list=[]    classes=[]    for id,child_dir in enumerate(os.listdir(path)):        classes_name=child_dir.encode("utf-8")        classes.append(classes_name)        child_path=os.path.join(path,child_dir)        for dir_image in os.listdir(child_path):            img=mpimg.imread(os.path.join(child_path,dir_image))              if img.shape!=tuple((64,64,3)):                print(os.path.join(child_path,dir_image),img.shape)              img_list.append(img)            label_list.append(id)      #训练数据集    train_np=np.array([img_list,label_list])        train_np=train_np.transpose()        #产生测试集    np.random.shuffle(train_np)    rate=int(0.3*train_np.shape[0])    test_np=train_np[:rate]    #测试集    train_np=train_np[rate:]   #训练集        print(train_np.shape)    print(test_np.shape)    #分类名    classes_np=np.array(classes)    data={"train":train_np,"test":test_np,"classes":classes}    #print(data["classes"])    with open(r"datasets/data.pkl","wb") as f:        pickle.dump(data,f)import timestart = time.time()save_image_to_pickle("train")dt=time.time()-startprint(dt)# with open(r"datasets/data.pkl","rb") as f:#     data=pickle.load(f)# print(data["classes"])# train=data["train"]# test=data["test"]# print(train[:,1])# print(test[:,1])#############存储npz格式###################################################def save_image_to_npz(path):    img_list=[]    label_list=[]    classes=[]    for id,child_dir in enumerate(os.listdir(path)):        classes_name=child_dir.encode("utf-8")        classes.append(classes_name)        child_path=os.path.join(path,child_dir)        for dir_image in os.listdir(child_path):            img=mpimg.imread(os.path.join(child_path,dir_image))            img_list.append(img)            label_list.append(id)    #训练数据集    train_np=np.array([img_list,label_list])    train_np=train_np.transpose()    #产生测试集    np.random.shuffle(train_np)    rate=int(0.3*train_np.shape[0])    test_np=train_np[:rate]        #print(train_np[:,1])    #print(test_np[0:,1])    #分类名    classes_np=np.array(classes)    #保存npz格式文件    np.savez(r"datasets/data.npz", train=train_np,test=test_np,classes=classes_np)###################################################################### import time# start_time=time.time()# save_image_to_npz("train")# dt=time.time()-start_time# print(dt)#读取npz格式文件# def read_npz(path):#     data=np.load(path,allow_pickle=True)#     return data# data=read_npz(r"datasets/data.npz")# print(data.files)# train=data["train"]# test=data["test"]# classes=data["classes"]# print(train[:,1])# print(test[:,1])# print(np.char.decode(classes,encoding="utf-8"))# print(train.shape,test.shape)

运行结果:

(420, 2)

(180, 2)

7.0843353271484375

本文主要以实例代码演示如何存储和读取h5、pkl、npz这种格式文件,方便以后数据集的处理。

后台-插件-广告管理-内容页尾部广告(手机)
标签:

评论留言

我要留言

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。