Почему ваша функция не работает так, как вы думаете (и как это исправить)

Почему ваша функция не работает так, как вы думаете (и как это исправить)

14 августа 2025 г.

Обзор контента

  • Ограничения
  • Выполнение побочных эффектов Python
  • Все выходы TF.Function должны быть возвращаемыми значениями
  • Рекурсивные tf.functions не поддерживаются
  • Известные проблемы
  • В зависимости от Global и Free переменных Python
  • В зависимости от объектов Python
  • Создание TF.Variables
  • Дальнейшее чтение

Ограничения

tf.functionимеет несколько ограничений по дизайну, о которых вы должны знать при преобразовании функции Python вtf.functionПолем

Выполнение побочных эффектов Python

Побочные эффекты, такие как печать, приложение к спискам и мутирующие глобальные значения, могут вести себя неожиданно вtf.functionИногда выполнять дважды или нет. Они случаются только в первый раз, когда вы называетеtf.functionс набором входов. После этого отслеживаетсяtf.GraphПовторяется, без выполнения кода Python.

Общее эмпирическое правило состоит в том, чтобы не полагаться на побочные эффекты Python в вашей логике и использовать их только для отладки ваших следов. В противном случае API -интерфейсы Tensorflow какtf.dataВtf.printВtf.summaryВtf.Variable.assign, иtf.TensorArrayЛучший способ обеспечить выполнение вашего кода во время выполнения TensorFlow с каждым вызовом.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Если вы хотите выполнить код Python во время каждого вызоваtf.functionВtf. py_functionвыходной люк. Недостаткиtf.py_functionэто то, что это не портативное или особенно эффективное, не может быть сохранено с помощьюSavedModel, и плохо работает в распределенных (мульти-GPU, TPU) настройках. Также с тех порtf.py_functionДолжен быть включен в график, он бросает все входы/выходы в тензоры.

@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
  print('Executing eagerly.')
  return x + y

@tf.function
def tf_wrapper(x, y):
  print('Tracing.')
  return py_plus(x, y)

Аtf.functionбудет проследить в первый раз:

tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()

Tracing.
Executing eagerly.
3.0

Ноtf.py_functionВнутри выполняется с нетерпением каждый раз:

tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()

Executing eagerly.
3.0

Изменение глобальных и свободных переменных Python

Изменение Python Global исвободные переменныесчитается побочным эффектом Python, так что это происходит только во время трассировки.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1

Python side effect

Иногда неожиданное поведение очень трудно заметить. В примере ниже,counterпредназначен для защиты увеличения переменной. Однако, поскольку это целое число Python, а не объект Tensorflow, его значение захватывается во время первого трассировки. Когдаtf.functionиспользуется,assign_addбудет записан безоговорочно на базовом графике. Поэтомуvувеличится на 1, каждый раз, когдаtf.functionназывается. Эта проблема распространена среди пользователей, которые пытаются перенести свой код графического режима TensorFlow в TensorFlow 2, используяtf.functionдекораторы, когда на боковых эффектах Python (counterв примере) используются для определения того, какие OPS запустить (assign_addв примере). Обычно пользователи осознают это только после того, как увидели подозрительные численные результаты или значительно более низкую производительность, чем ожидалось (например, если охраняемая операция очень дорогой).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3

1
2
3

Обходной путь для достижения ожидаемого поведения используетсяtf.init_scopeЧтобы поднять операции за пределами функционального графика. Это гарантирует, что переменная приращение выполняется только один раз во время отслеживания. Это следует отметитьinit_scopeИмеет другие побочные эффекты, включая очищенный поток управления и градиентную ленту. Иногда использованиеinit_scopeможет стать слишком сложным, чтобы реалистично управлять.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1

1
1
1

Таким образом, как правило, вы должны избегать мутирования объектов Python, таких как целые числа или контейнеры, такие как списки, которые живут внеtf.functionПолем Вместо этого используйте аргументы и объекты TF. Например, раздел"Накапливание значений в цикле"Имеет один пример того, как могут быть реализованы операции, подобные списку.

