Voraussagen, in Tensorflow Schätzer mit Eingabe fn
Ich den tutorial-code aus https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py und der code funktioniert einwandfrei, bis ich versuchte, um eine Vorhersage zu machen, anstatt einfach es zu bewerten. Ich habe versucht, eine andere Funktion für die Vorhersage, dass die so Aussehen (nur durch entfernen der parameter y):
def input_fn_predict(data_file, num_epochs, shuffle):
"""Input builder function."""
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine="python",
skiprows=1)
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
return tf.estimator.inputs.pandas_input_fn( #removed paramter y
x=df_data,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
num_threads=5)
Und nennen es so:
predictions = m.predict(
input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
print(i, p)
- Mache ich es richtig?
- Warum bekomme ich die Vorhersage 81404 statt 16282(Anzahl der line-in-test-Datei)?
- Jede Zeile enthält so etwas wie dieses:
{'Wahrscheinlichkeiten': array([ 0.78595656, 0.21404342], dtype=float32),
'logits': array([-1.3007226], dtype=float32), 'Klassen': array(['0'],
dtype=Objekt), 'class_ids': array([0]), 'logistic': array([
0.21404341], dtype=float32)}
Wie lese ich das?
Du musst angemeldet sein, um einen Kommentar abzugeben.
Müssen Sie
shuffle=False
da, um vorherzusagen, neue Labels, die Sie benötigen, um Daten zu erhalten, um.Unten ist mein code zum ausführen der Vorhersage (ich habe es getestet). Die input-Datei ist wie test-Daten (csv), aber es ist kein label-Spalte.
Nennen es:
Das Vorhersage-Ergebnis für eine Probe unter:
Was jedes Feld bedeutet, sind
Es sagt Voraus, der Ausgabe-Bezeichnung der Klasse-0 (in diesem Fall <=50K) mit
Vertrauen 0.78595656
Der Wert von z in die Gleichung 1/(1+e^(-z)) ist -1.3.
Die Klasse label ist 0