Руководство по QuickStart: API -интерфейсы TensorFlow Core

Руководство по QuickStart: API -интерфейсы TensorFlow Core

15 июня 2025 г.

Обзор контента

  • Настраивать
  • Загрузить и предварительно обрабатывать набор данных
  • Создайте модель машинного обучения
  • Определить функцию потери
  • Тренируйте и оцените свою модель
  • Сохранить и загрузить модель
  • Заключение

Этот учебник QuickStart демонстрирует, как вы можете использовать API низкого уровня TensorFlow Low-Level для создания и обучения модели множественной линейной регрессии, которая предсказывает эффективность топлива. Он используетАвто миль на галлонНабор данных, который содержит данные о эффективности топлива для автомобилей конца 1970-х и начала 1980-х годов.

Вы будете следить за типичными этапами процесса машинного обучения:

  1. Загрузите набор данных.
  2. Создайте входной трубопровод.
  3. Создайте модель нескольких линейной регрессии.
  4. Оценить производительность модели.

Настраивать

Импорт TensorFlow и другие необходимые библиотеки для начала:


import tensorflow as tf
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
print("TensorFlow version:", tf.__version__)
# Set a random seed for reproducible results 
tf.random.set_seed(22)


2024-08-15 02:47:48.100561: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-15 02:47:48.122351: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-15 02:47:48.128828: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TensorFlow version: 2.17.0

Загрузить и предварительно обрабатывать набор данных

Далее вам нужно загрузить и предварительно обрабатыватьНабор данных автоматического MPGизРепозиторий машинного обучения UCIПолем В этом наборе данных используется различные количественные и категориальные функции, такие как цилиндры, смещение, мощность и вес, чтобы предсказать топливную эффективность автомобилей в конце 1970-х и начале 1980-х годов.

Набор данных содержит несколько неизвестных значений. Обязательно отказались от пропущенных значений сpandas.DataFrame.dropna, и преобразовать набор данных вtf.float32Тензор типа сtf.convert_to_tensorиtf.castфункции.


url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

