Tensorflow: Lors de l'utilisation de tf.expand_dims?
Tensorflow tutoriels comprennent l'utilisation de tf.expand_dims
pour ajouter un "lot de dimension" d'un tenseur. J'ai lu la doc de cette fonction, mais il est encore un peu mystérieux pour moi. Personne ne sait exactement dans quelles circonstances il doit être utilisé?
Mon code est ci-dessous. Mon but est de calculer une perte en fonction de la distance entre les réels et les bacs. (E. g. si predictedBin = 10
et truthBin = 7
puis binDistanceLoss = 3
).
batch_size = tf.size(truthValues_placeholder)
labels = tf.expand_dims(truthValues_placeholder, 1)
predictedBin = tf.argmax(logits)
binDistanceLoss = tf.abs(tf.sub(labels, logits))
Dans ce cas, dois-je appliquer tf.expand_dims
à predictedBin
et binDistanceLoss
? Merci à l'avance.
Vous devez vous connecter pour publier un commentaire.
expand_dims
ne permet pas d'ajouter ou de réduire les éléments d'un tenseur, il change juste la forme en ajoutant1
à des dimensions. Par exemple, un vecteur avec 10 éléments pourraient être considérés comme une 10x1 de la matrice.La situation, j'ai rencontré à utiliser
expand_dims
, c'est quand j'ai essayé de construire un ConvNet pour classer les images en niveaux de gris. Les images en niveaux de gris sera chargé comme une matrice de taille[320, 320]
. Cependant,tf.nn.conv2d
exiger la saisie d'être[batch, in_height, in_width, in_channels]
, où lain_channels
dimension est manquante dans mes données, qui dans ce cas doit être1
. J'ai donc utiliséexpand_dims
pour ajouter une dimension de plus.Dans votre cas, je ne pense pas que vous avez besoin
expand_dims
.À ajouter à Da Tong réponse, vous pouvez élargir plus d'une dimension, à la même heure. Par exemple, si vous effectuez TensorFlow de
conv1d
opération sur les vecteurs de rang 1, vous avez besoin de les nourrir avec de rang trois.Effectuer
expand_dims
à plusieurs reprises qu'il est lisible, mais risque d'introduire des frais généraux dans le calcul graphique. Vous pouvez obtenir les mêmes fonctionnalités dans un one-liner avecreshape
:REMARQUE: Dans le cas où vous obtenez le message d'erreur
TypeError: Failed to convert object of type <type 'list'> to Tensor.
, essayez de passertf.shape(x)[0]
au lieu dex.get_shape()[0]
comme l'a suggéré ici.Espère que cela aide!
Cheers,
Andres
reshape
est plus rapide que de le faire, disons, deux ou troisexpand_dims
?