Comment explicitement la diffusion d'un tenseur pour correspondre à une autre forme dans tensorflow?
J'ai trois tenseurs, A, B and C
dans tensorflow, A
et B
sont à la fois de la forme (m, n, r)
, C
est un tenseur binaire de la forme (m, n, 1)
.
Je veux sélectionner les éléments à partir de A ou B selon la valeur de C
. L'outil idéal est tf.select
, cependant que ne pas avoir de radiodiffusion de la sémantique, donc je dois tout d'abord explicitement la diffusion C
à la même forme que le A et le B.
Ce serait ma première tentative à la façon de le faire, mais il ne m'aime pas le mélange d'un tenseur (tf.shape(A)[2]
) dans la liste de forme.
import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))
C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)
Quelle est la bonne approche ici?
- Un hack qui fonctionne: je peux utiliser la radiodiffusion et la sémantique de se multiplier et de se multiplier par ceux tenseur ainsi:
Expander = tf.ones_like(B)
, puisC = Expander*C
Vous devez vous connecter pour publier un commentaire.
EDIT: Dans toutes les versions de TensorFlow depuis 0.12rc0, le code en question fonctionne directement. TensorFlow automatiquement pile tenseurs et Python numéros en un tenseur argument. La solution ci-dessous à l'aide de
tf.pack()
est nécessaire uniquement dans les versions antérieures à 0.12rc0. Notez quetf.pack()
a été renommétf.pile()
dans TensorFlow 1.0.Votre solution est très proche de travail. Vous devez remplacer la ligne:
...avec les éléments suivants:
(La raison pour laquelle le problème est que TensorFlow de ne pas convertir implicitement une liste des tenseurs et Python littéraux dans un tenseur.
tf.pack()
prend une liste de tenseurs, donc il vous permet de convertir chacun des éléments de son entrée (1
,1
, ettf.shape(C)[2]
) à un tenseur. Étant donné que chaque élément est un scalaire, le résultat sera un vecteur.)[
et l'absence d'un)
, mais puis-je obtenir un laconique message d'erreur quand j' exécuter le tf session:InvalidArgumentError: Inputs to operation Select_13 of type Select must have the same size and shape. Input 0: dim { size: 20 } dim { size: 100 } dim { size: 1 } != input 1: dim { size: 20 } dim { size: 100 } dim { size: 10 }
tf.shape()
doit avoir étéA
(ouB
). Cela fonctionne pour moi - quelle erreur avez-vous observé?tf.shape()
. Merci!tf.stack
maintenant, oui?tf.stack()
pour résoudre ce problème (voir edit). Le correctif pour le problème sous-jacent sont venus longtemps avant de le renommer detf.pack()
àtf.stack()
, donc je vais le garder commetf.pack()
pour l'exactitude historique.Voici un sale hack: