yongyong-e

[Tensorflow-Slim] Convert to TFRecord file 본문

머신러닝/Tensorflow - Models

[Tensorflow-Slim] Convert to TFRecord file

Yonghan Kim 2017. 8. 16. 15:28

summary

자동차 차종 분류를 위해 자동차 이미지들을 TFRecord 형식으로 변환하는 방법에 대해 진행

TFRecord :  바이너리 파일 형식으로 텐서플로우에서 data 저장 및 입·출력을 위해 사용


1) Preparing image files

분류 하고 싶은 이미지를 다음과 같이 cars5디렉토리 안의 cars5_photos디렉토리에 label 별로 저장


      


   


2) slim디렉토리의 *.py파일 수정

앞서 준비된 이미지를 TFRecord 형식으로 변환하기 전에  

2017/08/14 - [머신러닝/Tensorflow] - (Tensorflow-Slim) Tutorial 

앞선 글에서 Flowers Dataset을 TFRecord로 변환 하기 위해 사용했던 파일의 코드를 수정하면서 진행


① download_and_convert_data.py 파일 수정

import 추가

from datasets import download_and_convert_cars5

from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist
from datasets import download_and_convert_cars5

main()함수에 cars5에 대한 코드를 추가

elif FLAGS.dataset_name == 'cars5':
download_and_convert_cars5.run(FLAGS.dataset_dir)

def main(_):
if not FLAGS.dataset_name:
raise ValueError('You must supply the dataset name with --dataset_name')
if not FLAGS.dataset_dir:
raise ValueError('You must supply the dataset directory with --dataset_dir')

if FLAGS.dataset_name == 'cifar10':
download_and_convert_cifar10.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'flowers':
download_and_convert_flowers.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'mnist':
download_and_convert_mnist.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'cars5':
download_and_convert_cars5.run(FLAGS.dataset_dir)
else:
raise ValueError(
'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)


slim/datasets 디렉토리에 cars5에 대한 파일 만들기

flowers를 참고를 하기 위해 flowers.py 와 download_and_convert_flowers.py를 복사

flowers.py    -->    cars5.py

download_and_convert_flowers.py    -->    download_and_convert_cars5.py


③ cars5.py 파일 수정

train과 validation 사이즈는 train 사이즈는 총 600장의 이미지에서 80%, validation는 20%를 의미

_FILE_PATTERN = 'cars5_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 480, 'validation': 120}

_NUM_CLASSES = 5


④ download_and_convert_cars5.py 파일 수정

_NUM_VALIDATION은 위 처럼 총 600장의 이미지에서 20%인 120을 의미

_NUM_SHARDS은 TFRecord로 변환될 파일의 크기로 현재 구성 된 cars5 이미지 데이터셋의 크기가 작으므로 1로 설정

# The number of images in the validation set.
_NUM_VALIDATION = 120

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 1

그리고 다음과 같이 flowers는 cars5로 변경

def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'cars5_%s_%05d-of-%05d.tfrecord' % (
split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)

cars5_root = os.path.join(dataset_dir, 'cars5_photos')
directories = []
class_names = []
for filename in os.listdir(cars5_root):
    path = os.path.join(cars5_root, filename)
if os.path.isdir(path):
directories.append(path)
class_names.append(filename)


④ dataset_factory.py 파일 수정

import 추가

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import cars5

cars5에 대한 datasets_map 추가

datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'cars5': cars5,
}


3) TFRecord 파일 형식으로 변환

자신의 dataset_name과 dataset_dir를 변경 후, 스크립트 실행

$ python download_and_convert_data.py \ --dataset_name=cars \ --dataset_dir=/home/yong/dataset/cars5


이후 labels.txt와 함께 *.tfrecord 파일이 생성된 것을 볼 수 있음


Comments