From cf4974dfecf6aa0390a3fc2b7e35b2612e746cac Mon Sep 17 00:00:00 2001 From: Michael Twohey Date: Sat, 23 Nov 2024 12:07:00 -0800 Subject: [PATCH] Update SeqClassification.cs Showing the probabilities of an answer --- Seq2SeqSharp/Applications/SeqClassification.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/Seq2SeqSharp/Applications/SeqClassification.cs b/Seq2SeqSharp/Applications/SeqClassification.cs index ae67891..3ecfab1 100644 --- a/Seq2SeqSharp/Applications/SeqClassification.cs +++ b/Seq2SeqSharp/Applications/SeqClassification.cs @@ -161,12 +161,21 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu using var targetIdxTensor = computeGraph.Argmax(probs, 1); float[] targetIdx = targetIdxTensor.ToWeightArray(); List targetWords = m_modelMetaData.TgtVocab.ConvertIdsToString(targetIdx.ToList()); + + List allWords = m_modelMetaData.TgtVocab.IndexToWord.Values.Select(s => s).ToList(); + var combinedOutput = probs.ToWeightArray().Zip(allWords, (prob, word) => $"{word.PadRight(15)}: {prob,10:F8}"); + Logger.WriteLine(Logger.Level.info, $"\nWrds w/Probs\n{string.Join("\n", combinedOutput)}"); + nr.Output.Add(new List>()); for (int k = 0; k < batchSize; k++) { nr.Output[0].Add(new List()); - nr.Output[0][k].Add(targetWords[k]); + + // Fetch the corresponding probability for the predicted target index + float probAtTargetIdx = probs.GetWeightAt([0, (long)targetIdx[k]]); + + nr.Output[0][k].Add($"{targetWords[k]} {probAtTargetIdx:F8}"); } }