Scikit-learn predict_proba donne de mauvaises réponses
C'est une question de suivi à partir de Comment savoir ce que les classes sont représentées dans le tableau de predict_proba dans Scikit-learn
Dans cette question, j'ai cité le code suivant:
>>> import sklearn
>>> sklearn.__version__
'0.13.1'
>>> from sklearn import svm
>>> model = svm.SVC(probability=True)
>>> X = [[1,2,3], [2,3,4]] # feature vectors
>>> Y = ['apple', 'orange'] # classes
>>> model.fit(X, Y)
>>> model.predict_proba([1,2,3])
array([[ 0.39097541, 0.60902459]])
J'ai découvert à cette question, ce résultat représente la probabilité de le point appartenant à chaque classe, dans l'ordre donné par le modèle.classes_
>>> zip(model.classes_, model.predict_proba([1,2,3])[0])
[('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
Donc... cette réponse, si on l'interprète correctement, dit que le point est probablement un "orange" (avec un assez faible niveau de confiance, en raison de la petite quantité de données). Mais intuitivement, ce résultat est manifestement inexact, puisque le point donné est identique à la formation des données pour 'pomme'. Juste pour être sûr, j'ai testé le sens inverse:
>>> zip(model.classes_, model.predict_proba([2,3,4])[0])
[('apple', 0.60705475211840931), ('orange', 0.39294524788159074)]
De nouveau, manifestement incorrect, mais dans l'autre sens.
Enfin, j'ai essayé avec des points qui ont été beaucoup plus loin.
>>> X = [[1,1,1], [20,20,20]] # feature vectors
>>> model.fit(X, Y)
>>> zip(model.classes_, model.predict_proba([1,1,1])[0])
[('apple', 0.33333332048410247), ('orange', 0.66666667951589786)]
Encore une fois, le modèle prédit la mauvaise probabilités. MAIS, le modèle.prédire la fonction se à droite!
>>> model.predict([1,1,1])[0]
'apple'
Maintenant, je me souviens avoir lu quelque chose dans les docs sur predict_proba inexacte pour les petits jeux de données, mais je n'arrive pas à le retrouver. Est-ce le comportement attendu, ou suis-je en train de faire quelque chose de mal? Si c'EST le comportement attendu, alors pourquoi ne le prédire et predict_proba fonction de désaccord de la sortie? Et surtout, quelle taille fait le jeu de données doivent être avant que je puisse avoir confiance dans les résultats de predict_proba?
-------- Mise à JOUR --------
Ok, donc j'ai fait un peu plus d'expériences dans ce: le comportement de predict_proba est fortement dépendante de 'n', mais pas en aucune façon prévisible!
>>> def train_test(n):
... X = [[1,2,3], [2,3,4]] * n
... Y = ['apple', 'orange'] * n
... model.fit(X, Y)
... print "n =", n, zip(model.classes_, model.predict_proba([1,2,3])[0])
...
>>> train_test(1)
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
>>> for n in range(1,10):
... train_test(n)
...
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
n = 2 [('apple', 0.98437355278112448), ('orange', 0.015626447218875527)]
n = 3 [('apple', 0.90235408180319321), ('orange', 0.097645918196806694)]
n = 4 [('apple', 0.83333299908143665), ('orange', 0.16666700091856332)]
n = 5 [('apple', 0.85714254878984497), ('orange', 0.14285745121015511)]
n = 6 [('apple', 0.87499969631893626), ('orange', 0.1250003036810636)]
n = 7 [('apple', 0.88888844127886335), ('orange', 0.11111155872113669)]
n = 8 [('apple', 0.89999988018127364), ('orange', 0.10000011981872642)]
n = 9 [('apple', 0.90909082368682159), ('orange', 0.090909176313178491)]
Comment dois-je utiliser cette fonction en toute sécurité dans mon code? À tout le moins, est-il une valeur de n pour laquelle il sera garanti d'accord avec le résultat du modèle.prévoir?
Vous devez vous connecter pour publier un commentaire.
si vous utilisez
svm.LinearSVC()
comme estimateur, et.decision_function()
(qui est comme svm.SVC est .predict_proba()) pour trier les résultats de la plus probable de la classe la moins probable. ceci est en accord avec.predict()
fonction. De Plus, cet estimateur est plus rapide et donne presque les mêmes résultats avecsvm.SVC()
le seul inconvénient pour vous peut-être que
.decision_function()
donne une valeur signée qqch comme entre -1 et 3 au lieu d'une valeur de probabilité. mais il est d'accord avec la prédiction.LinearSVC()
donnera des prédictions similaires commeSVC(kernel='linear')
mais pasSVC(kernel='rbf')
qui est le noyau par défaut pourSVC
.predict_probas
est l'aide de l'Platt mise à l'échelle caractéristique de libsvm à callibrate probabilités, voir:Donc en effet le hyperplane prédictions et la proba d'étalonnage peuvent être en désaccord, surtout si vous ne disposez que de 2 échantillons dans votre jeu de données. C'est bizarre que l'interne de la croix-validation par libsvm pour la mise à l'échelle des probabilités de ne pas (explicitement) dans ce cas. C'est peut-être un bug. On aurait pu plonger dans le Platt mise à l'échelle de code de libsvm à comprendre ce qui se passe.
De la nourriture pour la pensée ici. Je pense que j'ai effectivement eu le predict_proba de travail. Veuillez voir le code ci-dessous...
De sortie:
Top 5 Des Taux De Précision = 1.0
Top 1 Taux De Précision = 1.0
Ne pouvais pas le faire fonctionner pour mes propres données si 🙁
Il existe une certaine confusion quant à ce qui predict_proba la réalité. Il ne permet pas de prédire les probabilités comme le suggère le titre, mais les sorties distances.
Dans l'apple vs orange exemple 0.39097541, 0.60902459 la distance la plus courte 0.39097541 est la pomme de classe. qui est contre-intuitif. vous êtes à la recherche à la probabilité la plus élevée, mais ce n'est pas le cas.
Une autre source de confusion vient de là predict_proba match dur étiquettes, tout simplement pas dans l'ordre des classes, à partir de 0..n de manière séquentielle . Scikit semble aléatoire les classes, mais il est possible de les cartographier.
ici est de savoir comment il fonctionne.
prédit étiquettes [2 0 1 0 4]
rien, mais la troisième classe est un match.
selon prédit étiquettes en cm, classe 0 est prévue et réelle de classe est
0 argmax(pred_prob).
Mais, son mappé à
afin de trouver la deuxième classe
nous allons le faire à nouveau.
regarder les erreurs de classification résultat numero 4, où réel lebel 4, prédit 1 selon cm.
Ce sont mes 0.02.