
Class Weights For Balancing Data In TensorFlow Object Detection API

I'm fine-tuning an SSD object detector with the TensorFlow Object Detection API on the Open Images Dataset. My training data contains imbalanced classes, e.g. top (5K images), dress (50K ima

Solution 1:

The API expects a weight for each object (bbox) directly in the annotation data. Given this requirement, the options for using class weights seem to be:

1) If you have a custom dataset, you can modify the annotation of each object (bbox) to include a weight field as 'object/weight'.

2) If you don't want to modify the annotations, you can regenerate only the TFRecord file so that it includes the weights of the bboxes.

3) Modify the API code itself (this seemed quite tricky to me).
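Option 1 could look roughly like the following. This is a minimal sketch assuming Pascal VOC-style XML in which each `<object>` has a `<name>` child; the `<weight>` tag name, the `add_object_weights` helper and the `CLASS_WEIGHTS` mapping are hypothetical. A converter still has to copy the value into the TFRecord field `image/object/weight`, because that is the field the training input actually carries.

```python
import xml.etree.ElementTree as ET

# Hypothetical per-class weights (same values as the top/dress example).
CLASS_WEIGHTS = {'top': 1.0, 'dress': 0.1}


def add_object_weights(xml_text, class_weights, default=1.0):
    """Return the XML with a <weight> child appended to every <object>.

    Assumes Pascal VOC-style annotations where each <object> has a <name>.
    The <weight> tag is an assumed convention, not something the API parses.
    """
    root = ET.fromstring(xml_text)
    for obj in root.findall('object'):
        name = obj.find('name').text
        weight = ET.SubElement(obj, 'weight')
        # Classes missing from the mapping fall back to the default weight.
        weight.text = str(class_weights.get(name, default))
    return ET.tostring(root, encoding='unicode')


sample = """<annotation>
  <object><name>top</name></object>
  <object><name>dress</name></object>
</annotation>"""

print(add_object_weights(sample, CLASS_WEIGHTS))
```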

I went with option 2, so here is the code to generate such a weighted TFRecord file for a custom dataset with two classes ("top", "dress") and weights (1.0, 0.1), from a folder of XML annotations:

import io
import glob
import hashlib
import random
import xml.etree.ElementTree as ET

import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util

# Class names and their weights, in the same order ('...' = any further classes)
class_names = ['top', 'dress', ...]
class_weights = [1.0, 0.1, ...]

def create_example(xml_file):
    """Build a tf.train.Example, including per-box weights, from one XML file."""
    tree = ET.parse(xml_file)
    root = tree.getroot()
    image_name = root.find('filename').text
    image_path = root.find('path').text
    file_name = image_name.encode('utf8')
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []
    weights = []  # Important line: one weight per object (bbox)

    for member in root.findall('object'):
        class_name = member.find('name').text
        if class_name not in class_names:
            # Skip unknown classes so all per-object lists stay aligned.
            print('E: class not recognized: {}'.format(class_name))
            continue
        class_id = class_names.index(class_name)

        # Box corners, normalized to [0, 1] by the image size.
        bndbox = member.find('bndbox')
        xmin.append(float(bndbox.find('xmin').text) / width)
        ymin.append(float(bndbox.find('ymin').text) / height)
        xmax.append(float(bndbox.find('xmax').text) / width)
        ymax.append(float(bndbox.find('ymax').text) / height)

        classes_text.append(class_name.encode('utf8'))
        classes.append(class_id + 1)  # label map ids start at 1: top=1, dress=2
        weights.append(class_weights[class_id])  # Important line
        difficult_obj.append(0)
        truncated.append(0)
        poses.append('Unspecified'.encode('utf8'))

    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG: {}'.format(image_path))
    key = hashlib.sha256(encoded_jpg).hexdigest()

    # Create the TFRecord Example
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(file_name),
        'image/source_id': dataset_util.bytes_feature(file_name),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
        'image/object/weight': dataset_util.float_list_feature(weights),  # Important line
    }))
    return example

def main():
    weighted_tf_records_output = 'name_of_records_file.record'  # output file
    annotations_path = '/path/to/annotations/folder/*.xml'  # input annotations

    # List and shuffle the annotation files (no TF session needed).
    xml_files = glob.glob(annotations_path)
    random.shuffle(xml_files)

    with tf.io.TFRecordWriter(weighted_tf_records_output) as writer:
        for xml_file in xml_files:
            print('-> Processing {}'.format(xml_file))
            example = create_example(xml_file)
            writer.write(example.SerializeToString())

    print('-> Successfully converted dataset to TFRecord.')


if __name__ == '__main__':
    main()

If your annotations are in a different format, the code will be very similar, but this version will not work as-is.
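The weights (1.0, 0.1) are consistent with inverse class frequency: the question reports roughly ten times as many dress images as top images. A minimal sketch of deriving such weights, assuming a weight proportional to 1 / count and scaled so the rarest class gets 1.0 (other schemes are equally valid choices); `inverse_frequency_weights` is a hypothetical helper:

```python
def inverse_frequency_weights(counts):
    """Map class name -> weight proportional to 1 / count.

    Weights are scaled so the rarest class gets 1.0 (an assumed convention).
    """
    rarest = min(counts.values())
    return {name: rarest / count for name, count in counts.items()}


# Approximate image counts from the question: top ~5K, dress ~50K.
weights = inverse_frequency_weights({'top': 5000, 'dress': 50000})
class_names = ['top', 'dress']
class_weights = [weights[name] for name in class_names]
print(class_weights)  # [1.0, 0.1]
```

The resulting list is in the same order as `class_names`, which is the form the converter expects.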

Solution 2:

The Object Detection API losses are defined in: https://github.com/tensorflow/models/blob/master/research/object_detection/core/losses.py

In particular, the following loss classes have been implemented:

Classification losses:

  1. WeightedSigmoidClassificationLoss
  2. SigmoidFocalClassificationLoss
  3. WeightedSoftmaxClassificationLoss
  4. WeightedSoftmaxClassificationAgainstLogitsLoss
  5. BootstrappedSigmoidClassificationLoss

Localization losses:

  1. WeightedL2LocalizationLoss
  2. WeightedSmoothL1LocalizationLoss
  3. WeightedIOULocalizationLoss

The weight arguments of these losses are used to balance anchors (prior boxes); they have shape [batch_size, num_anchors] and are applied in addition to hard negative mining. Alternatively, the focal loss down-weights well-classified examples and focuses on the hard ones.
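To make the two mechanisms concrete, here is a NumPy sketch (not the API's code): a sigmoid cross-entropy scaled by per-anchor weights of shape [batch_size, num_anchors], and a sigmoid focal loss that further multiplies that term by alpha_t * (1 - p_t)^gamma. The defaults gamma = 2.0 and alpha = 0.25 are commonly used values and are assumptions here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def weighted_sigmoid_ce(logits, targets, weights):
    """Sigmoid cross-entropy per anchor and class, scaled by anchor weights.

    logits, targets: [batch_size, num_anchors, num_classes]
    weights: [batch_size, num_anchors], broadcast over the class axis.
    """
    # Numerically stable form of -t*log(p) - (1 - t)*log(1 - p).
    ce = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return ce * weights[..., np.newaxis]


def sigmoid_focal_loss(logits, targets, weights, gamma=2.0, alpha=0.25):
    """Weighted cross-entropy scaled by alpha_t * (1 - p_t)**gamma.

    p_t is the predicted probability of the true label, so entries that are
    already classified well (p_t close to 1) contribute very little.
    """
    p = sigmoid(logits)
    p_t = targets * p + (1 - targets) * (1 - p)
    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
    return alpha_t * (1.0 - p_t) ** gamma * weighted_sigmoid_ce(logits, targets, weights)


# One batch, two anchors, one class: an easy positive and a hard positive.
logits = np.array([[[4.0], [-4.0]]])
targets = np.ones_like(logits)
anchor_weights = np.ones((1, 2))
print(weighted_sigmoid_ce(logits, targets, anchor_weights))
print(sigmoid_focal_loss(logits, targets, anchor_weights))
```

Setting an anchor's weight to 0 removes it from both losses, while the focal term shrinks the easy anchor's loss far more than the hard one's.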

The main class imbalance comes from having many more negative examples (boxes without objects of interest) than positive examples (boxes with object classes). That seems to be why imbalance within the positive examples (i.e. an unequal distribution of positive class labels) is not handled by the object detection losses.
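If you still want to weight the positive classes, one way to express it in the same [batch_size, num_anchors] weight format is to look up a weight from each anchor's matched class. The helper below is hypothetical and only illustrates the idea; it is not how the API wires weights internally. It assumes label 0 marks background (negative) anchors and that class labels start at 1 (top = 1, dress = 2, as in the converter above).

```python
import numpy as np


def class_balanced_anchor_weights(anchor_labels, class_weights, background_weight=1.0):
    """Turn per-anchor class labels into per-anchor loss weights.

    anchor_labels: int array [batch_size, num_anchors]; 0 = background,
        k >= 1 = anchor matched to class k.
    class_weights: sequence where class_weights[k - 1] is the weight of class k.
    Returns a float array [batch_size, num_anchors].
    """
    # Index 0 holds the background weight, index k the weight of class k.
    lookup = np.concatenate([[background_weight], np.asarray(class_weights, dtype=float)])
    return lookup[anchor_labels]


labels = np.array([[0, 1, 2, 0, 2]])  # hypothetical anchor matches
print(class_balanced_anchor_weights(labels, [1.0, 0.1]))
```

The result could then be passed as the `weights` argument of a loss like the ones listed above, alongside hard negative mining.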
