# yolo2 (Example) 1

Author: 马育民 • 2020-10-07 17:21

### Dataset

This example uses the pascal-voc-trainval-2007 dataset.

### Imports

```
import os, glob
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from matplotlib import patches
from PIL import Image
import IPython.display as display
from lxml import etree
import xml.etree.ElementTree as ET

print(tf.__version__)
print(tf.test.is_gpu_available())
```

### Define constants

```
obj_names = ("aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow",
             "diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa",
             "train","tvmonitor","head","hand","foot")

img_path = '/kaggle/input/pascal-voc-trainval-2007/VOCdevkit/VOC2007/JPEGImages'
# img_val_path='/kaggle/input/grass-sugarbeet/data/val/image'
ann_path = '/kaggle/input/pascal-voc-trainval-2007/VOCdevkit/VOC2007/Annotations'
yolo_weight_path = '/kaggle/input/yoloweights/yolo.weights'
```

```
IMGSZ = 416
GRIDSZ = 13
batch = 8  # reduce this if you run out of memory

"""
Preset ANCHORS, given as (width, height) pairs in grid-cell units:
the first value of each pair is the width, the second is the height.
"""
ANCHORS = np.array(((0.57273, 0.677385),
                    (1.87446, 2.06253),
                    (3.33843, 5.47434),
                    (7.88282, 3.52778),
                    (9.77052, 9.16828)))
```

### Implement the parsing function

Collect the path of every image together with the bounding-box coordinates belonging to that image. Because the number of boxes differs from image to image, the box matrix is sized to the largest box count and the unused rows are left as zeros.

```
DEBUG = False

def parse_annotation(img_dir, ann_dir, labels):
    # img_dir: image directory
    # ann_dir: annotation xml directory
    # labels: tuple of class names (obj_names)
    """
    Parse annotation info from a Pascal VOC xml file, e.g.:
        <object>
            <name>sugarbeet</name>
            <bndbox>
                <xmin>1</xmin>
                <ymin>250</ymin>
                <xmax>53</xmax>
                <ymax>289</ymax>
            </bndbox>
        </object>
    """
    imgs_info = []
    max_boxes = 0
    # for each annotation xml file
    for ann in os.listdir(ann_dir):
        tree = ET.parse(os.path.join(ann_dir, ann))

        img_info = dict()
        boxes_counter = 0
        img_info['object'] = []
        for elem in tree.iter():
            if 'filename' in elem.tag:
                img_info['filename'] = os.path.join(img_dir, elem.text)
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                # assert img_info['width'] == 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                # assert img_info['height'] == 512

            if 'object' in elem.tag or 'part' in elem.tag:
                # x1-y1-x2-y2-label
                object_info = [0, 0, 0, 0, 0]
                boxes_counter += 1
                for attr in list(elem):
                    if 'name' in attr.tag:
                        try:
                            label = labels.index(attr.text) + 1
                        except Exception as e:
                            print("error:", attr.text, img_info['filename'])
                            raise e
                        object_info[4] = label
                        # if DEBUG:
                        #     print("label::", label)
                    if 'bndbox' in attr.tag:
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(float(pos.text))
                            if 'ymin' in pos.tag:
                                object_info[1] = int(float(pos.text))
                            if 'xmax' in pos.tag:
                                object_info[2] = int(float(pos.text))
                            if 'ymax' in pos.tag:
                                object_info[3] = int(float(pos.text))
                img_info['object'].append(object_info)

        imgs_info.append(img_info)  # filename, width/height, box info
        # (N,5) = (max_objects_num, 5)
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter

    # max_boxes is the largest number of boxes in any single image
    # boxes: [b, max_boxes, 5]
    boxes = np.zeros([len(imgs_info), max_boxes, 5])
    # print(boxes.shape)
    imgs = []       # filename list
    imgs_size = []
    for i, img_info in enumerate(imgs_info):
        # [N,5]
        img_boxes = np.array(img_info['object'])
        # overwrite the first N rows with the real boxes
        boxes[i, :img_boxes.shape[0]] = img_boxes

        imgs.append(img_info['filename'])
        imgs_size.append([img_info['width'], img_info['height']])
        # print(img_info['filename'], boxes[i,:5])

    # imgs: list of image paths
    # boxes: [b, max_boxes, 5]
    return imgs, tf.constant(boxes), tf.constant(imgs_size)
```

Call it:

```
imgs, boxes, imgs_size = parse_annotation(img_path, ann_path, obj_names)
```

Total number of images:

```
img_size = len(imgs)
img_size
```

Inspect the returned shapes:

```
boxes.shape, imgs_size.shape
```
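Since the padded rows are all zeros and real class labels start at 1, the padding can be sanity-checked by counting the rows with a non-zero label in each image's box matrix. A minimal sketch using the `boxes` tensor returned above (the variable name `valid_per_img` is only illustrative):

```
# Padded rows carry label 0, real objects carry labels >= 1,
# so counting label > 0 along axis 1 recovers the true box count per image.
valid_per_img = tf.reduce_sum(tf.cast(boxes[..., 4] > 0, tf.int32), axis=1)
print(valid_per_img[:10].numpy())          # box counts for the first 10 images
print(int(tf.reduce_max(valid_per_img)))   # equals max_boxes by construction
```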
### Test and visualize

Implement a function that displays an image and draws boxes around its objects, to check that the data was read correctly.

```
from matplotlib import pyplot as plt
from matplotlib import patches
from PIL import Image

def show_by_path(img, img_boxes):
    # img: image file path
    # img_boxes: [max_boxes, 5]
    f, ax1 = plt.subplots(1, figsize=(10, 10))
    # display the image
    ax1.imshow(Image.open(img))
    for x1, y1, x2, y2, l in img_boxes:  # [max_boxes, 5]
        x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
        w = x2 - x1
        h = y2 - y1
        if l >= 1 and l <= len(obj_names):
            # green for every valid class
            color = (0, 1, 0)
        else:
            # padded rows come after the real boxes, so stop at the first invalid one
            break

        rect = patches.Rectangle((x1, y1), w, h, linewidth=2,
                                 edgecolor=color, facecolor='none')
        ax1.add_patch(rect)
```

```
show_by_path(imgs[1], boxes[1])
show_by_path(imgs[2], boxes[2])
show_by_path(imgs[30], boxes[30])
show_by_path(imgs[300], boxes[300])
show_by_path(imgs[4000], boxes[4000])
```

### Rescale the box coordinates

The boxes are still in the original image's pixel coordinates; scale them to the 416×416 network input size.

```
DEBUG = True

def modify_coordinate(boxes, imgs_size):
    # boxes: [b, max_boxes, 5], imgs_size: [b, 2] (width, height)
    w = imgs_size[:, 0]
    h = imgs_size[:, 1]
    w = tf.reshape(w, [w.shape[0], 1])
    w = tf.cast(w, tf.float64)
    h = tf.reshape(h, [h.shape[0], 1])
    h = tf.cast(h, tf.float64)

    xmin = boxes[:, :, 0]
    ymin = boxes[:, :, 1]
    xmax = boxes[:, :, 2]
    ymax = boxes[:, :, 3]
    typ  = boxes[:, :, 4]

    # scale every coordinate from the original image size to IMGSZ
    box_xmin = xmin * IMGSZ / w
    box_ymin = ymin * IMGSZ / h
    box_xmax = xmax * IMGSZ / w
    box_ymax = ymax * IMGSZ / h

    # stack back to [b, max_boxes, 5]
    res = tf.stack([box_xmin, box_ymin, box_xmax, box_ymax, typ], axis=1)
    res = tf.transpose(res, perm=[0, 2, 1])
    res = tf.cast(res, tf.int32)
    return res

boxes_resize = modify_coordinate(boxes, imgs_size)
# print(boxes_resize.shape)
```

```
def preprocess(img, img_boxes):
    # img: image file path (string)
    # img_boxes: [max_boxes, 5]
    x = tf.io.read_file(img)
    x = tf.image.decode_png(x, channels=3)
    x = tf.image.resize(x, (IMGSZ, IMGSZ))
    # note: resize already returns float32 with values in [0, 255];
    # convert_image_dtype does not rescale a float input, so the
    # pixel values stay in the 0-255 range here.
    x = tf.image.convert_image_dtype(x, tf.float32)
    return x, img_boxes
```

```
def get_dataset(img_dir, ann_dir):
    # build the tf.data pipeline: image paths [b], boxes [b, max_boxes, 5]
    # imgs, boxes, imgs_size = parse_annotation(img_dir, ann_dir, obj_names)
    # boxes_resize = modify_coordinate(boxes, imgs_size)
    db = tf.data.Dataset.from_tensor_slices((imgs, boxes_resize))
    db = db.map(preprocess).shuffle(img_size).batch(batch)
    print('db Images:', len(imgs))
    return db
```

```
train_db = get_dataset(img_path, ann_path)
print(train_db)
```

### Display images and boxes from the db

```
def show_by_mat(img, img_boxes):
    """
    img: image tensor (as produced by preprocess, float32 in [0, 255])
    img_boxes: boxes in corner format (x1, y1, x2, y2, label)
    """
    # cast to int so imshow treats the values as 0-255 pixels
    img = tf.cast(img, tf.int32)
    f, ax1 = plt.subplots(1, figsize=(10, 10))
    ax1.imshow(img)
    for x1, y1, x2, y2, l in img_boxes:  # [max_boxes, 5]
        x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
        w = x2 - x1
        h = y2 - y1
        if l >= 1 and l <= len(obj_names):
            # green for every valid class
            color = (0, 1, 0)
        else:
            # padded rows come after the real boxes, so stop at the first invalid one
            break

        rect = patches.Rectangle((x1, y1), w, h, linewidth=2,
                                 edgecolor=color, facecolor='none')
        ax1.add_patch(rect)
```

Convert the corner-format ground-truth boxes of one image into the 13×13 grid representation: for every valid box, find the anchor with the highest IoU and record the box at that (cell, anchor) slot.

```
def process_true_boxes(gt_boxes):
    # gt_boxes: [max_boxes, 5]
    # 416 // 13 = 32 pixels per grid cell
    scale = IMGSZ // GRIDSZ
    # ANCHORS: [5,2]

    # mask marking which (cell, anchor) slots contain an object
    detector_mask = np.zeros([GRIDSZ, GRIDSZ, 5, 1])
    # x-y-w-h-l of the matched ground-truth box per (cell, anchor)
    matching_gt_box = np.zeros([GRIDSZ, GRIDSZ, 5, 5])
    # [max_boxes,5]  x1-y1-x2-y2-l  =>  x-y-w-h-l (grid units)
    gt_boxes_grid = np.zeros(gt_boxes.shape)
    # DB: tensor => numpy
    gt_boxes = gt_boxes.numpy()

    for i in range(gt_boxes.shape[0]):  # [max_boxes,5]
        # box: [5], x1-y1-x2-y2-l, 416-pixel coordinates => grid units
        box = gt_boxes[i]
        x = ((box[0] + box[2]) / 2) / scale
        y = ((box[1] + box[3]) / 2) / scale
        w = (box[2] - box[0]) / scale
        h = (box[3] - box[1]) / scale
        # [max_boxes,5]  x-y-w-h-l
        gt_boxes_grid[i] = [x, y, w, h, box[4]]

        if w * h > 0:  # valid box
            # e.g. x,y = 7.3, 6.8
            best_anchor = 0
            best_iou = 0
            for j in range(5):
                interct = np.minimum(w, ANCHORS[j, 0]) * np.minimum(h, ANCHORS[j, 1])
                union = w * h + (ANCHORS[j, 0] * ANCHORS[j, 1]) - interct
                iou = interct / union
                if iou > best_iou:  # best iou so far
                    best_anchor = j
                    best_iou = iou
            # found the best anchor
            if best_iou > 0:
                x_coord = np.floor(x).astype(np.int32)
                y_coord = np.floor(y).astype(np.int32)
                # [h,w,5,1]
                detector_mask[y_coord, x_coord, best_anchor] = 1
                # [h,w,5, x-y-w-h-l]
                matching_gt_box[y_coord, x_coord, best_anchor] = np.array([x, y, w, h, box[4]])

    # [max_boxes,5] => [13,13,5,5], [13,13,5,1], [max_boxes,5]
    return matching_gt_box, detector_mask, gt_boxes_grid
```
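The anchor matching in `process_true_boxes` ignores the box position and compares only widths and heights, as if the ground-truth box and each anchor shared the same top-left corner. A small vectorized sketch of the same IoU computation; the sample box size (w = 3, h = 5 grid cells) is made up for illustration:

```
# Hypothetical ground-truth box, already converted to grid units.
w, h = 3.0, 5.0
interct = np.minimum(w, ANCHORS[:, 0]) * np.minimum(h, ANCHORS[:, 1])
union = w * h + ANCHORS[:, 0] * ANCHORS[:, 1] - interct
iou = interct / union
print(iou)                  # IoU against each of the 5 anchors
print(int(np.argmax(iou)))  # index of the responsible anchor
```

For this sample box the third anchor (3.34 × 5.47) wins, which is also what the per-box loop above would select.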
```
def show_by_mat_center(img, gt_boxes_grid):
    """
    img: image tensor
    gt_boxes_grid: boxes in grid units, center format (x, y, w, h, label)
    """
    img = tf.cast(img, tf.int32)
    f, ax1 = plt.subplots(1, figsize=(10, 10))
    ax1.imshow(img)

    scale = IMGSZ // GRIDSZ
    for x, y, w, h, l in gt_boxes_grid:  # [max_boxes,5]
        # back from grid units to pixels, then from center to top-left corner
        x, y, w, h = float(x) * scale, float(y) * scale, float(w) * scale, float(h) * scale
        x = x - w / 2
        y = y - h / 2
        if l >= 1 and l <= len(obj_names):
            # green for every valid class
            color = (0, 1, 0)
        else:
            # ignore the padded (invalid) boxes
            break

        rect = patches.Rectangle((x, y), w, h, linewidth=2,
                                 edgecolor=color, facecolor='none')
        ax1.add_patch(rect)
```

```
def ground_truth_generator(db):
    for imgs, imgs_boxes in db:
        # imgs: [b, 416, 416, 3]
        # imgs_boxes: [b, max_boxes, 5]
        batch_matching_gt_box = []
        batch_detector_mask = []
        batch_gt_boxes_grid = []
        for i in range(imgs_boxes.shape[0]):  # for each image in the batch
            matching_gt_box, detector_mask, gt_boxes_grid = process_true_boxes(imgs_boxes[i])
            batch_matching_gt_box.append(matching_gt_box)
            batch_detector_mask.append(detector_mask)
            batch_gt_boxes_grid.append(gt_boxes_grid)

        # [b, 13, 13, 5, 1]
        detector_mask = tf.cast(np.array(batch_detector_mask), dtype=tf.float32)
        # [b, 13, 13, 5, 5]  x-y-w-h-l
        matching_gt_box = tf.cast(np.array(batch_matching_gt_box), dtype=tf.float32)
        # [b, max_boxes, 5]  x-y-w-h-l
        gt_boxes_grid = tf.cast(np.array(batch_gt_boxes_grid), dtype=tf.float32)

        # [b, 13, 13, 5]
        matching_classes = tf.cast(matching_gt_box[..., 4], dtype=tf.int32)
        # labels run from 1..len(obj_names) and 0 means "no object",
        # so the one-hot depth is len(obj_names)+1 and the background column is dropped
        # [b, 13, 13, 5, len(obj_names)+1]
        matching_classes_oh = tf.one_hot(matching_classes, depth=1 + len(obj_names))
        # [b, 13, 13, 5, len(obj_names)]
        matching_classes_oh = tf.cast(matching_classes_oh[..., 1:], dtype=tf.float32)

        # imgs:                [b, 416, 416, 3]
        # detector_mask:       [b, 13, 13, 5, 1]
        # matching_gt_box:     [b, 13, 13, 5, 5]
        # matching_classes_oh: [b, 13, 13, 5, len(obj_names)]
        # gt_boxes_grid:       [b, max_boxes, 5]
        yield imgs, detector_mask, matching_gt_box, matching_classes_oh, gt_boxes_grid
```

Original source: http://malaoshi.top/show_1EF6OcOcBxCt.html
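As a quick sanity check on the generator, a minimal sketch that assumes the `train_db` built earlier; the `_b` variable names are only illustrative. It pulls one batch and prints the tensor shapes noted in the comments above:

```
gen = ground_truth_generator(train_db)
imgs_b, mask_b, gt_box_b, classes_oh_b, grid_b = next(gen)
print(imgs_b.shape)        # (batch, 416, 416, 3)
print(mask_b.shape)        # (batch, 13, 13, 5, 1)
print(gt_box_b.shape)      # (batch, 13, 13, 5, 5)
print(classes_oh_b.shape)  # (batch, 13, 13, 5, len(obj_names))
print(grid_b.shape)        # (batch, max_boxes, 5)
```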