dataset = pd.read_csv(url, names=column_names, na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

dataset = dataset.dropna()
dataset_tf = tf.convert_to_tensor(dataset, dtype=tf.float32)
dataset.tail()


WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1723690071.413771  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.417612  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.421353  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.424990  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.436855  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.442040  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.445455  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.448955  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.452381  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.455938  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.459354  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690071.462823  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.697837  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.699861  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.701923  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.703929  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.706012  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.707881  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.709819  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.711727  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.713702  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.715579  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.717528  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.719413  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.760282  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.762246  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.764346  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.766784  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.768782  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.770677  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.772633  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.774570  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.776568  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.778957  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.781312  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1723690072.783540  148279 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355


Мип

Цилиндры

Смещение

Лошадиная сила

Масса

Ускорение

Модельный год

Источник

393

27.0

4

140.0

86.0

2790.0

15.6

82

1

394

44,0

4

97.0

52,0

2130.0

24.6

82

2

395

32,0

4

135.0

84,0

2295.0

11.6

82

1

396

28.0

4

120.0

79,0

2625.0

18.6

82

1

397

31.0

4

119,0

82.0

2720.0

19.4

82

1

Затем разделите набор данных на обучающие и тестовые наборы. Обязательно перетасовать набор данных с помощьюtf.random.shuffleЧтобы избежать предвзятых расколов.


dataset_shuffled = tf.random.shuffle(dataset_tf, seed=22)
train_data, test_data = dataset_shuffled[100:], dataset_shuffled[:100]
x_train, y_train = train_data[:, 1:], train_data[:, 0]
x_test, y_test = test_data[:, 1:], test_data[:, 0]

Выполнить базовую инженерную инженерию, одновременно кодируя"Origin"особенность. Аtf.one_hotФункция полезна для преобразования этого категориального столбца в 3 отдельных бинарных столбца.


def onehot_origin(x):
  origin = tf.cast(x[:, -1], tf.int32)
  # Use `origin - 1` to account for 1-indexed feature
  origin_oh = tf.one_hot(origin - 1, 3)
  x_ohe = tf.concat([x[:, :-1], origin_oh], axis = 1)
  return x_ohe

x_train_ohe, x_test_ohe = onehot_origin(x_train), onehot_origin(x_test)
x_train_ohe.numpy()


array([[  4., 140.,  72., ...,   1.,   0.,   0.],
       [  4., 120.,  74., ...,   0.,   0.,   1.],
       [  4., 122.,  88., ...,   0.,   1.,   0.],
       ...,
       [  8., 318., 150., ...,   1.,   0.,   0.],
       [  4., 156., 105., ...,   1.,   0.,   0.],
       [  6., 232., 100., ...,   1.,   0.,   0.]], dtype=float32)

Этот пример показывает проблему множественной регрессии с предикторами или функциями в совершенно разных масштабах. Следовательно, полезно стандартизировать данные, чтобы каждая функция имела ноль среднее значение и дисперсию единиц. Используйтеtf.reduce_meanиtf.math.reduce_stdфункции для стандартизации. Прогноз регрессионной модели может быть затем нестандарен, чтобы получить ее значение с точки зрения исходных единиц.


class Normalize(tf.Module):
  def __init__(self, x):
    # Initialize the mean and standard deviation for normalization
    self.mean = tf.math.reduce_mean(x, axis=0)
    self.std = tf.math.reduce_std(x, axis=0)

  def norm(self, x):
    # Normalize the input
    return (x - self.mean)/self.std

  def unnorm(self, x):
    # Unnormalize the input
    return (x * self.std) + self.mean


norm_x = Normalize(x_train_ohe)
norm_y = Normalize(y_train)
x_train_norm, y_train_norm = norm_x.norm(x_train_ohe), norm_y.norm(y_train)
x_test_norm, y_test_norm = norm_x.norm(x_test_ohe), norm_y.norm(y_test)

Создайте модель машинного обучения

Создайте модель линейной регрессии с API -интерфейсом ядра Tensorflow. Уравнение для множественной линейной регрессии заключается в следующем:

Y = xw+b

где

  • YM × 1: целевой вектор
  • XM × N: Матрица функции
  • wn × 1: вектор веса
  • Б: предвзятость

С помощью@tf.functionДекоратор, соответствующий код Python прослеживается, чтобы генерировать вызовный график Tensorflow. Этот подход полезен для сохранения и загрузки модели после обучения. Это также может обеспечить повышение производительности для моделей со многими уровнями и сложными операциями.


class LinearRegression(tf.Module):

  def __init__(self):
    self.built = False

  @tf.function
  def __call__(self, x):
    # Initialize the model parameters on the first call
    if not self.built:
      # Randomly generate the weight vector and bias term
      rand_w = tf.random.uniform(shape=[x.shape[-1], 1])
      rand_b = tf.random.uniform(shape=[])
      self.w = tf.Variable(rand_w)
      self.b = tf.Variable(rand_b)
      self.built = True
    y = tf.add(tf.matmul(x, self.w), self.b)
    return tf.squeeze(y, axis=1)

Для каждого примера модель возвращает прогноз для входного автомобиля MPG, вычисляя взвешенную сумму своих функций плюс термин смещения. Этот прогноз может быть затем не заперт, чтобы получить его значение с точки зрения исходных единиц.

lin_reg = LinearRegression()
prediction = lin_reg(x_train_norm[:1])
prediction_unnorm = norm_y.unnorm(prediction)
prediction_unnorm.numpy()

array([6.8007355], dtype=float32)

Определить функцию потери

Теперь определите функцию потери для оценки эффективности модели во время учебного процесса.

Поскольку проблемы с регрессией касаются непрерывных выходов, средняя квадратная ошибка (MSE) является идеальным выбором для функции потери. MSE определяется следующим уравнением:

Mse = 1m∑i = 1m (y^i -yi) 2

где

  • y^: вектор прогнозов
  • Y: Вектор истинных целей

Цель этой проблемы регрессии состоит в том, чтобы найти оптимальный вектор веса, W и BIAS, B, что сводит к минимуму функцию потери MSE.


def mse_loss(y_pred, y):
  return tf.reduce_mean(tf.square(y_pred - y))

Тренируйте и оцените свою модель

Использование мини-партий для обучения обеспечивает как эффективность памяти, так и более высокую конвергенцию. Аtf.data.DatasetAPI имеет полезные функции для партии и перетасовки. API позволяет создавать сложные входные трубопроводы из простых, многоразовых кусочков. Узнайте больше о строительстве входных трубопроводов TensorFlow вэто руководствоПолем


batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train_norm, y_train_norm))
train_dataset = train_dataset.shuffle(buffer_size=x_train.shape[0]).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test_norm, y_test_norm))
test_dataset = test_dataset.shuffle(buffer_size=x_test.shape[0]).batch(batch_size)

Затем напишите обучающий цикл, чтобы итеративно обновить параметры вашей модели, используя функцию потери MSE и ее градиенты в отношении входных параметров.

Этот итеративный метод называется какградиент спускПолем На каждой итерации параметры модели обновляются, сделав шаг в противоположном направлении их вычисленных градиентов. Размер этого шага определяется скоростью обучения, которая является настраиваемым гиперпараметром. Напомним, что градиент функции указывает на направление его самого крутого подъема; Следовательно, сделать шаг в противоположном направлении, указывает на направление самого крутого спуска, что в конечном итоге помогает минимизировать функцию потери MSE.


# Set training parameters
epochs = 100
learning_rate = 0.01
train_losses, test_losses = [], []

