h5py读取写入

cooolr 于 2021-02-05 发布
def update_model(self, model_name, feats=None, names=None):
    """Read, or append rows to, the HDF5 feature store for *model_name*.

    The file lives at ``f"{self.modles_path}/{model_name}"`` and holds two
    resizable datasets: ``dataset_feats`` (one feature row per entry) and
    ``dataset_names`` (fixed-width ``|S20`` byte strings, one per row).

    Read mode (``names`` empty or omitted): if the file exists, return the
    stored ``(feats, names)``; otherwise return the ``feats`` and ``names``
    arguments as given (an empty array and an empty list by default).

    Write mode (``names`` non-empty): create the file from the given rows if
    it does not exist yet. Otherwise append every given row, unless
    ``names[0]`` is already stored, in which case nothing is written.

    Args:
        model_name: File name inside ``self.modles_path``.
        feats: Feature rows, expected 2-D ``(n, dim)``. The row width used
            when the file is first created fixes the dataset's width.
        names: Row identifiers. They are stored as ``|S20`` bytes, so longer
            values are truncated and non-ASCII ``str`` values raise.

    Returns:
        ``(feats, names)`` in read mode, where stored names come back as a
        list of ``bytes``; ``None`` in write mode.
    """
    # Fresh objects per call instead of shared mutable defaults.
    if feats is None:
        feats = np.array([])
    if names is None:
        names = []
    filename = f"{self.modles_path}/{model_name}"

    if not names:
        # Read mode: load the history, if any.
        if os.path.exists(filename):
            with h5py.File(filename, 'r') as f:
                feats = f['dataset_feats'][:]
                names = f['dataset_names'][:].tolist()
        return feats, names

    feats = np.array(feats)
    names = np.array(names, dtype="|S20")
    # Compare ids in their stored form: h5py returns |S data as bytes
    # (already truncated to 20), so a str id would never match.
    news_id = bytes(names[0])

    if not os.path.exists(filename):
        with h5py.File(filename, 'w') as h5f:
            # Only the row count may grow; the row width follows the data.
            h5f.create_dataset('dataset_feats', data=feats,
                               maxshape=(None,) + feats.shape[1:], chunks=True)
            h5f.create_dataset('dataset_names', data=names,
                               maxshape=(None,), chunks=True)
        return

    with h5py.File(filename, 'a') as h5f:
        dt_feats = h5f['dataset_feats']
        dt_names = h5f['dataset_names']
        if news_id in dt_names[:].tolist():
            return

        # Grow each dataset by the number of new rows and write ALL of them
        # into the new tail, not just the final slot.
        old_feats = dt_feats.shape[0]
        dt_feats.resize((old_feats + len(feats),) + dt_feats.shape[1:])
        dt_feats[old_feats:] = feats

        old_names = dt_names.shape[0]
        dt_names.resize((old_names + len(names),))
        dt_names[old_names:] = names