tensorflow.python.framework.ops.EagerTensor object has no attribute _in_graph_mode

The reason for the error is that tf.keras optimizers apply gradients to variable objects (of type tf.Variable), while you are trying to apply gradients to a tensor (of type tf.Tensor). Tensors are immutable in TensorFlow, so the optimizer cannot update them in place.
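To see the difference, here is a minimal sketch (assuming TensorFlow 2.x with tf.keras) that runs one plain SGD step on a tf.Variable and then tries the same step on a constant tensor. The helper try_sgd_step is hypothetical and exists only for this illustration. The exception raised for the constant, and its message, depend on your TensorFlow version.

```python
import tensorflow as tf


def try_sgd_step(target):
    """Hypothetical helper: attempt one SGD step on `target`, return True on success."""
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)  # fresh optimizer for each call
    with tf.GradientTape() as tape:
        tape.watch(target)  # needed for a constant; harmless for a variable
        loss = tf.reduce_sum(target ** 2)
    grad = tape.gradient(loss, target)  # d(loss)/d(target) = 2 * target
    try:
        opt.apply_gradients([(grad, target)])
        return True
    except Exception as err:  # exception type and message vary across TF versions
        print(f"{type(err).__name__}: {err}")
        return False


var = tf.Variable([2.0])
const = tf.constant([2.0])

var_ok = try_sgd_step(var)      # the variable is updated in place
const_ok = try_sgd_step(const)  # the constant tensor cannot be updated
print(var_ok, var.numpy())      # var is now about 1.6, i.e. 2.0 - 0.1 * 4.0
print(const_ok)
```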

The fix is to create img as a tf.Variable. This is how your code should look:

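import tensorflow as tf

# Assumes img, model, lr, epoch and filter are defined as in your original code.
# NOTE: The original image tensor is lost here. If this is not desired, then you
# can rename the variable to something like img_var.
img = tf.Variable(img)

# Newer Keras optimizers no longer support the `decay` argument; this schedule
# reproduces the old behaviour, lr / (1 + 1e-6 * step).
schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=lr, decay_steps=1, decay_rate=1e-6)
opt = tf.keras.optimizers.Adam(learning_rate=schedule)

for _ in range(epoch):
    with tf.GradientTape() as tape:
        # Feed the variable's current value (a plain tensor) to the model.
        y = model(img.value())[:, :, :, filter]
        # Maximise the mean activation by minimising its negative.
        loss = -tf.math.reduce_mean(y)

    # Compute the gradients outside the tape context.
    grads = tape.gradient(loss, img)
    opt.apply_gradients(zip([grads], [img]))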
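If you want to check the fix in isolation, the sketch below runs the same loop end to end. The tiny Conv2D model, the random 16x16 image, filter_index, lr and epoch are all hypothetical stand-ins for the names in your code (filter is renamed to filter_index so it does not shadow Python's built-in).

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-ins for the question's model, image and settings.
model = tf.keras.Sequential([tf.keras.layers.Conv2D(4, 3, padding="same")])
filter_index = 2  # which output channel to maximise
lr, epoch = 0.05, 20

img = tf.Variable(tf.random.uniform((1, 16, 16, 3)))
# InverseTimeDecay mimics the old decay=1e-6 optimizer argument.
schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=lr, decay_steps=1, decay_rate=1e-6)
opt = tf.keras.optimizers.Adam(learning_rate=schedule)


def activation_loss():
    # Negative mean activation of one channel; minimising it maximises the activation.
    y = model(img.value())[:, :, :, filter_index]
    return -tf.math.reduce_mean(y)


initial_loss = float(activation_loss())  # also builds the model's weights
for _ in range(epoch):
    with tf.GradientTape() as tape:
        loss = activation_loss()
    grads = tape.gradient(loss, img)
    opt.apply_gradients(zip([grads], [img]))
final_loss = float(activation_loss())
print(initial_loss, final_loss)
```

Because this toy objective is linear in the image, every Adam step should push the mean activation up, so final_loss should come out lower than initial_loss.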

Also, it is recommended to compute the gradients outside the tape's context. Computing them inside would make the tape record the gradient computation itself, which uses more memory. That is only desirable if you want to calculate higher-order gradients. Since you don't need those, I have moved the call outside.
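A minimal sketch of the difference, using a hypothetical scalar x: for a first-order gradient the tape.gradient call can sit outside the tape, but for a second-order gradient the inner gradient has to be computed inside the outer tape so that the outer tape records it.

```python
import tensorflow as tf

x = tf.Variable(3.0)

# First-order only: call tape.gradient after leaving the context,
# so the tape does not record the gradient computation itself.
with tf.GradientTape() as tape:
    y = x ** 3
dy_dx = tape.gradient(y, x)  # 3 * x**2 = 27

# Higher-order: the inner gradient is computed inside the outer tape,
# so the outer tape records it and can differentiate it again.
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = x ** 3
    d1 = inner.gradient(y, x)  # recorded by `outer`
d2 = outer.gradient(d1, x)  # 6 * x = 18

print(float(dy_dx), float(d2))
```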

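Note that I have changed the line y = model(img)[:, :, :, filter] to y = model(img.value())[:, :, :, filter]. This is because tf.keras models (at least in the TensorFlow version this answer was written against) expect tensors as input, not variables (bug, or feature?). Because the variable is read inside the tape, gradients still flow back to img.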
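A small sketch of the same pattern, with a hypothetical Dense layer standing in for your model: value() returns an ordinary tf.Tensor, and since the trainable variable is read inside the tape, tape.gradient still returns a gradient with respect to the variable.

```python
import tensorflow as tf

v = tf.Variable([[1.0, 2.0, 3.0]])
dense = tf.keras.layers.Dense(1)  # hypothetical stand-in for your model

with tf.GradientTape() as tape:
    x = v.value()  # a plain tf.Tensor holding v's current value
    out = tf.reduce_sum(dense(x))
grad = tape.gradient(out, v)  # d(out)/d(v) is the transposed kernel

print(type(x).__name__, grad.numpy())
```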

Although not directly related, it can be useful to understand what causes this type of error in the first place.

This type of error occurs whenever we try to modify a constant tensor.

A simple example: a tensor created with tf.constant cannot be modified, so any attempt to update it in place raises a similar error.

unchangeable_tensors = tf.constant([1,2,3])

One way to avoid the error is to use tf.Variable() instead, as shown below:

changeable_tensors = tf.Variable([1,2,3])
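Putting the two lines together in a short sketch: calling assign on the constant fails with an AttributeError (the same kind of error as in the title, just for a different attribute), while the variable is updated in place.

```python
import tensorflow as tf

unchangeable_tensors = tf.constant([1, 2, 3])
changeable_tensors = tf.Variable([1, 2, 3])

# A constant is an immutable EagerTensor, so it has no assign method at all.
try:
    unchangeable_tensors.assign([4, 5, 6])
    constant_was_modified = True
except AttributeError as err:
    constant_was_modified = False
    print("constant:", err)

# A Variable holds mutable state, so it can be updated in place.
changeable_tensors.assign([4, 5, 6])
print("variable:", changeable_tensors.numpy())  # variable: [4 5 6]
```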
