tf.get_embedding_variable stuck in tf.nn.embedding_lookup when ids is negative number #1003

Open
welsonzhang opened this issue Jul 11, 2024 · 0 comments

Comments

@welsonzhang

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 20.04): 16.04
  • DeepRec version or commit id: deeprec2208
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:

Describe the current behavior
When ids contains negative numbers, tf.nn.embedding_lookup on a variable created with tf.get_embedding_variable gets stuck and does not return.

Describe the expected behavior
When ids contains negative numbers, tf.nn.embedding_lookup should return the embedding (a list of floats) for each id.

Code to reproduce the issue

import tensorflow as tf

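# Partitioned embedding variable with 4 shards and embedding dimension 3.
var = tf.get_embedding_variable("var_0",
                                embedding_dim=3,
                                initializer=tf.ones_initializer(tf.float32),
                                partitioner=tf.fixed_size_partitioner(num_shards=4))

# Per-shard key counts, printed after the training steps.
shape = [var1.total_count() for var1 in var]

# Look up ids that include negative values; the run gets stuck at this lookup.
emb = tf.nn.embedding_lookup(var, tf.cast([0,-1,-2,-5,-6,-7], tf.int64))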
fun = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(fun, name='reduce_sum')
opt = tf.train.AdagradOptimizer(0.1)

g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)

init = tf.global_variables_initializer()

sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
with tf.Session(config=sess_config) as sess:
  sess.run([init])
  print(sess.run([emb, train_op, loss]))
  print(sess.run([emb, train_op, loss]))
  print(sess.run([emb, train_op, loss]))
  print(sess.run([shape]))
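
To tell a hang apart from a slow run, the script above can be launched in a child process with a timeout. The following is only a minimal sketch using the Python standard library. The file name repro.py, the helper name run_with_timeout, and the 60 second limit are assumptions made for illustration and are not part of DeepRec.

```python
import subprocess
import sys


def run_with_timeout(cmd, timeout_s):
    """Run cmd in a child process.

    Returns (finished, returncode). finished is False when the child
    did not exit within timeout_s seconds; in that case subprocess.run
    kills the child before raising TimeoutExpired.
    """
    try:
        proc = subprocess.run(cmd, timeout=timeout_s)
        return True, proc.returncode
    except subprocess.TimeoutExpired:
        return False, None


if __name__ == "__main__":
    # repro.py is a hypothetical file name for the reproduction script above.
    finished, code = run_with_timeout([sys.executable, "repro.py"], timeout_s=60)
    if finished:
        print("script exited with return code", code)
    else:
        print("script did not finish within the timeout")
```

A timeout only shows that the script did not finish in the allotted time. Running it once with non-negative ids and once with the negative ids under the same limit gives a fairer comparison.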

