diff --git a/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj b/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj index 6c38ebbb8..7f2257ba3 100644 --- a/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj +++ b/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/Tes.ApiClients/Tes.ApiClients.csproj b/src/Tes.ApiClients/Tes.ApiClients.csproj index 446559573..9e3752919 100644 --- a/src/Tes.ApiClients/Tes.ApiClients.csproj +++ b/src/Tes.ApiClients/Tes.ApiClients.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/Tes.Runner.Test/Tes.Runner.Test.csproj b/src/Tes.Runner.Test/Tes.Runner.Test.csproj index f47a18d37..49711d9fe 100644 --- a/src/Tes.Runner.Test/Tes.Runner.Test.csproj +++ b/src/Tes.Runner.Test/Tes.Runner.Test.csproj @@ -11,7 +11,7 @@ - + diff --git a/src/Tes.Runner/Tes.Runner.csproj b/src/Tes.Runner/Tes.Runner.csproj index 35f75644d..73ea6d965 100644 --- a/src/Tes.Runner/Tes.Runner.csproj +++ b/src/Tes.Runner/Tes.Runner.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/Tes/Models/PostgreSqlOptions.cs b/src/Tes/Models/PostgreSqlOptions.cs index 65ef70e46..0fe19f255 100644 --- a/src/Tes/Models/PostgreSqlOptions.cs +++ b/src/Tes/Models/PostgreSqlOptions.cs @@ -27,5 +27,6 @@ public static string GetConfigurationSectionName(string serviceName = "Tes") public string DatabaseName { get; set; } = "tes_db"; public string DatabaseUserLogin { get; set; } public string DatabaseUserPassword { get; set; } + public bool UseManagedIdentity { get; set; } } } diff --git a/src/Tes/Models/TesTaskPostgres.cs b/src/Tes/Models/TesTaskPostgres.cs index bc6a528ac..6a7ed97aa 100644 --- a/src/Tes/Models/TesTaskPostgres.cs +++ b/src/Tes/Models/TesTaskPostgres.cs @@ -9,7 +9,7 @@ namespace Tes.Models /// /// Database schema for encapsulating a TesTask as Json for Postgresql. /// - [Table(Repository.TesDbContext.TesTasksPostgresTableName)] + [Table("testasks")] public class TesTaskDatabaseItem { [Column("id")] diff --git a/src/Tes/Repository/PostgreSqlCachingRepository.cs b/src/Tes/Repository/PostgreSqlCachingRepository.cs index 8a8c36811..e069fdde1 100644 --- a/src/Tes/Repository/PostgreSqlCachingRepository.cs +++ b/src/Tes/Repository/PostgreSqlCachingRepository.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Polly; @@ -16,6 +17,7 @@ namespace Tes.Repository { public abstract class PostgreSqlCachingRepository : IDisposable where T : class { + private readonly IServiceScopeFactory _scopeFactory = null!; private readonly TimeSpan _writerWaitTime = TimeSpan.FromMilliseconds(50); private readonly int _batchSize = 1000; private static readonly TimeSpan defaultCompletedTaskCacheExpiration = TimeSpan.FromDays(1); @@ -30,17 +32,16 @@ public abstract class PostgreSqlCachingRepository : IDisposable where T : cla private readonly Task _writerWorkerTask; protected enum WriteAction { Add, Update, Delete } - - protected Func CreateDbContext { get; init; } protected readonly ICache _cache; protected readonly ILogger _logger; private bool _disposedValue; - protected PostgreSqlCachingRepository(ILogger logger = default, ICache cache = default) + protected PostgreSqlCachingRepository(ILogger logger = default, ICache cache = default, IServiceScopeFactory scopeFactory = default) { _logger = logger; _cache = cache; + _scopeFactory = scopeFactory; // The only "normal" exit for _writerWorkerTask is "cancelled". Anything else should force the process to exit because it means that this repository will no longer write to the database! _writerWorkerTask = Task.Run(() => WriterWorkerAsync(_writerWorkerCancellationTokenSource.Token)) @@ -187,7 +188,8 @@ private async ValueTask WriteItemsAsync(IList<(T DbItem, WriteAction Action, Tas if (dbItems.Count == 0) { return; } cancellationToken.ThrowIfCancellationRequested(); - using var dbContext = CreateDbContext(); + using var scope = _scopeFactory.CreateScope(); + using var dbContext = scope.ServiceProvider.GetRequiredService(); // Manually set entity state to avoid potential NPG PostgreSql bug dbContext.ChangeTracker.AutoDetectChangesEnabled = false; diff --git a/src/Tes/Repository/TesDbContext.cs b/src/Tes/Repository/TesDbContext.cs index aa766ba61..1f288fcca 100644 --- a/src/Tes/Repository/TesDbContext.cs +++ b/src/Tes/Repository/TesDbContext.cs @@ -1,40 +1,40 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System; +using Azure.Core; +using Azure.Identity; using Microsoft.EntityFrameworkCore; using Tes.Models; +using Tes.Utilities; namespace Tes.Repository { public class TesDbContext : DbContext { - public const string TesTasksPostgresTableName = "testasks"; + private const int maxBatchSize = 1000; + private readonly PostgresConnectionStringUtility connectionStringUtility = null!; public TesDbContext() { // Default constructor, which is required to run the EF migrations tool, // "dotnet ef migrations add InitialCreate" + // DI will NOT use this constructor } - public TesDbContext(string connectionString) + public TesDbContext(PostgresConnectionStringUtility connectionStringUtility) { - ArgumentException.ThrowIfNullOrEmpty(connectionString, nameof(connectionString)); - ConnectionString = connectionString; + this.connectionStringUtility = connectionStringUtility; } - public string ConnectionString { get; set; } public DbSet TesTasks { get; set; } protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) { - if (!optionsBuilder.IsConfigured) - { - // use PostgreSQL - optionsBuilder - .UseNpgsql(ConnectionString, options => options.MaxBatchSize(1000)) - .UseLowerCaseNamingConvention(); - } + string connectionString = this.connectionStringUtility.GetConnectionString().Result; + + optionsBuilder + .UseNpgsql(connectionString, options => options.MaxBatchSize(maxBatchSize)) + .UseLowerCaseNamingConvention(); } } } diff --git a/src/Tes/Repository/TesTaskPostgreSqlRepository.cs b/src/Tes/Repository/TesTaskPostgreSqlRepository.cs index 5fb345841..de3f01be3 100644 --- a/src/Tes/Repository/TesTaskPostgreSqlRepository.cs +++ b/src/Tes/Repository/TesTaskPostgreSqlRepository.cs @@ -11,8 +11,8 @@ namespace Tes.Repository using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; + using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; - using Microsoft.Extensions.Options; using Polly; using Tes.Models; using Tes.Utilities; @@ -23,34 +23,24 @@ namespace Tes.Repository /// public sealed class TesTaskPostgreSqlRepository : PostgreSqlCachingRepository, IRepository { + private readonly IServiceScopeFactory _scopeFactory = null!; + /// /// Default constructor that also will create the schema if it does not exist /// /// /// /// - public TesTaskPostgreSqlRepository(IOptions options, ILogger logger, ICache cache = null) + public TesTaskPostgreSqlRepository(ILogger logger = default, IServiceScopeFactory scopeFactory = default, ICache cache = null) : base(logger, cache) { - var connectionString = new ConnectionStringUtility().GetPostgresConnectionString(options); - CreateDbContext = () => { return new TesDbContext(connectionString); }; - using var dbContext = CreateDbContext(); + _scopeFactory = scopeFactory; + using var scope = _scopeFactory.CreateScope(); + using var dbContext = scope.ServiceProvider.GetRequiredService(); dbContext.Database.MigrateAsync().Wait(); WarmCacheAsync(CancellationToken.None).Wait(); } - /// - /// Constructor for testing to enable mocking DbContext - /// - /// A delegate that creates a TesTaskPostgreSqlRepository context - public TesTaskPostgreSqlRepository(Func createDbContext) - : base() - { - CreateDbContext = createDbContext; - using var dbContext = createDbContext(); - dbContext.Database.MigrateAsync().Wait(); - } - private async Task WarmCacheAsync(CancellationToken cancellationToken) { if (_cache is null) @@ -224,7 +214,8 @@ private async Task GetItemFromCacheOrDatabase(string id, bo if (!_cache?.TryGetValue(id, out item) ?? true) { - using var dbContext = CreateDbContext(); + using var scope = _scopeFactory.CreateScope(); + using var dbContext = scope.ServiceProvider.GetRequiredService(); // Search for Id within the JSON item = await _asyncPolicy.ExecuteAsync(ct => dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id, ct), cancellationToken); @@ -252,7 +243,8 @@ private async Task> InternalGetItemsAsync(Expression q.OrderBy(t => t.Json.CreationTime).ThenBy(t => t.Json.Id); orderBy = pagination is null ? orderBy : q => q.OrderBy(t => t.Json.Id); - using var dbContext = CreateDbContext(); + using var scope = _scopeFactory.CreateScope(); + using var dbContext = scope.ServiceProvider.GetRequiredService(); return (await GetItemsAsync(dbContext.TesTasks, WhereTesTask(predicate), cancellationToken, orderBy, pagination)).Select(item => EnsureActiveItemInCache(item, t => t.Json.Id, t => t.Json.IsActiveState()).Json); } diff --git a/src/Tes/Tes.csproj b/src/Tes/Tes.csproj index ab248c35b..e75c31818 100644 --- a/src/Tes/Tes.csproj +++ b/src/Tes/Tes.csproj @@ -5,18 +5,20 @@ + - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + + - + diff --git a/src/Tes/Utilities/PostgresConnectionStringUtility.cs b/src/Tes/Utilities/PostgresConnectionStringUtility.cs index e5263e47d..25476141e 100644 --- a/src/Tes/Utilities/PostgresConnectionStringUtility.cs +++ b/src/Tes/Utilities/PostgresConnectionStringUtility.cs @@ -3,35 +3,81 @@ using System; using System.Text; -using Microsoft.Extensions.Options; +using System.Threading.Tasks; +using Azure.Core; using Tes.Models; namespace Tes.Utilities { - public class ConnectionStringUtility + public class PostgresConnectionStringUtility { - public string GetPostgresConnectionString(IOptions options) + private const string azureDatabaseForPostgresqlScope = "https://ossrdbms-aad.database.windows.net/.default"; + private readonly string connectionString = null!; + private readonly TokenCredential tokenCredential = null!; + public bool UseManagedIdentity { get; set; } + + public PostgresConnectionStringUtility(PostgreSqlOptions options, TokenCredential tokenCredential) + { + this.tokenCredential = tokenCredential; + connectionString = InternalGetConnectionString(options); + UseManagedIdentity = options.UseManagedIdentity; + } + + public async Task GetConnectionString() { - ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerName, nameof(options.Value.ServerName)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerNameSuffix, nameof(options.Value.ServerNameSuffix)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerPort, nameof(options.Value.ServerPort)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerSslMode, nameof(options.Value.ServerSslMode)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseName, nameof(options.Value.DatabaseName)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseUserLogin, nameof(options.Value.DatabaseUserLogin)); - ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseUserPassword, nameof(options.Value.DatabaseUserPassword)); - - if (options.Value.ServerName.Contains(options.Value.ServerNameSuffix, StringComparison.OrdinalIgnoreCase)) + if (UseManagedIdentity) { - throw new ArgumentException($"'{nameof(options.Value.ServerName)}' should only contain the name of the server like 'myserver' and NOT the full host name like 'myserver{options.Value.ServerNameSuffix}'", nameof(options.Value.ServerName)); + // Use AAD managed identity + // https://learn.microsoft.com/en-us/azure/postgresql/single-server/how-to-connect-with-managed-identity + // https://learn.microsoft.com/en-us/azure/postgresql/single-server/concepts-azure-ad-authentication + + var accessToken = await tokenCredential.GetTokenAsync( + new TokenRequestContext(scopes: new string[] { azureDatabaseForPostgresqlScope }), System.Threading.CancellationToken.None); + + return $"{connectionString}Password={accessToken.Token};"; + } + + return connectionString; + } + + private string InternalGetConnectionString(PostgreSqlOptions options) + { + ArgumentException.ThrowIfNullOrEmpty(options.ServerName, nameof(options.ServerName)); + ArgumentException.ThrowIfNullOrEmpty(options.ServerNameSuffix, nameof(options.ServerNameSuffix)); + ArgumentException.ThrowIfNullOrEmpty(options.ServerPort, nameof(options.ServerPort)); + ArgumentException.ThrowIfNullOrEmpty(options.ServerSslMode, nameof(options.ServerSslMode)); + ArgumentException.ThrowIfNullOrEmpty(options.DatabaseName, nameof(options.DatabaseName)); + ArgumentException.ThrowIfNullOrEmpty(options.DatabaseUserLogin, nameof(options.DatabaseUserLogin)); + + if (!options.UseManagedIdentity) + { + // Ensure password is set if NOT using Managed Identity + ArgumentException.ThrowIfNullOrEmpty(options.DatabaseUserPassword, nameof(options.DatabaseUserPassword)); + } + + if (options.UseManagedIdentity && !string.IsNullOrWhiteSpace(options.DatabaseUserPassword)) + { + // throw if password IS set when using Managed Identity + throw new ArgumentException("DatabaseUserPassword shall not be set if UseManagedIdentity is true"); + } + + if (options.ServerName.Contains(options.ServerNameSuffix, StringComparison.OrdinalIgnoreCase)) + { + throw new ArgumentException($"'{nameof(options.ServerName)}' should only contain the name of the server like 'myserver' and NOT the full host name like 'myserver{options.ServerNameSuffix}'", nameof(options.ServerName)); } var connectionStringBuilder = new StringBuilder(); - connectionStringBuilder.Append($"Server={options.Value.ServerName}{options.Value.ServerNameSuffix};"); - connectionStringBuilder.Append($"Database={options.Value.DatabaseName};"); - connectionStringBuilder.Append($"Port={options.Value.ServerPort};"); - connectionStringBuilder.Append($"User Id={options.Value.DatabaseUserLogin};"); - connectionStringBuilder.Append($"Password={options.Value.DatabaseUserPassword};"); - connectionStringBuilder.Append($"SSL Mode={options.Value.ServerSslMode};"); + connectionStringBuilder.Append($"Server={options.ServerName}{options.ServerNameSuffix};"); + connectionStringBuilder.Append($"Database={options.DatabaseName};"); + connectionStringBuilder.Append($"Port={options.ServerPort};"); + connectionStringBuilder.Append($"SSL Mode={options.ServerSslMode};"); + connectionStringBuilder.Append($"User Id={options.DatabaseUserLogin};"); + + if (!options.UseManagedIdentity) + { + connectionStringBuilder.Append($"Password={options.DatabaseUserPassword};"); + } + return connectionStringBuilder.ToString(); } } diff --git a/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs b/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs index ca192e7e6..99058d260 100644 --- a/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs +++ b/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs @@ -61,10 +61,9 @@ await PostgreSqlTestUtility.CreateTestDbAsync( DatabaseUserPassword = adminPw }; - var optionsMock = new Mock>(); - optionsMock.Setup(x => x.Value).Returns(options); - var connectionString = new ConnectionStringUtility().GetPostgresConnectionString(optionsMock.Object); - repository = new TesTaskPostgreSqlRepository(() => new TesDbContext(connectionString)); + var optionsMock = new Mock(); + var connectionString = new PostgresConnectionStringUtility(optionsMock.Object, null); + repository = new TesTaskPostgreSqlRepository(); Console.WriteLine("Creation complete."); } diff --git a/src/TesApi.Web/Startup.cs b/src/TesApi.Web/Startup.cs index e28d2c4d2..5463f7e78 100644 --- a/src/TesApi.Web/Startup.cs +++ b/src/TesApi.Web/Startup.cs @@ -20,6 +20,7 @@ using Tes.ApiClients.Options; using Tes.Models; using Tes.Repository; +using Tes.Utilities; using TesApi.Filters; using TesApi.Web.Management; using TesApi.Web.Management.Batch; @@ -77,6 +78,10 @@ public void ConfigureServices(IServiceCollection services) .Configure(configuration.GetSection(MarthaOptions.SectionName)) .AddMemoryCache(o => o.ExpirationScanFrequency = TimeSpan.FromHours(12)) + + .AddSingleton() + .AddSingleton() + .AddDbContext(ServiceLifetime.Scoped) .AddSingleton, TesRepositoryCache>() .AddSingleton() .AddSingleton() @@ -108,7 +113,6 @@ public void ConfigureServices(IServiceCollection services) .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton(s => new DefaultAzureCredential()) .AddSingleton() .AddSingleton() .AddTransient() diff --git a/src/TesApi.Web/TesApi.Web.csproj b/src/TesApi.Web/TesApi.Web.csproj index ea889c7b5..5697eeff8 100644 --- a/src/TesApi.Web/TesApi.Web.csproj +++ b/src/TesApi.Web/TesApi.Web.csproj @@ -12,7 +12,7 @@ - + diff --git a/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj b/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj index 6c662bbee..373b43bb0 100644 --- a/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj +++ b/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj @@ -26,7 +26,7 @@ - +