Тонкая настройка Resnet-18 с Tensorflow Model Garden для классификации CIFAR-10

Тонкая настройка Resnet-18 с Tensorflow Model Garden для классификации CIFAR-10

13 августа 2025 г.

Обзор контента

  • Настраивать
  • Настройте модель RESNET-18 для набора данных CIFAR-10
  • Визуализировать обучающую модель
  • Визуализировать модель тестирования
  • Поезда и оценить
  • Экспорт сохраненной модели

Этот учебник изыскает остаточную сеть (Resnet) из TensorFlowМодельный садупаковка (tensorflow-models) классифицировать изображения вCifarнабор данных.

Model Garden содержит коллекцию современных моделей зрения, внедренных с API высокого уровня Tensorflow. Реализации демонстрируют лучшие методы моделирования, позволяя пользователям в полной мере воспользоваться TensorFlow для их исследований и разработки продукта.

В этом уроке используетсяResnetМодель, современный классификатор изображения. В этом уроке используется модель Resnet-18, сверточную нейронную сеть с 18 уровнями.

Этот учебник демонстрирует, как:

  1. Используйте модели из пакета Tensorflow Models.
  2. Fine-Tune предварительно построенный Resnet для классификации изображений.
  3. Экспортируйте настроенную модель Resnet.

Настраивать

Установите и импортируйте необходимые модули.

pip install -U -q "tf-models-official"

Импортируйте TensorFlow, наборы данных TensorFlow и несколько помощников.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

2023-10-17 11:52:54.005237: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-17 11:52:54.005294: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-17 11:52:54.005338: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Аtensorflow_modelsпакет содержит модель Resnet Vision иofficial.vision.servingМодель содержит функцию для сохранения и экспорта настроенной модели.

import tensorflow_models as tfm

# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib

Настройте модель RESNET-18 для набора данных CIFAR-10

Набор данных CIFAR10 содержит 60 000 цветных изображений в взаимоисключающих 10 классах с 6000 изображений в каждом классе.

В модельном саду коллекции параметров, которые определяют модель, называютсяконфигурацииПолем Model Garden может создать конфигурацию на основе известного набора параметров черезфабрикаПолем

Используйтеresnet_imagenetконфигурация фабрики, как определеноtfm.vision.configs.image_classification.image_classification_imagenetПолем Конфигурация настроена на обучение Resnet для сходимости наImageNetПолем

exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds,ds_info = tfds.load(
tfds_name,
with_info=True)
ds_info

2023-10-17 11:52:59.285390: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)

Отрегулируйте конфигурации модели и набора данных так, чтобы она работала с CIFAR-10 (cifar10)

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size

Отрегулируйте конфигурацию тренера.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device=='CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps=5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps =  ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100

Running on CPU is slow, so only train for a few steps.

Распечатайте измененную конфигурацию.

pprint.pprint(exp_config.as_dict())

display.Javascript("google.colab.output.setIframeHeight('300px');")