В некоторых случаях вы можете захватить и манипулировать состоянием, если этоtf.VariableПолем Так обновляется веса моделей керасConcreteFunctionПолем

Использование итераторов Python и генераторов

Многие функции Python, такие как генераторы и итераторы, полагаются на время выполнения Python, чтобы отслеживать штат. В целом, хотя эти конструкции работают, как и ожидалось в нетерпеливом режиме, они являются примерами побочных эффектов Python и, следовательно, происходят только во время трассировки.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value: 1
Value: 1
Value: 1

Так же, как у Tensorflow есть специализированныйtf.TensorArrayДля конструкций списка у него есть специализированныйtf.data.IteratorДля итерационных конструкций. См. Раздел оАвтограф -преобразованиедля обзора. Такжеtf.dataAPI может помочь реализовать шаблоны генератора:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)

Value: 1
Value: 2
Value: 3

Все выходы TF.Function должны быть возвращаемыми значениями

За исключениемtf.VariableS, функция TF. должна вернуть все свои выходы. Попытка непосредственно получить доступ к любым тензорам из функции, не проходя через возвратные значения, вызывает «утечки».

Например, функция ниже «утечка» тензораaЧерез Python Globalx:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

3
'SymbolicTensor' object has no attribute 'numpy'

Это верно, даже если утечка также возвращается:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))

2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 7, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 4, in leaky_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/override_binary_operator.py", line 113, in binary_op_wrapper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py", line 28, in _add_dispatch_factory
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1701, in _add_dispatch
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2682, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1177, in from_node_def

The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139959630636096), which is out of scope.

Обычно утечки, подобные этим, происходят, когда вы используете операторы Python или структуры данных. В дополнение к утечке недоступных тензоров, такие заявления также, вероятно, неверны, поскольку они считаются побочными эффектами Python и не гарантируют выполнять при каждом вызове функции.

Общие способы утечки локальных тензоров также включают мутирование внешней коллекции питонов или объект:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

Рекурсивные tf.functions не поддерживаются

Рекурсивныйtf.functionS не поддерживаются и могут вызвать бесконечные петли. Например,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.

Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_167534/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)

    RecursionError: maximum recursion depth exceeded while calling a Python object

Даже если рекурсивныйtf.functionКажется, работает, функция Python будет прослежена несколько раз и может иметь последствия для производительности. Например,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings

tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

Известные проблемы

Если тыtf.functionне оценивается правильно, ошибка может быть объяснена этими известными проблемами, которые планируются исправить в будущем.

В зависимости от Global и Free переменных Python

tf.functionСоздает новыйConcreteFunctionПри вызове с новым значением аргумента Python. Тем не менее, это не делает этого для закрытия питона, глобальных или нелокалов этогоtf.functionПолем Если их значение меняется между вызовами вtf.function,tf.functionвсе еще будет использовать значения, которые они имели, когда они были прослежены. Это отличается от того, как работают регулярные функции Python.

По этой причине вы должны следовать стилю функционального программирования, который использует аргументы вместо закрытия внешних имен.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))

Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)

print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))

Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Еще один способ обновить глобальную ценность - сделать егоtf.Variableи используйтеVariable.assignметод вместо.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())

Variable: tf.Tensor(2, shape=(), dtype=int32)

print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())

Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

В зависимости от объектов Python

Проход индивидуальных объектов Python в качестве аргументов дляtf.functionподдерживается, но имеет определенные ограничения.

Для максимального покрытия функций рассмотрите возможность преобразования объектов вТипы расширенияПрежде чем передавать их вtf.functionПолем Вы также можете использовать примитивы Python иtf.nest-Подолетимые структуры.

Однако, как это покрытоПравила трассировки, когда обычайTraceTypeне предоставлен пользовательским классом Python,tf.functionвынужден использовать равенство на основе экземпляров, что означает, что оно будетне создавать новую трассуКогда вы передаететот же объект с модифицированными атрибутамиПолем

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))

tf.Tensor(20.0, shape=(), dtype=float32)

