diff --git a/Seq2SeqSharp/Applications/SeqClassification.cs b/Seq2SeqSharp/Applications/SeqClassification.cs index ae67891..3ecfab1 100644 --- a/Seq2SeqSharp/Applications/SeqClassification.cs +++ b/Seq2SeqSharp/Applications/SeqClassification.cs @@ -161,12 +161,21 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu using var targetIdxTensor = computeGraph.Argmax(probs, 1); float[] targetIdx = targetIdxTensor.ToWeightArray(); List targetWords = m_modelMetaData.TgtVocab.ConvertIdsToString(targetIdx.ToList()); + + List allWords = m_modelMetaData.TgtVocab.IndexToWord.Values.Select(s => s).ToList(); + var combinedOutput = probs.ToWeightArray().Zip(allWords, (prob, word) => $"{word.PadRight(15)}: {prob,10:F8}"); + Logger.WriteLine(Logger.Level.info, $"\nWrds w/Probs\n{string.Join("\n", combinedOutput)}"); + nr.Output.Add(new List>()); for (int k = 0; k < batchSize; k++) { nr.Output[0].Add(new List()); - nr.Output[0][k].Add(targetWords[k]); + + // Fetch the corresponding probability for the predicted target index + float probAtTargetIdx = probs.GetWeightAt([0, (long)targetIdx[k]]); + + nr.Output[0][k].Add($"{targetWords[k]} {probAtTargetIdx:F8}"); } }