Wrong binary accuracy with Jax #20178

Open
eli-osherovich opened this issue Aug 28, 2024 · 4 comments


@eli-osherovich

I have some very strange results out of the `

Consider the code below:

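import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras
import keras
import numpy as np

# A model whose output is a boolean mask: True where the input exceeds 0.5.
inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out)

x = np.random.rand(32, 1)

res = mm.predict(x)
# Both the labels and the predictions passed below are boolean arrays.
met = keras.metrics.BinaryAccuracy()
met.update_state(x>0.5, res>0.5)
met.result()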

I would expect a result of 1 on every run. Instead, I get a seemingly random result (close to 0.5).

Package versions: TensorFlow 2.17.0, Keras 3.5.0, JAX 0.4.26, NumPy 1.26.4.
eli-osherovich (Author) commented Aug 28, 2024

The result is correct if I cast the second parameter of update_state to a float or int.
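Why would casting help? One plausible explanation, which is an assumption and not something confirmed in this thread: the metric casts its 0.5 threshold to the dtype of y_pred before comparing. For a boolean y_pred, 0.5 becomes True, and no boolean is greater than True, so every thresholded prediction becomes False and the accuracy collapses to the fraction of False labels (roughly 0.5 for uniform x). This would also fit the TensorFlow error reported later in the thread, which fails in a Greater op on bool inputs. The NumPy sketch below models that assumed dtype handling with a hypothetical helper, binary_accuracy_sketch; it is not Keras source code.

```python
import numpy as np


def binary_accuracy_sketch(y_true, y_pred, threshold=0.5):
    """Hypothetical NumPy model of threshold-based binary accuracy.

    Assumed dtype handling (a guess, not Keras source): y_true and the
    threshold are cast to y_pred's dtype before comparing.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true).astype(y_pred.dtype)
    # For a boolean y_pred, 0.5 is cast to True here.
    thr = np.asarray(threshold).astype(y_pred.dtype)
    # No boolean is greater than True, so this is all False for bool input.
    y_pred_bin = (y_pred > thr).astype(y_true.dtype)
    return float(np.mean(y_true == y_pred_bin))


rng = np.random.default_rng(0)
x = rng.random((32, 1))
labels = x > 0.5
preds = x > 0.5  # perfect predictions, as in the issue

acc_bool = binary_accuracy_sketch(labels, preds)  # fraction of False labels
acc_float = binary_accuracy_sketch(labels, preds.astype("float32"))  # 1.0
print(acc_bool, acc_float)
```

Under this assumption, the float cast restores the expected result because 0.5 stays 0.5 and the 0.0/1.0 predictions threshold back to the labels.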

@mehtamansi29 (Collaborator)

Hi @eli-osherovich,

In met.update_state(x>0.5, res>0.5), both x>0.5 and res>0.5 are boolean arrays, but the BinaryAccuracy metric accepts only numerical values (floats or integers).

Running the same code with the TensorFlow backend raises an error instead.
Error:
InvalidArgumentError: Value for attr 'T' of bool is not in the list of allowed values: float, double, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64
; NodeDef: {{node Greater}}; Op<name=Greater; signature=x:T, y:T -> z:bool; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64]> [Op:Greater] name

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np

inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out) 

x = np.random.rand(32, 1)

res = mm.predict(x)
met = keras.metrics.BinaryAccuracy()
met.update_state(x>0.5, res>0.5)
met.result()

So the JAX backend should raise the same error when boolean values are passed to the BinaryAccuracy metric. You could open a new issue in the JAX repository to request this error message.

@fchollet (Member)

We could consider casting the values to floatx() in update_state(). Would you like to open a PR, @eli-osherovich?
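The proposed casting could look roughly like the following. This is a framework-free sketch using a hypothetical class, FloatCastingBinaryAccuracy, that mimics the update_state()/result() pattern; float32 stands in for Keras's floatx(), and none of this is the actual Keras implementation or the eventual PR.

```python
import numpy as np


class FloatCastingBinaryAccuracy:
    """Hypothetical sketch of casting inputs to a float dtype in update_state.

    float32 stands in for keras.backend.floatx(); this is not Keras's
    BinaryAccuracy implementation.
    """

    def __init__(self, threshold=0.5, dtype="float32"):
        self.threshold = threshold
        self.dtype = dtype
        self.correct = 0.0  # running count of matching elements
        self.count = 0      # running count of all elements seen

    def update_state(self, y_true, y_pred):
        # Cast first, so boolean inputs behave like 0.0/1.0 ...
        y_true = np.asarray(y_true).astype(self.dtype)
        y_pred = np.asarray(y_pred).astype(self.dtype)
        # ... then threshold the predictions and compare with the labels.
        y_pred_bin = (y_pred > self.threshold).astype(self.dtype)
        matches = y_true == y_pred_bin
        self.correct += float(np.sum(matches))
        self.count += matches.size

    def result(self):
        return self.correct / self.count if self.count else 0.0


x = np.random.default_rng(1).random((32, 1))
met = FloatCastingBinaryAccuracy()
met.update_state(x > 0.5, x > 0.5)  # boolean inputs, as in the issue
print(met.result())
```

Casting at the entry point means the threshold comparison always happens in a numeric dtype, regardless of what the caller passes in.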

@mehtamansi29 (Collaborator) commented Aug 30, 2024

> We could consider casting the values to floatx() in update_state(). Would you like to open a PR, @eli-osherovich?

Hi @fchollet, I will raise a PR to cast the values to floatx() in update_state().
