Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/AppCommon/Commands/BaseCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ public async Task Run(CancellationToken cancellationToken = default)
Out.WriteError(halt.Message);
if (halt.InnerException != null)
{
Out.WriteLine();
Out.WriteLine("Original exception message: " + halt.InnerException.Message);
var original = halt.GetBaseException();
if (original != halt && !halt.Message.Contains(original.Message))
{
Out.WriteLine();
Out.WriteLine("Original exception message: " + original.Message);
}

Exceptions.ReportError(halt);
}
Environment.ExitCode = halt.ExitCode;
}
Expand Down
65 changes: 59 additions & 6 deletions src/AppCommon/Commands/SqlServerCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,21 @@ public static Command CreateCommand()
command.SetHandler(async context =>
{
var shared = SharedOptions.Parse(context);
var connectionStrings = GetConnectionStrings(context.ParseResult);
var cancellationToken = context.GetCancellationToken();

var runner = new SqlServerCommand(shared, connectionStrings);
await runner.Run(cancellationToken);
try
{
var connectionStrings = GetConnectionStrings(context.ParseResult);

var runner = new SqlServerCommand(shared, connectionStrings);
await runner.Run(cancellationToken);
}
catch (HaltException halt)
{
Out.WriteLine();
Out.WriteError(halt.Message);
Environment.ExitCode = halt.ExitCode;
}
});

return command;
Expand All @@ -56,9 +66,7 @@ static string[] GetConnectionStrings(ParseResult parsed)
throw new FileNotFoundException($"Could not find file specified by {ConnectionStringSource.Name} parameter", sourcePath);
}

return File.ReadAllLines(sourcePath)
.Where(line => !string.IsNullOrWhiteSpace(line))
.ToArray();
return ParseConnectionStringSource(File.ReadAllLines(sourcePath), sourcePath);
}

var single = parsed.GetValueForOption(ConnectionString);
Expand Down Expand Up @@ -92,6 +100,44 @@ static string[] GetConnectionStrings(ParseResult parsed)
return list.ToArray();
}

internal static string[] ParseConnectionStringSource(string[] lines, string sourcePath)
{
var connectionStrings = new List<string>();

for (var i = 0; i < lines.Length; i++)
{
var line = lines[i].Trim();

if (string.IsNullOrWhiteSpace(line))
{
continue;
}

if (line.Length >= 2 && ((line[0] == '"' && line[^1] == '"') || (line[0] == '\'' && line[^1] == '\'')))
{
line = line[1..^1].Trim();
}

try
{
_ = new SqlConnectionStringBuilder(line);
}
catch (Exception x) when (x is FormatException or ArgumentException or KeyNotFoundException)
{
throw new HaltException(HaltReason.InvalidConfig, $"ERROR: Line {i + 1} of '{sourcePath}' could not be parsed as a SQL Server connection string: {x.Message}");
}

connectionStrings.Add(line);
}

if (connectionStrings.Count == 0)
{
throw new HaltException(HaltReason.InvalidConfig, $"ERROR: The file '{sourcePath}' does not contain any connection strings.");
}

return connectionStrings.ToArray();
}

