Comment obtenir les dimensions d'un tenseur (dans TensorFlow) au moment de la construction du graphique?
J'essaie une Op qui ne se comporte pas comme prévu.
graph = tf.Graph()
with graph.as_default():
train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
embeddings = tf.Variable(
tf.random_uniform([50000, 64], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_dataset)
embed = tf.reduce_sum(embed, reduction_indices=0)
J'ai donc besoin de connaître les dimensions du Tenseur embed
. Je sais que cela peut être fait au moment de l'exécution, mais c'est trop de travail pour une opération simple. Quelle est la façon la plus facile de le faire?
source d'informationauteur Thoran
Vous devez vous connecter pour publier un commentaire.
Tensor.get_shape()
de ce post.À partir de la documentation:
Je vois la plupart des gens confus au sujet de
tf.shape(tensor)
ettensor.get_shape()
Soyons clair:
tf.shape
tf.shape
est utilisé pour la forme dynamique. Si votre tenseur de forme est changablel'utiliser.Un exemple: une entrée est une image avec changable la largeur et la hauteur, nous voulons la redimensionner à la moitié de sa taille, alors on peut écrire quelque chose comme:
new_height = tf.shape(image)[0] /2
tensor.get_shape
tensor.get_shape
est utilisé pour les formes, ce qui signifie que le tenseur de forme peut être déduit dans le graphique.Conclusion:
tf.shape
peut être utilisé presque partout, maist.get_shape
seulement pour les formes peuvent être déduites à partir d'un graphique.Une fonction de accès les valeurs:
Exemple:
Il suffit d'imprimer les intégrer après la construction graphique (ops) sans cours d'exécution:
Cela permettra de montrer la forme de l'intégrer tenseur:
Généralement, il est bon de vérifier les formes de tous les tenseurs avant la formation de vos modèles.
Nous allons faire simple comme l'enfer. Si vous voulez un numéro unique pour le nombre de dimensions, comme
2, 3, 4, etc.,
alors utilisez simplementtf.rank()
. Mais, si vous voulez la forme exacte du tenseur ensuite utilisertensor.get_shape()
La méthode tf.la forme est un TensorFlow méthode statique. Cependant, il y a aussi la méthode get_shape pour le Tenseur de classe. Voir
https://www.tensorflow.org/api_docs/python/tf/Tensor#get_shape