Determining If A Value Is In A Set In TensorFlow
Solution 1:
To provide a more concrete answer, say you want to check, element by element, whether each value in the tensor x appears in a 1D tensor s. You could do the following:
tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1)
For example, for s and x:
s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0],
                  [4, 4, 4, 0, 0]],
                 [[3, 5, 5, 6, 4],
                  [4, 7, 3, 8, 9]]])
x has shape [2, 2, 5] and s has shape [2], so tile_multiples = [1, 1, 1, 2]. This means tf.expand_dims adds a new trailing dimension to x, and tf.tile repeats each value 2 times along that dimension (once for each element in s). So, x_tile has shape [2, 2, 5, 2] and will look like:
[[[[1 1]
   [2 2]
   [3 3]
   [0 0]
   [0 0]]

  [[4 4]
   [4 4]
   [4 4]
   [0 0]
   [0 0]]]

 [[[3 3]
   [5 5]
   [5 5]
   [6 6]
   [4 4]]

  [[4 4]
   [7 7]
   [3 3]
   [8 8]
   [9 9]]]]
tf.equal then compares each tiled copy with the corresponding value in s, and tf.reduce_any along the last dimension returns True wherever any of those comparisons matched. The final result, x_in_s, is:
[[[False False  True False False]
  [ True  True  True False False]]

 [[ True False False False  True]
  [ True False  True False False]]]
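The same mask can probably be obtained without tf.tile, because tf.equal broadcasts its arguments. Here is a minimal sketch, assuming TensorFlow 2.x with eager execution; the helper name membership_mask is made up for illustration and is not part of TensorFlow:

```python
import tensorflow as tf

def membership_mask(x, s):
    """Return a boolean tensor shaped like x: True where the element occurs in s.

    Hypothetical helper for illustration; assumes s is a 1-D tensor.
    """
    # expand_dims gives x a trailing axis of size 1, so tf.equal can
    # broadcast it against s. The result has shape x.shape + [len(s)].
    matches = tf.equal(tf.expand_dims(x, -1), s)
    # Collapse the trailing axis: True if x matched any element of s.
    return tf.reduce_any(matches, axis=-1)

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0],
                  [4, 4, 4, 0, 0]],
                 [[3, 5, 5, 6, 4],
                  [4, 7, 3, 8, 9]]])
x_in_s = membership_mask(x, s)
print(x_in_s.numpy())
```

Skipping the explicit tile avoids building tile_multiples by hand; the intermediate comparison tensor still has one entry per (element of x, element of s) pair, so memory use grows with the size of s either way.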
Solution 2:
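Take a look at this related question: Count number of "True" values in boolean Tensor
You should be able to build a tensor consisting of [a, b, c, d, e] and then check whether any of its entries is equal to x using tf.equal(.). Counting the True values in the result, as in the linked question, then tells you whether x is in the set.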
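As a rough sketch of that idea, assuming TensorFlow 2.x with eager execution: the values below are placeholders standing in for [a, b, c, d, e], and is_in_set is a hypothetical helper name, not a TensorFlow function.

```python
import tensorflow as tf

def is_in_set(x, candidates):
    """Return a scalar boolean tensor: True if x equals any entry of candidates."""
    # tf.equal broadcasts the scalar x against every entry of candidates;
    # tf.reduce_any collapses the resulting boolean mask to one value.
    return tf.reduce_any(tf.equal(candidates, x))

# Illustrative values standing in for [a, b, c, d, e]
candidates = tf.constant([10, 20, 30, 40, 50])
found = is_in_set(tf.constant(30), candidates)
missing = is_in_set(tf.constant(35), candidates)
print(bool(found), bool(missing))
```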
Solution 3:
Here are two solutions. We want to check whether query is in whitelist:
import tensorflow as tf  # TensorFlow 1.x API (graph mode)

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"
# Use broadcasting to compare query against every element of whitelist
broadcast_equal = tf.equal(whitelist, query)
# Method 1: cast the boolean mask to integers and sum it
broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
broadcast_sum = tf.reduce_sum(broadcast_equal_int)
# Method 2: count the True entries directly with tf.count_nonzero
nz_cnt = tf.count_nonzero(broadcast_equal)
with tf.Session() as sess:
    print(sess.run([broadcast_equal, broadcast_sum, nz_cnt]))
#=> [array([False, False, True, False]), 1, 1]
So if either count (broadcast_sum or nz_cnt) is > 0, then the item is in the set.
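The snippet above relies on the TensorFlow 1.x session API. In TensorFlow 2.x, tf.count_nonzero lives at tf.math.count_nonzero and ops run eagerly, so a port might look like the sketch below. It assumes TensorFlow 2.x is installed; the variable in_whitelist is introduced here for illustration.

```python
import tensorflow as tf

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

# Eager execution: ops run immediately, so no Session is needed.
broadcast_equal = tf.equal(whitelist, query)

# TF 2.x location of count_nonzero; it accepts boolean input.
nz_cnt = tf.math.count_nonzero(broadcast_equal)

# The item is in the set when at least one comparison matched.
in_whitelist = bool(nz_cnt > 0)
print(broadcast_equal.numpy(), int(nz_cnt), in_whitelist)
```

If only a yes/no answer is needed, tf.reduce_any(broadcast_equal) gives it directly without counting.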