python读文件夹图片,做数据集

2021-12-03
6

程序功能:读取文件夹内图片并输出形状[m,n_H,n_W,n_C]的数组
m:图片数量
n_H:图片高度
n_W:图片宽度
n_C:图片维数

def read_picture(path,n_C):
    import os
    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt
    #function:读取path路径下的图片,并转为形状为[m,n_H,n_W,n_C]的数组
    #path:str,图片所在路径
    #n_C:int,图像维数,黑白图像输入1,rgb图像输入3
    #datas:返回维度为(m,n_H,n_W,n_C)的array(数组)矩阵
    datas=[]
    x_dirs=os.listdir(path)
    for x_file in x_dirs:
        fpath=os.path.join(path,x_file)
        if n_C == 1 :
            _x=Image.open(fpath).convert("L")
            #plt.imshow(_x,"gray")   #显示图像(只显示最后一张)
        elif n_C ==3:
            _x=Image.open(fpath)
            #plt.imshow(_x)         #显示图像(只显示最后一张)
        else:
            print("错误:图像维数错误")
        n_W=_x.size[0]
        n_H=_x.size[1]
        #若要对图像进行放大缩小,激活(去掉注释)以下函数
        '''
        rat=0.8          #放大/缩小倍数
        n_W=int(rat*n_W)
        n_H=int(rat*n_H)
        _x=_x.resize((n_W,n_H))  #直接给n_W,n_H赋值可将图像变为任意大小
        '''
        datas.append(np.array(_x))
        _x.close()  
    datas=np.array(datas)
    
    m=datas.shape[0]
    datas=datas.reshape((m,n_H,n_W,n_C))
    #print(datas.shape)
    
    return datas

评论