{'runtime': {'all_reduce_alg': None,
             'batchnorm_spatial_persistent': False,
             'dataset_num_private_threads': None,
             'default_shard_dim': -1,
             'distribution_strategy': 'mirrored',
             'enable_xla': True,
             'gpu_thread_mode': None,
             'loss_scale': None,
             'mixed_precision_dtype': None,
             'num_cores_per_replica': 1,
             'num_gpus': 0,
             'num_packs': 1,
             'per_gpu_thread_count': 0,
             'run_eagerly': False,
             'task_index': -1,
             'tpu': None,
             'tpu_enable_xla_dynamic_padder': None,
             'use_tpu_mp_strategy': False,
             'worker_hosts': None},
 'task': {'allow_image_summary': False,
          'differential_privacy_config': None,
          'eval_input_partition_dims': [],
          'evaluation': {'precision_and_recall_thresholds': None,
                         'report_per_class_precision_and_recall': False,
                         'top_k': 5},
          'freeze_backbone': False,
          'init_checkpoint': None,
          'init_checkpoint_modules': 'all',
          'losses': {'l2_weight_decay': 0.0001,
                     'label_smoothing': 0.0,
                     'loss_weight': 1.0,
                     'one_hot': True,
                     'soft_labels': False,
                     'use_binary_cross_entropy': False},
          'model': {'add_head_batch_norm': False,
                    'backbone': {'resnet': {'bn_trainable': True,
                                            'depth_multiplier': 1.0,
                                            'model_id': 18,
                                            'replace_stem_max_pool': False,
                                            'resnetd_shortcut': False,
                                            'scale_stem': True,
                                            'se_ratio': 0.0,
                                            'stem_type': 'v0',
                                            'stochastic_depth_drop_rate': 0.0},
                                 'type': 'resnet'},
                    'dropout_rate': 0.0,
                    'input_size': [32, 32, 3],
                    'kernel_initializer': 'random_uniform',
                    'norm_activation': {'activation': 'relu',
                                        'norm_epsilon': 1e-05,
                                        'norm_momentum': 0.9,
                                        'use_sync_bn': False},
                    'num_classes': 10,
                    'output_softmax': False},
          'model_output_keys': [],
          'name': None,
          'train_data': {'apply_tf_data_service_before_batching': False,
                         'aug_crop': True,
                         'aug_policy': None,
                         'aug_rand_hflip': True,
                         'aug_type': None,
                         'autotune_algorithm': None,
                         'block_length': 1,
                         'cache': False,
                         'center_crop_fraction': 0.875,
                         'color_jitter': 0.0,
                         'crop_area_range': (0.08, 1.0),
                         'cycle_length': 10,
                         'decode_jpeg_only': True,
                         'decoder': {'simple_decoder': {'attribute_names': [],
                                                        'mask_binarize_threshold': None,
                                                        'regenerate_source_id': False},
                                     'type': 'simple_decoder'},
                         'deterministic': None,
                         'drop_remainder': True,
                         'dtype': 'float32',
                         'enable_shared_tf_data_service_between_parallel_trainers': False,
                         'enable_tf_data_service': False,
                         'file_type': 'tfrecord',
                         'global_batch_size': 128,
                         'image_field_key': 'image/encoded',
                         'input_path': '',
                         'is_multilabel': False,
                         'is_training': True,
                         'label_field_key': 'image/class/label',
                         'mixup_and_cutmix': None,
                         'prefetch_buffer_size': None,
                         'randaug_magnitude': 10,
                         'random_erasing': None,
                         'repeated_augment': None,
                         'seed': None,
                         'sharding': True,
                         'shuffle_buffer_size': 10000,
                         'tf_data_service_address': None,
                         'tf_data_service_job_name': None,
                         'tf_resize_method': 'bilinear',
                         'tfds_as_supervised': False,
                         'tfds_data_dir': '',
                         'tfds_name': 'cifar10',
                         'tfds_skip_decoding_feature': '',
                         'tfds_split': 'train',
                         'three_augment': False,
                         'trainer_id': None,
                         'weights': None},
          'train_input_partition_dims': [],
          'validation_data': {'apply_tf_data_service_before_batching': False,
                              'aug_crop': True,
                              'aug_policy': None,
                              'aug_rand_hflip': True,
                              'aug_type': None,
                              'autotune_algorithm': None,
                              'block_length': 1,
                              'cache': False,
                              'center_crop_fraction': 0.875,
                              'color_jitter': 0.0,
                              'crop_area_range': (0.08, 1.0),
                              'cycle_length': 10,
                              'decode_jpeg_only': True,
                              'decoder': {'simple_decoder': {'attribute_names': [],
                                                             'mask_binarize_threshold': None,
                                                             'regenerate_source_id': False},
                                          'type': 'simple_decoder'},
                              'deterministic': None,
                              'drop_remainder': True,
                              'dtype': 'float32',
                              'enable_shared_tf_data_service_between_parallel_trainers': False,
                              'enable_tf_data_service': False,
                              'file_type': 'tfrecord',
                              'global_batch_size': 128,
                              'image_field_key': 'image/encoded',
                              'input_path': '',
                              'is_multilabel': False,
                              'is_training': False,
                              'label_field_key': 'image/class/label',
                              'mixup_and_cutmix': None,
                              'prefetch_buffer_size': None,
                              'randaug_magnitude': 10,
                              'random_erasing': None,
                              'repeated_augment': None,
                              'seed': None,
                              'sharding': True,
                              'shuffle_buffer_size': 10000,
                              'tf_data_service_address': None,
                              'tf_data_service_job_name': None,
                              'tf_resize_method': 'bilinear',
                              'tfds_as_supervised': False,
                              'tfds_data_dir': '',
                              'tfds_name': 'cifar10',
                              'tfds_skip_decoding_feature': '',
                              'tfds_split': 'test',
                              'three_augment': False,
                              'trainer_id': None,
                              'weights': None} },
 'trainer': {'allow_tpu_summary': False,
             'best_checkpoint_eval_metric': '',
             'best_checkpoint_export_subdir': '',
             'best_checkpoint_metric_comp': 'higher',
             'checkpoint_interval': 20,
             'continuous_eval_timeout': 3600,
             'eval_tf_function': True,
             'eval_tf_while_loop': False,
             'loss_upper_bound': 1000000.0,
             'max_to_keep': 5,
             'optimizer_config': {'ema': None,
                                  'learning_rate': {'cosine': {'alpha': 0.0,
                                                               'decay_steps': 20,
                                                               'initial_learning_rate': 0.1,
                                                               'name': 'CosineDecay',
                                                               'offset': 0},
                                                    'type': 'cosine'},
                                  'optimizer': {'sgd': {'clipnorm': None,
                                                        'clipvalue': None,
                                                        'decay': 0.0,
                                                        'global_clipnorm': None,
                                                        'momentum': 0.9,
                                                        'name': 'SGD',
                                                        'nesterov': False},
                                                'type': 'sgd'},
                                  'warmup': {'linear': {'name': 'linear',
                                                        'warmup_learning_rate': 0,
                                                        'warmup_steps': 100},
                                             'type': 'linear'} },
             'preemption_on_demand_checkpoint': True,
             'recovery_begin_steps': 0,
             'recovery_max_trials': 0,
             'steps_per_loop': 5,
             'summary_interval': 100,
             'train_steps': 20,
             'train_tf_function': True,
             'train_tf_while_loop': True,
             'validation_interval': 1000,
             'validation_steps': 78,
             'validation_summary_subdir': 'validation'} }
