
Determining If A Value Is In A Set In TensorFlow

The tf.logical_or, tf.logical_and, and tf.select functions are very useful. However, suppose you have a value x and want to check whether it is in the set {a, b, c, d, e}. In Python you…

Solution 1:

To provide a more concrete answer, say you want to check, for every element of the tensor x, whether it equals any value in a 1-D tensor s. You could do the following:

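import tensorflow as tf
# [1] * rank(x) followed by len(s)
tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
# compare with s along the new last axis, then OR the results
x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1)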
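For reuse, the same tile-and-compare steps can be wrapped in a function that works for an x of any rank. This is a minimal sketch assuming TensorFlow 2.x eager execution; `isin_tiled` is a hypothetical helper name, not part of the TensorFlow API, and the small inputs in the usage line are made up for illustration.

```python
import tensorflow as tf


def isin_tiled(x, s):
    """Return a bool tensor shaped like x: True where an element of x is in s.

    Hypothetical helper following the tile-then-compare approach;
    assumes TensorFlow 2.x eager execution and a 1-D s.
    """
    x = tf.convert_to_tensor(x)
    s = tf.convert_to_tensor(s, dtype=x.dtype)
    # One multiple of 1 per existing dimension of x, then len(s) for the new axis.
    tile_multiples = tf.concat(
        [tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
    # Add a trailing axis and repeat each element len(s) times along it.
    x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
    # Copy k of each element is compared with s[k]; OR across the copies.
    return tf.reduce_any(tf.equal(x_tile, s), axis=-1)


print(isin_tiled([[1, 5], [7, 3]], [3, 5]).numpy())
```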

For example, given the following s and x:

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0], 
                  [4, 4, 4, 0, 0]], 
                 [[3, 5, 5, 6, 4], 
                  [4, 7, 3, 8, 9]]])

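x has shape [2, 2, 5] and s has shape [2], so tile_multiples = [1, 1, 1, 2]. tf.expand_dims(x, -1) gives x a new trailing dimension of size 1 (shape [2, 2, 5, 1]), and tiling repeats each element of x 2 times (once for each element in s) along that new dimension. So x_tile has shape [2, 2, 5, 2] and looks like: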

[[[[1 1]
   [2 2]
   [3 3]
   [0 0]
   [0 0]]

  [[4 4]
   [4 4]
   [4 4]
   [0 0]
   [0 0]]]

 [[[3 3]
   [5 5]
   [5 5]
   [6 6]
   [4 4]]

  [[4 4]
   [7 7]
   [3 3]
   [8 8]
   [9 9]]]]

Then tf.equal(x_tile, s) compares the two copies of each element with s[0] and s[1] respectively (s is broadcast along the last dimension), and tf.reduce_any along that last dimension returns True wherever an element of x matches any value in s, giving the final result:

[[[False False  True False False]
  [ True  True  True False False]]

 [[ True False False False  True]
  [ True False  True False False]]]
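Because tf.equal broadcasts its arguments, the explicit tf.tile call is not strictly necessary: comparing tf.expand_dims(x, -1) directly against s should give the same result without building x_tile. A sketch assuming TensorFlow 2.x eager execution, reusing the example s and x from above:

```python
import tensorflow as tf

# Same example data as above.
s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0],
                  [4, 4, 4, 0, 0]],
                 [[3, 5, 5, 6, 4],
                  [4, 7, 3, 8, 9]]])

# expand_dims gives x shape [2, 2, 5, 1]; tf.equal broadcasts it against
# s (shape [2]), so matches has shape [2, 2, 5, 2] with no tiled copy of x.
matches = tf.equal(tf.expand_dims(x, -1), s)

# OR over the last axis: True where an element of x equals any value in s.
x_in_s = tf.reduce_any(matches, axis=-1)  # shape [2, 2, 5]

print(x_in_s.numpy())
```

The tiled version makes the intermediate [2, 2, 5, 2] tensor explicit, which can help when reading or debugging; the broadcast version lets tf.equal perform that expansion implicitly.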

Solution 2:

Take a look at this related question: Count number of "True" values in boolean Tensor

You should be able to build a tensor consisting of [a, b, c, d, e] and then use tf.equal to check whether any of its rows is equal to x.
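A minimal sketch of this suggestion, assuming TensorFlow 2.x eager execution and a scalar x; the integers standing in for a, b, c, d, e are hypothetical placeholders. tf.reduce_any gives a yes/no answer, while tf.math.count_nonzero counts the True values, as in the linked question:

```python
import tensorflow as tf

# Hypothetical stand-ins for the set {a, b, c, d, e}, stored as a 1-D tensor.
candidates = tf.constant([10, 20, 30, 40, 50])
x = tf.constant(30)

# tf.equal broadcasts the scalar x against every entry of candidates.
is_equal = tf.equal(candidates, x)  # [False, False, True, False, False]

# True if at least one entry matched.
x_in_set = tf.reduce_any(is_equal)

# Number of matching entries (counting the True values).
num_matches = tf.math.count_nonzero(is_equal)

print(bool(x_in_set), int(num_matches))
```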


Solution 3:

Here are two solutions for checking whether query is in whitelist:

import tensorflow as tf  # TensorFlow 1.x (graph mode with sessions)

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

# use broadcasting to compare query with every entry element-wise
broadcast_equal = tf.equal(whitelist, query)

# method 1: cast to integers and sum
broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
broadcast_sum = tf.reduce_sum(broadcast_equal_int)

# method 2: count the True entries directly
nz_cnt = tf.count_nonzero(broadcast_equal)

sess = tf.Session()
sess.run([broadcast_equal, broadcast_sum, nz_cnt])
#=> [array([False, False,  True, False]), 1, 1]

If either count is greater than 0, query is in whitelist.
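The snippet above uses the TensorFlow 1.x session API (tf.Session, sess.run). Under TensorFlow 2.x, eager execution is on by default and the counting op is available as tf.math.count_nonzero, so the same checks can be evaluated directly. A sketch assuming TensorFlow 2.x; the in_whitelist line using tf.reduce_any is an addition, not part of the original answer:

```python
import tensorflow as tf

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

# Broadcasting compares the scalar string query with every whitelist entry.
broadcast_equal = tf.equal(whitelist, query)

# Method 1: cast to integers and sum.
broadcast_sum = tf.reduce_sum(tf.cast(broadcast_equal, tf.int8))

# Method 2: count the True entries directly.
nz_cnt = tf.math.count_nonzero(broadcast_equal)

# Alternative: ask for a boolean answer instead of a count.
in_whitelist = tf.reduce_any(broadcast_equal)

print(broadcast_equal.numpy(), int(broadcast_sum), int(nz_cnt), bool(in_whitelist))
```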

