From e1f1973520e8788af61721e74ea1297870b23884 Mon Sep 17 00:00:00 2001 From: Gary Sharp Date: Thu, 22 Jan 2026 15:26:23 +1100 Subject: [PATCH] feature: Bootstrapper secure server discovery --- Disco.Client/Disco.Client.csproj | 1 + Disco.Client/ErrorReporting.cs | 15 +- .../Extensions/ClientServicesExtensions.cs | 23 +- Disco.Client/Interop/EndpointDiscovery.cs | 317 ++++++++++ Disco.Client/Presentation.cs | 18 +- Disco.Client/Program.cs | 65 +- Disco.Client/Start.bat | 6 +- Disco.ClientBootstrapper/BootstrapperLoop.cs | 367 ++++++------ .../Disco.ClientBootstrapper.csproj | 3 + Disco.ClientBootstrapper/FormStatus.cs | 32 +- Disco.ClientBootstrapper/InstallLoop.cs | 69 +-- .../Interop/CertificateInterop.cs | 15 +- .../Interop/InstallInterop.cs | 209 ++++--- .../Interop/NetworkInterop.cs | 48 +- Disco.ClientBootstrapper/Program.cs | 104 ++-- .../Modules/DevicesConfiguration.cs | 6 + Disco.Models/Disco.Models.csproj | 1 + .../DeviceEnrolmentServerDiscoveryMethod.cs | 13 + .../Interop/DiscoServices/UpdateRequestV2.cs | 1 + .../Enrolment/ConfigEnrolmentIndexModel.cs | 11 +- .../Enrolment/WindowsDeviceEnrolment.cs | 13 + Disco.Services/Disco.Services.csproj | 8 + Disco.Services/Interop/DNS/ADnsRecord.cs | 24 + Disco.Services/Interop/DNS/CnameDnsRecord.cs | 12 + Disco.Services/Interop/DNS/DnsRecord.cs | 20 + Disco.Services/Interop/DNS/DnsRecordType.cs | 10 + Disco.Services/Interop/DNS/DnsService.cs | 30 + Disco.Services/Interop/DNS/NativeDns.cs | 202 +++++++ Disco.Services/Interop/DNS/SrvDnsRecord.cs | 21 + Disco.Services/Interop/DNS/TxtDnsRecord.cs | 12 + .../Interop/DiscoServices/UpdateQuery.cs | 5 + Disco.Services/Interop/VicEduDept/VicSmart.cs | 38 ++ .../API/Controllers/EnrolmentController.cs | 16 + .../Config/Controllers/EnrolmentController.cs | 26 + .../Config/Models/Enrolment/IndexModel.cs | 8 + .../Areas/Config/Views/Enrolment/Index.cshtml | 163 ++++- .../Config/Views/Enrolment/Index.generated.cs | 557 +++++++++++++++++- .../Services/Controllers/ClientController.cs | 32 + .../API.EnrolmentController.generated.cs | 28 + Disco.sln | 5 +- 40 files changed, 2094 insertions(+), 460 deletions(-) create mode 100644 Disco.Client/Interop/EndpointDiscovery.cs create mode 100644 Disco.Models/Services/Devices/DeviceEnrolmentServerDiscoveryMethod.cs create mode 100644 Disco.Services/Interop/DNS/ADnsRecord.cs create mode 100644 Disco.Services/Interop/DNS/CnameDnsRecord.cs create mode 100644 Disco.Services/Interop/DNS/DnsRecord.cs create mode 100644 Disco.Services/Interop/DNS/DnsRecordType.cs create mode 100644 Disco.Services/Interop/DNS/DnsService.cs create mode 100644 Disco.Services/Interop/DNS/NativeDns.cs create mode 100644 Disco.Services/Interop/DNS/SrvDnsRecord.cs create mode 100644 Disco.Services/Interop/DNS/TxtDnsRecord.cs diff --git a/Disco.Client/Disco.Client.csproj b/Disco.Client/Disco.Client.csproj index 434bf266..86c4d507 100644 --- a/Disco.Client/Disco.Client.csproj +++ b/Disco.Client/Disco.Client.csproj @@ -141,6 +141,7 @@ + diff --git a/Disco.Client/ErrorReporting.cs b/Disco.Client/ErrorReporting.cs index 747f3665..19cb8f5e 100644 --- a/Disco.Client/ErrorReporting.cs +++ b/Disco.Client/ErrorReporting.cs @@ -10,19 +10,18 @@ namespace Disco.Client { public static class ErrorReporting { - private const string ServicePathTemplate = "http://DISCO:9292/Services/Client/ClientError"; public static string DeviceIdentifier { get; set; } public static string EnrolmentSessionId { get; set; } - public static void ReportError(Exception Ex, bool ReportToServer) + public static void ReportError(Exception exception, bool reportToServer) { - bool isClientServiceException = Ex is ClientServiceException; + bool isClientServiceException = exception is ClientServiceException; ErrorReport report = new ErrorReport() { DeviceIdentifier = DeviceIdentifier, SessionId = EnrolmentSessionId, - JsonException = Ex.IntenseExceptionSerialization() + JsonException = exception.IntenseExceptionSerialization() }; try @@ -38,7 +37,7 @@ namespace Disco.Client catch (Exception) { } // Don't log server errors back to the server - if (!isClientServiceException && ReportToServer) + if (!isClientServiceException && reportToServer) { try { @@ -49,7 +48,7 @@ namespace Disco.Client try { - Presentation.WriteFatalError(Ex); + Presentation.WriteFatalError(exception); } catch (Exception) { } } @@ -85,7 +84,9 @@ namespace Disco.Client string reportJson = JsonConvert.SerializeObject(report); string reportResponse; - HttpWebRequest request = (HttpWebRequest)WebRequest.Create(ServicePathTemplate); + var serverUri = new Uri(Program.ServerUrl ?? new Uri("http://disco:9292"), "/Services/Client/ClientError"); + + HttpWebRequest request = (HttpWebRequest)WebRequest.Create(serverUri); request.UserAgent = $"Disco-Client/{Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}"; request.ContentType = "application/json"; request.Method = WebRequestMethods.Http.Post; diff --git a/Disco.Client/Extensions/ClientServicesExtensions.cs b/Disco.Client/Extensions/ClientServicesExtensions.cs index 708cac7f..74c96737 100644 --- a/Disco.Client/Extensions/ClientServicesExtensions.cs +++ b/Disco.Client/Extensions/ClientServicesExtensions.cs @@ -1,30 +1,23 @@ using Disco.Models.ClientServices; using Newtonsoft.Json; +using System; using System.IO; using System.Net; using System.Reflection; namespace Disco.Client.Extensions { - public static class ClientServicesExtensions + internal static class ClientServicesExtensions { - //#if DEBUG - // public const string ServicePathAuthenticatedTemplate = "http://WS-GSHARP:57252/Services/Client/Authenticated/{0}"; - // public const string ServicePathUnauthenticatedTemplate = "http://WS-GSHARP:57252/Services/Client/Unauthenticated/{0}"; - //#else - public const string ServicePathAuthenticatedTemplate = "http://DISCO:9292/Services/Client/Authenticated/{0}"; - public const string ServicePathUnauthenticatedTemplate = "http://DISCO:9292/Services/Client/Unauthenticated/{0}"; - //#endif - - public static ResponseType Post(this ServiceBase Service, bool Authenticated) + public static ResponseType Post(this ServiceBase service, bool authenticated) { ResponseType serviceResponse; - string serviceUrl; + Uri serviceUrl; - if (Authenticated) - serviceUrl = string.Format(ServicePathAuthenticatedTemplate, Service.Feature); + if (authenticated) + serviceUrl = new Uri(Program.ServerUrl, $"/Services/Client/Authenticated/{service.Feature}"); else - serviceUrl = string.Format(ServicePathUnauthenticatedTemplate, Service.Feature); + serviceUrl = new Uri(Program.ServerUrl, $"/Services/Client/Unauthenticated/{service.Feature}"); HttpWebRequest request = (HttpWebRequest)WebRequest.Create(serviceUrl); request.UserAgent = $"Disco-Client/{Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}"; @@ -39,7 +32,7 @@ namespace Disco.Client.Extensions { using (var jsonWriter = new JsonTextWriter(requestWriter)) { - jsonSerializer.Serialize(jsonWriter, Service); + jsonSerializer.Serialize(jsonWriter, service); } } diff --git a/Disco.Client/Interop/EndpointDiscovery.cs b/Disco.Client/Interop/EndpointDiscovery.cs new file mode 100644 index 00000000..7c74f90a --- /dev/null +++ b/Disco.Client/Interop/EndpointDiscovery.cs @@ -0,0 +1,317 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net; +using System.Net.NetworkInformation; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Disco.Client.Interop +{ + internal class EndpointDiscovery + { + [DllImport("dnsapi", EntryPoint = "DnsQuery_W", CharSet = CharSet.Unicode, SetLastError = true, ExactSpelling = true)] + private static extern int DnsQuery([MarshalAs(UnmanagedType.VBByRefStr)] ref string pszName, NativeDnsQueryTypes wType, NativeDnsQueryOptions options, int aipServers, ref IntPtr ppQueryResults, int pReserved); + + [DllImport("dnsapi", CharSet = CharSet.Auto, SetLastError = true)] + private static extern void DnsRecordListFree(IntPtr pRecordList, int FreeType); + private const int DNS_ERROR_RCODE_NAME_ERROR = 0x232B; + private const int DNS_ERROR_BAD_PACKET = 0x251E; + public static Tuple DiscoverServer(Uri forcedServerUri) + { + // 1. Check first command line argument for server name + if (forcedServerUri != null) + return Tuple.Create(forcedServerUri, "Manual"); + + // 2. Check for a DNS SRV record for _discoict._tcp.domain + var domainSuffixes = new List(); + var primaryDomain = IPGlobalProperties.GetIPGlobalProperties().DomainName; + if (!string.IsNullOrEmpty(primaryDomain)) + domainSuffixes.Add(primaryDomain); + var networkInterfaces = NetworkInterface.GetAllNetworkInterfaces() + .Where(ni => ni.OperationalStatus == OperationalStatus.Up); + foreach (var ni in networkInterfaces) + { + var domainSuffix = ni.GetIPProperties().DnsSuffix; + if (!string.IsNullOrWhiteSpace(domainSuffix)) + { + if (domainSuffix.Equals("mshome.net", StringComparison.OrdinalIgnoreCase)) + continue; + + if (!domainSuffixes.Contains(domainSuffix, StringComparer.OrdinalIgnoreCase)) + domainSuffixes.Add(domainSuffix); + } + } + foreach (var domain in domainSuffixes) + { + var dnsRecords = GetSRVRecords("_discoict._tcp." + domain); + if (dnsRecords.Count > 0) + { + var firstRecord = dnsRecords.OrderBy(r => r.Priority).ThenByDescending(r => r.Weight).First(); + if (firstRecord.Port == 443) + return Tuple.Create(new Uri($"https://{firstRecord.Target}"), "SRV"); + else + return Tuple.Create(new Uri($"https://{firstRecord.Target}:{firstRecord.Port}"), "SRV"); + } + } + + // 3. Detect VicSmart network and try resolving with Disco ICT Online Services + if (TryResolveVicSmartServer(domainSuffixes, out var vicSmartServerUrl)) + return Tuple.Create(vicSmartServerUrl, "VicSmart"); + + // 4. Legacy: Ping 'disco' and assume port 9292 + using (Ping p = new Ping()) + { + try + { + PingReply pr = p.Send("disco", 2000); + if (pr.Status == IPStatus.Success) + return Tuple.Create(new Uri("http://disco:9292"), "Legacy"); + } + catch (Exception) + { + } + } + throw new Exception("Could not locate Disco ICT server on the network."); + } + + private static bool TryResolveVicSmartServer(List domainSuffixes, out Uri serverUrl) + { + if (IsVicSmartNetwork(domainSuffixes)) + { + var potentialVicSmartAddresses = NetworkInterface.GetAllNetworkInterfaces() + .Where(ni => ni.OperationalStatus == OperationalStatus.Up) + .SelectMany(ni => ni.GetIPProperties().UnicastAddresses) + .Where(ua => ua.Address.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork) + .Select(ua => ua.Address.GetAddressBytes()) + .Where(a => a[0] == 10) + .Select(a => (ushort)((a[1] >> 4) & 0x000F) | ((a[1] << 4) & 0x00F0) | ((a[2] << 12) & 0xF000) | ((a[2] << 4) & 0x0F00)) + .Distinct() + .Select(a => $"{a:x4}.vicsmart.discoict.com") + .ToList(); + + foreach (var potentialAddress in potentialVicSmartAddresses) + { + var records = GetTxtRecords(potentialAddress); + + foreach (var record in records) + { + if (!record.Content.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + continue; + + if (Uri.TryCreate(record.Content, UriKind.Absolute, out var discoveredUri)) + { + serverUrl = discoveredUri; + return true; + } + } + } + } + + serverUrl = null; + return false; + } + + private static bool IsVicSmartNetwork(List domainSuffixes) + { + if (domainSuffixes.Any(s => string.Equals("services.education.vic.gov.au", s, StringComparison.OrdinalIgnoreCase)) || + domainSuffixes.Any(s => string.Equals("education.vic.gov.au", s, StringComparison.OrdinalIgnoreCase)) + ) + return true; + + IPHostEntry doeWanDnsEntry; + try + { + doeWanDnsEntry = Dns.GetHostEntry("broadband.doe.wan"); + if (doeWanDnsEntry.AddressList.Length > 0) + return true; + } + catch (Exception) + { } + return false; + } + + private static List GetTxtRecords(string name) + { + IntPtr resourceRecordsPointer = IntPtr.Zero; + var records = new List(); + var retry = 5; + retry: + try + { + int queryResult = DnsQuery(ref name, NativeDnsQueryTypes.DNS_TYPE_TEXT, NativeDnsQueryOptions.DNS_QUERY_STANDARD, 0, ref resourceRecordsPointer, 0); + if (queryResult != 0) + { + if (queryResult == DNS_ERROR_RCODE_NAME_ERROR) + return records; + else if (queryResult == DNS_ERROR_BAD_PACKET && retry > 0) + { + // Sometimes a BAD_PACKET error is returned, retry a few times + Thread.Sleep(200); + retry--; + goto retry; + } + else + throw new Win32Exception(queryResult); + } + NativeDnsTxtRecord record; + for (var resourceRecordPointer = resourceRecordsPointer; !resourceRecordPointer.Equals(IntPtr.Zero); resourceRecordPointer = record.pNext) + { + record = Marshal.PtrToStructure(resourceRecordPointer); + if (record.wType == (ushort)NativeDnsQueryTypes.DNS_TYPE_TEXT) + records.Add(DnsTxtRecord.FromNativeRecord(record)); + } + } + finally + { + if (resourceRecordsPointer != IntPtr.Zero) + DnsRecordListFree(resourceRecordsPointer, 0); + } + return records; + } + + private static List GetSRVRecords(string name) + { + IntPtr resourceRecordsPointer = IntPtr.Zero; + var records = new List(); + var retry = 5; + retry: + try + { + int queryResult = DnsQuery(ref name, NativeDnsQueryTypes.DNS_TYPE_SRV, NativeDnsQueryOptions.DNS_QUERY_STANDARD, 0, ref resourceRecordsPointer, 0); + if (queryResult != 0) + { + if (queryResult == DNS_ERROR_RCODE_NAME_ERROR) + return records; + else if (queryResult == DNS_ERROR_BAD_PACKET && retry > 0) + { + // Sometimes a BAD_PACKET error is returned, retry a few times + Thread.Sleep(200); + retry--; + goto retry; + } + else + throw new Win32Exception(queryResult); + } + NativeDnsSrvRecord record; + for (var resourceRecordPointer = resourceRecordsPointer; !resourceRecordPointer.Equals(IntPtr.Zero); resourceRecordPointer = record.pNext) + { + record = Marshal.PtrToStructure(resourceRecordPointer); + if (record.wType == (ushort)NativeDnsQueryTypes.DNS_TYPE_SRV) + records.Add(DnsSrvRecord.FromNativeRecord(record)); + } + } + finally + { + if (resourceRecordsPointer != IntPtr.Zero) + DnsRecordListFree(resourceRecordsPointer, 0); + } + return records; + } + + private enum NativeDnsQueryOptions + { + DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE = 1, + DNS_QUERY_BYPASS_CACHE = 8, + DNS_QUERY_DONT_RESET_TTL_VALUES = 0x100000, + DNS_QUERY_NO_HOSTS_FILE = 0x40, + DNS_QUERY_NO_LOCAL_NAME = 0x20, + DNS_QUERY_NO_NETBT = 0x80, + DNS_QUERY_NO_RECURSION = 4, + DNS_QUERY_NO_WIRE_QUERY = 0x10, + DNS_QUERY_RESERVED = -16777216, + DNS_QUERY_RETURN_MESSAGE = 0x200, + DNS_QUERY_STANDARD = 0, + DNS_QUERY_TREAT_AS_FQDN = 0x1000, + DNS_QUERY_USE_TCP_ONLY = 2, + DNS_QUERY_WIRE_ONLY = 0x100 + } + + private enum NativeDnsQueryTypes + { + DNS_TYPE_TEXT = 0x0010, + DNS_TYPE_SRV = 0x0021 + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsSrvRecord + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + [MarshalAs(UnmanagedType.LPWStr)] + public string pNameTarget; + public ushort wPriority; + public ushort wWeight; + public ushort wPort; + public ushort Pad; + } + + private class DnsSrvRecord + { + public string Name { get; set; } + public int Type { get; set; } + public int Ttl { get; set; } + public string Target { get; set; } + public int Priority { get; set; } + public int Weight { get; set; } + public int Port { get; set; } + + public static DnsSrvRecord FromNativeRecord(NativeDnsSrvRecord nativeRecord) + { + return new DnsSrvRecord + { + Name = nativeRecord.pName, + Type = nativeRecord.wType, + Ttl = nativeRecord.dwTtl, + Target = nativeRecord.pNameTarget, + Priority = nativeRecord.wPriority, + Weight = nativeRecord.wWeight, + Port = nativeRecord.wPort + }; + } + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsTxtRecord + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + public uint dwStringLength; + [MarshalAs(UnmanagedType.LPWStr)] + public string pStringArray; + } + + private class DnsTxtRecord + { + public string Name { get; set; } + public int Type { get; set; } + public int Ttl { get; set; } + public string Content { get; set; } + + public static DnsTxtRecord FromNativeRecord(NativeDnsTxtRecord nativeRecord) + { + return new DnsTxtRecord + { + Name = nativeRecord.pName, + Type = nativeRecord.wType, + Ttl = nativeRecord.dwTtl, + Content = nativeRecord.pStringArray, + }; + } + } + + } +} diff --git a/Disco.Client/Presentation.cs b/Disco.Client/Presentation.cs index 211ece31..66d1e647 100644 --- a/Disco.Client/Presentation.cs +++ b/Disco.Client/Presentation.cs @@ -1,6 +1,7 @@ using Disco.Client.Extensions; using Disco.Client.Interop; using System; +using System.Net; using System.Reflection; using System.Text; using System.Threading; @@ -26,7 +27,7 @@ namespace Disco.Client } public static void UpdateStatus(string SubHeading, string Message, bool ShowProgress, int Progress) { - Console.WriteLine($"#{SubHeading.EscapeMessage()},{Message.EscapeMessage()},{ShowProgress.ToString()},{Progress.ToString()}"); + Console.WriteLine($"#{SubHeading.EscapeMessage()},{Message.EscapeMessage()},{ShowProgress},{Progress}"); } public static void TryDelay(int Milliseconds) { @@ -38,6 +39,11 @@ namespace Disco.Client { StringBuilder message = new StringBuilder(); message.AppendLine($"Version: {Assembly.GetExecutingAssembly().GetName().Version.ToString(3)}"); + message.Append($"Server: {Program.ServerUrl})"); + if (Program.ServerUrl.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + message.AppendLine(" [Secure]"); + else + message.AppendLine(" [Insecure]"); message.AppendLine($"Device: {Hardware.Information.SerialNumber} ({Hardware.Information.Manufacturer} {Hardware.Information.Model})"); Console.ForegroundColor = ConsoleColor.Yellow; UpdateStatus("Preparation Client Started", message.ToString(), false, 0); @@ -48,12 +54,18 @@ namespace Disco.Client { Console.ForegroundColor = ConsoleColor.Magenta; - ClientServiceException clientServiceException = ex as ClientServiceException; - if (clientServiceException != null) + if (ex is ClientServiceException clientServiceException) { UpdateStatus($"An error occurred during {clientServiceException.ServiceFeature}", clientServiceException.Message, false, 0); } + else if (ex is WebException exWeb && + exWeb.Response is HttpWebResponse webResponse && + webResponse.StatusCode == HttpStatusCode.InternalServerError) + { + UpdateStatus("Something went wrong on the server", + "Review logs for more information (Configuration > Logging)", false, 0); + } else { StringBuilder message = new StringBuilder(); diff --git a/Disco.Client/Program.cs b/Disco.Client/Program.cs index 50839971..51cd656a 100644 --- a/Disco.Client/Program.cs +++ b/Disco.Client/Program.cs @@ -1,6 +1,8 @@ using Disco.Client.Extensions; +using Disco.Client.Interop; using Disco.Models.ClientServices; using System; +using System.Diagnostics; using System.Linq; using System.Net; @@ -11,6 +13,9 @@ namespace Disco.Client public static bool IsAuthenticated { get; set; } public static bool RebootRequired { get; set; } public static bool AllowUninstall { get; set; } + public static int BootstrapperVersion { get; private set; } = 1; + public static int BootstrapperProcessId { get; private set; } = -1; + public static Uri ServerUrl { get; private set; } [STAThread] public static void Main(string[] args) @@ -24,12 +29,15 @@ namespace Disco.Client { Console.WriteLine("Waiting for Debugger to Attach"); System.Threading.Thread.Sleep(1000); - } while (!System.Diagnostics.Debugger.IsAttached); + } while (!Debugger.IsAttached); } #endif // Initialize Environment Settings - SetupEnvironment(); + SetupEnvironment(args); + + if (ServerUrl == null) + keepProcessing = DiscoverDiscoIct(); // Report to Bootstrapper Presentation.WriteBanner(); @@ -45,7 +53,7 @@ namespace Disco.Client Presentation.WriteFooter(RebootRequired, AllowUninstall, !keepProcessing); } - public static void SetupEnvironment() + public static void SetupEnvironment(string[] args) { // Hookup Unhandled Error Handling AppDomain.CurrentDomain.UnhandledException += ErrorReporting.CurrentDomain_UnhandledException; @@ -54,21 +62,66 @@ namespace Disco.Client WebRequest.DefaultWebProxy = new WebProxy(); // Override Http 100 Continue Behaviour ServicePointManager.Expect100Continue = false; + ServicePointManager.SecurityProtocol |= SecurityProtocolType.Tls12; // Assume success unless otherwise notified AllowUninstall = true; + if (args != null && args.Length == 3) + { + // Parse Bootstrapper Version + int parsedVersion; + if (int.TryParse(args[0], out parsedVersion)) + BootstrapperVersion = parsedVersion; + // Parse Bootstrapper Process ID + int parsedProcessId; + if (int.TryParse(args[1], out parsedProcessId)) + BootstrapperProcessId = parsedProcessId; + // Parse Server URL + Uri parsedUri; + if (Uri.TryCreate(args[2], UriKind.Absolute, out parsedUri)) + ServerUrl = parsedUri; + } + else + { + BootstrapperVersion = 1; + BootstrapperProcessId = -1; + ServerUrl = null; + } + // Detect Disco.Bootstrapper - Create Enable UI Delay if Running + Presentation.DelayUI = false; try { - Presentation.DelayUI = (System.Diagnostics.Process.GetProcessesByName("Disco.ClientBootstrapper").Length > 0); + if (BootstrapperProcessId != -1) + { + var parentProcess = Process.GetProcessById(BootstrapperProcessId); + Presentation.DelayUI = !parentProcess.HasExited; + } } catch (Exception) { - Presentation.DelayUI = true; // Add Delays on Error } } + public static bool DiscoverDiscoIct() + { + try + { + Presentation.UpdateStatus("Detecting Disco ICT", "Locating Disco ICT Server, Please wait...", true, -1); + Presentation.TryDelay(3000); + ServerUrl = EndpointDiscovery.DiscoverServer(null).Item1; + + // Complete + return true; + } + catch (Exception ex) + { + ErrorReporting.ReportError(ex, false); + } + return false; + } + public static bool WhoAmI() { try @@ -144,7 +197,7 @@ namespace Disco.Client var secondsConsumed = (DateTimeOffset.Now - startTime).TotalSeconds; var progress = (int)((secondsConsumed / totalSeconds) * 100); - Presentation.UpdateStatus($"Pending Device Enrolment Approval: {response.PendingIdentifier}", $"Waiting for enrolment session '{response.PendingIdentifier}' to be approved.{Environment.NewLine}Reason: {response.PendingReason}", true, progress); + Presentation.UpdateStatus($"Pending Device Enrolment Approval: {response.PendingIdentifier}", $"Server: {Program.ServerUrl}{Environment.NewLine}Reason: {response.PendingReason}", true, progress); System.Threading.Thread.Sleep(TimeSpan.FromSeconds(10)); } else diff --git a/Disco.Client/Start.bat b/Disco.Client/Start.bat index a8df9249..65dabbe3 100644 --- a/Disco.Client/Start.bat +++ b/Disco.Client/Start.bat @@ -1,9 +1,9 @@ @ECHO OFF IF /I "%USERDOMAIN%"=="NT AUTHORITY" GOTO RunAsNetworkService -Disco.Client.exe +Disco.Client.exe %1 %2 %3 EXIT /B 0 :RunAsNetworkService ECHO #Running,Launching Preparation Client, Please wait...{newline}Starting client as 'NT AUTHORITY\Network Service',true,-1 -PsExec -acceptula -i -u "NT AUTHORITY\Network Service" -w "%CD%" "%CD%\Start.bat" -EXIT /B 0 \ No newline at end of file +PsExec -acceptula -i -u "NT AUTHORITY\Network Service" -w "%CD%" "%CD%\Start.bat %1 %2 %3" +EXIT /B 0 \ No newline at end of file diff --git a/Disco.ClientBootstrapper/BootstrapperLoop.cs b/Disco.ClientBootstrapper/BootstrapperLoop.cs index 9edcf8be..935f3a79 100644 --- a/Disco.ClientBootstrapper/BootstrapperLoop.cs +++ b/Disco.ClientBootstrapper/BootstrapperLoop.cs @@ -1,4 +1,6 @@ -using System; +using Disco.Client.Interop; +using Disco.ClientBootstrapper.Interop; +using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -6,46 +8,191 @@ using System.Linq; using System.Net; using System.Text; using System.Threading; +using System.Threading.Tasks; namespace Disco.ClientBootstrapper { - class BootstrapperLoop + internal class BootstrapperLoop { - - public Thread LoopThread; - public delegate void LoopCompleteCallback(); - private LoopCompleteCallback mLoopCompleteCallback; - private IStatus statusUI; + private readonly Func completeCallback; + private readonly CancellationToken cancellationToken; + private readonly IStatus statusUI; + private readonly Uri forcedServerUrl; private string tempWorkingDirectory; - private StringBuilder errorMessage; private Process clientProcess; - //#if DEBUG - // public const string DiscoServerName = "WS-GSHARP"; - // public const int DiscoServerPort = 57252; - //#else - public const string DiscoServerName = "DISCO"; - public const int DiscoServerPort = 9292; - //#endif - - public BootstrapperLoop(IStatus StatusUI, LoopCompleteCallback Callback) + public BootstrapperLoop(IStatus statusUI, Uri forcedServerUrl, Func callback, CancellationToken cancellationToken) { - statusUI = StatusUI; - mLoopCompleteCallback = Callback; - errorMessage = new StringBuilder(); + this.statusUI = statusUI; + this.forcedServerUrl = forcedServerUrl; + completeCallback = callback; + this.cancellationToken = cancellationToken; } public void Start() { - LoopThread = new Thread(new ThreadStart(loopHost)); - LoopThread.Start(); + Task.Factory.StartNew(async () => + { + await Loop(forcedServerUrl, cancellationToken); + }, cancellationToken); } - private void loopHost() + private async Task Loop(Uri forcedServerUrl, CancellationToken cancellationToken) { try { - loop(); + statusUI.UpdateStatus("System Preparation (Bootstrapper)", "Starting", "Please wait...", true, -1); + + tempWorkingDirectory = Path.Combine(Path.GetPathRoot(Environment.SystemDirectory), @"Disco\Temp"); + if (!Directory.Exists(tempWorkingDirectory)) + Directory.CreateDirectory(tempWorkingDirectory); + + // Check for Network Connectivity + statusUI.UpdateStatus(null, "Detecting Network", "Checking network connectivity, Please wait...", true, -1); + if (!NetworkInterop.HasNetworkConnectivity()) + { + statusUI.UpdateStatus(null, "Detecting Network", "No network connectivity detected, Diagnosing...", true, -1); + statusUI_WriteAdapterInfo(); + + if (!NetworkInterop.HasNetworkConnectivity()) + { + // Check for Wireless + var hasWireless = (NetworkInterop.NetworkAdapters.Count(na => na.IsWireless) > 0); + if (hasWireless) + { + // True: Do wireless loop + statusUI.UpdateStatus(null, "Configuring Wireless Network", "Wireless adapter detected, Configuring...", true, -1); + await NetworkInterop.ConfigureWireless(cancellationToken); + statusUI.UpdateStatus(null, "Waiting for Wireless Network", null, true, 0); + for (int i = 0; i < 30; i++) + { + statusUI_WriteAdapterInfo(); + statusUI.UpdateStatus(null, null, null, true, i); + await Program.SleepThread(2000, false, cancellationToken); + if (NetworkInterop.HasNetworkConnectivity()) + break; + } + if (!NetworkInterop.HasNetworkConnectivity()) + { + statusUI.UpdateStatus(null, "Wireless Network Failed", "Unable to connect to the wireless network, please connect the network cable...", false); + await Program.SleepThread(3000, false, cancellationToken); + } + } + + if (!NetworkInterop.HasNetworkConnectivity()) + { + // Instruct user to connect network cable + statusUI.UpdateStatus(null, "Please connect the network cable", null); + for (int i = 0; i < 30; i++) + { + statusUI_WriteAdapterInfo(); + statusUI.UpdateStatus(null, null, null, true, i); + await Program.SleepThread(2000, false, cancellationToken); + if (NetworkInterop.HasNetworkConnectivity()) + break; + } + } + } + + if (!NetworkInterop.HasNetworkConnectivity()) + { + // Client Failed + if (completeCallback != null) + await completeCallback(cancellationToken); + return; + } + } + + Tuple serverDiscovery; + statusUI.UpdateStatus(null, "Detecting Disco ICT", "Locating Disco ICT Server, Please wait...", true, -1); + try + { + serverDiscovery = EndpointDiscovery.DiscoverServer(forcedServerUrl); + statusUI.UpdateStatus(null, null, $"{serverDiscovery.Item1} ({serverDiscovery.Item2})", true, -1); + } + catch (Exception) + { + statusUI.UpdateStatus(null, null, "Failed to locate Disco ICT Server, exiting...", true, -1); + await Program.SleepThread(2000, false, cancellationToken); + throw; + } + + // Download Client + statusUI.UpdateStatus(null, "Downloading", "Retrieving Preparation Client, Please wait...", true, -1); + string clientSourceLocation = Path.Combine(tempWorkingDirectory, "PreparationClient.zip"); + using (var webClient = new WebClient()) + { + // Don't use a proxy when downloading the Client + webClient.Proxy = new WebProxy(); + webClient.Headers.Add("X-DiscoICT-Discovery", serverDiscovery.Item2); + try + { + webClient.DownloadFile(new Uri(serverDiscovery.Item1, "/Services/Client/PreparationClient"), clientSourceLocation); + } + catch (WebException ex) + { + if (ex.Response != null && + ex.Response is HttpWebResponse response) + { + if (response.StatusCode == HttpStatusCode.BadRequest) + { + statusUI.UpdateStatus(null, "Download failed: Bad Request", response.StatusDescription, true, -1); + await Program.SleepThread(5000, false, cancellationToken); + } + else if (response.StatusCode == HttpStatusCode.InternalServerError) + { + statusUI.UpdateStatus(null, "Download failed: Something went wrong on the server", "Review logs for more information (Configuration > Logging)", true, -1); + await Program.SleepThread(5000, false, cancellationToken); + } + } + throw; + } + } + + // Unzip Client + statusUI.UpdateStatus(null, "Extracting", "Retrieving Preparation Client, Please wait...", true, -1); + string clientLocation = Path.Combine(tempWorkingDirectory, "PreparationClient"); + if (Directory.Exists(clientLocation)) + Directory.Delete(clientLocation, true); + + Directory.CreateDirectory(clientLocation); + using (var clientSource = Ionic.Zip.ZipFile.Read(clientSourceLocation)) + { + clientSource.ExtractAll(clientLocation, Ionic.Zip.ExtractExistingFileAction.OverwriteSilently); + } + + // Launch Client + statusUI.UpdateStatus("System Preparation (Client)", "Running", "Launching Preparation Client, Please wait...", true, -1); + ProcessStartInfo clientProcessStart = new ProcessStartInfo(Path.Combine(clientLocation, "Start.bat"), $"2 {Process.GetCurrentProcess().Id} {serverDiscovery.Item1}") + { + WorkingDirectory = clientLocation, + CreateNoWindow = true, + RedirectStandardOutput = true, + UseShellExecute = false, + }; + using (clientProcess = Process.Start(clientProcessStart)) + { + // Read StdOutput until End + try + { + clientProcess.OutputDataReceived += new DataReceivedEventHandler(clientProcess_OutputDataReceived); + clientProcess.BeginOutputReadLine(); + clientProcess.WaitForExit(); + } + catch (Exception) { throw; } + finally + { + try { clientProcess.CloseMainWindow(); } + catch (Exception) { } + } + } + clientProcess = null; + + // Cleanup + if (Directory.Exists(tempWorkingDirectory)) + Directory.Delete(tempWorkingDirectory, true); + CertificateInterop.RemoveTempCerts(); + } catch (Exception ex) { @@ -53,167 +200,20 @@ namespace Disco.ClientBootstrapper return; if (ex.GetType() == typeof(ThreadInterruptedException)) return; - Program.WriteAppError(ex); - throw; - } - } - - private void loop() - { - -#if Debug - statusUI.UpdateStatus("Waiting for Debugger", "Please wait...", true, -1); - try - { - do - { - System.Threading.Thread.Sleep(10); - } while (!System.Diagnostics.Debugger.IsAttached); - } - catch (Exception ex) - { - statusUI.UpdateStatus("Error", ex.Message, true, -1); - return; - } -#else - statusUI.UpdateStatus("System Preparation (Bootstrapper)", "Starting", "Please wait...", true, -1); -#endif - - tempWorkingDirectory = Path.Combine(Path.GetPathRoot(Environment.SystemDirectory), "Disco\\Temp"); - if (!Directory.Exists(tempWorkingDirectory)) - Directory.CreateDirectory(tempWorkingDirectory); - - // Check for Network Connectivity - statusUI.UpdateStatus(null, "Detecting Network", "Checking network connectivity, Please wait...", true, -1); - if (!Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - { - statusUI.UpdateStatus(null, "Detecting Network", "No network connectivity detected, Diagnosing...", true, -1); - statusUI_WriteAdapterInfo(); - - if (!Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - { - // Check for Wireless - var hasWireless = (Interop.NetworkInterop.NetworkAdapters.Count(na => na.IsWireless) > 0); - if (hasWireless) - { - // True: Do wireless loop - statusUI.UpdateStatus(null, "Configuring Wireless Network", "Wireless adapter detected, Configuring...", true, -1); - Interop.NetworkInterop.ConfigureWireless(); - statusUI.UpdateStatus(null, "Waiting for Wireless Network", null, true, 0); - for (int i = 0; i < 100; i++) - { - statusUI_WriteAdapterInfo(); - statusUI.UpdateStatus(null, null, null, true, i); - Program.SleepThread(500, false); - if (Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - break; - } - if (!Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - { - statusUI.UpdateStatus(null, "Wireless Network Failed", "Unable to connect to the wireless network, please connect the network cable...", false); - Program.SleepThread(3000, false); - } - } - - if (!Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - { - // Instruct user to connect network cable - statusUI.UpdateStatus(null, "Please connect the network cable", null); - for (int i = 0; i < 100; i++) - { - statusUI_WriteAdapterInfo(); - statusUI.UpdateStatus(null, null, null, true, i); - Program.SleepThread(500, false); - if (Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - break; - } - } - } - - if (!Interop.NetworkInterop.PingDiscoIct(DiscoServerName)) - { - // Client Failed - if (mLoopCompleteCallback != null) - { - mLoopCompleteCallback.BeginInvoke(null, null); - } + if (ex.GetType() == typeof(OperationCanceledException)) return; - } + Program.WriteAppError(ex); } - - // Download Client - statusUI.UpdateStatus(null, "Downloading", "Retrieving Preparation Client, Please wait...", true, -1); - string clientSourceLocation = Path.Combine(tempWorkingDirectory, "PreparationClient.zip"); - using (var webClient = new WebClient()) - { - // Don't use a proxy when downloading the Client - webClient.Proxy = new WebProxy(); - - webClient.DownloadFile($"http://{DiscoServerName}:{DiscoServerPort}/Services/Client/PreparationClient", clientSourceLocation); - } - - // Unzip Client - statusUI.UpdateStatus(null, "Extracting", "Retrieving Preparation Client, Please wait...", true, -1); - string clientLocation = Path.Combine(tempWorkingDirectory, "PreparationClient"); - if (Directory.Exists(clientLocation)) - Directory.Delete(clientLocation, true); - - Directory.CreateDirectory(clientLocation); - using (var clientSource = Ionic.Zip.ZipFile.Read(clientSourceLocation)) - { - clientSource.ExtractAll(clientLocation, Ionic.Zip.ExtractExistingFileAction.OverwriteSilently); - } - - // Launch Client - statusUI.UpdateStatus("System Preparation (Client)", "Running", "Launching Preparation Client, Please wait...", true, -1); - ProcessStartInfo clientProcessStart = new ProcessStartInfo(Path.Combine(clientLocation, "Start.bat")) - { - WorkingDirectory = clientLocation, - CreateNoWindow = true, - RedirectStandardOutput = true, - UseShellExecute = false, - }; - using (clientProcess = Process.Start(clientProcessStart)) - { - // Read StdOutput until End - try - { - clientProcess.OutputDataReceived += new DataReceivedEventHandler(clientProcess_OutputDataReceived); - clientProcess.BeginOutputReadLine(); - clientProcess.WaitForExit(); - } - catch (Exception) { throw; } - finally - { - try { clientProcess.CloseMainWindow(); } - catch (Exception) { } - } - } - clientProcess = null; - - // Cleanup - if (Directory.Exists(tempWorkingDirectory)) - Directory.Delete(tempWorkingDirectory, true); - Interop.CertificateInterop.RemoveTempCerts(); - - // Pause if Error - if (errorMessage.Length > 0) - { - Program.SleepThread(10000, true); - } - // End Of Loop - if (mLoopCompleteCallback != null) - { - mLoopCompleteCallback.BeginInvoke(null, null); - } + if (completeCallback != null) + await completeCallback(cancellationToken); } - void statusUI_WriteAdapterInfo() + private void statusUI_WriteAdapterInfo() { var info = new StringBuilder(); - foreach (var na in Interop.NetworkInterop.NetworkAdapters) + foreach (var na in NetworkInterop.NetworkAdapters) { if (na.IsWireless) { @@ -228,11 +228,10 @@ namespace Disco.ClientBootstrapper } - void clientProcess_OutputDataReceived(object sender, DataReceivedEventArgs e) + private void clientProcess_OutputDataReceived(object sender, DataReceivedEventArgs e) { if (!string.IsNullOrWhiteSpace(e.Data)) { - Debug.WriteLine($"OUTPUT: {e.Data}"); var data = e.Data.Substring(1).Split(new char[] { ',' }); switch (e.Data[0]) { @@ -249,15 +248,5 @@ namespace Disco.ClientBootstrapper } } - //void clientProcess_ErrorDataReceived(object sender, DataReceivedEventArgs e) - //{ - // if (!string.IsNullOrEmpty(e.Data)) - // { - // System.Diagnostics.Debug.WriteLine(string.Format("ERROR: {0}", e.Data)); - // this.errorMessage.AppendLine(e.Data); - // statusUI.UpdateStatus(null, "An Error Occurred", this.errorMessage.ToString(), false); - // } - //} - } } diff --git a/Disco.ClientBootstrapper/Disco.ClientBootstrapper.csproj b/Disco.ClientBootstrapper/Disco.ClientBootstrapper.csproj index 1f857e38..3fc1d6af 100644 --- a/Disco.ClientBootstrapper/Disco.ClientBootstrapper.csproj +++ b/Disco.ClientBootstrapper/Disco.ClientBootstrapper.csproj @@ -90,6 +90,9 @@ + + Interop\EndpointDiscovery.cs + DotNetZip\BZip2\BitWriter.cs diff --git a/Disco.ClientBootstrapper/FormStatus.cs b/Disco.ClientBootstrapper/FormStatus.cs index 95b1d8fd..d1d25bbd 100644 --- a/Disco.ClientBootstrapper/FormStatus.cs +++ b/Disco.ClientBootstrapper/FormStatus.cs @@ -7,22 +7,22 @@ namespace Disco.ClientBootstrapper { private delegate void dUpdateStatus(string Heading, string SubHeading, string Message, bool? ShowProgress, int? Progress); - private dUpdateStatus mUpdateStatus; + private readonly dUpdateStatus mUpdateStatus; public FormStatus() { InitializeComponent(); var version = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version; - this.labelVersion.Text = $"v{version.ToString(3)}"; + labelVersion.Text = $"v{version.ToString(3)}"; - this.FormClosed += new FormClosedEventHandler(FormStatus_FormClosed); + FormClosed += new FormClosedEventHandler(FormStatus_FormClosed); mUpdateStatus = new dUpdateStatus(UpdateStatusDo); Cursor.Hide(); } - void FormStatus_FormClosed(object sender, FormClosedEventArgs e) + private void FormStatus_FormClosed(object sender, FormClosedEventArgs e) { Cursor.Show(); Program.ExitApplication(); @@ -32,43 +32,43 @@ namespace Disco.ClientBootstrapper { try { - this.Invoke(mUpdateStatus, Heading, SubHeading, Message, ShowProgress, Progress); + Invoke(mUpdateStatus, Heading, SubHeading, Message, ShowProgress, Progress); } catch (Exception) { } } private void UpdateStatusDo(string Heading, string SubHeading, string Message, bool? ShowProgress, int? Progress) { if (Heading != null) - if (this.labelHeading.Text != Heading) - this.labelHeading.Text = Heading; + if (labelHeading.Text != Heading) + labelHeading.Text = Heading; if (SubHeading != null) - if (this.labelSubHeading.Text != SubHeading) - this.labelSubHeading.Text = SubHeading; + if (labelSubHeading.Text != SubHeading) + labelSubHeading.Text = SubHeading; if (Message != null) - if (this.labelMessage.Text != Message) - this.labelMessage.Text = Message; + if (labelMessage.Text != Message) + labelMessage.Text = Message; if (ShowProgress.HasValue) { if (ShowProgress.Value) { - this.progressBar.Visible = true; + progressBar.Visible = true; if (Progress.HasValue) { if (Progress.Value >= 0) { - this.progressBar.Value = Math.Min(Progress.Value, 100); - this.progressBar.Style = ProgressBarStyle.Continuous; + progressBar.Value = Math.Min(Progress.Value, 100); + progressBar.Style = ProgressBarStyle.Continuous; } else { - this.progressBar.Style = ProgressBarStyle.Marquee; + progressBar.Style = ProgressBarStyle.Marquee; } } } else { - this.progressBar.Visible = false; + progressBar.Visible = false; } } } diff --git a/Disco.ClientBootstrapper/InstallLoop.cs b/Disco.ClientBootstrapper/InstallLoop.cs index 8923f22f..848497c3 100644 --- a/Disco.ClientBootstrapper/InstallLoop.cs +++ b/Disco.ClientBootstrapper/InstallLoop.cs @@ -1,54 +1,49 @@ using System; using System.Threading; +using System.Threading.Tasks; namespace Disco.ClientBootstrapper { - class InstallLoop + internal class InstallLoop { + private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + private readonly string installLocation; + private readonly string wimImageId; + private readonly string tempPath; + private readonly Action completeCallback; + private readonly Uri forcedServerUrl; - public Thread LoopThread; - public delegate void CompleteCallback(); - private CompleteCallback mCompleteCallback; - private string InstallLocation; - private string WimImageId; - private string TempPath; - - public InstallLoop(string InstallLocation, string WimImageId, string TempPath) + public InstallLoop(string installLocation, string wimImageId, string tempPath, Action completeCallback, Uri forcedServerUrl) { - this.InstallLocation = InstallLocation; - this.WimImageId = WimImageId; - this.TempPath = TempPath; + this.installLocation = installLocation; + this.wimImageId = wimImageId; + this.tempPath = tempPath; + this.completeCallback = completeCallback; + this.forcedServerUrl = forcedServerUrl; } - public void Start(CompleteCallback Callback) + public void Start() { - mCompleteCallback = Callback; - LoopThread = new Thread(new ThreadStart(loopHost)); - LoopThread.Start(); - } - private void loopHost() - { - try + var cancellationToken = cancellationTokenSource.Token; + Task.Run(async () => { - - //Program.Status.UpdateStatus(null, null, "Testing UI"); - //Program.SleepThread(5000, false); - Interop.InstallInterop.Install(InstallLocation, WimImageId, TempPath); - if (mCompleteCallback != null) + try { - mCompleteCallback.BeginInvoke(null, null); + await Interop.InstallInterop.Install(installLocation, wimImageId, tempPath, forcedServerUrl, cancellationToken); + completeCallback?.BeginInvoke(null, null); } - } - catch (Exception ex) - { - if (ex.GetType() == typeof(ThreadAbortException)) - return; - if (ex.GetType() == typeof(ThreadInterruptedException)) - return; - Program.WriteAppError(ex); - throw; - } + catch (Exception ex) + { + if (ex.GetType() == typeof(ThreadAbortException)) + return; + if (ex.GetType() == typeof(ThreadInterruptedException)) + return; + if (ex.GetType() == typeof(OperationCanceledException)) + return; + Program.WriteAppError(ex); + throw; + } + }, cancellationToken); } - } } diff --git a/Disco.ClientBootstrapper/Interop/CertificateInterop.cs b/Disco.ClientBootstrapper/Interop/CertificateInterop.cs index 2466ad9e..9589c2ef 100644 --- a/Disco.ClientBootstrapper/Interop/CertificateInterop.cs +++ b/Disco.ClientBootstrapper/Interop/CertificateInterop.cs @@ -4,6 +4,8 @@ using System.IO; using System.Linq; using System.Security.Cryptography.X509Certificates; using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; namespace Disco.ClientBootstrapper.Interop { @@ -20,12 +22,12 @@ namespace Disco.ClientBootstrapper.Interop //Remove(StoreName.Root, StoreLocation.LocalMachine, _tempCerts); } } - public static void AddTempCerts() + public static async Task AddTempCerts(CancellationToken cancellationToken) { if (_tempCerts == null) _tempCerts = new List(); - var inlineCertificateLocation = Program.InlinePath.Value; + var inlineCertificateLocation = Path.GetDirectoryName(typeof(Program).Assembly.Location); // Root Certificates try @@ -35,6 +37,7 @@ namespace Disco.ClientBootstrapper.Interop { foreach (var certFile in CertFiles) { + cancellationToken.ThrowIfCancellationRequested(); var cert = new X509Certificate2(File.ReadAllBytes(certFile), "password"); var result = Add(StoreName.Root, StoreLocation.LocalMachine, cert); if (result) @@ -42,7 +45,7 @@ namespace Disco.ClientBootstrapper.Interop if (Path.GetFileNameWithoutExtension(certFile).ToLower().Contains("temp")) _tempCerts.Add(cert.SerialNumber); Program.Status.UpdateStatus(null, null, $"Added Root Certificate: {cert.ShortSubjectName()}"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); } } } @@ -60,6 +63,7 @@ namespace Disco.ClientBootstrapper.Interop { foreach (var certFile in CertFiles) { + cancellationToken.ThrowIfCancellationRequested(); var cert = new X509Certificate2(File.ReadAllBytes(certFile), "password"); var result = Add(StoreName.CertificateAuthority, StoreLocation.LocalMachine, cert); if (result) @@ -67,7 +71,7 @@ namespace Disco.ClientBootstrapper.Interop if (Path.GetFileNameWithoutExtension(certFile).ToLower().Contains("temp")) _tempCerts.Add(cert.SerialNumber); Program.Status.UpdateStatus(null, null, $"Added Intermediate Certificate: {cert.ShortSubjectName()}"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); } } } @@ -85,6 +89,7 @@ namespace Disco.ClientBootstrapper.Interop { foreach (var certFile in CertFiles) { + cancellationToken.ThrowIfCancellationRequested(); var cert = new X509Certificate2(File.ReadAllBytes(certFile), "password", X509KeyStorageFlags.MachineKeySet | X509KeyStorageFlags.PersistKeySet); var result = Add(StoreName.My, StoreLocation.LocalMachine, cert); if (result) @@ -92,7 +97,7 @@ namespace Disco.ClientBootstrapper.Interop if (Path.GetFileNameWithoutExtension(certFile).ToLower().Contains("temp")) _tempCerts.Add(cert.SerialNumber); Program.Status.UpdateStatus(null, null, $"Added Host Certificate: {cert.ShortSubjectName()}"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); } } } diff --git a/Disco.ClientBootstrapper/Interop/InstallInterop.cs b/Disco.ClientBootstrapper/Interop/InstallInterop.cs index d2e3cfef..6722c6ec 100644 --- a/Disco.ClientBootstrapper/Interop/InstallInterop.cs +++ b/Disco.ClientBootstrapper/Interop/InstallInterop.cs @@ -4,15 +4,17 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace Disco.ClientBootstrapper.Interop { public static class InstallInterop { [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)] - static extern bool MoveFileEx(string lpExistingFileName, string lpNewFileName, MoveFileFlags dwFlags); + private static extern bool MoveFileEx(string lpExistingFileName, string lpNewFileName, MoveFileFlags dwFlags); [Flags] - enum MoveFileFlags + private enum MoveFileFlags { MOVEFILE_REPLACE_EXISTING = 0x00000001, MOVEFILE_COPY_ALLOWED = 0x00000002, @@ -22,19 +24,19 @@ namespace Disco.ClientBootstrapper.Interop MOVEFILE_FAIL_IF_NOT_TRACKABLE = 0x00000020 } - private static void Install(string RootFilesystemLocation, RegistryKey RootRegistryLocation, string FilesystemInstallLocation, string VirtualRootFilesystemLocation) + private static async Task Install(string rootFilesystemLocation, RegistryKey rootRegistryLocation, string filesystemInstallLocation, string virtualRootFilesystemLocation, Uri forcedServerUrl, CancellationToken cancellationToken) { var SourceLocation = Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location); - var InstallLocation = Path.Combine(RootFilesystemLocation, FilesystemInstallLocation); - var BootstrapperCmdLinePath = Path.Combine(VirtualRootFilesystemLocation, FilesystemInstallLocation, "Disco.ClientBootstrapper.exe"); + var InstallLocation = Path.Combine(rootFilesystemLocation, filesystemInstallLocation); + var BootstrapperCmdLinePath = Path.Combine(virtualRootFilesystemLocation, filesystemInstallLocation, "Disco.ClientBootstrapper.exe"); - var GroupPolicyScriptsIniLocation = Path.Combine(RootFilesystemLocation, "Windows\\System32\\GroupPolicy\\Machine\\Scripts\\scripts.ini"); - var GroupPolicyScriptsIniBackupLocation = Path.Combine(RootFilesystemLocation, "Windows\\System32\\GroupPolicy\\Machine\\Scripts\\disco_scripts.ini"); + var GroupPolicyScriptsIniLocation = Path.Combine(rootFilesystemLocation, @"Windows\System32\GroupPolicy\Machine\Scripts\scripts.ini"); + var GroupPolicyScriptsIniBackupLocation = Path.Combine(rootFilesystemLocation, @"Windows\System32\GroupPolicy\Machine\Scripts\disco_scripts.ini"); // Create file system Location #region "Create File System Location" Program.Status.UpdateStatus(null, null, "Creating Installation Location"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); if (Directory.Exists(InstallLocation)) { // Try and Delete Directory @@ -52,19 +54,23 @@ namespace Disco.ClientBootstrapper.Interop var installDir = Directory.CreateDirectory(InstallLocation); installDir.Attributes = installDir.Attributes | FileAttributes.Hidden; } + cancellationToken.ThrowIfCancellationRequested(); #endregion // Copy files to file system location #region "Copy to File System" Program.Status.UpdateStatus(null, null, "Copying Files"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); // Copy Bootstrapper // ie: Executing Assembly File.Copy(System.Reflection.Assembly.GetExecutingAssembly().Location, Path.Combine(InstallLocation, "Disco.ClientBootstrapper.exe")); + cancellationToken.ThrowIfCancellationRequested(); + foreach (var file in Directory.EnumerateFiles(SourceLocation)) { + cancellationToken.ThrowIfCancellationRequested(); var fileName = Path.GetFileName(file); // Only Copy Certain Files @@ -86,7 +92,7 @@ namespace Disco.ClientBootstrapper.Interop // Backup & Create Group Policy Scripts.ini #region "Group Policy Scripts.ini" Program.Status.UpdateStatus(null, null, "Creating Group Policy Script Entry"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); // Backup if (!File.Exists(GroupPolicyScriptsIniBackupLocation)) { @@ -95,6 +101,7 @@ namespace Disco.ClientBootstrapper.Interop File.Move(GroupPolicyScriptsIniLocation, GroupPolicyScriptsIniBackupLocation); } } + cancellationToken.ThrowIfCancellationRequested(); // Create if (File.Exists(GroupPolicyScriptsIniLocation)) @@ -105,56 +112,67 @@ namespace Disco.ClientBootstrapper.Interop { using (var scriptsIniStreamWriter = new StreamWriter(scriptsIniStream, Encoding.Unicode)) { - scriptsIniStreamWriter.Write($"[Startup]{Environment.NewLine}0CmdLine={BootstrapperCmdLinePath}{Environment.NewLine}0Parameters=/AllowUninstall"); - scriptsIniStreamWriter.Flush(); + scriptsIniStreamWriter.WriteLine("[Startup]"); + scriptsIniStreamWriter.WriteLine($"0CmdLine={BootstrapperCmdLinePath}"); + if (forcedServerUrl == null) + scriptsIniStreamWriter.WriteLine("0Parameters=/AllowUninstall"); + else + scriptsIniStreamWriter.WriteLine($"0Parameters=/AllowUninstall {forcedServerUrl}"); } } + cancellationToken.ThrowIfCancellationRequested(); #endregion // Backup & Create Group Policy Registry #region "Group Policy Registry" Program.Status.UpdateStatus(null, null, "Creating Group Policy Registry Entries"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); // Backup Scripts - using (var regGroupPolicy = RootRegistryLocation.OpenSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy", true)) + using (var regGroupPolicy = rootRegistryLocation.OpenSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy", true)) { if (regGroupPolicy != null && regGroupPolicy.GetSubKeyNames().Contains("Scripts") && !regGroupPolicy.GetSubKeyNames().Contains("Disco_Scripts")) { RegistryUtilities.RenameSubKey(regGroupPolicy, "Scripts", "Disco_Scripts"); } } + cancellationToken.ThrowIfCancellationRequested(); // Create Scripts - RootRegistryLocation.CreateSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy\\Scripts\\Shutdown").Dispose(); - using (var regScriptsStartup = RootRegistryLocation.CreateSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy\\Scripts\\Startup\\0")) + rootRegistryLocation.CreateSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy\Scripts\Shutdown").Dispose(); + using (var regScriptsStartup = rootRegistryLocation.CreateSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy\Scripts\Startup\0")) { regScriptsStartup.SetValue("GPO-ID", "LocalGPO", RegistryValueKind.String); regScriptsStartup.SetValue("SOM-ID", "Local", RegistryValueKind.String); - regScriptsStartup.SetValue("FileSysPath", Path.Combine(Environment.SystemDirectory, "GroupPolicy\\Machine"), RegistryValueKind.String); + regScriptsStartup.SetValue("FileSysPath", Path.Combine(Environment.SystemDirectory, @"GroupPolicy\Machine"), RegistryValueKind.String); regScriptsStartup.SetValue("DisplayName", "Local Group Policy", RegistryValueKind.String); regScriptsStartup.SetValue("GPOName", "Local Group Policy", RegistryValueKind.String); regScriptsStartup.SetValue("PSScriptOrder", 1, RegistryValueKind.DWord); using (var regScriptsStartup0 = regScriptsStartup.CreateSubKey("0")) { regScriptsStartup0.SetValue("Script", BootstrapperCmdLinePath, RegistryValueKind.String); - regScriptsStartup0.SetValue("Parameters", "/AllowUninstall", RegistryValueKind.String); + if (forcedServerUrl == null) + regScriptsStartup0.SetValue("Parameters", "/AllowUninstall", RegistryValueKind.String); + else + regScriptsStartup0.SetValue("Parameters", $"/AllowUninstall {forcedServerUrl}", RegistryValueKind.String); regScriptsStartup0.SetValue("IsPowershell", 0, RegistryValueKind.DWord); regScriptsStartup0.SetValue("ExecTime", new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, RegistryValueKind.Binary); } } - RootRegistryLocation.CreateSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine\\Scripts\\Shutdown").Dispose(); + rootRegistryLocation.CreateSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy\State\Machine\Scripts\Shutdown").Dispose(); + cancellationToken.ThrowIfCancellationRequested(); // Backup Scripts State - using (var regGroupPolicy = RootRegistryLocation.OpenSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine", true)) + using (var regGroupPolicy = rootRegistryLocation.OpenSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy\State\Machine", true)) { if (regGroupPolicy != null && regGroupPolicy.GetSubKeyNames().Contains("Scripts") && !regGroupPolicy.GetSubKeyNames().Contains("Disco_Scripts")) { RegistryUtilities.RenameSubKey(regGroupPolicy, "Scripts", "Disco_Scripts"); } } + cancellationToken.ThrowIfCancellationRequested(); // Create Scripts State - using (var regStateScriptsStartup = RootRegistryLocation.CreateSubKey("Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine\\Scripts\\Startup\\0")) + using (var regStateScriptsStartup = rootRegistryLocation.CreateSubKey(@"Microsoft\Windows\CurrentVersion\Group Policy\State\Machine\Scripts\Startup\0")) { regStateScriptsStartup.SetValue("GPO-ID", "LocalGPO", RegistryValueKind.String); regStateScriptsStartup.SetValue("SOM-ID", "Local", RegistryValueKind.String); @@ -165,17 +183,21 @@ namespace Disco.ClientBootstrapper.Interop using (var regStateScriptsStartup0 = regStateScriptsStartup.CreateSubKey("0")) { regStateScriptsStartup0.SetValue("Script", BootstrapperCmdLinePath, RegistryValueKind.String); - regStateScriptsStartup0.SetValue("Parameters", "/AllowUninstall", RegistryValueKind.String); + if (forcedServerUrl == null) + regStateScriptsStartup0.SetValue("Parameters", "/AllowUninstall", RegistryValueKind.String); + else + regStateScriptsStartup0.SetValue("Parameters", $"/AllowUninstall {forcedServerUrl}", RegistryValueKind.String); regStateScriptsStartup0.SetValue("ExecTime", new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, RegistryValueKind.Binary); } } + cancellationToken.ThrowIfCancellationRequested(); #endregion // Set Registry Startup Environment Policies #region "Registry Startup Policies" Program.Status.UpdateStatus(null, null, "Creating Startup Policy Registry Entries"); - Program.SleepThread(500, false); - using (var regWinlogon = RootRegistryLocation.OpenSubKey("Microsoft\\Windows NT\\CurrentVersion\\Winlogon", true)) + await Program.SleepThread(500, false, cancellationToken); + using (var regWinlogon = rootRegistryLocation.OpenSubKey(@"Microsoft\Windows NT\CurrentVersion\Winlogon", true)) { regWinlogon.SetValue("HideStartupScripts", 0, RegistryValueKind.DWord); regWinlogon.SetValue("RunStartupScriptSync", 1, RegistryValueKind.DWord); @@ -183,94 +205,110 @@ namespace Disco.ClientBootstrapper.Interop #endregion } - public static void Install(string InstallLocation, string WimImageId, string TempPath) + public static async Task Install(string installLocation, string wimImageId, string tempPath, Uri forcedServerUrl, CancellationToken cancellationToken) { Program.Status.UpdateStatus("Installing Bootstrapper", "Starting", "Please wait...", false); - if (string.IsNullOrWhiteSpace(InstallLocation)) - InstallLocation = Path.Combine(Path.GetPathRoot(Environment.SystemDirectory), "Disco"); + if (string.IsNullOrWhiteSpace(installLocation)) + installLocation = Path.Combine(Path.GetPathRoot(Environment.SystemDirectory), "Disco"); - if (InstallLocation.EndsWith(".wim", StringComparison.OrdinalIgnoreCase)) + cancellationToken.ThrowIfCancellationRequested(); + + if (installLocation.EndsWith(".wim", StringComparison.OrdinalIgnoreCase)) { // Offline File System (WIM) - Program.Status.UpdateStatus("Installing Bootstrapper (Offline)", "Installing", $"Install Location: {InstallLocation}"); - Program.SleepThread(1000, false); + Program.Status.UpdateStatus("Installing Bootstrapper (Offline)", "Installing", $"Install Location: {installLocation}"); + await Program.SleepThread(1000, false, cancellationToken); // Mount WIM int wimImageIndex = 0; - using (var wim = new WIMInterop.WindowsImageContainer(InstallLocation, WIMInterop.WindowsImageContainer.CreateFileMode.OpenExisting, WIMInterop.WindowsImageContainer.CreateFileAccess.Write)) + using (var wim = new WIMInterop.WindowsImageContainer(installLocation, WIMInterop.WindowsImageContainer.CreateFileMode.OpenExisting, WIMInterop.WindowsImageContainer.CreateFileAccess.Write)) { - if (WimImageId == null) - WimImageId = "1"; - if (!int.TryParse(WimImageId, out wimImageIndex)) + cancellationToken.ThrowIfCancellationRequested(); + if (wimImageId == null) + wimImageId = "1"; + if (!int.TryParse(wimImageId, out wimImageIndex)) { - Program.Status.UpdateStatus(null, "Analysing WIM", $"Looking for Image Name: {WimImageId}"); - Program.SleepThread(500, false); + Program.Status.UpdateStatus(null, "Analysing WIM", $"Looking for Image Name: {wimImageId}"); + await Program.SleepThread(500, false, cancellationToken); for (int i = 0; i < wim.ImageCount; i++) { var wimImageInfo = new System.Xml.XmlDocument(); using (var wimImage = wim[i]) wimImageInfo.LoadXml(wimImage.ImageInformation); var wimImageInfoName = wimImageInfo.SelectSingleNode("//IMAGE/NAME"); - if (wimImageInfoName != null && wimImageInfoName.InnerText.Equals(WimImageId, StringComparison.OrdinalIgnoreCase)) + if (wimImageInfoName != null && wimImageInfoName.InnerText.Equals(wimImageId, StringComparison.OrdinalIgnoreCase)) { wimImageIndex = i + 1; - Program.Status.UpdateStatus(null, "Analysing WIM", $"Found Image Id '{WimImageId}' at Index {wimImageIndex}"); - Program.SleepThread(500, false); + Program.Status.UpdateStatus(null, "Analysing WIM", $"Found Image Id '{wimImageId}' at Index {wimImageIndex}"); + await Program.SleepThread(500, false, cancellationToken); break; } } } } + cancellationToken.ThrowIfCancellationRequested(); if (wimImageIndex == 0) { - Program.Status.UpdateStatus(null, "Error", $"Unable to load WIM Image Id: {WimImageId}"); - Program.SleepThread(5000, false); + Program.Status.UpdateStatus(null, "Error", $"Unable to load WIM Image Id: {wimImageId}"); + await Program.SleepThread(5000, false, cancellationToken); return; } // Get Temp Path - var wimMountPath = Path.Combine(TempPath ?? Path.GetTempPath(), "DiscoClientBootstrapperWimMount"); + var wimMountPath = Path.Combine(tempPath ?? Path.GetTempPath(), "DiscoClientBootstrapperWimMount"); if (Directory.Exists(wimMountPath)) Directory.Delete(wimMountPath, true); Directory.CreateDirectory(wimMountPath); - var wimTempMountPath = Path.Combine(TempPath ?? Path.GetTempPath(), "DiscoClientBootstrapperWimTempMount"); + cancellationToken.ThrowIfCancellationRequested(); + + var wimTempMountPath = Path.Combine(tempPath ?? Path.GetTempPath(), "DiscoClientBootstrapperWimTempMount"); if (Directory.Exists(wimTempMountPath)) Directory.Delete(wimTempMountPath, true); Directory.CreateDirectory(wimTempMountPath); + cancellationToken.ThrowIfCancellationRequested(); + bool wimCommitChanges = true; WIMInterop.WindowsImageContainer.NativeMethods.MessageCallback m_MessageCallback = null; try { // Mount WIM Program.Status.UpdateStatus(null, "Mounting WIM", $"Mounting WIM Image to '{wimMountPath}'"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); m_MessageCallback = new WIMInterop.WindowsImageContainer.NativeMethods.MessageCallback(WimImageEventMessagePump); WIMInterop.WindowsImageContainer.NativeMethods.RegisterCallback(m_MessageCallback); - WIMInterop.WindowsImageContainer.NativeMethods.MountImage(wimMountPath, InstallLocation, wimImageIndex, wimTempMountPath); + WIMInterop.WindowsImageContainer.NativeMethods.MountImage(wimMountPath, installLocation, wimImageIndex, wimTempMountPath); // Load Local Machine Registry var wimHivePath = Path.Combine(wimMountPath, "Windows\\System32\\config\\SOFTWARE"); Program.Status.UpdateStatus(null, "Mounting Offline Registry Hive", $"Mounting Offline Registry Hive at '{wimHivePath}'"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); using (var wimReg = new RegistryInterop(RegistryInterop.RegistryHives.HKEY_LOCAL_MACHINE, "DiscoClientBootstrapperWimHive", wimHivePath)) { - using (RegistryKey rootRegistryLocation = Registry.LocalMachine.OpenSubKey("DiscoClientBootstrapperWimHive", true)) + try { - string rootFileSystemLocation = wimMountPath; - string fileSystemInstallLocation = "Disco"; - string virtualRootFileSystemLocation = "C:\\"; + cancellationToken.ThrowIfCancellationRequested(); + using (RegistryKey rootRegistryLocation = Registry.LocalMachine.OpenSubKey("DiscoClientBootstrapperWimHive", true)) + { + string rootFileSystemLocation = wimMountPath; + string fileSystemInstallLocation = "Disco"; + string virtualRootFileSystemLocation = "C:\\"; - Install(rootFileSystemLocation, rootRegistryLocation, fileSystemInstallLocation, virtualRootFileSystemLocation); + cancellationToken.ThrowIfCancellationRequested(); + await Install(rootFileSystemLocation, rootRegistryLocation, fileSystemInstallLocation, virtualRootFileSystemLocation, forcedServerUrl, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + } + } + finally + { + // Unload Local Machine Registry + Program.Status.UpdateStatus(null, "Unmounting Offline Registry Hive", $"Unmounting Offline Registry Hive at '{wimHivePath}'"); + await Program.SleepThread(500, false, cancellationToken); + wimReg.Unload(); } - - // Unload Local Machine Registry - Program.Status.UpdateStatus(null, "Unmounting Offline Registry Hive", $"Unmounting Offline Registry Hive at '{wimHivePath}'"); - Program.SleepThread(500, false); - wimReg.Unload(); } } catch (Exception) @@ -282,8 +320,8 @@ namespace Disco.ClientBootstrapper.Interop { // Unmount WIM Program.Status.UpdateStatus(null, "Unmounting WIM", $"Unmounting WIM Image at '{wimMountPath}'"); - Program.SleepThread(500, false); - WIMInterop.WindowsImageContainer.NativeMethods.DismountImage(wimMountPath, InstallLocation, wimImageIndex, wimCommitChanges); + await Program.SleepThread(500, false, cancellationToken); + WIMInterop.WindowsImageContainer.NativeMethods.DismountImage(wimMountPath, installLocation, wimImageIndex, wimCommitChanges); if (m_MessageCallback != null) { @@ -295,23 +333,25 @@ namespace Disco.ClientBootstrapper.Interop Directory.Delete(wimMountPath, true); if (Directory.Exists(wimTempMountPath)) Directory.Delete(wimTempMountPath, true); + + cancellationToken.ThrowIfCancellationRequested(); } } else { // Online File System - Program.Status.UpdateStatus("Installing Bootstrapper (Online)", "Installing", $"Install Location: {InstallLocation}", true, -1); - Program.SleepThread(1000, false); - string rootFileSystemLocation = Path.GetPathRoot(InstallLocation); + Program.Status.UpdateStatus("Installing Bootstrapper (Online)", "Installing", $"Install Location: {installLocation}", true, -1); + await Program.SleepThread(1000, false, cancellationToken); + string rootFileSystemLocation = Path.GetPathRoot(installLocation); RegistryKey rootRegistryLocation = Registry.LocalMachine.OpenSubKey("SOFTWARE", true); - string fileSystemInstallLocation = InstallLocation.Substring(rootFileSystemLocation.Length); + string fileSystemInstallLocation = installLocation.Substring(rootFileSystemLocation.Length); - Install(rootFileSystemLocation, rootRegistryLocation, fileSystemInstallLocation, rootFileSystemLocation); + await Install(rootFileSystemLocation, rootRegistryLocation, fileSystemInstallLocation, rootFileSystemLocation, forcedServerUrl, cancellationToken); Program.Status.UpdateStatus(null, "Online File System Installation Complete", string.Empty, true, -1); - Program.SleepThread(1000, false); + await Program.SleepThread(1000, false, cancellationToken); } Program.Status.UpdateStatus(null, "Complete", "Finished Installing Bootstrapper"); - Program.SleepThread(1500, false); + await Program.SleepThread(1500, false, cancellationToken); } private static uint WimImageEventMessagePump( @@ -349,41 +389,28 @@ namespace Disco.ClientBootstrapper.Interop return status; } - public static void Uninstall() + public static async Task Uninstall(CancellationToken cancellationToken) { // Application Directory - var appDirectory = Program.InlinePath.Value; - if (Program.AllowUninstall && !appDirectory.StartsWith("\\\\")) + var appDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location); + if (Program.AllowUninstall && !appDirectory.StartsWith(@"\\")) { Program.Status.UpdateStatus("System Preparation (Bootstrapper)", "Uninstalling Bootstrapper...", string.Empty, false, 0); - Program.SleepThread(1000, true); - //var uninstallScriptLocation = System.IO.Path.Combine(appDirectory, "UninstallBootstrapper.vbs"); - //if (System.IO.File.Exists(uninstallScriptLocation)) - //{ - // var bootstrapperPID = System.Diagnostics.Process.GetCurrentProcess().Id; - // var cscriptPath = System.IO.Path.Combine(Environment.SystemDirectory, "cscript.exe"); - // var cscriptArgs = string.Format("\"{0}\" /WaitForProcessID:{1}", uninstallScriptLocation, bootstrapperPID); - - // var startProc = new ProcessStartInfo(cscriptPath, cscriptArgs); - // startProc.WorkingDirectory = Environment.SystemDirectory; - // startProc.WindowStyle = ProcessWindowStyle.Hidden; - - // Process.Start(startProc); - //} + await Program.SleepThread(1000, true, cancellationToken); // Remove Registry Entries - using (var regWinlogon = Registry.LocalMachine.OpenSubKey("SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Winlogon", true)) + using (var regWinlogon = Registry.LocalMachine.OpenSubKey(@"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Winlogon", true)) { regWinlogon.DeleteValue("HideStartupScripts", false); regWinlogon.DeleteValue("RunStartupScriptSync", false); } - Registry.LocalMachine.DeleteSubKeyTree("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy\\Scripts\\Shutdown", false); - Registry.LocalMachine.DeleteSubKeyTree("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy\\Scripts\\Startup", false); - Registry.LocalMachine.DeleteSubKeyTree("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine\\Scripts\\Shutdown", false); - Registry.LocalMachine.DeleteSubKeyTree("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine\\Scripts\\Startup", false); + Registry.LocalMachine.DeleteSubKeyTree(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy\Scripts\Shutdown", false); + Registry.LocalMachine.DeleteSubKeyTree(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy\Scripts\Startup", false); + Registry.LocalMachine.DeleteSubKeyTree(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy\State\Machine\Scripts\Shutdown", false); + Registry.LocalMachine.DeleteSubKeyTree(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy\State\Machine\Scripts\Startup", false); // Restore Registry Backups - using (var regGroupPolicy = Registry.LocalMachine.OpenSubKey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy", true)) + using (var regGroupPolicy = Registry.LocalMachine.OpenSubKey(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy", true)) { if (regGroupPolicy != null && regGroupPolicy.GetSubKeyNames().Contains("Disco_Scripts")) { @@ -391,7 +418,7 @@ namespace Disco.ClientBootstrapper.Interop RegistryUtilities.RenameSubKey(regGroupPolicy, "Disco_Scripts", "Scripts"); } } - using (var regGroupPolicy = Registry.LocalMachine.OpenSubKey("SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Group Policy\\State\\Machine", true)) + using (var regGroupPolicy = Registry.LocalMachine.OpenSubKey(@"SOFTWARE\Microsoft\Windows\CurrentVersion\Group Policy\State\Machine", true)) { if (regGroupPolicy != null && regGroupPolicy.GetSubKeyNames().Contains("Disco_Scripts")) { @@ -401,10 +428,10 @@ namespace Disco.ClientBootstrapper.Interop } // Delete Group Policy Script File - var groupPolicyScriptsPath = Path.Combine(Environment.SystemDirectory, "GroupPolicy\\Machine\\Scripts\\scripts.ini"); + var groupPolicyScriptsPath = Path.Combine(Environment.SystemDirectory, @"GroupPolicy\Machine\Scripts\scripts.ini"); if (File.Exists(groupPolicyScriptsPath)) File.Delete(groupPolicyScriptsPath); - var groupPolicyScriptsBackupPath = Path.Combine(Environment.SystemDirectory, "GroupPolicy\\Machine\\Scripts\\disco_scripts.ini"); + var groupPolicyScriptsBackupPath = Path.Combine(Environment.SystemDirectory, @"GroupPolicy\Machine\Scripts\disco_scripts.ini"); if (File.Exists(groupPolicyScriptsBackupPath)) File.Move(groupPolicyScriptsBackupPath, groupPolicyScriptsPath); diff --git a/Disco.ClientBootstrapper/Interop/NetworkInterop.cs b/Disco.ClientBootstrapper/Interop/NetworkInterop.cs index fc5aab1e..68d47ab0 100644 --- a/Disco.ClientBootstrapper/Interop/NetworkInterop.cs +++ b/Disco.ClientBootstrapper/Interop/NetworkInterop.cs @@ -1,14 +1,17 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Management; using System.Net.NetworkInformation; using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; using System.Xml; namespace Disco.ClientBootstrapper.Interop { - static class NetworkInterop + internal static class NetworkInterop { #region PInvoke @@ -164,30 +167,35 @@ namespace Disco.ClientBootstrapper.Interop } } - public static bool PingDiscoIct(string ServerName) + public static bool HasNetworkConnectivity() { - using (Ping p = new Ping()) + var nics = NetworkInterface.GetAllNetworkInterfaces() + .Where(ni => ni.OperationalStatus == OperationalStatus.Up) + .ToList(); + + foreach (var nic in nics) { - try + if (nic.Supports(NetworkInterfaceComponent.IPv4)) { - PingReply pr = p.Send(ServerName, 2000); - if (pr.Status == IPStatus.Success) - return true; - else - return false; - } - catch (Exception) - { - return false; + var ipProps = nic.GetIPProperties(); + var ipv4Props = ipProps.GetIPv4Properties(); + if (ipv4Props.IsAutomaticPrivateAddressingActive) + continue; + + return ipProps.UnicastAddresses + .Where(ua => ua.Address.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork) + .Any(); } } + + return false; } - public static void ConfigureWireless() + public static async Task ConfigureWireless(CancellationToken cancellationToken) { // Add Certificates Program.Status.UpdateStatus(null, null, "Configuring Wireless Certificates"); - CertificateInterop.AddTempCerts(); + await CertificateInterop.AddTempCerts(cancellationToken); // Add Wireless Profiles Program.Status.UpdateStatus(null, null, "Configuring Wireless Profiles"); @@ -208,15 +216,16 @@ namespace Disco.ClientBootstrapper.Interop { foreach (var inlineWirelessProfile in wirelessInlineProfiles) { + cancellationToken.ThrowIfCancellationRequested(); if (inlineWirelessProfile.AddProfile(wlanHandle, na.Guid)) { Program.Status.UpdateStatus(null, null, $"Added Wireless Profile: {inlineWirelessProfile.ProfileName}"); - Program.SleepThread(500, false); + await Program.SleepThread(500, false, cancellationToken); } else { Program.Status.UpdateStatus(null, null, $"Unable to add Wireless Profile: {inlineWirelessProfile.ProfileName}"); - Program.SleepThread(5000, false); + await Program.SleepThread(5000, false, cancellationToken); } } } @@ -246,14 +255,15 @@ namespace Disco.ClientBootstrapper.Interop private static List GetInlineWirelessProfiles() { - var inlineProfileFiles = System.IO.Directory.EnumerateFiles(Program.InlinePath.Value, "WLAN_Profile_*.xml").ToList(); + var directoryPath = Path.GetDirectoryName(typeof(Program).Assembly.Location); + var inlineProfileFiles = Directory.EnumerateFiles(directoryPath, "WLAN_Profile_*.xml").ToList(); var inlineProfiles = new List(inlineProfileFiles.Count); foreach (var filename in inlineProfileFiles) { var profile = new WirelessProfile() { Filename = filename, - ProfileXml = System.IO.File.ReadAllText(filename) + ProfileXml = File.ReadAllText(filename) }; var profileXml = new XmlDocument(); profileXml.LoadXml(profile.ProfileXml); diff --git a/Disco.ClientBootstrapper/Program.cs b/Disco.ClientBootstrapper/Program.cs index 76ec367c..a0d1a8b0 100644 --- a/Disco.ClientBootstrapper/Program.cs +++ b/Disco.ClientBootstrapper/Program.cs @@ -1,36 +1,58 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Net; using System.Threading; +using System.Threading.Tasks; using System.Windows.Forms; namespace Disco.ClientBootstrapper { - static class Program + internal static class Program { - public static IStatus Status { get; set; } - public static BootstrapperLoop BootstrapperLoop { get; set; } - public static InstallLoop InstallLoop { get; set; } + private static readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + public static IStatus Status { get; private set; } public static List PostBootstrapperActions { get; set; } - public static bool AllowUninstall { get; set; } - public static bool ApplicationExiting { get; set; } - public static Lazy InlinePath = new Lazy(() => - { - return System.IO.Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location); - }); + public static bool AllowUninstall { get; private set; } + public static Uri ForcedServerUrl { get; private set; } = null; /// /// The main entry point for the application. /// [STAThread] - static void Main(string[] args) + private static void Main(string[] args) { Application.ThreadException += new ThreadExceptionEventHandler(Application_ThreadException); Application.EnableVisualStyles(); Application.SetCompatibleTextRenderingDefault(false); + ServicePointManager.SecurityProtocol |= SecurityProtocolType.Tls12; + if (args.Length > 0) { +#if DEBUG + if (args.Any(a => a.Equals("debug", StringComparison.OrdinalIgnoreCase))) + { + do + { + Console.WriteLine("Waiting for Debugger to Attach"); + Thread.Sleep(1000); + } while (!System.Diagnostics.Debugger.IsAttached); + } +#endif + + if (args.Any(a => a.StartsWith("http://", StringComparison.OrdinalIgnoreCase))) + throw new ArgumentException("Only HTTPS URLs are supported for a forced server URL."); + var forcedServerArg = args.FirstOrDefault(a => a.StartsWith("https://", StringComparison.OrdinalIgnoreCase)); + if (forcedServerArg != null) + { + if (Uri.TryCreate(forcedServerArg, UriKind.Absolute, out var forcedUri)) + ForcedServerUrl = forcedUri; + else + throw new ArgumentException("The provided forced server URL is not valid."); + } + switch (args[0].ToLower()) { case "/install": @@ -46,14 +68,17 @@ namespace Disco.ClientBootstrapper wimImage = args[2]; if (args.Length > 3) tempPath = args[3]; - InstallLoop = new InstallLoop(installLocation, wimImage, tempPath); - InstallLoop.Start(new InstallLoop.CompleteCallback(InstallComplete)); + var installLoop = new InstallLoop(installLocation, wimImage, tempPath, InstallComplete, ForcedServerUrl); + installLoop.Start(); Application.Run(); return; case "/uninstall": AllowUninstall = true; Status = new NullStatus(); - Interop.InstallInterop.Uninstall(); + Task.Run(async () => + { + await Interop.InstallInterop.Uninstall(cancellationTokenSource.Token); + }).Wait(cancellationTokenSource.Token); return; case "/allowuninstall": AllowUninstall = true; @@ -71,13 +96,13 @@ namespace Disco.ClientBootstrapper statusForm.Show(); } - BootstrapperLoop = new BootstrapperLoop(Status, new BootstrapperLoop.LoopCompleteCallback(LoopComplete)); - BootstrapperLoop.Start(); + var bootstrapperLoop = new BootstrapperLoop(Status, ForcedServerUrl, LoopComplete, cancellationTokenSource.Token); + bootstrapperLoop.Start(); Application.Run(); } - static void Application_ThreadException(object sender, ThreadExceptionEventArgs e) + private static void Application_ThreadException(object sender, ThreadExceptionEventArgs e) { WriteAppError(e.Exception); } @@ -100,7 +125,7 @@ namespace Disco.ClientBootstrapper catch (Exception) { } } - public static void LoopComplete() + public static async Task LoopComplete(CancellationToken cancellationToken) { // Run Post Actions if (PostBootstrapperActions != null) @@ -108,32 +133,32 @@ namespace Disco.ClientBootstrapper // Check Uninstall if (AllowUninstall && PostBootstrapperActions.Contains("UninstallBootstrapper")) { - Interop.InstallInterop.Uninstall(); + await Interop.InstallInterop.Uninstall(cancellationToken); } // Check ShutdownActions if (PostBootstrapperActions.Contains("Shutdown")) { Status.UpdateStatus("System Preparation (Bootstrapper)", "Shutting Down; Finished...", string.Empty, false, 0); - SleepThread(4000, true); + await SleepThread(4000, true, cancellationToken); Interop.ShutdownInterop.Shutdown(); } else if (PostBootstrapperActions.Contains("Reboot")) { Status.UpdateStatus("System Preparation (Bootstrapper)", "Rebooting; Finished...", string.Empty, false, 0); - SleepThread(4000, true); + await SleepThread(4000, true, cancellationToken); Interop.ShutdownInterop.Reboot(); } else { Status.UpdateStatus("System Preparation (Bootstrapper)", "Starting System; Finished...", string.Empty, false, 0); - SleepThread(2000, true); + await SleepThread(2000, true, cancellationToken); } } else { Status.UpdateStatus("System Preparation (Bootstrapper)", "Starting System; Finished...", string.Empty, false, 0); - SleepThread(2000, true); + await SleepThread(2000, true, cancellationToken); } ExitApplication(); @@ -146,33 +171,12 @@ namespace Disco.ClientBootstrapper public static void ExitApplication() { - if (!ApplicationExiting) - { - ApplicationExiting = true; - if (BootstrapperLoop != null) - { - if (BootstrapperLoop.LoopThread != null) - { - if (BootstrapperLoop.LoopThread.ThreadState == ThreadState.WaitSleepJoin) - { - BootstrapperLoop.LoopThread.Interrupt(); - } - if (BootstrapperLoop.LoopThread.ThreadState == ThreadState.Running) - { - BootstrapperLoop.LoopThread.Abort(); - } - } - } - Application.Exit(); - } + if (!cancellationTokenSource.IsCancellationRequested) + cancellationTokenSource.Cancel(); + Application.Exit(); } - public static void Trace(string Format, params string[] args) - { - System.Diagnostics.Debug.WriteLine(Format, args); - } - - public static void SleepThread(int millisecondsTimeout, bool updateUI) + public static async Task SleepThread(int millisecondsTimeout, bool updateUI, CancellationToken cancellationToken) { if (updateUI) { @@ -180,12 +184,12 @@ namespace Disco.ClientBootstrapper { int progress = Convert.ToInt32(((Convert.ToDouble(i) / Convert.ToDouble(millisecondsTimeout)) * 100)); Status.UpdateStatus(null, null, null, true, progress); - Thread.Sleep(500); + await Task.Delay(500, cancellationToken); } } else { - Thread.Sleep(millisecondsTimeout); + await Task.Delay(millisecondsTimeout, cancellationToken); } } } diff --git a/Disco.Data/Configuration/Modules/DevicesConfiguration.cs b/Disco.Data/Configuration/Modules/DevicesConfiguration.cs index b38fed40..09690554 100644 --- a/Disco.Data/Configuration/Modules/DevicesConfiguration.cs +++ b/Disco.Data/Configuration/Modules/DevicesConfiguration.cs @@ -14,5 +14,11 @@ namespace Disco.Data.Configuration.Modules get => Get(DeviceExportOptions.DefaultOptions()); set => Set(value); } + + public bool EnrollmentLegacyDiscoveryDisabled + { + get => Get(false); + set => Set(value); + } } } diff --git a/Disco.Models/Disco.Models.csproj b/Disco.Models/Disco.Models.csproj index 5f02c943..c8aeab5b 100644 --- a/Disco.Models/Disco.Models.csproj +++ b/Disco.Models/Disco.Models.csproj @@ -62,6 +62,7 @@ + diff --git a/Disco.Models/Services/Devices/DeviceEnrolmentServerDiscoveryMethod.cs b/Disco.Models/Services/Devices/DeviceEnrolmentServerDiscoveryMethod.cs new file mode 100644 index 00000000..18a5a675 --- /dev/null +++ b/Disco.Models/Services/Devices/DeviceEnrolmentServerDiscoveryMethod.cs @@ -0,0 +1,13 @@ +namespace Disco.Models.Services.Devices +{ + public enum DeviceEnrolmentServerDiscoveryMethod + { + Unknown = 0, + Manual = 1, + SRV = 2, + VicSmart = 3, + Legacy = 4, + Mac = 50, + MacSecure = 51, + } +} diff --git a/Disco.Models/Services/Interop/DiscoServices/UpdateRequestV2.cs b/Disco.Models/Services/Interop/DiscoServices/UpdateRequestV2.cs index ae317def..42c68e07 100644 --- a/Disco.Models/Services/Interop/DiscoServices/UpdateRequestV2.cs +++ b/Disco.Models/Services/Interop/DiscoServices/UpdateRequestV2.cs @@ -24,6 +24,7 @@ namespace Disco.Models.Services.Interop.DiscoServices public List Stat_JobIdentifiers { get; set; } public List Stat_Jobs { get; set; } + public List Stat_EnrollmentDiscovery { get; set; } public class StatisticIntPair { diff --git a/Disco.Models/UI/Config/Enrolment/ConfigEnrolmentIndexModel.cs b/Disco.Models/UI/Config/Enrolment/ConfigEnrolmentIndexModel.cs index ae35bb4f..6ef4771a 100644 --- a/Disco.Models/UI/Config/Enrolment/ConfigEnrolmentIndexModel.cs +++ b/Disco.Models/UI/Config/Enrolment/ConfigEnrolmentIndexModel.cs @@ -1,8 +1,17 @@ -namespace Disco.Models.UI.Config.Enrolment +using System; + +namespace Disco.Models.UI.Config.Enrolment { public interface ConfigEnrolmentIndexModel : BaseUIModel { string MacSshUsername { get; set; } int PendingTimeoutMinutes { get; set; } + Uri MacEnrolUrl { get; set; } + bool HostingPluginInstalled { get; set; } + bool IsVicSmartDeployment { get; set; } + bool IsServicesEducationVicGovAuDomain { get; set; } + string DnsSrvRecordName { get; set; } + string DnsSrvRecordValue { get; set; } + bool LegacyDiscoveryEnabled { get; set; } } } diff --git a/Disco.Services/Devices/Enrolment/WindowsDeviceEnrolment.cs b/Disco.Services/Devices/Enrolment/WindowsDeviceEnrolment.cs index e1d0097e..ecf5f2bd 100644 --- a/Disco.Services/Devices/Enrolment/WindowsDeviceEnrolment.cs +++ b/Disco.Services/Devices/Enrolment/WindowsDeviceEnrolment.cs @@ -1,6 +1,7 @@ using Disco.Data.Repository; using Disco.Models.ClientServices; using Disco.Models.Repository; +using Disco.Models.Services.Devices; using Disco.Services.Authorization; using Disco.Services.Interop.ActiveDirectory; using Disco.Services.Users; @@ -18,6 +19,18 @@ namespace Disco.Services.Devices.Enrolment private static readonly string pendingIdentifierAlphabet = "23456789ABCDEFGHJKMNPQRSTWXYZ"; private static readonly Random pendingIdentifierRng = new Random(); private static readonly ConcurrentDictionary pendingEnrolments = new ConcurrentDictionary(); + private static readonly Dictionary discoveryMethodStatistics = Enum.GetValues(typeof(DeviceEnrolmentServerDiscoveryMethod)).Cast().ToDictionary(k => k, k => 0); + + public static string GetDnsServiceLocationRecordName() + => $"_discoict._tcp.{ActiveDirectory.Context.PrimaryDomain.Name}"; + + public static void IncrementDiscoveryMethod(DeviceEnrolmentServerDiscoveryMethod method) + { + discoveryMethodStatistics[method]++; + } + + public static IEnumerable> GetDiscoveryMethodStatistics() + => discoveryMethodStatistics.AsEnumerable(); private static void CleanupPendingEnrolments() { diff --git a/Disco.Services/Disco.Services.csproj b/Disco.Services/Disco.Services.csproj index 1b57ab56..fca0d1c5 100644 --- a/Disco.Services/Disco.Services.csproj +++ b/Disco.Services/Disco.Services.csproj @@ -497,6 +497,14 @@ + + + + + + + + diff --git a/Disco.Services/Interop/DNS/ADnsRecord.cs b/Disco.Services/Interop/DNS/ADnsRecord.cs new file mode 100644 index 00000000..abd712ff --- /dev/null +++ b/Disco.Services/Interop/DNS/ADnsRecord.cs @@ -0,0 +1,24 @@ +using System; +using System.Net; + +namespace Disco.Services.Interop.DNS +{ + public class ADnsRecord : DnsRecord + { + public IPAddress Address { get; } + + public ADnsRecord(string name, TimeSpan timeToLive, uint address) + : base(name, DnsRecordType.A, timeToLive, UIntToIPAddress(address).ToString()) + { + Address = UIntToIPAddress(address); + } + + private static IPAddress UIntToIPAddress(uint address) + { + byte[] bytes = BitConverter.GetBytes(address); + if (BitConverter.IsLittleEndian) + Array.Reverse(bytes); + return new IPAddress(bytes); + } + } +} diff --git a/Disco.Services/Interop/DNS/CnameDnsRecord.cs b/Disco.Services/Interop/DNS/CnameDnsRecord.cs new file mode 100644 index 00000000..731b1436 --- /dev/null +++ b/Disco.Services/Interop/DNS/CnameDnsRecord.cs @@ -0,0 +1,12 @@ +using System; + +namespace Disco.Services.Interop.DNS +{ + public class CnameDnsRecord : DnsRecord + { + public CnameDnsRecord(string name, TimeSpan timeToLive, string canonicalName) + : base(name, DnsRecordType.Cname, timeToLive, canonicalName) + { + } + } +} diff --git a/Disco.Services/Interop/DNS/DnsRecord.cs b/Disco.Services/Interop/DNS/DnsRecord.cs new file mode 100644 index 00000000..c20d8bee --- /dev/null +++ b/Disco.Services/Interop/DNS/DnsRecord.cs @@ -0,0 +1,20 @@ +using System; + +namespace Disco.Services.Interop.DNS +{ + public abstract class DnsRecord + { + public string Name { get; } + public DnsRecordType Type { get; } + public TimeSpan TimeToLive { get; } + public string Content { get; } + + protected DnsRecord(string name, DnsRecordType type, TimeSpan timeToLive, string content) + { + Name = name; + Type = type; + TimeToLive = timeToLive; + Content = content; + } + } +} diff --git a/Disco.Services/Interop/DNS/DnsRecordType.cs b/Disco.Services/Interop/DNS/DnsRecordType.cs new file mode 100644 index 00000000..7af27365 --- /dev/null +++ b/Disco.Services/Interop/DNS/DnsRecordType.cs @@ -0,0 +1,10 @@ +namespace Disco.Services.Interop.DNS +{ + public enum DnsRecordType + { + A = 0x01, + Cname = 0x05, + Txt = 0x10, + Srv = 0x21 + } +} diff --git a/Disco.Services/Interop/DNS/DnsService.cs b/Disco.Services/Interop/DNS/DnsService.cs new file mode 100644 index 00000000..fc02a3e4 --- /dev/null +++ b/Disco.Services/Interop/DNS/DnsService.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; + +namespace Disco.Services.Interop.DNS +{ + public class DnsService + { + public DnsService() + { + } + + public static List Query(string name, bool bypassCache = false) where T : DnsRecord + { + DnsRecordType recordType; + if (typeof(T) == typeof(ADnsRecord)) + recordType = DnsRecordType.A; + else if (typeof(T) == typeof(CnameDnsRecord)) + recordType = DnsRecordType.Cname; + else if (typeof(T) == typeof(TxtDnsRecord)) + recordType = DnsRecordType.Txt; + else if (typeof(T) == typeof(SrvDnsRecord)) + recordType = DnsRecordType.Srv; + else + throw new NotSupportedException($"Unsupported DNS record type: {typeof(T).Name}"); + var records = NativeDns.QueryRecords(recordType, name, bypassCache); + return records.ConvertAll(r => (T)r); + } + + } +} diff --git a/Disco.Services/Interop/DNS/NativeDns.cs b/Disco.Services/Interop/DNS/NativeDns.cs new file mode 100644 index 00000000..016dac5b --- /dev/null +++ b/Disco.Services/Interop/DNS/NativeDns.cs @@ -0,0 +1,202 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Disco.Services.Interop.DNS +{ + internal static class NativeDns + { + + [DllImport("dnsapi", EntryPoint = "DnsQuery_W", CharSet = CharSet.Unicode, SetLastError = true, ExactSpelling = true)] + private static extern int DnsQuery([MarshalAs(UnmanagedType.VBByRefStr)] ref string pszName, NativeDnsQueryTypes wType, NativeDnsQueryOptions options, int aipServers, ref IntPtr ppQueryResults, int pReserved); + + [DllImport("dnsapi", CharSet = CharSet.Auto, SetLastError = true)] + private static extern void DnsRecordListFree(IntPtr pRecordList, int FreeType); + private const int DNS_ERROR_RCODE_NAME_ERROR = 0x232B; + private const int DNS_ERROR_BAD_PACKET = 0x251E; + + public static List QueryRecords(DnsRecordType type, string name, bool bypassCache) + { + NativeDnsQueryTypes queryType; + Func> marshaller; + + switch (type) + { + case DnsRecordType.A: + queryType = NativeDnsQueryTypes.DNS_TYPE_A; + marshaller = MarshalARecord; + break; + case DnsRecordType.Cname: + queryType = NativeDnsQueryTypes.DNS_TYPE_CNAME; + marshaller = MarshalCnameRecord; + break; + case DnsRecordType.Txt: + queryType = NativeDnsQueryTypes.DNS_TYPE_TEXT; + marshaller = MarshalTxtRecord; + break; + case DnsRecordType.Srv: + queryType = NativeDnsQueryTypes.DNS_TYPE_SRV; + marshaller = MarshalSrvRecord; + break; + default: + throw new NotSupportedException($"Unsupported DNS record type: {type}"); + } + + IntPtr rrPointers = IntPtr.Zero; + var records = new List(); + var retry = 5; + retry: + try + { + int queryResult = DnsQuery(ref name, queryType, bypassCache ? NativeDnsQueryOptions.DNS_QUERY_BYPASS_CACHE : NativeDnsQueryOptions.DNS_QUERY_STANDARD, 0, ref rrPointers, 0); + if (queryResult != 0) + { + if (queryResult == DNS_ERROR_RCODE_NAME_ERROR) + return records; + else if (queryResult == DNS_ERROR_BAD_PACKET && retry > 0) + { + // Sometimes a BAD_PACKET error is returned, retry a few times + Thread.Sleep(100); + retry--; + goto retry; + } + else + throw new Win32Exception(queryResult); + } + for (var rrPointer = rrPointers; !rrPointer.Equals(IntPtr.Zero);) + { + var (record, rrPointerNext) = marshaller(rrPointer); + records.Add(record); + rrPointer = rrPointerNext; + } + } + finally + { + if (rrPointers != IntPtr.Zero) + DnsRecordListFree(rrPointers, 0); + } + return records; + } + + private static Tuple MarshalARecord(IntPtr pointer) + { + var native = Marshal.PtrToStructure(pointer); + var record = new ADnsRecord(native.pName, TimeSpan.FromSeconds(native.dwTtl), native.IpAddress); + return Tuple.Create((DnsRecord)record, native.pNext); + } + + private static Tuple MarshalCnameRecord(IntPtr pointer) + { + var native = Marshal.PtrToStructure(pointer); + var record = new CnameDnsRecord(native.pName, TimeSpan.FromSeconds(native.dwTtl), native.pNameHost); + return Tuple.Create((DnsRecord)record, native.pNext); + } + + private static Tuple MarshalTxtRecord(IntPtr pointer) + { + var native = Marshal.PtrToStructure(pointer); + var record = new TxtDnsRecord(native.pName, TimeSpan.FromSeconds(native.dwTtl), native.pStringArray); + return Tuple.Create((DnsRecord)record, native.pNext); + } + + private static Tuple MarshalSrvRecord(IntPtr pointer) + { + var native = Marshal.PtrToStructure(pointer); + var record = new SrvDnsRecord(native.pName, TimeSpan.FromSeconds(native.dwTtl), native.pNameTarget, native.wPriority, native.wWeight, native.wPort); + return Tuple.Create((DnsRecord)record, native.pNext); + } + + private enum NativeDnsQueryOptions + { + DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE = 1, + DNS_QUERY_BYPASS_CACHE = 8, + DNS_QUERY_DONT_RESET_TTL_VALUES = 0x100000, + DNS_QUERY_NO_HOSTS_FILE = 0x40, + DNS_QUERY_NO_LOCAL_NAME = 0x20, + DNS_QUERY_NO_NETBT = 0x80, + DNS_QUERY_NO_RECURSION = 4, + DNS_QUERY_NO_WIRE_QUERY = 0x10, + DNS_QUERY_RESERVED = -16777216, + DNS_QUERY_RETURN_MESSAGE = 0x200, + DNS_QUERY_STANDARD = 0, + DNS_QUERY_TREAT_AS_FQDN = 0x1000, + DNS_QUERY_USE_TCP_ONLY = 2, + DNS_QUERY_WIRE_ONLY = 0x100 + } + + private enum NativeDnsQueryTypes + { + DNS_TYPE_A = 0x0001, + DNS_TYPE_CNAME = 0x0005, + DNS_TYPE_TEXT = 0x0010, + DNS_TYPE_SRV = 0x0021 + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsSrvData + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + [MarshalAs(UnmanagedType.LPWStr)] + public string pNameTarget; + public ushort wPriority; + public ushort wWeight; + public ushort wPort; + public ushort Pad; + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsTxtData + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + public uint dwStringLength; + [MarshalAs(UnmanagedType.LPWStr)] + public string pStringArray; + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsPtrData + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + [MarshalAs(UnmanagedType.LPWStr)] + public string pNameHost; + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativeDnsAData + { + public IntPtr pNext; + [MarshalAs(UnmanagedType.LPWStr)] + public string pName; + public ushort wType; + public ushort wDataLength; + public int flags; + public int dwTtl; + public int dwReserved; + public uint IpAddress; + } + + } +} diff --git a/Disco.Services/Interop/DNS/SrvDnsRecord.cs b/Disco.Services/Interop/DNS/SrvDnsRecord.cs new file mode 100644 index 00000000..4b53e2ab --- /dev/null +++ b/Disco.Services/Interop/DNS/SrvDnsRecord.cs @@ -0,0 +1,21 @@ +using System; + +namespace Disco.Services.Interop.DNS +{ + public class SrvDnsRecord : DnsRecord + { + public string Target { get; } + public ushort Priority { get; } + public ushort Weight { get; } + public ushort Port { get; } + + public SrvDnsRecord(string name, TimeSpan timeToLive, string target, ushort priority, ushort weight, ushort port) + : base(name, DnsRecordType.Srv, timeToLive, $"{priority} {weight} {port} {target}") + { + Target = target; + Priority = priority; + Weight = weight; + Port = port; + } + } +} diff --git a/Disco.Services/Interop/DNS/TxtDnsRecord.cs b/Disco.Services/Interop/DNS/TxtDnsRecord.cs new file mode 100644 index 00000000..e68a2587 --- /dev/null +++ b/Disco.Services/Interop/DNS/TxtDnsRecord.cs @@ -0,0 +1,12 @@ +using System; + +namespace Disco.Services.Interop.DNS +{ + public class TxtDnsRecord : DnsRecord + { + public TxtDnsRecord(string name, TimeSpan timeToLive, string text) + : base(name, DnsRecordType.Txt, timeToLive, text) + { + } + } +} diff --git a/Disco.Services/Interop/DiscoServices/UpdateQuery.cs b/Disco.Services/Interop/DiscoServices/UpdateQuery.cs index 2477e318..9d698be0 100644 --- a/Disco.Services/Interop/DiscoServices/UpdateQuery.cs +++ b/Disco.Services/Interop/DiscoServices/UpdateQuery.cs @@ -1,6 +1,7 @@ using Disco.Data.Repository; using Disco.Models.Repository; using Disco.Models.Services.Interop.DiscoServices; +using Disco.Services.Devices.Enrolment; using Disco.Services.Tasks; using Newtonsoft.Json; using System; @@ -221,6 +222,10 @@ namespace Disco.Services.Interop.DiscoServices RepairerLogged = j.JobType == JobType.JobTypeIds.HWar ? j.WarrantyRepairerLoggedDate : j.RepairerLoggedDate, RepairerCompleted = j.JobType == JobType.JobTypeIds.HWar ? j.WarrantyRepairerCompletedDate : j.RepairerCompletedDate }).ToList(); + + m.Stat_EnrollmentDiscovery = WindowsDeviceEnrolment.GetDiscoveryMethodStatistics() + .Where(s => s.Value != 0) + .Select(s => new StatisticInt() { Key = s.Key.ToString(), Value = s.Value }).ToList(); } m.InstalledPlugins = Plugins.Plugins.GetPlugins().Select(manifest => new StatisticString() { Key = manifest.Id, Value = manifest.VersionFormatted }).ToList(); diff --git a/Disco.Services/Interop/VicEduDept/VicSmart.cs b/Disco.Services/Interop/VicEduDept/VicSmart.cs index 7c27ab16..af42fd71 100644 --- a/Disco.Services/Interop/VicEduDept/VicSmart.cs +++ b/Disco.Services/Interop/VicEduDept/VicSmart.cs @@ -1,13 +1,51 @@ using Disco.Services.Interop.DiscoServices; using System; using System.IO; +using System.Linq; using System.Net; +using System.Net.NetworkInformation; using System.Xml.Linq; namespace Disco.Services.Interop.VicEduDept { public class VicSmart { + public static bool IsVicSmartDeployment() + { + var nics = NetworkInterface.GetAllNetworkInterfaces() + .Where(ni => ni.OperationalStatus == OperationalStatus.Up) + .ToList(); + + bool found10Net = false; + foreach (var nic in nics) + { + if (nic.Supports(NetworkInterfaceComponent.IPv4)) + { + var ipProps = nic.GetIPProperties(); + var ipv4Props = ipProps.GetIPv4Properties(); + if (ipv4Props.IsAutomaticPrivateAddressingActive) + continue; + + found10Net = ipProps.UnicastAddresses + .Where(ua => + ua.Address.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork && + ua.Address.GetAddressBytes()[0] == 10) + .Any(); + if (found10Net) + break; + } + } + if (!found10Net) + return false; + + try + { + var entry = Dns.GetHostEntry("broadband.doe.wan"); + return entry.AddressList.Length > 0; + } + catch (Exception) + { return false; } // Fail on error + } /// /// Queries DoE VicSmart Service to detect the current site. diff --git a/Disco.Web/Areas/API/Controllers/EnrolmentController.cs b/Disco.Web/Areas/API/Controllers/EnrolmentController.cs index 958b5fc0..ccea355c 100644 --- a/Disco.Web/Areas/API/Controllers/EnrolmentController.cs +++ b/Disco.Web/Areas/API/Controllers/EnrolmentController.cs @@ -88,5 +88,21 @@ namespace Disco.Web.Areas.API.Controllers return BadRequest(ex.Message); } } + + [DiscoAuthorize(Claims.Config.Enrolment.Configure)] + [HttpPost, ValidateAntiForgeryToken] + public virtual ActionResult LegacyDiscovery(bool enabled) + { + try + { + Database.DiscoConfiguration.Devices.EnrollmentLegacyDiscoveryDisabled = !enabled; + Database.SaveChanges(); + return Ok(); + } + catch (Exception ex) + { + return BadRequest(ex.Message); + } + } } } diff --git a/Disco.Web/Areas/Config/Controllers/EnrolmentController.cs b/Disco.Web/Areas/Config/Controllers/EnrolmentController.cs index 4fcf2968..79990672 100644 --- a/Disco.Web/Areas/Config/Controllers/EnrolmentController.cs +++ b/Disco.Web/Areas/Config/Controllers/EnrolmentController.cs @@ -1,7 +1,13 @@ using Disco.Models.UI.Config.Enrolment; using Disco.Services.Authorization; +using Disco.Services.Devices.Enrolment; +using Disco.Services.Interop.ActiveDirectory; +using Disco.Services.Interop.DNS; +using Disco.Services.Interop.VicEduDept; +using Disco.Services.Plugins; using Disco.Services.Plugins.Features.UIExtension; using Disco.Services.Web; +using System; using System.Linq; using System.Web.Mvc; @@ -12,10 +18,30 @@ namespace Disco.Web.Areas.Config.Controllers [DiscoAuthorize(Claims.Config.Enrolment.Show)] public virtual ActionResult Index() { + var serverUrl = Request.Url; + if ((serverUrl.HostNameType == UriHostNameType.Dns && serverUrl.Host.Equals("localhost", StringComparison.OrdinalIgnoreCase)) || + serverUrl.HostNameType == UriHostNameType.IPv4 || serverUrl.HostNameType == UriHostNameType.IPv6) + { + serverUrl = new UriBuilder(serverUrl) + { + Host = Environment.MachineName + }.Uri; + } + + var srvRecord = DnsService.Query(WindowsDeviceEnrolment.GetDnsServiceLocationRecordName(), true).FirstOrDefault(); + var srvValue = srvRecord == null ? null : (srvRecord.Port == 443 ? srvRecord.Target : $"{srvRecord.Target}:{srvRecord.Port}"); + var m = new Models.Enrolment.IndexModel() { MacSshUsername = Database.DiscoConfiguration.Bootstrapper.MacSshUsername, PendingTimeoutMinutes = (int)Database.DiscoConfiguration.Bootstrapper.PendingTimeout.TotalMinutes, + MacEnrolUrl = new Uri(serverUrl, Url.Action(MVC.Services.Client.Unauthenticated("MacSecureEnrol"))), + HostingPluginInstalled = Plugins.PluginInstalled("Hosting"), + IsServicesEducationVicGovAuDomain = ActiveDirectory.Context.PrimaryDomain.Name.Equals("services.education.vic.gov.au", StringComparison.OrdinalIgnoreCase), + IsVicSmartDeployment = VicSmart.IsVicSmartDeployment(), + DnsSrvRecordName = WindowsDeviceEnrolment.GetDnsServiceLocationRecordName(), + DnsSrvRecordValue = srvValue, + LegacyDiscoveryEnabled = !Database.DiscoConfiguration.Devices.EnrollmentLegacyDiscoveryDisabled, }; // UI Extensions diff --git a/Disco.Web/Areas/Config/Models/Enrolment/IndexModel.cs b/Disco.Web/Areas/Config/Models/Enrolment/IndexModel.cs index e70f7b42..5e63e5b4 100644 --- a/Disco.Web/Areas/Config/Models/Enrolment/IndexModel.cs +++ b/Disco.Web/Areas/Config/Models/Enrolment/IndexModel.cs @@ -1,4 +1,5 @@ using Disco.Models.UI.Config.Enrolment; +using System; namespace Disco.Web.Areas.Config.Models.Enrolment { @@ -6,5 +7,12 @@ namespace Disco.Web.Areas.Config.Models.Enrolment { public string MacSshUsername { get; set; } public int PendingTimeoutMinutes { get; set; } + public Uri MacEnrolUrl { get; set; } + public bool HostingPluginInstalled { get; set; } + public bool IsVicSmartDeployment { get; set; } + public bool IsServicesEducationVicGovAuDomain { get; set; } + public string DnsSrvRecordName { get; set; } + public string DnsSrvRecordValue { get; set; } + public bool LegacyDiscoveryEnabled { get; set; } } } \ No newline at end of file diff --git a/Disco.Web/Areas/Config/Views/Enrolment/Index.cshtml b/Disco.Web/Areas/Config/Views/Enrolment/Index.cshtml index 13145d82..06c3a4a3 100644 --- a/Disco.Web/Areas/Config/Views/Enrolment/Index.cshtml +++ b/Disco.Web/Areas/Config/Views/Enrolment/Index.cshtml @@ -121,7 +121,7 @@ able to connect to the requesting Apple Mac client via SSH. Enter/Script the following command: This url will return a JSON response containing basic information about the enrolment.
@@ -133,6 +133,167 @@ +
+

Bootstrapper Server Discovery

+ + + + +
+
+ The Disco ICT + @if (Authorization.Has(Claims.Config.Enrolment.DownloadBootstrapper)) + { + @Html.ActionLink("Bootstrapper", MVC.Services.Client.Bootstrapper()) + } + else + { + Bootstrapper + } + is used to enrol devices. It is strongly recommended that HTTPS be used for all communication. + the + The @Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install()) + plugin can be used to automate deployment of HTTPS certificates. +
+
+ The Bootstrapper discovers the server using the first successful method (in order): +
+
    +
  1. +
    Manually Specified
    +
    + The server url can be specified at the command line. The url must use HTTPS. For example: +
    +
    Disco.ClientBootstrapper.exe https://@Request.Url.Authority
    +
  2. +
  3. +
    DNS Service Location (SRV) Record
    + Expected Record Name: @Model.DnsSrvRecordName + @if (Model.IsServicesEducationVicGovAuDomain) + { +
    + This mechanism is not supported in the shared education.vic.gov.au domain and can be ignored. +
    + } + else + { + if (Model.DnsSrvRecordValue == null) + { +
    + + No Service Location (SRV) record found. + + @if (Request.IsSecureConnection) + { + + Please create a DNS Service Location (SRV) record: + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Service:_discoict
    Protocol:_tcp
    Priority:0
    Weight:0
    Port:@Request.Url.Port
    Host offering this service:@Request.Url.Host
    + } + else + { +
    + Please configure and connect with HTTPS. + + You can enable HTTPS automation using the + @Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install()) + plugin. + +
    + } +
    + } + else + { +
    + Value: https://@Model.DnsSrvRecordValue + @if (Request.IsSecureConnection && !string.Equals(Model.DnsSrvRecordValue, Request.Url.Authority, StringComparison.OrdinalIgnoreCase)) + { +
    + The Service Location (SRV) record does not match the way you are currently accessing the server: @Request.Url.Authority. +
    + } +
    + } + } +
  4. + @if (Model.IsVicSmartDeployment) + { +
  5. +
    Victorian Government Schools VicSmart Discovery
    + If the Bootstrapper detects it is running inside the VicSmart network, it will query Online Services for the Disco ICT server address based on the subnets assigned to each school. + This is configured in the @Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install()) + plugin. +
  6. + } +
  7. +
    Legacy Discovery
    +
    + The Bootstrapper will attempt to send an ICMP ping to "disco". If the ping is successful, it will attempt to connect to http://disco:9292/. +
    +
    + @if (canConfig) + { + + + } + else + { + + } + + @AjaxHelpers.AjaxLoader() +
    + @if ((Model.IsServicesEducationVicGovAuDomain || Model.DnsSrvRecordValue != null) && Model.LegacyDiscoveryEnabled) + { +
    + + It is not recommended to have Legacy Discovery enabled. Please use the latest Bootstrapper and disable this option. +
    + } +
    + This method is not secure and is only provided for backwards compatibility. In time this method will be removed. +
    +
  8. +