<IPython.core.display.Javascript object>

Создайте стратегию распространения.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])

Warning: this will be really slow.

СоздатьTaskобъект (tfm.core.base_task.Task) изconfig_definitions.TaskConfigПолем

АTaskУ объекта есть все методы, необходимые для создания набора данных, создания модели и работы обучения и оценки. Эти методы обусловленыtfm.core.train_lib.run_experimentПолем

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

#  tf.keras.utils.plot_model(task.build_model(), show_shapes=True)

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')

images.shape: (128, 32, 32, 3)  images.dtype: tf.float32
labels.shape: (128,)            labels.dtype: tf.int32
2023-10-17 11:53:02.248801: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Визуализировать данные обучения

DataLoader применяет нормализацию Z-показателя с использованиемpreprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), поэтому изображения, возвращаемые набором данных, не могут быть напрямую отображаться стандартными инструментами. Код визуализации должен пересекать данные в диапазоне [0,1].

plt.hist(images.numpy().flatten());

Использоватьds_info(который является экземпляромtfds.core.DatasetInfo) поиск текстовых описаний каждого идентификатора класса.

label_info = ds_info.features['label']
label_info.int2str(1)

'automobile'

Визуализируйте партию данных.

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  min = images.numpy().min()
  max = images.numpy().max()
  delta = max - min

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i]-min) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")

plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)

2023-10-17 11:53:04.198417: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Визуализировать данные тестирования

Визуализируйте партию изображений из набора данных проверки.

plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)

2023-10-17 11:53:07.007846: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Поезда и оценить

model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)

restoring or initializing model...
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
train | step:      0 | training until step 20...
2023-10-17 11:53:09.849007: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
train | step:      5 | steps/sec:    0.5 | output: 
    {'accuracy': 0.103125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4828125,
     'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step:     10 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0828125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4984375,
     'training_loss': 2.8205295}
train | step:     15 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0921875,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.503125,
     'training_loss': 2.8169343}
train | step:     20 | steps/sec:    0.8 | output: 
    {'accuracy': 0.1015625,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.45,
     'training_loss': 2.8760865}
 eval | step:     20 | running 78 steps of evaluation...
 eval | step:     20 | steps/sec:   24.4 | eval time:    3.2 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 24.40085348913806,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
2023-10-17 11:53:43.844533: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
eval | step:     20 | running 78 steps of evaluation...
2023-10-17 11:53:45.627213: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
eval | step:     20 | steps/sec:   40.1 | eval time:    1.9 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 40.14298727815298,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}

#  tf.keras.utils.plot_model(model, show_shapes=True)

РаспечататьaccuracyВtop_5_accuracy, иvalidation_lossпоказатели оценки.

for key, value in eval_logs.items():
    if isinstance(value, tf.Tensor):
      value = value.numpy()
    print(f'{key:20}: {value:.3f}')

accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143

Запустите партию обработанных учебных данных через модель и просмотрите результаты

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')

2023-10-17 11:53:49.840600: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
4/4 [==============================] - 1s 13ms/step
2023-10-17 11:53:50.778301: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Экспорт сохраненной модели

Аkeras.Modelобъект возвращаетсяtrain_lib.run_experimentОжидает, что данные будут нормализованы погрузчиком набора данных с использованием того же среднего значения и дисперсии в статишике вpreprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)Полем Эта экспортная функция обрабатывает эти детали, так что вы можете пройтиtf.uint8изображения и получите правильные результаты.

# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')

INFO:tensorflow:Assets written to: ./export/assets
INFO:tensorflow:Assets written to: ./export/assets

Проверьте экспортируемую модель.

# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']

Визуализируйте прогнозы.

plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')

2023-10-17 11:54:01.438509: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Первоначально опубликовано наTensorflowВеб -сайт, эта статья появляется здесь под новым заголовком и имеет лицензию в CC на 4.0. Образцы кода, разделенные по лицензии Apache 2.0.


Оригинал
PREVIOUS ARTICLE
NEXT ARTICLE