nom-nom-nix-gc/src/models/mod.rs

166 lines
5.4 KiB
Rust

use std::collections::HashMap;
use std::fs;
use std::ops::DerefMut;
use std::sync::Arc;
use url::Url;
use anyhow::{Result, Context};
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use handlebars::Handlebars;
use serde::{Deserialize, Serialize};
use serde_json;
use tokio::sync::RwLock;
use tokio_postgres::NoTls;
use webauthn_rs::prelude::{Uuid, PasskeyRegistration, Passkey};
use webauthn_rs::{Webauthn, WebauthnBuilder};
/// Application settings, deserialized from a JSON file by [`read_config`].
#[derive(Deserialize, Debug, Clone)]
pub struct Configuration {
/// Public URL of the service; parsed with `Url::parse` and used as the WebAuthn origin.
pub url: String,
/// PostgreSQL host, handed to `tokio_postgres::Config::host_path`
/// (presumably a Unix-socket directory — confirm against deployment).
pub db_host: Option<String>,
/// PostgreSQL port.
pub db_port: Option<u16>,
/// Name of the PostgreSQL database to connect to.
pub db_name: String
}
/// Loads the JSON [`Configuration`] stored at `config_path`.
///
/// # Errors
///
/// Fails when the file cannot be read, or when its contents are not valid JSON
/// matching the [`Configuration`] shape; each failure carries a context message
/// naming the step that went wrong.
pub fn read_config(config_path: &str) -> anyhow::Result<Configuration> {
    fs::read_to_string(config_path)
        .with_context(|| format!("Cannot read the configuration file at {}.", config_path))
        .and_then(|raw_json| {
            serde_json::from_str::<Configuration>(&raw_json)
                .context("Cannot parse JSON configuration.")
        })
}
/// Shared handle to a `T` guarded by a `tokio::sync::RwLock` inside an `Arc`;
/// clones of the handle point at the same locked value.
pub type DbField<T> = Arc<RwLock<T>>;
/// In-memory user store built from [`DbField`] maps; cloning shares the same maps.
///
/// NOTE(review): not referenced by `AppState` in the visible code, which talks to
/// Postgres through a pool instead — possibly a leftover; confirm before relying on it.
#[derive(Clone)]
pub struct Db {
/// Passkey per UUID (presumably the owning user's UUID — confirm against writers).
pub user_keys: DbField<HashMap<Uuid, Passkey>>,
/// `User` record per UUID.
pub user_uuid_object: DbField<HashMap<Uuid, User>>,
}
/// Transient, in-process session state; cloning shares the same underlying map.
#[derive(Clone)]
pub struct TempSession {
/// In-progress registrations keyed by UUID
/// (presumably the registration-link UUID — confirm against callers).
pub user_registrations: Arc<RwLock<HashMap<Uuid,PendingRegistration>>>
}
/// Pairs a `User` with the WebAuthn `PasskeyRegistration` state of a registration
/// that has not completed yet (presumably produced when the ceremony starts and
/// consumed when it finishes — confirm against the handlers).
#[derive(Clone)]
pub struct PendingRegistration {
/// The user being registered.
pub user: User,
/// WebAuthn state that must be kept until the client answers the challenge.
pub registration: PasskeyRegistration
}
/// A user account. Equality and hashing are derived from `user_name` alone,
/// so two `User`s with the same name are considered the same user.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct User {
/// Account name; also the value written to the `user_name` DB columns.
pub user_name: String,
}
/// Application-wide state handed to request handlers. WebAuthn and Handlebars are
/// `Arc`-wrapped and the session map is `Arc`-shared, so clones refer to the same
/// instances; the pool handle is cloned as `deadpool_postgres::Pool` defines.
#[derive(Clone)]
pub struct AppState<'a>{
/// Configured WebAuthn relying party.
pub webauthn: Arc<Webauthn>,
/// PostgreSQL connection pool.
pub db: Pool,
/// Handlebars template registry.
pub hbs: Arc<Handlebars<'a>>,
/// In-memory state for registrations in progress.
pub session: TempSession
}
/// Holds the SQL migrations embedded by refinery's `embed_migrations!` macro, which
/// generates the `embedded::migrations` module used by `AppState::run_migrations`.
/// No path is passed, so refinery's default migrations directory applies
/// (presumably `./migrations` — confirm).
mod embedded {
use refinery::embed_migrations;
embed_migrations!();
}
impl AppState<'_> {
pub fn new(conf: Configuration) -> Self {
let rp = "localhost";
let rp_origin = Url::parse(&conf.url).expect("Invalid URL");
let builder = WebauthnBuilder::new(rp, &rp_origin).expect("Invalid configuration");
let builder = builder.rp_name("LocalHost");
let webauthn = Arc::new(builder.build().expect("Invalid configuration"));
let hbs = Arc::new(crate::templates::new().unwrap());
let session: TempSession = TempSession {
user_registrations: Arc::new(RwLock::new(HashMap::new()))
};
let mut pg_config = tokio_postgres::Config::new();
pg_config.host_path(conf.db_host.unwrap().clone());
pg_config.dbname(&conf.db_name);
pg_config.port(conf.db_port.unwrap());
let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast
};
println!("{:?}", pg_config);
let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
let pool = Pool::builder(mgr).max_size(16).build().unwrap();
AppState {
webauthn,
db: pool,
hbs,
session
}
}
/**
Generate a new registration uuid for the given `username` and
saves it to the DB.
*/
pub async fn generate_registration_link(&self, user: User) -> Result<Uuid> {
let uuid = Uuid::new_v4();
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO PendingRegistrations (id, user_name) VALUES ($1, $2)").await?;
client.query(&stmt, &[&uuid, &user.user_name]).await?;
Ok(uuid)
}
/**
Retrieves the `User` attached to the registration `Uuid`.
Returns `Some` `User` if it can find it, `None` if the `Uuid`
doesn't exist in the DB.
*/
pub async fn retrieve_registration_link(&self, uuid: Uuid) -> Result<Option<User>> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("SELECT user_name FROM PendingRegistrations WHERE id=$1").await?;
let row = client.query_one(&stmt, &[&uuid]).await?;
Ok(Some(User{ user_name: row.get(0)}))
}
pub async fn run_migrations(&self) -> Result<()> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut().deref_mut();
let report = embedded::migrations::runner().run_async(client).await?;
println!("{:?}", report);
Ok(())
}
pub async fn save_user(&self, user: &User) -> Result<()> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO Users (user_name) VALUES ($1, $2, $3)").await?;
let _ = client.query(&stmt, &[&user.user_name]).await?;
Ok(())
}
pub async fn save_user_key(&self, uuid: &Uuid, passkey: &Passkey) -> Result<()> {
let passkey_json: String = serde_json::to_string(&passkey)?;
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("INSERT INTO Keys (key_dump, user_id) VALUES ($1,$2)").await?;
client.query(&stmt, &[&passkey_json, &uuid.to_string()]).await?;
Ok(())
}
pub async fn get_user_keys(&self, user_uuid: &Uuid) -> Result<Vec<Result<Passkey, serde_json::Error>>> {
let mut conn = self.db.get().await?;
let client = conn.deref_mut();
let stmt = client.prepare_cached("SELECT key_dump FROM Keys WHERE user_id = $1").await?;
let res = client.query(&stmt, &[&user_uuid.to_string()]).await?;
let res2 = res.iter().map(|row| {
serde_json::from_str(row.get(0))
}).collect();
Ok(res2)
}
}