diff --git a/Disco.Services/Expressions/Extensions/UserExt.cs b/Disco.Services/Expressions/Extensions/UserExt.cs index 2fa95aa3..f0090bf0 100644 --- a/Disco.Services/Expressions/Extensions/UserExt.cs +++ b/Disco.Services/Expressions/Extensions/UserExt.cs @@ -1,6 +1,8 @@ -using Disco.Models.Repository; +using Disco.Data.Repository; +using Disco.Models.Repository; using Disco.Services.Users; using System; +using System.Data.Entity; using System.Linq; namespace Disco.Services.Expressions.Extensions @@ -69,5 +71,41 @@ namespace Disco.Services.Expressions.Extensions return authorization.HasAny(Claims); } #endregion + + #region Flag Extensions + + public static bool AddFlag(User user, string flagName, User techUser) + => AddFlag(user, flagName, techUser, comments: null); + + public static bool AddFlag(User user, string flagName, User techUser, string comments) + { + using (var database = new DiscoDataContext()) + { + database.Configuration.LazyLoadingEnabled = true; + + var flag = database.UserFlags.Single(f => f.Name == flagName); + if (flag == null) + throw new ArgumentException("Invalid User Flag Name", nameof(flagName)); + + var flagUser = database.Users.Include(u => u.UserFlagAssignments).FirstOrDefault(u => u.UserId == user.UserId); + if (flagUser == null) + throw new ArgumentException("Invalid User", nameof(user)); + + if (flagUser.UserFlagAssignments.Any(fa => !fa.RemovedDate.HasValue && fa.UserFlagId == flag.Id)) + return false; + + var addingUser = database.Users.Find(techUser.UserId); + if (addingUser == null) + throw new ArgumentException("Invalid Tech User", nameof(techUser)); + + var userFlagAssignment = flagUser.OnAddUserFlagUnsafe(database, flag, addingUser, comments); + + database.SaveChanges(); + } + + return true; + } + + #endregion } }