yongyong-e
[Tensorflow-Slim] Convert to TFRecord file 본문
summary
자동차 차종 분류를 위해 자동차 이미지들을 TFRecord 형식으로 변환하는 방법에 대해 진행
TFRecord : 바이너리 파일 형식으로 텐서플로우에서 data 저장 및 입·출력을 위해 사용
1) Preparing image files
분류 하고 싶은 이미지를 다음과 같이 cars5디렉토리 안의 cars5_photos디렉토리에 label 별로 저장
2) slim디렉토리의 *.py파일 수정
앞서 준비된 이미지를 TFRecord 형식으로 변환하기 전에
2017/08/14 - [머신러닝/Tensorflow] - (Tensorflow-Slim) Tutorial
앞선 글에서 Flowers Dataset을 TFRecord로 변환 하기 위해 사용했던 파일의 코드를 수정하면서 진행
① download_and_convert_data.py 파일 수정
import 추가
from datasets import download_and_convert_cars5
from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist
from datasets import download_and_convert_cars5
main()함수에 cars5에 대한 코드를 추가
elif FLAGS.dataset_name == 'cars5':
download_and_convert_cars5.run(FLAGS.dataset_dir)
def main(_):
if not FLAGS.dataset_name:
raise ValueError('You must supply the dataset name with --dataset_name')
if not FLAGS.dataset_dir:
raise ValueError('You must supply the dataset directory with --dataset_dir')
if FLAGS.dataset_name == 'cifar10':
download_and_convert_cifar10.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'flowers':
download_and_convert_flowers.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'mnist':
download_and_convert_mnist.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'cars5':
download_and_convert_cars5.run(FLAGS.dataset_dir)
else:
raise ValueError(
'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
② slim/datasets 디렉토리에 cars5에 대한 파일 만들기
flowers를 참고를 하기 위해 flowers.py 와 download_and_convert_flowers.py를 복사
flowers.py --> cars5.py
download_and_convert_flowers.py --> download_and_convert_cars5.py
③ cars5.py 파일 수정
train과 validation 사이즈는 train 사이즈는 총 600장의 이미지에서 80%, validation는 20%를 의미
_FILE_PATTERN = 'cars5_%s_*.tfrecord'
SPLITS_TO_SIZES = {'train': 480, 'validation': 120}
_NUM_CLASSES = 5
④ download_and_convert_cars5.py 파일 수정
_NUM_VALIDATION은 위 처럼 총 600장의 이미지에서 20%인 120을 의미
_NUM_SHARDS은 TFRecord로 변환될 파일의 크기로 현재 구성 된 cars5 이미지 데이터셋의 크기가 작으므로 1로 설정
# The number of images in the validation set.
_NUM_VALIDATION = 120
# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards per dataset split.
_NUM_SHARDS = 1
그리고 다음과 같이 flowers는 cars5로 변경
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'cars5_%s_%05d-of-%05d.tfrecord' % (
split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
cars5_root = os.path.join(dataset_dir, 'cars5_photos')
directories = []
class_names = []
for filename in os.listdir(cars5_root):
path = os.path.join(cars5_root, filename)
if os.path.isdir(path):
directories.append(path)
class_names.append(filename)
④ dataset_factory.py 파일 수정
import 추가
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import cars5
cars5에 대한 datasets_map 추가
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'cars5': cars5,
}
3) TFRecord 파일 형식으로 변환
자신의 dataset_name과 dataset_dir를 변경 후, 스크립트 실행
$ python download_and_convert_data.py \ --dataset_name=cars \ --dataset_dir=/home/yong/dataset/cars5
이후 labels.txt와 함께 *.tfrecord 파일이 생성된 것을 볼 수 있음
'머신러닝 > Tensorflow - Models' 카테고리의 다른 글
[Tensorflow Object Detection API] 2. Training your own dataset (11) | 2017.08.30 |
---|---|
[Tensorflow Object Detection API] 1. Creating your own dataset (25) | 2017.08.29 |
[Tensorflow-Slim] Tutorial (0) | 2017.08.14 |
[Tensorflow Object Detection API] Training a pet detector (0) | 2017.08.03 |
[Tensorflow Object Detection API] How to install (0) | 2017.07.28 |