// auth-proxy/src/Program.cs

using System.Globalization;
using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using Fido2NetLib;
using Fido2NetLib.Objects;
using Microsoft.Data.Sqlite;
using SQLitePCL;
namespace WebauthnProxy;
/// <summary>
/// JSON body accepted by <c>POST /auth/add-key</c>: the registration password
/// together with the raw attestation response to verify and persist.
/// </summary>
public class AddKeyRequest {
    /// <summary>
    /// Shared registration secret; checked against the configured
    /// <c>WEBAUTHN_PASSWORD</c> before any key is stored.
    /// </summary>
    public required string Password { get; set; }

    /// <summary>
    /// Raw attestation response from the authenticator, handed to Fido2 for
    /// verification.
    /// </summary>
    public required AuthenticatorAttestationRawResponse Response { get; set; }
}
/// <summary>
/// WebAuthn authentication proxy. Registers FIDO2 credentials (guarded by a
/// shared password), verifies assertions against them, issues random session
/// tokens stored in SQLite and delivered as a cookie, and answers
/// <c>/auth/check</c> with 200/401 depending on that cookie (presumably queried
/// by a fronting reverse proxy — confirm against deployment config).
/// </summary>
public static class Program {
    private const string COOKIE_NAME = "__Secure-Token";
    private const string SESS_ATTESTATION_KEY = "fido2.attestationOptions";
    private const string SESS_ASSERTION_KEY = "fido2.assertionOptions";

    // Read once at type initialization; served verbatim by GET /auth/login.
    private static readonly string s_loginHtmlCache = File.ReadAllText(Path.Combine(
        AppDomain.CurrentDomain.BaseDirectory, "login.html"));

    // Relying-party domain; also used as the cookie Domain attribute.
    private static readonly string s_domain = Environment
        .GetEnvironmentVariable("WEBAUTHN_DOMAIN") ?? "localhost";

    // Path of the SQLite database holding the credentials and tokens tables.
    private static readonly string s_db = Environment
        .GetEnvironmentVariable("WEBAUTHN_DB") ?? "credentials.db";

    // Token lifetime: configured in days (default 7), stored in seconds.
    private static readonly long s_lifetime = long.Parse(Environment
        .GetEnvironmentVariable("WEBAUTHN_LIFETIME") ?? "7",
        CultureInfo.InvariantCulture) * 60 * 60 * 24;

    // SQLite date modifier ("+N seconds") for the token lifetime, formatted
    // culture-independently so the SQL never sees locale-specific digits/signs.
    private static readonly string s_lifetimeModifier = string.Create(
        CultureInfo.InvariantCulture, $"+{s_lifetime} seconds");

    // Shared secret required by /auth/add-key; registration is refused when
    // this is unset or empty.
    private static readonly string? s_password = Environment
        .GetEnvironmentVariable("WEBAUTHN_PASSWORD");

    private static readonly int s_port = int.Parse(Environment
        .GetEnvironmentVariable("WEBAUTHN_PORT") ?? "5000",
        CultureInfo.InvariantCulture);

    // Expected origin: plain http including the port on localhost, https on
    // the bare domain otherwise.
    private static readonly Fido2 s_fido2 = new(new Fido2Configuration {
        ServerDomain = s_domain,
        ServerName = "WebauthnProxy",
        Origins = new(new[] {
            s_domain == "localhost"
                ? $"http://{s_domain}:{s_port}"
                : $"https://{s_domain}",
        }),
    });

    // Registered credential ids offered as allowCredentials in assertion
    // options. Request handlers run concurrently and List<T> is not
    // thread-safe, so every access goes through s_keysLock.
    private static readonly List<PublicKeyCredentialDescriptor> s_keys = new();
    private static readonly object s_keysLock = new();

    private static string ConnectionString => $"Data Source={s_db}";

    /// <summary>
    /// Initializes the database and credential cache, maps the endpoints and
    /// runs the server on all interfaces at <c>WEBAUTHN_PORT</c>.
    /// </summary>
    public static void Main(string[] args) {
        var app = Initialize(args);
        app.UseSession();
        app.MapGet("/favicon.ico", () => Results.File(Convert.FromBase64String(
            $"AAABAAEAEBAAAAAAAABoBQAAFgAAACgAAAAQAAAAIAAAAAEACAAAAAAAAAEAAAAAAAAAAAAAAAEAAAAAAAD///8{new string('A', 1788)}="),
            contentType: "image/x-icon"));
        app.MapGet("/auth/check", context => HandleCheckAsync(context));
        app.MapGet("/auth/login", context => HandleLoginAsync(context));
        app.MapGet("/auth/logout", context => HandleLogout(context));
        app.MapPost("/auth/key", context => HandleAssertionOptionsAsync(context));
        app.MapPost("/auth/complete", context => HandleAssertionCompleteAsync(context));
        app.MapPost("/auth/new-key", context => HandleAttestationOptionsAsync(context));
        app.MapPost("/auth/add-key", context => HandleAddKeyAsync(context));
        app.Run($"http://0.0.0.0:{s_port}");
    }