# Format training loop
for epoch in range(epochs):
  batch_losses_train, batch_losses_test = [], []

  # Iterate through the training data
  for x_batch, y_batch in train_dataset:
    with tf.GradientTape() as tape:
      y_pred_batch = lin_reg(x_batch)
      batch_loss = mse_loss(y_pred_batch, y_batch)
    # Update parameters with respect to the gradient calculations
    grads = tape.gradient(batch_loss, lin_reg.variables)
    for g,v in zip(grads, lin_reg.variables):
      v.assign_sub(learning_rate * g)
    # Keep track of batch-level training performance 
    batch_losses_train.append(batch_loss)

  # Iterate through the testing data
  for x_batch, y_batch in test_dataset:
    y_pred_batch = lin_reg(x_batch)
    batch_loss = mse_loss(y_pred_batch, y_batch)
    # Keep track of batch-level testing performance 
    batch_losses_test.append(batch_loss)

  # Keep track of epoch-level model performance
  train_loss = tf.reduce_mean(batch_losses_train)
  test_loss = tf.reduce_mean(batch_losses_test)
  train_losses.append(train_loss)
  test_losses.append(test_loss)
  if epoch % 10 == 0:
    print(f'Mean squared error for step {epoch}: {train_loss.numpy():0.3f}')

# Output final losses
print(f"\nFinal train loss: {train_loss:0.3f}")
print(f"Final test loss: {test_loss:0.3f}")


Mean squared error for step 0: 2.866
Mean squared error for step 10: 0.453
Mean squared error for step 20: 0.285
Mean squared error for step 30: 0.231
Mean squared error for step 40: 0.209
Mean squared error for step 50: 0.203
Mean squared error for step 60: 0.194
Mean squared error for step 70: 0.184
Mean squared error for step 80: 0.186
Mean squared error for step 90: 0.176

Final train loss: 0.177
Final test loss: 0.157

Постройте изменения в потере MSE с течением времени. Расчет показателей производительности на назначенномвалидация наборилиТестовый наборгарантирует, что модель не переполняется в набор обучения и может хорошо обобщать невидимые данные.


matplotlib.rcParams['figure.figsize'] = [9, 6]

plt.plot(range(epochs), train_losses, label = "Training loss")
plt.plot(range(epochs), test_losses, label = "Testing loss")
plt.xlabel("Epoch")
plt.ylabel("Mean squared error loss")
plt.legend()
plt.title("MSE loss vs training iterations");

Похоже, что модель отлично справляется с установкой данных обучения, а также хорошо обобщает невидимые тестовые данные.

Сохранить и загрузить модель

Начните с создания экспортного модуля, который принимает необработанные данные и выполняет следующие операции:

  • Извлечение функций
  • Нормализация
  • Прогноз
  • Нинормализация


class ExportModule(tf.Module):
  def __init__(self, model, extract_features, norm_x, norm_y):
    # Initialize pre and postprocessing functions
    self.model = model
    self.extract_features = extract_features
    self.norm_x = norm_x
    self.norm_y = norm_y

  @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)]) 
  def __call__(self, x):
    # Run the ExportModule for new data points
    x = self.extract_features(x)
    x = self.norm_x.norm(x)
    y = self.model(x)
    y = self.norm_y.unnorm(y)
    return y


lin_reg_export = ExportModule(model=lin_reg,
                              extract_features=onehot_origin,
                              norm_x=norm_x,
                              norm_y=norm_y)

Если вы хотите сохранить модель в ее текущем состоянии, используйтеtf.saved_model.saveфункция Чтобы загрузить сохраненную модель для прогнозирования, используйтеtf.saved_model.loadфункция


import tempfile
import os

models = tempfile.mkdtemp()
save_path = os.path.join(models, 'lin_reg_export')
tf.saved_model.save(lin_reg_export, save_path)


INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpesymyg6z/lin_reg_export/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpesymyg6z/lin_reg_export/assets


lin_reg_loaded = tf.saved_model.load(save_path)
test_preds = lin_reg_loaded(x_test)
test_preds[:10].numpy()


array([28.097498, 26.193336, 33.564373, 27.719315, 31.787922, 24.014559,
       24.421043, 13.459579, 28.562454, 27.368692], dtype=float32)

Заключение

Поздравляю! Вы обучили регрессионную модель, используя API низкого уровня TensorFlow Core.

Для получения дополнительных примеров использования API -интерфейсов TensorFlow Core, ознакомьтесь с следующими руководствами:

  • Логистическая регрессияДля бинарной классификации
  • Многослойные персептроныдля написания ручной распознавания цифр

Первоначально опубликовано наTensorflowВеб -сайт, эта статья появляется здесь под новым заголовком и имеет лицензию в CC на 4.0. Образцы кода, разделенные по лицензии Apache 2.0


Оригинал
PREVIOUS ARTICLE
NEXT ARTICLE