Skip to content

Commit

Permalink
Handle auth change while already connected. (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
scottf authored Nov 8, 2023
1 parent 4208ad1 commit 60544fe
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 22 deletions.
22 changes: 14 additions & 8 deletions src/NATS.Client/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ internal bool connect(Srv s, out Exception exToThrow)
{
exToThrow = null;

NATSConnectionException natsAuthEx = null;
String lastAuthExMessage = null;

for(var i = 0; i < 6; i++) //Precaution to not end up in server returning ExTypeA, ExTypeB, ExTypeA etc.
{
Expand All @@ -1196,16 +1196,18 @@ internal bool connect(Srv s, out Exception exToThrow)
}
catch (NATSConnectionException ex)
{
if (!ex.IsAuthenticationOrAuthorizationError())
string message = ex.Message.ToLower();
if (!NATSException.IsAuthenticationOrAuthorizationError(message, true))
{
throw;
}

ScheduleErrorEvent(s, ex);

if (natsAuthEx == null || !natsAuthEx.Message.Equals(ex.Message, StringComparison.OrdinalIgnoreCase))
// avoiding double the same
if (lastAuthExMessage == null || !lastAuthExMessage.Equals(message))
{
natsAuthEx = ex;
lastAuthExMessage = message;
continue;
}

Expand Down Expand Up @@ -2448,9 +2450,6 @@ public Exception LastError
// sets the connection's lastError.
internal void processErr(MemoryStream errorStream)
{
bool invokeDelegates = false;
Exception ex = null;

string s = getNormalizedError(errorStream);

if (IC.STALE_CONNECTION.Equals(s))
Expand All @@ -2466,7 +2465,9 @@ internal void processErr(MemoryStream errorStream)
}
else
{
ex = new NATSException("Error from processErr(): " + s);
NATSException ex = new NATSException("Error from processErr(): " + s);
bool invokeDelegates = false;

lock (mu)
{
lastEx = ex;
Expand All @@ -2478,6 +2479,11 @@ internal void processErr(MemoryStream errorStream)
}

close(ConnState.CLOSED, invokeDelegates, ex);

if (NATSException.IsAuthenticationOrAuthorizationError(s))
{
processReconnect();
}
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/NATS.Client/Exceptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ public class NATSException : Exception
public NATSException() : base() { }
public NATSException(string err) : base (err) {}
public NATSException(string err, Exception innerEx) : base(err, innerEx) { }

public static bool IsAuthenticationOrAuthorizationError(string message, bool alreadyLowered = false)
{
string lowerMessage = alreadyLowered ? message : message.ToLower();
return lowerMessage.Contains("user authentication")
|| lowerMessage.Contains("authorization violation")
|| lowerMessage.Contains("authentication expired");
}
}

/// <summary>
Expand All @@ -33,14 +41,6 @@ public class NATSConnectionException : NATSException
{
public NATSConnectionException(string err) : base(err) { }
public NATSConnectionException(string err, Exception innerEx) : base(err, innerEx) { }

public bool IsAuthenticationOrAuthorizationError()
{
string lowerMessage = Message.ToLower();
return lowerMessage.Contains("user authentication")
|| lowerMessage.Contains("authorization violation")
|| lowerMessage.Contains("authentication expired");
}
}

/// <summary>
Expand Down
33 changes: 27 additions & 6 deletions src/Tests/IntegrationTests/TestAuthorization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
using System.Threading;
using NATS.Client;
using NATS.Client.Internals;
using UnitTests;
using Xunit;
using Xunit.Abstractions;

namespace IntegrationTests
{
Expand All @@ -27,7 +29,15 @@ namespace IntegrationTests
/// </summary>
public class TestAuthorization : TestSuite<AuthorizationSuiteContext>
{
public TestAuthorization(AuthorizationSuiteContext context) : base(context) {}
private readonly ITestOutputHelper output;

public TestAuthorization(ITestOutputHelper output, AuthorizationSuiteContext context) : base(context)
{
this.output = output;
Console.SetOut(new TestBase.ConsoleWriter(output));
}

// public TestAuthorization(AuthorizationSuiteContext context) : base(context) {}

int hitDisconnect;

Expand Down Expand Up @@ -261,23 +271,34 @@ public void TestRealUserAuthenticationExpired()
string credsFile = Path.GetTempFileName();
File.WriteAllText(credsFile, cred);

CountdownEvent userAuthenticationExpired = new CountdownEvent(1);
CountdownEvent userAuthenticationExpiredCde = new CountdownEvent(1);
CountdownEvent reconnectCde = new CountdownEvent(1);

using (NATSServer.CreateWithConfig(Context.Server3.Port, "operatorJnatsTest.conf"))
{
var opts = Context.GetTestOptionsWithDefaultTimeout(Context.Server3.Port);
opts.SetUserCredentials(credsFile);
opts.MaxReconnect = 1;
opts.DisconnectedEventHandler += (sender, e) =>
{
if (e.Error.ToString().Contains("user authentication expired"))
{
userAuthenticationExpired.Signal();
userAuthenticationExpiredCde.Signal();
}
};
opts.ReconnectedEventHandler += (sender, e) =>
{
if (userAuthenticationExpiredCde.IsSet)
{
reconnectCde.Signal();
}
};

IConnection c = Context.ConnectionFactory.CreateConnection(opts);
userAuthenticationExpired.Wait(wait);
Assert.True(userAuthenticationExpired.IsSet);
IConnection c = Context.ConnectionFactory.CreateConnection(opts, true);
userAuthenticationExpiredCde.Wait(wait);
Assert.True(userAuthenticationExpiredCde.IsSet);
reconnectCde.Wait(wait);
Assert.True(reconnectCde.IsSet);
}
}

Expand Down

0 comments on commit 60544fe

Please sign in to comment.