nom-nom-nix-gc/src/models/mod.rs

243 lines
8.9 KiB
Rust

use std::collections::HashMap;
use std::fs;
use std::ops::DerefMut;
use std::sync::Arc;
use postgres_types::{FromSql, ToSql};
use url::Url;
use anyhow::{Result, Context};
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use handlebars::Handlebars;
use serde::{Deserialize, Serialize};
use serde_json::{self, Value};
use tokio::sync::RwLock;
use tokio_postgres::NoTls;
use webauthn_rs::prelude::{Uuid, PasskeyRegistration, Passkey, PasskeyAuthentication, AuthenticationResult};
use webauthn_rs::{Webauthn, WebauthnBuilder};
/**
Application settings, deserialized from the JSON configuration file by
`read_config` and consumed by `AppState::new`.
*/
#[derive(Debug, Clone, Deserialize)]
pub struct Configuration {
    /// URL parsed by `AppState::new` and used as the WebAuthn origin.
    pub url: String,
    /// Value handed to `tokio_postgres::Config::host_path` by `AppState::new`.
    pub db_host: Option<String>,
    /// PostgreSQL port handed to the connection configuration by `AppState::new`.
    pub db_port: Option<u16>,
    /// Name of the PostgreSQL database to connect to.
    pub db_name: String,
}
/**
Loads the application `Configuration` from the JSON file at `config_path`.

# Errors

Fails when the file cannot be read or when its content is not valid JSON
for a `Configuration`; the error carries a message saying which step failed.
*/
pub fn read_config(config_path: &str) -> anyhow::Result<Configuration> {
    let raw = fs::read_to_string(config_path)
        .with_context(|| format!("Cannot read the configuration file at {}.", config_path))?;
    serde_json::from_str(&raw).context("Cannot parse JSON configuration.")
}
/**
Alias for a value shared through an `Arc` and guarded by a tokio `RwLock`.
*/
pub type DbField<T> = Arc<RwLock<T>>;
#[derive(Clone, Debug)]
pub struct TempSession {
pub user_registrations: Arc<RwLock<HashMap<RegistrationUuid,PasskeyRegistration>>>,
pub user_pending_logins: Arc<RwLock<HashMap<LoginUuid,(PasskeyAuthentication, User)>>>,
pub user_sessions: Arc<RwLock<HashMap<SessionUuid,User>>>,
}
/**
Identifier of a pending registration.

Used as the key of `TempSession::user_registrations` and as the `id` column
of the `PendingRegistrations` table.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RegistrationUuid(pub Uuid);
/**
Identifier of a pending login; key of `TempSession::user_pending_logins`.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LoginUuid(pub Uuid);
/**
Identifier of a `User`, stored in the `id` column of the `Users` table.

`#[postgres(transparent)]` makes the derived `FromSql`/`ToSql` impls treat
this newtype as the wrapped `Uuid`.
*/
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, FromSql, ToSql,
)]
#[postgres(transparent)]
pub struct UserUuid(pub Uuid);
/**
Identifier of a user session; key of `TempSession::user_sessions`.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SessionUuid(pub Uuid);
/**
A user account, as stored in the `Users` table.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct User {
    /// Primary identifier (`Users.id`).
    pub uuid: UserUuid,
    /// User name (`Users.user_name`).
    pub name: String,
}
/**
Shared application state, built once by `AppState::new`.

Cloning is cheap with respect to the `Arc`-wrapped members: the clone
refers to the same WebAuthn instance, template registry and session maps.
*/
#[derive(Debug, Clone)]
pub struct AppState<'a> {
    /// WebAuthn instance built from the configured origin.
    pub webauthn: Arc<Webauthn>,
    /// PostgreSQL connection pool.
    pub db: Pool,
    /// Handlebars template registry.
    pub hbs: Arc<Handlebars<'a>>,
    /// In-memory registration, login and session tables.
    pub session: TempSession,
}
mod embedded {
use refinery::embed_migrations;
embed_migrations!();
}
/**
A named passkey belonging to a user, as stored in the `Keys` table.
*/
#[derive(Debug, Clone)]
pub struct Key {
    /// Name of the key (`Keys.name`).
    pub name: String,
    /// The WebAuthn passkey; persisted as JSON in `Keys.key_dump`.
    pub key: Passkey,
}
impl AppState<'_> {
pub fn new(conf: Configuration) -> Self {
let rp = "localhost";
let rp_origin = Url::parse(&conf.url).expect("Invalid URL");
let builder = WebauthnBuilder::new(rp, &rp_origin).expect("Invalid configuration");
let builder = builder.rp_name("LocalHost");
let webauthn = Arc::new(builder.build().expect("Invalid configuration"));
let hbs = Arc::new(crate::templates::new().unwrap());
let session: TempSession = TempSession {
user_registrations: Arc::new(RwLock::new(HashMap::new())),
user_pending_logins: Arc::new(RwLock::new(HashMap::new())),
user_sessions: Arc::new(RwLock::new(HashMap::new()))
};
let mut pg_config = tokio_postgres::Config::new();
pg_config.host_path(conf.db_host.unwrap().clone());
pg_config.dbname(&conf.db_name);
pg_config.port(conf.db_port.unwrap());
let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast
};
let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
let pool = Pool::builder(mgr).max_size(16).build().unwrap();
AppState {
webauthn,
db: pool,
hbs,
session
}
}
/**
Generate a new registration uuid for the given `UserUuid` and
saves it to the DB.
*/
pub async fn generate_registration_uuid(&self, user_id: &UserUuid) -> Result<RegistrationUuid> {
let registration_uuid = Uuid::new_v4();
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO PendingRegistrations (id, user_id) VALUES ($1, $2)").await?;
client.query(&stmt, &[&registration_uuid, &user_id.0]).await?;
Ok(RegistrationUuid(registration_uuid))
}
/**
Retrieves the registred `User` attached to the `RegistrationUuid`.
Returns `Some` `User` if it can find it, `None` if the `RegistrationUuid`
doesn't exist in the registration table.
*/
pub async fn retrieve_registration_user(&self, registration_id: &RegistrationUuid) -> Result<Option<User>> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("SELECT u.id, u.user_name FROM Users u INNER JOIN PendingRegistrations r ON r.user_id = u.id WHERE r.id=$1").await?;
let row = client.query_one(&stmt, &[&registration_id.0]).await?;
let usr = User {
uuid: row.get(0),
name: row.get(1),
};
Ok(Some(usr))
}
/**
Deletes the registration `Uuid` attached to the `UserUuid` passed in
parameter.
*/
pub async fn delete_registration_link(&self, user_id: &RegistrationUuid) -> Result<()> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("DELETE FROM PendingRegistrations WHERE id=$1").await?;
client.query(&stmt, &[&user_id.0]).await?;
Ok(())
}
pub async fn run_migrations(&self) -> Result<()> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut().deref_mut();
let report = embedded::migrations::runner().run_async(client).await?;
println!("{:?}", report);
Ok(())
}
pub async fn get_user(&self, username: &str) -> Result<User> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("SELECT id, user_name FROM Users WHERE user_name = $1").await?;
let row = client.query_one(&stmt, &[&username]).await?;
Ok(User{ uuid: row.get(0), name: row.get(1) })
}
pub async fn save_user(&self, user: &User) -> Result<()> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO Users (id, user_name) VALUES ($1, $2)").await?;
let _ = client.query(&stmt, &[&user.uuid.0, &user.name]).await?;
Ok(())
}
pub async fn save_user_key(&self, user_id: &UserUuid, passkey: &Key) -> Result<()> {
let passkey_json: Value = serde_json::to_value(&passkey.key).unwrap();
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO Keys (user_id, name, key_dump) VALUES ($1, $2, $3)").await?;
client.query(&stmt, &[&user_id.0, &passkey.name, &passkey_json]).await?;
Ok(())
}
pub async fn get_user_keys(&self, user_id: &UserUuid) -> Result<Vec<Key>> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("SELECT name, key_dump FROM Keys WHERE user_id = $1").await?;
let res = client.query(&stmt, &[&user_id.0]).await?;
let res = res.iter().map(|row| {
let key: Value = row.get(1);
Key {
name: row.get(0),
// We can safely assume the JSON was valid, postgres
// would reject it if it wasn't. We can safely unwrap
// here.
key: serde_json::from_value(key).unwrap()
}}).collect();
Ok(res)
}
/**
Updates the keys for the user `UserUuid` according to the last
successful login `AuthenticationResult`.
*/
pub async fn update_user_keys(&self, user_id: &UserUuid, auth_res: AuthenticationResult) -> Result<()> {
let mut conn = self.db.get().await?;
let transaction = conn.deref_mut().transaction().await?;
// Retrieve user keys
let stmt = transaction.prepare_cached("SELECT name, key_dump FROM Keys WHERE user_id = $1").await?;
let res = transaction.query(&stmt, &[&user_id.0]).await?;
let mut keys: Vec<Key> = res.iter().map(|row| {
let key = row.get(1);
Key { name: row.get(0), key: serde_json::from_value(key).unwrap() }
}).collect();
// Update keys
keys.iter_mut().for_each(
|key| {
key.key.update_credential(&auth_res);
}
);
// Delete current keys, save updated keys
let stmt = transaction.prepare_cached("DELETE FROM Keys WHERE user_id = $1").await?;
let _ = transaction.execute(&stmt, &[&user_id.0]).await?;
let stmt = transaction.prepare_cached("INSERT INTO Keys (user_id, name, key_dump) VALUES ($1,$2,$3)").await?;
for key in keys {
let passkey_val: Value = serde_json::to_value(&key.key).unwrap();
let _ = transaction.execute(&stmt, &[&user_id.0, &key.name, &passkey_val]).await?;
}
transaction.commit().await?;
Ok(())
}
}