readonly string[] connectionStrings;
DatabaseDetails[] databases;
string scopeType;
Expand All @@ -111,6 +157,13 @@ protected override async Task<EnvironmentDetails> GetEnvironment(CancellationTok

foreach (var db in databases)
{
Out.WriteLine($"Testing connection to server '{db.DataSource}', database '{db.DatabaseName ?? "(default)"}', Integrated Security={(db.IntegratedSecurity ? "true" : "false")}...");
if (db.IntegratedSecurity && OperatingSystem.IsWindows())
{
using var identity = System.Security.Principal.WindowsIdentity.GetCurrent();
Out.WriteLine($" - Connecting as Windows identity '{identity.Name}'");
}

await db.TestConnection(cancellationToken);
}

Expand Down
36 changes: 36 additions & 0 deletions src/AppCommon/Infra/ConnectionStringSanitizer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System.Text.RegularExpressions;
using Microsoft.Data.SqlClient;

static partial class ConnectionStringSanitizer
{
const string Mask = "*****";

[GeneratedRegex("""(?<key>\b(?:password|pwd))\s*=\s*(?:'[^']*'|"[^"]*"|[^;]*)""", RegexOptions.IgnoreCase)]
private static partial Regex SecretRegex();

/// <summary>
/// Returns the connection string with the Password/PWD value blanked so it can be safely echoed or logged.
/// </summary>
public static string Sanitize(string connectionString)
{
try
{
var builder = new SqlConnectionStringBuilder { ConnectionString = connectionString };
if (!string.IsNullOrEmpty(builder.Password))
{
builder.Password = Mask;
}
return builder.ToString();
}
catch (Exception x) when (x is FormatException or ArgumentException)
{
// Not parseable as a connection string, fall back to pattern-based redaction
return RedactText(connectionString);
}
}

/// <summary>
/// Redacts Password/PWD values in free-form text, such as an exception dump that may embed a connection string.
/// </summary>
public static string RedactText(string text) => text is null ? null : SecretRegex().Replace(text, $"${{key}}={Mask}");
}
35 changes: 34 additions & 1 deletion src/AppCommon/Infra/Exceptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ public static void SetupUnhandledExceptionHandling()

public static void ReportError(Exception x)
{
WriteDiagnosticsLog(x);

var settings = new RaygunSettings()
{
ApplicationVersion = Versioning.NuGetVersion
};

RunInfo.Add("ToolOutput", Out.GetToolOutput());

if (x is SqlException sqlX)
if (FindInChain<SqlException>(x) is SqlException sqlX)
{
RunInfo.Add("SqlException.Number", sqlX.Number.ToString());
if (sqlX.Errors is not null)
Expand Down Expand Up @@ -82,4 +84,35 @@ public static void ReportError(Exception x)
}

}

/// <summary>
/// Writes the full (redacted) exception chain to a local log file so support can diagnose failures
/// without a live debugger, including when running with --unattended.
/// </summary>
static void WriteDiagnosticsLog(Exception x)
{
try
{
var path = Path.Join(Environment.CurrentDirectory, $"throughput-diagnostics-{DateTime.Now:yyyyMMdd-HHmmss}.log");
File.WriteAllText(path, ConnectionStringSanitizer.RedactText(x.ToString()));
Console.WriteLine($"Diagnostic details written to {path}");
}
catch (Exception logX) when (logX is IOException or UnauthorizedAccessException)
{
Console.WriteLine($"Unable to write diagnostics log: {logX.Message}");
}
}

static T FindInChain<T>(Exception x) where T : Exception
{
while (x is not null)
{
if (x is T match)
{
return match;
}
x = x.InnerException;
}
return null;
}
}
26 changes: 25 additions & 1 deletion src/Query/SqlTransport/DatabaseDetails.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
Expand All @@ -12,6 +13,8 @@ public class DatabaseDetails
readonly string connectionString;

public string DatabaseName { get; }
public string DataSource { get; }
public bool IntegratedSecurity { get; }
public List<QueueTableName> Tables { get; private set; }
public int ErrorCount { get; private set; }

Expand All @@ -21,6 +24,8 @@ public DatabaseDetails(string connectionString)
{
var builder = new SqlConnectionStringBuilder { ConnectionString = connectionString, TrustServerCertificate = true };
DatabaseName = builder["Initial Catalog"] as string ?? builder["Database"] as string;
DataSource = builder.DataSource;
IntegratedSecurity = builder.IntegratedSecurity;
this.connectionString = builder.ToString();
}
catch (Exception x) when (x is FormatException or ArgumentException)
Expand All @@ -39,10 +44,29 @@ public async Task TestConnection(CancellationToken cancellationToken = default)
}
catch (SqlException x) when (IsConnectionOrLoginIssue(x))
{
throw new QueryException(QueryFailureReason.Auth, "Could not access SQL database. Is the connection string correct?", x);
throw new QueryException(QueryFailureReason.Auth, BuildConnectionIssueMessage(x), x);
}
}

string BuildConnectionIssueMessage(SqlException x)
{
var serverName = string.IsNullOrEmpty(x.Server) ? DataSource : x.Server;

var message = new StringBuilder()
.AppendLine($"SQL error {x.Number} (state {x.State}, class {x.Class}) from {serverName}: {x.Message}");

if (x.Errors is not null && x.Errors.Count > 1)
{
for (var i = 0; i < x.Errors.Count; i++)
{
var err = x.Errors[i];
_ = message.AppendLine($" - SQL error {err.Number} (state {err.State}, class {err.Class}): {err.Message}");
}
}

return message.ToString().TrimEnd();
}