    /// <summary>
    /// GET /auth/check: 200 "authorized" when the token cookie is valid,
    /// otherwise 401 "unauthorized".
    /// </summary>
    private static async Task HandleCheckAsync(HttpContext context) {
        // The token is a bearer secret; it must never be written to logs.
        var authorized = TokenIsValid(context.Request.Cookies[COOKIE_NAME]);
        context.Response.ContentType = "text/plain";
        context.Response.StatusCode = authorized
            ? (int)HttpStatusCode.OK
            : (int)HttpStatusCode.Unauthorized;
        await context.Response.WriteAsync(authorized ? "authorized" : "unauthorized");
    }

    /// <summary>GET /auth/login: serves the cached login page.</summary>
    private static async Task HandleLoginAsync(HttpContext context) {
        context.Response.ContentType = "text/html";
        context.Response.StatusCode = (int)HttpStatusCode.OK;
        await context.Response.WriteAsync(s_loginHtmlCache);
    }

    /// <summary>
    /// GET /auth/logout: revokes the caller's token, expires the cookie and
    /// redirects to "/". Does nothing when no token cookie is present.
    /// </summary>
    private static Task HandleLogout(HttpContext context) {
        var token = context.Request.Cookies[COOKIE_NAME];
        if (token == null)
            return Task.CompletedTask;
        // Revoke server-side as well, so a copied cookie stops working now
        // rather than remaining valid until it would have expired.
        RevokeToken(token);
        context.Response.Cookies.Append(COOKIE_NAME, string.Empty, new CookieOptions {
            Path = "/",
            Secure = true,
            HttpOnly = true,
            MaxAge = TimeSpan.Zero,
            Domain = s_domain,
        });
        context.Response.Redirect("/");
        return Task.CompletedTask;
    }

    /// <summary>
    /// POST /auth/key: creates assertion options (challenge) for the
    /// registered credentials, stores them in the session and returns them.
    /// </summary>
    private static async Task HandleAssertionOptionsAsync(HttpContext context) {
        // Snapshot under the lock: the options object may keep a reference to
        // the collection and enumerate it later during serialization.
        List<PublicKeyCredentialDescriptor> allowed;
        lock (s_keysLock) {
            allowed = new List<PublicKeyCredentialDescriptor>(s_keys);
        }
        var options = s_fido2.GetAssertionOptions(
            allowed, UserVerificationRequirement.Discouraged);
        context.Session.SetString(SESS_ASSERTION_KEY,
            JsonSerializer.Serialize(options));
        await context.Response.WriteAsJsonAsync(options);
    }

    /// <summary>
    /// POST /auth/complete: verifies an assertion against the challenge held
    /// in the session and the stored public key; on success issues a token
    /// cookie. Every failure answers 401 with a JSON <c>error</c> field.
    /// </summary>
    private static async Task HandleAssertionCompleteAsync(HttpContext context) {
        var assertionResponse = await context.Request
            .ReadFromJsonAsync<AuthenticatorAssertionRawResponse>();
        if (assertionResponse == null) {
            await WriteJsonErrorAsync(context, 401, "bad response");
            return;
        }
        // The challenge is single-use: drop it before verification so it
        // cannot be replayed even if verification fails.
        var optsJson = context.Session.GetString(SESS_ASSERTION_KEY);
        context.Session.Remove(SESS_ASSERTION_KEY);
        if (string.IsNullOrEmpty(optsJson)) {
            // Session expired or /auth/key was never called in this session.
            await WriteJsonErrorAsync(context, 401, "no pending challenge");
            return;
        }
        var opts = AssertionOptions.FromJson(optsJson);
        using var connection = new SqliteConnection(ConnectionString);
        connection.Open();
        var pubKey = FindPublicKey(connection, assertionResponse.Id);
        if (pubKey == null) {
            await WriteJsonErrorAsync(context, 401, "bad key");
            return;
        }
        bool verified;
        try {
            // NOTE(review): the stored signature counter is always passed as 0
            // and never persisted, so cloned-authenticator detection is
            // effectively disabled. Fixing it needs a counter column.
            var res = await s_fido2.MakeAssertionAsync(
                assertionResponse, opts, pubKey, 0, (_, _) => Task.FromResult(true));
            verified = res.Status == "ok";
        } catch (Fido2VerificationException) {
            // Invalid signature/challenge/origin: a client error, not a 500.
            verified = false;
        }
        if (!verified) {
            await WriteJsonErrorAsync(context, 401, "auth failed");
            return;
        }
        context.Response.Cookies.Append(
            COOKIE_NAME,
            GenerateToken(connection),
            new CookieOptions {
                Path = "/",
                Secure = true,
                HttpOnly = true,
                SameSite = SameSiteMode.None,
                Domain = s_domain,
                MaxAge = TimeSpan.FromSeconds(s_lifetime),
            });
        await context.Response.WriteAsJsonAsync(new { status = "ok" });
    }