print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(

Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Используя то же самоеtf.functionЧтобы оценить измененный экземпляр модели, будет багги, так как у нее все еще естьТот и тот же на основе экземпляра Трасетипкак оригинальная модель.

По этой причине вам рекомендуется написать вашtf.functionчтобы избежать в зависимости от атрибутов изменяемого объекта или реализацииПротокол трассировкиЧтобы объекты информировалиtf.functionо таких атрибутах.

Если это невозможно, один обходной путь - сделать новыйtf.functionS каждый раз, когда вы изменяете свой объект, чтобы заставить обращаться:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))

tf.Tensor(20.0, shape=(), dtype=float32)

print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.

Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

КакПовреждение может быть дорогим, вы можете использоватьtf.VariableS как атрибуты объекта, которые могут быть мутированы (но не изменены, осторожны!) Для аналогичного эффекта без необходимости отслеживания.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))

tf.Tensor(20.0, shape=(), dtype=float32)

print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!

Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Создание TF.Variables

tf.functionтолько поддерживает Синглтонtf.VariableS создал один раз на первом вызове и повторно использовал в последующих вызовах функций. Приведенный ниже фрагмент кода создаст новыйtf.VariableВ каждом вызове функции, который приводит кValueErrorисключение.

Пример:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)

Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_167534/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_167534/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Общий шаблон, используемый для оборудования этого ограничения, - это начать с значения Python None, а затем условно создаватьtf.VariableЕсли значение нет:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Использование с несколькими оптимизаторами керас

Вы можете встретитьсяValueError: tf.function only supports singleton tf.Variables created on the first call.При использовании более одного оптимизатора кераса сtf.functionПолем Эта ошибка возникает из -за того, что оптимизаторы внутренне создаютtf.VariableS, когда они применяют градиенты в первый раз.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)

Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_167534/950644149.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_167534/950644149.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 291, in apply_gradients  **
        self.apply(grads, trainable_variables)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 330, in apply
        self.build(trainable_variables)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 97, in build
        self.add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/optimizer.py", line 36, in add_variable_from_reference
        return super().add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 227, in add_variable_from_reference
        return self.add_variable(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py", line 201, in add_variable
        variable = backend.Variable(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/common/variables.py", line 163, in __init__
        self._initialize_with_initializer(initializer)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/core.py", line 40, in _initialize_with_initializer
        self._value = tf.Variable(

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Если вам нужно изменить объект состояния между вызовами, простым определитьtf.Moduleподкласс и создайте экземпляры для удержания этих объектов:

class TrainStep(tf.Module):
  def __init__(self, optimizer):
    self.optimizer = optimizer

  @tf.function
  def __call__(self, w, x, y):
    with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    self.optimizer.apply_gradients(zip(gradients, [w]))


opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)

train_o1(w, x, y)
train_o2(w, x, y)

Вы также можете сделать это вручную, создав несколько экземпляров@tf.functionобертка, по одному для каждого оптимизатора:

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)

Using with multiple Keras models

Вы также можете встретитьсяValueError: tf.function only supports singleton tf.Variables created on the first call.При передаче различных экземпляров модели к одному и тому жеtf.functionПолем

Эта ошибка возникает из -за моделей кераса (которыене иметь определенной формы ввода) и слои кераса создаютtf.VariableS, когда они впервые называются. Вы можете попытаться инициализировать эти переменные внутриtf.function, который уже был вызван. Чтобы избежать этой ошибки, попробуйте позвонитьmodel.build(input_shape)Инициализировать все веса перед тренировкой модели.

Дальнейшее чтение

Чтобы узнать, как экспортировать и загрузитьtf.function, смРуководство SaveDmodelПолем Чтобы узнать больше об оптимизации графика, которые выполняются после трассировки, см.Грапплеерский гидПолем Чтобы узнать, как оптимизировать ваш трубопровод данных и профилировать вашу модель, см.ПрофилировщикПолем

Первоначально опубликовано наTensorflowВеб -сайт, эта статья появляется здесь под новым заголовком и имеет лицензию в CC на 4.0. Образцы кода, разделенные по лицензии Apache 2.0.


Оригинал
PREVIOUS ARTICLE
NEXT ARTICLE