+
+
@if (canShowStatus && Authorization.Has(Claims.Config.Logging.Show)) {

Live Enrolment Logging

diff --git a/Disco.Web/Areas/Config/Views/Enrolment/Index.generated.cs b/Disco.Web/Areas/Config/Views/Enrolment/Index.generated.cs index 83f57c72..a0356404 100644 --- a/Disco.Web/Areas/Config/Views/Enrolment/Index.generated.cs +++ b/Disco.Web/Areas/Config/Views/Enrolment/Index.generated.cs @@ -451,10 +451,26 @@ WriteLiteral(">\r\n curl (Model.MacEnrolUrl + + #line default + #line hidden +, 4888), false) +); -WriteLiteral(">http://disco:9292/Services/Client/Unauthenticated/MacSecureEnrol\r\n " + -" \r\n "); + + + #line 124 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Model.MacEnrolUrl); + + + #line default + #line hidden +WriteLiteral("\r\n \r\n <script>
\r\n tag embedded on the organisation\'s in" + "tranet.\r\n \r\n \r\n \r\n " + -"\r\n\r\n"); +"\r\n\r\n\r\n

Bootstrapper Server Discovery

\r\n \r\n \r\n " + +" + + + + + + + + + + + + + + + + + + + + + + \r\n \r\n " + +"
\r\n
\r\n The Disco ICT\r\n"); - #line 136 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 143 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 143 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (Authorization.Has(Claims.Config.Enrolment.DownloadBootstrapper)) + { + + + #line default + #line hidden + + #line 145 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Html.ActionLink("Bootstrapper", MVC.Services.Client.Bootstrapper())); + + + #line default + #line hidden + + #line 145 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + } + else + { + + + #line default + #line hidden +WriteLiteral(" "); + +WriteLiteral("Bootstrapper"); + +WriteLiteral("\r\n"); + + + #line 150 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + + + #line default + #line hidden +WriteLiteral(" is used to enrol devices. It is strongly recommended that HTT" + +"PS be used for all communication.\r\n the\r\n " + +"The "); + + + #line 153 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install())); + + + #line default + #line hidden +WriteLiteral(@" + plugin can be used to automate deployment of HTTPS certificates. +
+
+ The Bootstrapper discovers the server using the first successful method (in order): +
+
    +
  1. +
    Manually Specified
    +
    + The server url can be specified at the command line. The url must use HTTPS. For example: +
    + Disco.ClientBootstrapper.exe https://"); + + + #line 165 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Request.Url.Authority); + + + #line default + #line hidden +WriteLiteral("\r\n
  2. \r\n
  3. \r\n " + +"
    DNS Service Location (SRV) Record
    \r\n Expected" + +" Record Name: "); + + + #line 169 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Model.DnsSrvRecordName); + + + #line default + #line hidden +WriteLiteral("\r\n"); + + + #line 170 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 170 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (Model.IsServicesEducationVicGovAuDomain) + { + + + #line default + #line hidden +WriteLiteral(" \r\n This mechanism is not supported in the shared " + +"education.vic.gov.au domain and can be ignored.\r\n \r\n"); + + + #line 175 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + else + { + if (Model.DnsSrvRecordValue == null) + { + + + #line default + #line hidden +WriteLiteral(" \r\n \r\n No Service Location (SRV) record found" + +".\r\n \r\n"); + + + #line 184 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 184 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (Request.IsSecureConnection) + { + + + #line default + #line hidden +WriteLiteral(" \r\n " + +" Please create a DNS Service Location (SRV) record:\r\n " + +" \r\n"); + +WriteLiteral(" +
Service:_discoict
Protocol:_tcp
Priority:0
Weight:0
Port:"); + + + #line 208 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Request.Url.Port); + + + #line default + #line hidden +WriteLiteral(@"
Host offering this service:"); + + + #line 212 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Request.Url.Host); + + + #line default + #line hidden +WriteLiteral("
\r\n"); + + + #line 215 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + else + { + + + #line default + #line hidden +WriteLiteral(@"
+ Please configure and connect with HTTPS. + + You can enable HTTPS automation using the +"); + +WriteLiteral(" "); + + + #line 222 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install())); + + + #line default + #line hidden +WriteLiteral("\r\n plugin.\r\n " + +" \r\n
\r\n"); + + + #line 226 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + + + #line default + #line hidden +WriteLiteral(" \r\n"); + + + #line 228 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + else + { + + + #line default + #line hidden +WriteLiteral("
\r\n Value:" + +" https://"); + + + #line 232 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Model.DnsSrvRecordValue); + + + #line default + #line hidden +WriteLiteral("\r\n"); + + + #line 233 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 233 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (Request.IsSecureConnection && !string.Equals(Model.DnsSrvRecordValue, Request.Url.Authority, StringComparison.OrdinalIgnoreCase)) + { + + + #line default + #line hidden +WriteLiteral(" \r\n The Service Location (SRV) record does not match the way you are currently " + +"accessing the server: "); + + + #line 236 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Request.Url.Authority); + + + #line default + #line hidden +WriteLiteral(".\r\n
\r\n"); + + + #line 238 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + + + #line default + #line hidden +WriteLiteral(" \r\n"); + + + #line 240 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + } + + + #line default + #line hidden +WriteLiteral(" \r\n"); + + + #line 243 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 243 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (Model.IsVicSmartDeployment) + { + + + #line default + #line hidden +WriteLiteral(@"
  • +
    Victorian Government Schools VicSmart Discovery
    + If the Bootstrapper detects it is running inside the VicSmart network, it will query Online Services for the Disco ICT server address based on the subnets assigned to each school. + This is configured in the "); + + + #line 248 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + Write(Html.ActionLink("Hosting", Model.HostingPluginInstalled ? MVC.Config.Plugins.Configure("Hosting") : MVC.Config.Plugins.Install())); + + + #line default + #line hidden +WriteLiteral("\r\n plugin.\r\n
  • \r\n"); + + + #line 251 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + + + #line default + #line hidden +WriteLiteral(@"
  • +
    Legacy Discovery
    +
    + The Bootstrapper will attempt to send an ICMP ping to "disco". If the ping is successful, it will attempt to connect to http://disco:9292/. +
    +
    +"); + + + #line 258 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 258 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if (canConfig) + { + + + #line default + #line hidden +WriteLiteral("
    \r\n"); + + + #line 281 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + + + #line default + #line hidden + + #line 281 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + if ((Model.IsServicesEducationVicGovAuDomain || Model.DnsSrvRecordValue != null) && Model.LegacyDiscoveryEnabled) + { + + + #line default + #line hidden +WriteLiteral(" \r\n \r\n It is not recommended to have Legacy Disco" + +"very enabled. Please use the latest Bootstrapper and disable this option.\r\n " + +" \r\n"); + + + #line 287 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + } + + + #line default + #line hidden +WriteLiteral(@"
    + This method is not secure and is only provided for backwards compatibility. In time this method will be removed. +
    +
  • + + + + + +"); + + + #line 297 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" if (canShowStatus && Authorization.Has(Claims.Config.Logging.Show)) { @@ -499,13 +1026,13 @@ WriteLiteral("><script>\r\n tag embedded on the WriteLiteral("

    Live Enrolment Logging

    \r\n"); - #line 139 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 300 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" #line default #line hidden - #line 139 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 300 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" Write(Html.Partial(MVC.Config.Shared.Views.LogEvents, new Disco.Web.Areas.Config.Models.Shared.LogEventsModel() { IsLive = true, @@ -519,7 +1046,7 @@ Write(Html.Partial(MVC.Config.Shared.Views.LogEvents, new Disco.Web.Areas.Config #line default #line hidden - #line 146 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 307 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" } @@ -533,13 +1060,13 @@ WriteLiteral(" class=\"actionBar\""); WriteLiteral(">\r\n"); - #line 149 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 310 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" #line default #line hidden - #line 149 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 310 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" if (Authorization.Has(Claims.Config.Enrolment.DownloadBootstrapper)) { @@ -547,14 +1074,14 @@ WriteLiteral(">\r\n"); #line default #line hidden - #line 151 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 312 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" Write(Html.ActionLinkButton("Download Bootstrapper", MVC.Services.Client.Bootstrapper())); #line default #line hidden - #line 151 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 312 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" } @@ -564,7 +1091,7 @@ WriteLiteral(">\r\n"); WriteLiteral(" "); - #line 153 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 314 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" if (canShowStatus) { @@ -572,14 +1099,14 @@ WriteLiteral(" "); #line default #line hidden - #line 155 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 316 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" Write(Html.ActionLinkButton("Enrolment Status", MVC.Config.Enrolment.Status())); #line default #line hidden - #line 155 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" + #line 316 "..\..\Areas\Config\Views\Enrolment\Index.cshtml" } diff --git a/Disco.Web/Areas/Services/Controllers/ClientController.cs b/Disco.Web/Areas/Services/Controllers/ClientController.cs index 969543c4..834077db 100644 --- a/Disco.Web/Areas/Services/Controllers/ClientController.cs +++ b/Disco.Web/Areas/Services/Controllers/ClientController.cs @@ -1,5 +1,6 @@ using Disco.Data.Repository; using Disco.Models.ClientServices; +using Disco.Models.Services.Devices; using Disco.Services; using Disco.Services.Authorization; using Disco.Services.Devices.Enrolment; @@ -22,11 +23,21 @@ namespace Disco.Web.Areas.Services.Controllers public virtual ActionResult PreparationClient() { + var discoveryMethodHeader = Request.Headers["X-DiscoICT-Discovery"]; + if (!string.IsNullOrEmpty(discoveryMethodHeader) && Enum.TryParse(discoveryMethodHeader, out var discoveryMethod)) + WindowsDeviceEnrolment.IncrementDiscoveryMethod(discoveryMethod); + + if (!CheckLegacyEnrollmentDiscovery()) + return BadRequest("Enrollment Legacy Discovery is disabled. Please use secure connection (HTTPS) for device enrollment."); + return File(Links.ClientBin.PreparationClient_zip, "application/x-msdownload", "PreparationClient.zip"); } public virtual ActionResult Unauthenticated(string feature) { + if (!CheckLegacyEnrollmentDiscovery()) + return BadRequest("Enrollment Legacy Discovery is disabled. Please use secure connection (HTTPS) for device enrollment."); + if (string.IsNullOrEmpty(feature)) { return Json(null); @@ -64,6 +75,7 @@ namespace Disco.Web.Areas.Services.Controllers } case "macenrol": { + WindowsDeviceEnrolment.IncrementDiscoveryMethod(DeviceEnrolmentServerDiscoveryMethod.Mac); var Binder = ModelBinders.Binders.GetBinder(typeof(MacEnrol)); var BinderContext = new ModelBindingContext() { @@ -78,6 +90,7 @@ namespace Disco.Web.Areas.Services.Controllers } case "macsecureenrol": { + WindowsDeviceEnrolment.IncrementDiscoveryMethod(DeviceEnrolmentServerDiscoveryMethod.MacSecure); using (var database = new DiscoDataContext()) { var host = HttpContext.Request.UserHostAddress; @@ -93,6 +106,9 @@ namespace Disco.Web.Areas.Services.Controllers [Authorize] public virtual ActionResult Authenticated(string feature) { + if (!CheckLegacyEnrollmentDiscovery()) + return BadRequest("Enrollment Legacy Discovery is disabled. Please use secure connection (HTTPS) for device enrollment."); + if (string.IsNullOrEmpty(feature)) { WhoAmIResponse whoAmIResponse = new WhoAmI().BuildResponse(); @@ -171,5 +187,21 @@ namespace Disco.Web.Areas.Services.Controllers return Content("Error Message Logged"); } + private bool CheckLegacyEnrollmentDiscovery() + { + if (!Request.IsSecureConnection) + { + using (DiscoDataContext database = new DiscoDataContext()) + { + if (database.DiscoConfiguration.Devices.EnrollmentLegacyDiscoveryDisabled) + { + EnrolmentLog.LogClientError(Request.UserHostAddress, Request.UserHostName, string.Empty, "Enrollment Legacy Discovery is disabled. Please use secure connection (HTTPS) for device enrollment.", string.Empty); + return false; + } + } + } + return true; + } + } } diff --git a/Disco.Web/Extensions/T4MVC/API.EnrolmentController.generated.cs b/Disco.Web/Extensions/T4MVC/API.EnrolmentController.generated.cs index 8694b8ca..fec36f75 100644 --- a/Disco.Web/Extensions/T4MVC/API.EnrolmentController.generated.cs +++ b/Disco.Web/Extensions/T4MVC/API.EnrolmentController.generated.cs @@ -83,6 +83,12 @@ namespace Disco.Web.Areas.API.Controllers { return new T4MVC_System_Web_Mvc_ActionResult(Area, Name, ActionNames.MacSshPassword); } + [NonAction] + [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] + public virtual System.Web.Mvc.ActionResult LegacyDiscovery() + { + return new T4MVC_System_Web_Mvc_ActionResult(Area, Name, ActionNames.LegacyDiscovery); + } [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] public EnrolmentController Actions { get { return MVC.API.Enrolment; } } @@ -103,6 +109,7 @@ namespace Disco.Web.Areas.API.Controllers public readonly string PendingTimeoutMinutes = "PendingTimeoutMinutes"; public readonly string MacSshUsername = "MacSshUsername"; public readonly string MacSshPassword = "MacSshPassword"; + public readonly string LegacyDiscovery = "LegacyDiscovery"; } [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] @@ -112,6 +119,7 @@ namespace Disco.Web.Areas.API.Controllers public const string PendingTimeoutMinutes = "PendingTimeoutMinutes"; public const string MacSshUsername = "MacSshUsername"; public const string MacSshPassword = "MacSshPassword"; + public const string LegacyDiscovery = "LegacyDiscovery"; } @@ -151,6 +159,14 @@ namespace Disco.Web.Areas.API.Controllers { public readonly string MacSshPassword = "MacSshPassword"; } + static readonly ActionParamsClass_LegacyDiscovery s_params_LegacyDiscovery = new ActionParamsClass_LegacyDiscovery(); + [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] + public ActionParamsClass_LegacyDiscovery LegacyDiscoveryParams { get { return s_params_LegacyDiscovery; } } + [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] + public class ActionParamsClass_LegacyDiscovery + { + public readonly string enabled = "enabled"; + } static readonly ViewsClass s_views = new ViewsClass(); [GeneratedCode("T4MVC", "2.0"), DebuggerNonUserCode] public ViewsClass Views { get { return s_views; } } @@ -222,6 +238,18 @@ namespace Disco.Web.Areas.API.Controllers return callInfo; } + [NonAction] + partial void LegacyDiscoveryOverride(T4MVC_System_Web_Mvc_ActionResult callInfo, bool enabled); + + [NonAction] + public override System.Web.Mvc.ActionResult LegacyDiscovery(bool enabled) + { + var callInfo = new T4MVC_System_Web_Mvc_ActionResult(Area, Name, ActionNames.LegacyDiscovery); + ModelUnbinderHelpers.AddRouteValues(callInfo.RouteValueDictionary, "enabled", enabled); + LegacyDiscoveryOverride(callInfo, enabled); + return callInfo; + } + } } diff --git a/Disco.sln b/Disco.sln index 23aad660..f0f71722 100644 --- a/Disco.sln +++ b/Disco.sln @@ -145,8 +145,9 @@ Global UpdateAssemblyVersion = True UpdateAssemblyFileVersion = True UpdateAssemblyInfoVersion = False - AssemblyVersionSettings = None.None.DateStamp.TimeStamp - AssemblyFileVersionSettings = None.None.DateStamp.TimeStamp + ShouldCreateLogs = True + AssemblyVersionSettings = None.None.DateStamp.None + AssemblyFileVersionSettings = None.None.DateStamp.None UpdatePackageVersion = False AssemblyInfoVersionType = SettingsVersion InheritWinAppVersionFrom = None