    /// <summary>
    /// POST /auth/new-key: creates credential-creation options for the single
    /// "default" user, stores them in the session and returns them.
    /// </summary>
    private static async Task HandleAttestationOptionsAsync(HttpContext context) {
        var user = new Fido2User {
            Id = Encoding.UTF8.GetBytes("default"),
            Name = "Default User",
            DisplayName = "Default User",
        };
        var options = s_fido2.RequestNewCredential(
            user,
            new List<PublicKeyCredentialDescriptor>(),
            AuthenticatorSelection.Default,
            AttestationConveyancePreference.None);
        context.Session.SetString(SESS_ATTESTATION_KEY,
            JsonSerializer.Serialize(options));
        await context.Response.WriteAsJsonAsync(options);
    }

    /// <summary>
    /// POST /auth/add-key: checks the password, verifies the attestation
    /// against the session's creation options, then stores the credential in
    /// the database and the in-memory key list.
    /// Responds 400 (bad body / no pending registration / verification
    /// failure) or 401 (wrong password) on failure.
    /// </summary>
    private static async Task HandleAddKeyAsync(HttpContext context) {
        var req = await context.Request.ReadFromJsonAsync<AddKeyRequest>();
        if (req == null) {
            context.Response.StatusCode = 400;
            return;
        }
        if (!PasswordMatches(req.Password)) {
            await WriteJsonErrorAsync(context, 401, "bad password");
            return;
        }
        var optsJson = context.Session.GetString(SESS_ATTESTATION_KEY);
        context.Session.Remove(SESS_ATTESTATION_KEY);
        if (string.IsNullOrEmpty(optsJson)) {
            // Session expired or /auth/new-key was never called in this session.
            await WriteJsonErrorAsync(context, 400, "no pending registration");
            return;
        }
        var opts = CredentialCreateOptions.FromJson(optsJson);
        byte[]? credentialId = null;
        byte[]? publicKey = null;
        try {
            var cred = await s_fido2.MakeNewCredentialAsync(
                req.Response, opts, (_, _) => Task.FromResult(true));
            credentialId = cred.Result?.CredentialId;
            publicKey = cred.Result?.PublicKey;
        } catch (Fido2VerificationException) {
            // Leave both null; reported below as a failed registration.
        }
        if (credentialId == null || publicKey == null) {
            await WriteJsonErrorAsync(context, 400, "registration failed");
            return;
        }
        using (var connection = new SqliteConnection(ConnectionString)) {
            await connection.OpenAsync();
            using var cmd = new SqliteCommand(
                "insert into credentials (id, key) values (@Id, @Key)",
                connection);
            cmd.Parameters.AddWithValue("@Id", Convert.ToBase64String(credentialId));
            cmd.Parameters.AddWithValue("@Key", Convert.ToBase64String(publicKey));
            await cmd.ExecuteNonQueryAsync();
        }
        // Advertise the credential only once it is durably stored, so a failed
        // insert cannot leave a key that has no public key behind it.
        lock (s_keysLock) {
            s_keys.Add(new PublicKeyCredentialDescriptor(credentialId));
        }
    }

    /// <summary>Sets <paramref name="statusCode"/> and writes <c>{"error": ...}</c>.</summary>
    private static async Task WriteJsonErrorAsync(HttpContext context, int statusCode, string error) {
        context.Response.StatusCode = statusCode;
        await context.Response.WriteAsJsonAsync(new { error });
    }