static bool IsConnectionOrLoginIssue(SqlException x)
{
// Reference is here: https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/cc645603(v=sql.105)?redirectedfrom=MSDN
Expand Down
131 changes: 131 additions & 0 deletions src/Tests/SqlServer/ConnectionStringSourceTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
namespace Tests.SqlServer
{
using NUnit.Framework;

[TestFixture]
public class ConnectionStringSourceTests
{
const string SourcePath = "connection-strings.txt";

[Test]
public void Parses_valid_lines_and_skips_blank_lines()
{
var lines = new[]
{
"",
"Server=srv1;Database=db1",
" ",
"Server=srv2;Database=db2",
""
};

var result = SqlServerCommand.ParseConnectionStringSource(lines, SourcePath);

Assert.That(result, Is.EqualTo(new[] { "Server=srv1;Database=db1", "Server=srv2;Database=db2" }));
}

[TestCase("\"Server=srv;Database=db\"")]
[TestCase("'Server=srv;Database=db'")]
[TestCase(" \"Server=srv;Database=db\" ")]
public void Strips_a_single_pair_of_wrapping_quotes(string line)
{
var result = SqlServerCommand.ParseConnectionStringSource(new[] { line }, SourcePath);

Assert.That(result, Is.EqualTo(new[] { "Server=srv;Database=db" }));
}

[Test]
public void Reports_one_based_line_number_and_parse_error_for_invalid_line()
{
var lines = new[]
{
"Server=srv1;Database=db1",
"this is not a connection string"
};

var halt = Assert.Throws<HaltException>(() => SqlServerCommand.ParseConnectionStringSource(lines, SourcePath));

Assert.That(halt.ExitCode, Is.EqualTo((int)HaltReason.InvalidConfig));
Assert.That(halt.Message, Does.Contain("Line 2"));
Assert.That(halt.Message, Does.Contain(SourcePath));
}

[Test]
public void Reports_invalid_keyword_parse_error()
{
var halt = Assert.Throws<HaltException>(() => SqlServerCommand.ParseConnectionStringSource(new[] { "NotAKeyword=value" }, SourcePath));

Assert.That(halt.ExitCode, Is.EqualTo((int)HaltReason.InvalidConfig));
Assert.That(halt.Message, Does.Contain("Line 1"));
}

[Test]
public void Throws_when_file_contains_no_connection_strings()
{
var halt = Assert.Throws<HaltException>(() => SqlServerCommand.ParseConnectionStringSource(new[] { "", " " }, SourcePath));

Assert.That(halt.ExitCode, Is.EqualTo((int)HaltReason.InvalidConfig));
}
}

[TestFixture]
public class ConnectionStringSanitizerTests
{
[TestCase("Server=srv;Database=db;User ID=user;Password=s3cret!")]
[TestCase("Server=srv;Database=db;User ID=user;Pwd=s3cret!")]
[TestCase("Server=srv;Database=db;User ID=user;Password='s3cret!;more'")]
public void Sanitize_blanks_the_password(string connectionString)
{
var sanitized = ConnectionStringSanitizer.Sanitize(connectionString);

Assert.That(sanitized, Does.Not.Contain("s3cret!"));
Assert.That(sanitized, Does.Contain("srv"));
Assert.That(sanitized, Does.Contain("user"));
}

[Test]
public void Sanitize_keeps_connection_string_without_password_intact()
{
var sanitized = ConnectionStringSanitizer.Sanitize("Server=srv;Database=db;Integrated Security=True");

Assert.That(sanitized, Does.Contain("srv"));
Assert.That(sanitized, Does.Not.Contain("*****"));
}

[Test]
public void Sanitize_falls_back_to_redaction_for_unparseable_input()
{
var sanitized = ConnectionStringSanitizer.Sanitize("some garbage ;; Password=s3cret!;more garbage");

Assert.That(sanitized, Does.Not.Contain("s3cret!"));
Assert.That(sanitized, Does.Contain("*****"));
}

[TestCase("Login failed. Connection: Server=x;Password=abc123;Encrypt=true", "abc123")]
[TestCase("Connection: Server=x;Pwd=abc123;Encrypt=true", "abc123")]
[TestCase("Connection: Server=x;Password='se;cret';Encrypt=true", "se;cret")]
[TestCase("Connection: Server=x;Password=\"se;cret\";Encrypt=true", "se;cret")]
[TestCase("Connection: Server=x; Password = abc123 ;Encrypt=true", "abc123")]
public void RedactText_masks_secrets_in_free_form_text(string text, string secret)
{
var redacted = ConnectionStringSanitizer.RedactText(text);

Assert.That(redacted, Does.Not.Contain(secret));
Assert.That(redacted, Does.Contain("*****"));
}

[Test]
public void RedactText_leaves_text_without_secrets_untouched()
{
const string text = "SQL error 18456 (state 38) from MYSERVER: Login failed for user 'sa'.";

Assert.That(ConnectionStringSanitizer.RedactText(text), Is.EqualTo(text));
}

[Test]
public void RedactText_handles_null()
{
Assert.That(ConnectionStringSanitizer.RedactText(null), Is.Null);
}
}
}