feature: Bootstrapper secure server discovery

This commit is contained in:
Gary Sharp
2026-01-22 15:26:23 +11:00
parent 71fa53bfb2
commit e1f1973520
40 changed files with 2094 additions and 460 deletions
+1
View File
@@ -141,6 +141,7 @@
<Compile Include="Extensions\EnrolExtensions.cs" />
<Compile Include="Extensions\WhoAmIExtensions.cs" />
<Compile Include="Interop\Certificates.cs" />
<Compile Include="Interop\EndpointDiscovery.cs" />
<Compile Include="Interop\Hardware.cs" />
<Compile Include="Interop\LocalAuthentication.cs" />
<Compile Include="Interop\Native\NetworkConnectionStatuses.cs" />
+8 -7
View File
@@ -10,19 +10,18 @@ namespace Disco.Client
{
public static class ErrorReporting
{
private const string ServicePathTemplate = "http://DISCO:9292/Services/Client/ClientError";
public static string DeviceIdentifier { get; set; }
public static string EnrolmentSessionId { get; set; }
public static void ReportError(Exception Ex, bool ReportToServer)
public static void ReportError(Exception exception, bool reportToServer)
{
bool isClientServiceException = Ex is ClientServiceException;
bool isClientServiceException = exception is ClientServiceException;
ErrorReport report = new ErrorReport()
{
DeviceIdentifier = DeviceIdentifier,
SessionId = EnrolmentSessionId,
JsonException = Ex.IntenseExceptionSerialization()
JsonException = exception.IntenseExceptionSerialization()
};
try
@@ -38,7 +37,7 @@ namespace Disco.Client
catch (Exception) { }
// Don't log server errors back to the server
if (!isClientServiceException && ReportToServer)
if (!isClientServiceException && reportToServer)
{
try
{
@@ -49,7 +48,7 @@ namespace Disco.Client
try
{
Presentation.WriteFatalError(Ex);
Presentation.WriteFatalError(exception);
}
catch (Exception) { }
}
@@ -85,7 +84,9 @@ namespace Disco.Client
string reportJson = JsonConvert.SerializeObject(report);
string reportResponse;
HttpWebRequest request = (HttpWebRequest)WebRequest.Create(ServicePathTemplate);
var serverUri = new Uri(Program.ServerUrl ?? new Uri("http://disco:9292"), "/Services/Client/ClientError");
HttpWebRequest request = (HttpWebRequest)WebRequest.Create(serverUri);
request.UserAgent = $"Disco-Client/{Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}";
request.ContentType = "application/json";
request.Method = WebRequestMethods.Http.Post;
@@ -1,30 +1,23 @@
using Disco.Models.ClientServices;
using Newtonsoft.Json;
using System;
using System.IO;
using System.Net;
using System.Reflection;
namespace Disco.Client.Extensions
{
public static class ClientServicesExtensions
internal static class ClientServicesExtensions
{
//#if DEBUG
// public const string ServicePathAuthenticatedTemplate = "http://WS-GSHARP:57252/Services/Client/Authenticated/{0}";
// public const string ServicePathUnauthenticatedTemplate = "http://WS-GSHARP:57252/Services/Client/Unauthenticated/{0}";
//#else
public const string ServicePathAuthenticatedTemplate = "http://DISCO:9292/Services/Client/Authenticated/{0}";
public const string ServicePathUnauthenticatedTemplate = "http://DISCO:9292/Services/Client/Unauthenticated/{0}";
//#endif
public static ResponseType Post<ResponseType>(this ServiceBase<ResponseType> Service, bool Authenticated)
public static ResponseType Post<ResponseType>(this ServiceBase<ResponseType> service, bool authenticated)
{
ResponseType serviceResponse;
string serviceUrl;
Uri serviceUrl;
if (Authenticated)
serviceUrl = string.Format(ServicePathAuthenticatedTemplate, Service.Feature);
if (authenticated)
serviceUrl = new Uri(Program.ServerUrl, $"/Services/Client/Authenticated/{service.Feature}");
else
serviceUrl = string.Format(ServicePathUnauthenticatedTemplate, Service.Feature);
serviceUrl = new Uri(Program.ServerUrl, $"/Services/Client/Unauthenticated/{service.Feature}");
HttpWebRequest request = (HttpWebRequest)WebRequest.Create(serviceUrl);
request.UserAgent = $"Disco-Client/{Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}";
@@ -39,7 +32,7 @@ namespace Disco.Client.Extensions
{
using (var jsonWriter = new JsonTextWriter(requestWriter))
{
jsonSerializer.Serialize(jsonWriter, Service);
jsonSerializer.Serialize(jsonWriter, service);
}
}
+317
View File
@@ -0,0 +1,317 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;
using System.Threading;
namespace Disco.Client.Interop
{
internal class EndpointDiscovery
{
[DllImport("dnsapi", EntryPoint = "DnsQuery_W", CharSet = CharSet.Unicode, SetLastError = true, ExactSpelling = true)]
private static extern int DnsQuery([MarshalAs(UnmanagedType.VBByRefStr)] ref string pszName, NativeDnsQueryTypes wType, NativeDnsQueryOptions options, int aipServers, ref IntPtr ppQueryResults, int pReserved);
[DllImport("dnsapi", CharSet = CharSet.Auto, SetLastError = true)]
private static extern void DnsRecordListFree(IntPtr pRecordList, int FreeType);
private const int DNS_ERROR_RCODE_NAME_ERROR = 0x232B;
private const int DNS_ERROR_BAD_PACKET = 0x251E;
public static Tuple<Uri, string> DiscoverServer(Uri forcedServerUri)
{
// 1. Check first command line argument for server name
if (forcedServerUri != null)
return Tuple.Create(forcedServerUri, "Manual");
// 2. Check for a DNS SRV record for _discoict._tcp.domain
var domainSuffixes = new List<string>();
var primaryDomain = IPGlobalProperties.GetIPGlobalProperties().DomainName;
if (!string.IsNullOrEmpty(primaryDomain))
domainSuffixes.Add(primaryDomain);
var networkInterfaces = NetworkInterface.GetAllNetworkInterfaces()
.Where(ni => ni.OperationalStatus == OperationalStatus.Up);
foreach (var ni in networkInterfaces)
{
var domainSuffix = ni.GetIPProperties().DnsSuffix;
if (!string.IsNullOrWhiteSpace(domainSuffix))
{
if (domainSuffix.Equals("mshome.net", StringComparison.OrdinalIgnoreCase))
continue;
if (!domainSuffixes.Contains(domainSuffix, StringComparer.OrdinalIgnoreCase))
domainSuffixes.Add(domainSuffix);
}
}
foreach (var domain in domainSuffixes)
{
var dnsRecords = GetSRVRecords("_discoict._tcp." + domain);
if (dnsRecords.Count > 0)
{
var firstRecord = dnsRecords.OrderBy(r => r.Priority).ThenByDescending(r => r.Weight).First();
if (firstRecord.Port == 443)
return Tuple.Create(new Uri($"https://{firstRecord.Target}"), "SRV");
else
return Tuple.Create(new Uri($"https://{firstRecord.Target}:{firstRecord.Port}"), "SRV");
}
}
// 3. Detect VicSmart network and try resolving with Disco ICT Online Services
if (TryResolveVicSmartServer(domainSuffixes, out var vicSmartServerUrl))
return Tuple.Create(vicSmartServerUrl, "VicSmart");
// 4. Legacy: Ping 'disco' and assume port 9292
using (Ping p = new Ping())
{
try
{
PingReply pr = p.Send("disco", 2000);
if (pr.Status == IPStatus.Success)
return Tuple.Create(new Uri("http://disco:9292"), "Legacy");
}
catch (Exception)
{
}
}
throw new Exception("Could not locate Disco ICT server on the network.");
}
private static bool TryResolveVicSmartServer(List<string> domainSuffixes, out Uri serverUrl)
{
if (IsVicSmartNetwork(domainSuffixes))
{
var potentialVicSmartAddresses = NetworkInterface.GetAllNetworkInterfaces()
.Where(ni => ni.OperationalStatus == OperationalStatus.Up)
.SelectMany(ni => ni.GetIPProperties().UnicastAddresses)
.Where(ua => ua.Address.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork)
.Select(ua => ua.Address.GetAddressBytes())
.Where(a => a[0] == 10)
.Select(a => (ushort)((a[1] >> 4) & 0x000F) | ((a[1] << 4) & 0x00F0) | ((a[2] << 12) & 0xF000) | ((a[2] << 4) & 0x0F00))
.Distinct()
.Select(a => $"{a:x4}.vicsmart.discoict.com")
.ToList();
foreach (var potentialAddress in potentialVicSmartAddresses)
{
var records = GetTxtRecords(potentialAddress);
foreach (var record in records)
{
if (!record.Content.StartsWith("https://", StringComparison.OrdinalIgnoreCase))
continue;
if (Uri.TryCreate(record.Content, UriKind.Absolute, out var discoveredUri))
{
serverUrl = discoveredUri;
return true;
}
}
}
}
serverUrl = null;
return false;
}
private static bool IsVicSmartNetwork(List<string> domainSuffixes)
{
if (domainSuffixes.Any(s => string.Equals("services.education.vic.gov.au", s, StringComparison.OrdinalIgnoreCase)) ||
domainSuffixes.Any(s => string.Equals("education.vic.gov.au", s, StringComparison.OrdinalIgnoreCase))
)
return true;
IPHostEntry doeWanDnsEntry;
try
{
doeWanDnsEntry = Dns.GetHostEntry("broadband.doe.wan");
if (doeWanDnsEntry.AddressList.Length > 0)
return true;
}
catch (Exception)
{ }
return false;
}
private static List<DnsTxtRecord> GetTxtRecords(string name)
{
IntPtr resourceRecordsPointer = IntPtr.Zero;
var records = new List<DnsTxtRecord>();
var retry = 5;
retry:
try
{
int queryResult = DnsQuery(ref name, NativeDnsQueryTypes.DNS_TYPE_TEXT, NativeDnsQueryOptions.DNS_QUERY_STANDARD, 0, ref resourceRecordsPointer, 0);
if (queryResult != 0)
{
if (queryResult == DNS_ERROR_RCODE_NAME_ERROR)
return records;
else if (queryResult == DNS_ERROR_BAD_PACKET && retry > 0)
{
// Sometimes a BAD_PACKET error is returned, retry a few times
Thread.Sleep(200);
retry--;
goto retry;
}
else
throw new Win32Exception(queryResult);
}
NativeDnsTxtRecord record;
for (var resourceRecordPointer = resourceRecordsPointer; !resourceRecordPointer.Equals(IntPtr.Zero); resourceRecordPointer = record.pNext)
{
record = Marshal.PtrToStructure<NativeDnsTxtRecord>(resourceRecordPointer);
if (record.wType == (ushort)NativeDnsQueryTypes.DNS_TYPE_TEXT)
records.Add(DnsTxtRecord.FromNativeRecord(record));
}
}
finally
{
if (resourceRecordsPointer != IntPtr.Zero)
DnsRecordListFree(resourceRecordsPointer, 0);
}
return records;
}
private static List<DnsSrvRecord> GetSRVRecords(string name)
{
IntPtr resourceRecordsPointer = IntPtr.Zero;
var records = new List<DnsSrvRecord>();
var retry = 5;
retry:
try
{
int queryResult = DnsQuery(ref name, NativeDnsQueryTypes.DNS_TYPE_SRV, NativeDnsQueryOptions.DNS_QUERY_STANDARD, 0, ref resourceRecordsPointer, 0);
if (queryResult != 0)
{
if (queryResult == DNS_ERROR_RCODE_NAME_ERROR)
return records;
else if (queryResult == DNS_ERROR_BAD_PACKET && retry > 0)
{
// Sometimes a BAD_PACKET error is returned, retry a few times
Thread.Sleep(200);
retry--;
goto retry;
}
else
throw new Win32Exception(queryResult);
}
NativeDnsSrvRecord record;
for (var resourceRecordPointer = resourceRecordsPointer; !resourceRecordPointer.Equals(IntPtr.Zero); resourceRecordPointer = record.pNext)
{
record = Marshal.PtrToStructure<NativeDnsSrvRecord>(resourceRecordPointer);
if (record.wType == (ushort)NativeDnsQueryTypes.DNS_TYPE_SRV)
records.Add(DnsSrvRecord.FromNativeRecord(record));
}
}
finally
{
if (resourceRecordsPointer != IntPtr.Zero)
DnsRecordListFree(resourceRecordsPointer, 0);
}
return records;
}
private enum NativeDnsQueryOptions
{
DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE = 1,
DNS_QUERY_BYPASS_CACHE = 8,
DNS_QUERY_DONT_RESET_TTL_VALUES = 0x100000,
DNS_QUERY_NO_HOSTS_FILE = 0x40,
DNS_QUERY_NO_LOCAL_NAME = 0x20,
DNS_QUERY_NO_NETBT = 0x80,
DNS_QUERY_NO_RECURSION = 4,
DNS_QUERY_NO_WIRE_QUERY = 0x10,
DNS_QUERY_RESERVED = -16777216,
DNS_QUERY_RETURN_MESSAGE = 0x200,
DNS_QUERY_STANDARD = 0,
DNS_QUERY_TREAT_AS_FQDN = 0x1000,
DNS_QUERY_USE_TCP_ONLY = 2,
DNS_QUERY_WIRE_ONLY = 0x100
}
private enum NativeDnsQueryTypes
{
DNS_TYPE_TEXT = 0x0010,
DNS_TYPE_SRV = 0x0021
}
[StructLayout(LayoutKind.Sequential)]
private struct NativeDnsSrvRecord
{
public IntPtr pNext;
[MarshalAs(UnmanagedType.LPWStr)]
public string pName;
public ushort wType;
public ushort wDataLength;
public int flags;
public int dwTtl;
public int dwReserved;
[MarshalAs(UnmanagedType.LPWStr)]
public string pNameTarget;
public ushort wPriority;
public ushort wWeight;
public ushort wPort;
public ushort Pad;
}
private class DnsSrvRecord
{
public string Name { get; set; }
public int Type { get; set; }
public int Ttl { get; set; }
public string Target { get; set; }
public int Priority { get; set; }
public int Weight { get; set; }
public int Port { get; set; }
public static DnsSrvRecord FromNativeRecord(NativeDnsSrvRecord nativeRecord)
{
return new DnsSrvRecord
{
Name = nativeRecord.pName,
Type = nativeRecord.wType,
Ttl = nativeRecord.dwTtl,
Target = nativeRecord.pNameTarget,
Priority = nativeRecord.wPriority,
Weight = nativeRecord.wWeight,
Port = nativeRecord.wPort
};
}
}
[StructLayout(LayoutKind.Sequential)]
private struct NativeDnsTxtRecord
{
public IntPtr pNext;
[MarshalAs(UnmanagedType.LPWStr)]
public string pName;
public ushort wType;
public ushort wDataLength;
public int flags;
public int dwTtl;
public int dwReserved;
public uint dwStringLength;
[MarshalAs(UnmanagedType.LPWStr)]
public string pStringArray;
}
private class DnsTxtRecord
{
public string Name { get; set; }
public int Type { get; set; }
public int Ttl { get; set; }
public string Content { get; set; }
public static DnsTxtRecord FromNativeRecord(NativeDnsTxtRecord nativeRecord)
{
return new DnsTxtRecord
{
Name = nativeRecord.pName,
Type = nativeRecord.wType,
Ttl = nativeRecord.dwTtl,
Content = nativeRecord.pStringArray,
};
}
}
}
}
+15 -3
View File
@@ -1,6 +1,7 @@
using Disco.Client.Extensions;
using Disco.Client.Interop;
using System;
using System.Net;
using System.Reflection;
using System.Text;
using System.Threading;
@@ -26,7 +27,7 @@ namespace Disco.Client
}
public static void UpdateStatus(string SubHeading, string Message, bool ShowProgress, int Progress)
{
Console.WriteLine($"#{SubHeading.EscapeMessage()},{Message.EscapeMessage()},{ShowProgress.ToString()},{Progress.ToString()}");
Console.WriteLine($"#{SubHeading.EscapeMessage()},{Message.EscapeMessage()},{ShowProgress},{Progress}");
}
public static void TryDelay(int Milliseconds)
{
@@ -38,6 +39,11 @@ namespace Disco.Client
{
StringBuilder message = new StringBuilder();
message.AppendLine($"Version: {Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}");
message.Append($"Server: {Program.ServerUrl})");
if (Program.ServerUrl.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase))
message.AppendLine(" [Secure]");
else
message.AppendLine(" [Insecure]");
message.AppendLine($"Device: {Hardware.Information.SerialNumber} ({Hardware.Information.Manufacturer} {Hardware.Information.Model})");
Console.ForegroundColor = ConsoleColor.Yellow;
UpdateStatus("Preparation Client Started", message.ToString(), false, 0);
@@ -48,12 +54,18 @@ namespace Disco.Client
{
Console.ForegroundColor = ConsoleColor.Magenta;
ClientServiceException clientServiceException = ex as ClientServiceException;
if (clientServiceException != null)
if (ex is ClientServiceException clientServiceException)
{
UpdateStatus($"An error occurred during {clientServiceException.ServiceFeature}",
clientServiceException.Message, false, 0);
}
else if (ex is WebException exWeb &&
exWeb.Response is HttpWebResponse webResponse &&
webResponse.StatusCode == HttpStatusCode.InternalServerError)
{
UpdateStatus("Something went wrong on the server",
"Review logs for more information (Configuration > Logging)", false, 0);
}
else
{
StringBuilder message = new StringBuilder();
+59 -6
View File
@@ -1,6 +1,8 @@
using Disco.Client.Extensions;
using Disco.Client.Interop;
using Disco.Models.ClientServices;
using System;
using System.Diagnostics;
using System.Linq;
using System.Net;
@@ -11,6 +13,9 @@ namespace Disco.Client
public static bool IsAuthenticated { get; set; }
public static bool RebootRequired { get; set; }
public static bool AllowUninstall { get; set; }
public static int BootstrapperVersion { get; private set; } = 1;
public static int BootstrapperProcessId { get; private set; } = -1;
public static Uri ServerUrl { get; private set; }
[STAThread]
public static void Main(string[] args)
@@ -24,12 +29,15 @@ namespace Disco.Client
{
Console.WriteLine("Waiting for Debugger to Attach");
System.Threading.Thread.Sleep(1000);
} while (!System.Diagnostics.Debugger.IsAttached);
} while (!Debugger.IsAttached);
}
#endif
// Initialize Environment Settings
SetupEnvironment();
SetupEnvironment(args);
if (ServerUrl == null)
keepProcessing = DiscoverDiscoIct();
// Report to Bootstrapper
Presentation.WriteBanner();
@@ -45,7 +53,7 @@ namespace Disco.Client
Presentation.WriteFooter(RebootRequired, AllowUninstall, !keepProcessing);
}
public static void SetupEnvironment()
public static void SetupEnvironment(string[] args)
{
// Hookup Unhandled Error Handling
AppDomain.CurrentDomain.UnhandledException += ErrorReporting.CurrentDomain_UnhandledException;
@@ -54,21 +62,66 @@ namespace Disco.Client
WebRequest.DefaultWebProxy = new WebProxy();
// Override Http 100 Continue Behaviour
ServicePointManager.Expect100Continue = false;
ServicePointManager.SecurityProtocol |= SecurityProtocolType.Tls12;
// Assume success unless otherwise notified
AllowUninstall = true;
if (args != null && args.Length == 3)
{
// Parse Bootstrapper Version
int parsedVersion;
if (int.TryParse(args[0], out parsedVersion))
BootstrapperVersion = parsedVersion;
// Parse Bootstrapper Process ID
int parsedProcessId;
if (int.TryParse(args[1], out parsedProcessId))
BootstrapperProcessId = parsedProcessId;
// Parse Server URL
Uri parsedUri;
if (Uri.TryCreate(args[2], UriKind.Absolute, out parsedUri))
ServerUrl = parsedUri;
}
else
{
BootstrapperVersion = 1;
BootstrapperProcessId = -1;
ServerUrl = null;
}
// Detect Disco.Bootstrapper - Create Enable UI Delay if Running
Presentation.DelayUI = false;
try
{
Presentation.DelayUI = (System.Diagnostics.Process.GetProcessesByName("Disco.ClientBootstrapper").Length > 0);
if (BootstrapperProcessId != -1)
{
var parentProcess = Process.GetProcessById(BootstrapperProcessId);
Presentation.DelayUI = !parentProcess.HasExited;
}
}
catch (Exception)
{
Presentation.DelayUI = true; // Add Delays on Error
}
}
public static bool DiscoverDiscoIct()
{
try
{
Presentation.UpdateStatus("Detecting Disco ICT", "Locating Disco ICT Server, Please wait...", true, -1);
Presentation.TryDelay(3000);
ServerUrl = EndpointDiscovery.DiscoverServer(null).Item1;
// Complete
return true;
}
catch (Exception ex)
{
ErrorReporting.ReportError(ex, false);
}
return false;
}
public static bool WhoAmI()
{
try
@@ -144,7 +197,7 @@ namespace Disco.Client
var secondsConsumed = (DateTimeOffset.Now - startTime).TotalSeconds;
var progress = (int)((secondsConsumed / totalSeconds) * 100);
Presentation.UpdateStatus($"Pending Device Enrolment Approval: {response.PendingIdentifier}", $"Waiting for enrolment session '{response.PendingIdentifier}' to be approved.{Environment.NewLine}Reason: {response.PendingReason}", true, progress);
Presentation.UpdateStatus($"Pending Device Enrolment Approval: {response.PendingIdentifier}", $"Server: {Program.ServerUrl}{Environment.NewLine}Reason: {response.PendingReason}", true, progress);
System.Threading.Thread.Sleep(TimeSpan.FromSeconds(10));
}
else
+3 -3
View File
@@ -1,9 +1,9 @@
@ECHO OFF
IF /I "%USERDOMAIN%"=="NT AUTHORITY" GOTO RunAsNetworkService
Disco.Client.exe
Disco.Client.exe %1 %2 %3
EXIT /B 0
:RunAsNetworkService
ECHO #Running,Launching Preparation Client, Please wait...{newline}Starting client as 'NT AUTHORITY\Network Service',true,-1
PsExec -acceptula -i -u "NT AUTHORITY\Network Service" -w "%CD%" "%CD%\Start.bat"
EXIT /B 0
PsExec -acceptula -i -u "NT AUTHORITY\Network Service" -w "%CD%" "%CD%\Start.bat %1 %2 %3"
EXIT /B 0