    /// <summary>
    /// Constant-time comparison of <paramref name="candidate"/> with the
    /// configured password. Always false when no password is configured.
    /// </summary>
    private static bool PasswordMatches(string? candidate) {
        if (string.IsNullOrEmpty(s_password) || candidate == null)
            return false;
        // FixedTimeEquals avoids leaking how many leading bytes matched.
        return CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(s_password),
            Encoding.UTF8.GetBytes(candidate));
    }

    /// <summary>
    /// Looks up the stored public key for a credential id (stored base64).
    /// Returns null when the credential is unknown.
    /// </summary>
    private static byte[]? FindPublicKey(SqliteConnection connection, byte[] credentialId) {
        using var cmd = new SqliteCommand(
            "select key from credentials where id=@Id",
            connection);
        cmd.Parameters.AddWithValue("@Id", Convert.ToBase64String(credentialId));
        return cmd.ExecuteScalar() is string keyStr
            ? Convert.FromBase64String(keyStr)
            : null;
    }

    /// <summary>
    /// Creates a 256-bit random token, records it in the tokens table on
    /// <paramref name="connection"/> and returns it base64-encoded.
    /// </summary>
    private static string GenerateToken(SqliteConnection connection) {
        // Static CSPRNG helper: nothing to dispose.
        var token = Convert.ToBase64String(RandomNumberGenerator.GetBytes(32));
        using var cmd = new SqliteCommand(
            "insert into tokens (token) values (@Token)",
            connection);
        cmd.Parameters.AddWithValue("@Token", token);
        cmd.ExecuteNonQuery();
        return token;
    }

    /// <summary>Deletes <paramref name="token"/> from the tokens table, if present.</summary>
    private static void RevokeToken(string token) {
        using var connection = new SqliteConnection(ConnectionString);
        connection.Open();
        using var cmd = new SqliteCommand(
            "delete from tokens where token=@Token",
            connection);
        cmd.Parameters.AddWithValue("@Token", token);
        cmd.ExecuteNonQuery();
    }

    /// <summary>
    /// True when <paramref name="token"/> exists in the tokens table and its
    /// <c>created_date</c> plus the configured lifetime is still in the future.
    /// </summary>
    private static bool TokenIsValid(string? token) {
        if (string.IsNullOrWhiteSpace(token)) return false;
        using var connection = new SqliteConnection(ConnectionString);
        connection.Open();
        using var cmd = new SqliteCommand(
            @"select 1 from tokens where
                token=@Token
                and julianday(created_date, @Timeout) > julianday(CURRENT_TIMESTAMP)",
            connection);
        cmd.Parameters.AddWithValue("@Token", token);
        cmd.Parameters.AddWithValue("@Timeout", s_lifetimeModifier);
        return cmd.ExecuteScalar() != null;
    }

    /// <summary>Deletes every token whose lifetime has elapsed.</summary>
    private static void RemoveExpiredTokens() {
        using var connection = new SqliteConnection(ConnectionString);
        connection.Open();
        using var cmd = new SqliteCommand(
            @"delete from tokens where
                julianday(created_date, @Timeout) < julianday(CURRENT_TIMESTAMP)",
            connection);
        cmd.Parameters.AddWithValue("@Timeout", s_lifetimeModifier);
        cmd.ExecuteNonQuery();
    }

    /// <summary>
    /// Prepares SQLite (creating the schema on first run), prunes expired
    /// tokens, loads registered credential ids into <c>s_keys</c>, and builds
    /// the web application with in-memory session support.
    /// </summary>
    /// <exception cref="FileNotFoundException">
    /// The database has no credentials table and schema.sql is missing.
    /// </exception>
    private static WebApplication Initialize(string[] args) {
        Batteries.Init();
        using (var connection = new SqliteConnection(ConnectionString)) {
            connection.Open();
            EnsureSchema(connection);
            RemoveExpiredTokens();
            LoadCredentialDescriptors(connection);
        }
        var builder = WebApplication.CreateBuilder(args);
        builder.Services.AddDistributedMemoryCache();
        builder.Services.AddSession();
        return builder.Build();
    }

    /// <summary>
    /// Runs schema.sql (next to the executable) when the credentials table
    /// does not exist yet.
    /// </summary>
    private static void EnsureSchema(SqliteConnection connection) {
        using var probe = new SqliteCommand(
            "select 1 from sqlite_master where type='table' and name='credentials'",
            connection);
        if (probe.ExecuteScalar() != null)
            return;
        var schemaPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "schema.sql");
        if (!File.Exists(schemaPath))
            throw new FileNotFoundException("Database schema file not found.", schemaPath);
        using var schemaCmd = new SqliteCommand(File.ReadAllText(schemaPath), connection);
        schemaCmd.ExecuteNonQuery();
    }

    /// <summary>Adds a descriptor to <c>s_keys</c> for every stored (base64) credential id.</summary>
    private static void LoadCredentialDescriptors(SqliteConnection connection) {
        using var cmd = new SqliteCommand("select id from credentials", connection);
        using var reader = cmd.ExecuteReader();
        lock (s_keysLock) {
            while (reader.Read()) {
                s_keys.Add(new PublicKeyCredentialDescriptor(
                    Convert.FromBase64String(reader.GetString(0))));
            }
        }
    }
}