227 lines
7.9 KiB
Rust
227 lines
7.9 KiB
Rust
use std::collections::HashMap;
|
|
use std::fs;
|
|
use std::ops::DerefMut;
|
|
use std::sync::Arc;
|
|
use postgres_types::{FromSql, ToSql};
|
|
use url::Url;
|
|
|
|
use anyhow::{Result, Context};
|
|
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
|
|
use handlebars::Handlebars;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json;
|
|
use tokio::sync::RwLock;
|
|
use tokio_postgres::NoTls;
|
|
use webauthn_rs::prelude::{Uuid, PasskeyRegistration, Passkey, PasskeyAuthentication, AuthenticationResult};
|
|
use webauthn_rs::{Webauthn, WebauthnBuilder};
|
|
|
|
mod authentication;
|
|
|
|
#[derive(Deserialize, Debug, Clone)]
|
|
pub struct Configuration {
|
|
pub url: String,
|
|
pub db_host: Option<String>,
|
|
pub db_port: Option<u16>,
|
|
pub db_name: String
|
|
}
|
|
|
|
|
|
pub fn read_config(config_path: &str) -> anyhow::Result<Configuration> {
|
|
let content = fs::read_to_string(config_path)
|
|
.with_context(|| format!("Cannot read the configuration file at {}.", config_path))?;
|
|
let res: Configuration = serde_json::from_str(&content)
|
|
.context("Cannot parse JSON configuration.")?;
|
|
Ok(res)
|
|
}
|
|
|
|
pub type DbField<T> = Arc<RwLock<T>>;
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct TempSession {
|
|
pub user_registrations: Arc<RwLock<HashMap<RegistrationUuid,PasskeyRegistration>>>,
|
|
pub user_pending_logins: Arc<RwLock<HashMap<LoginUuid,(PasskeyAuthentication, User)>>>,
|
|
pub user_sessions: Arc<RwLock<HashMap<SessionUuid,User>>>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
|
|
pub struct RegistrationUuid(pub Uuid);
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
|
|
pub struct LoginUuid(pub Uuid);
|
|
|
|
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, FromSql, ToSql)]
|
|
#[postgres(transparent)]
|
|
pub struct UserUuid(pub Uuid);
|
|
|
|
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
|
pub struct SessionUuid(pub Uuid);
|
|
|
|
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
|
pub struct ProjectUuid(pub Uuid);
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
|
|
pub struct User {
|
|
pub uuid: UserUuid,
|
|
pub name: String,
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AppState<'a>{
|
|
pub webauthn: Arc<Webauthn>,
|
|
pub db: Pool,
|
|
pub hbs: Arc<Handlebars<'a>>,
|
|
pub session: TempSession
|
|
}
|
|
|
|
mod embedded {
|
|
use refinery::embed_migrations;
|
|
embed_migrations!();
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct Key {
|
|
pub name: String,
|
|
pub key: Passkey
|
|
}
|
|
|
|
/// A binary cache, as stored in the `BinaryCaches` table; projects reference
/// it by `name`.
///
/// `Debug` is implemented by hand so that `secret_key` never ends up in logs,
/// `{:?}` output or panic messages; every other field is printed as usual.
#[derive(Clone, Eq, PartialEq)]
pub struct BinaryCache {
    /// Name of the cache; `create_project` looks caches up by this value.
    pub name: String,
    /// Access key credential (presumably S3-style — confirm).
    pub access_key: String,
    /// Secret key credential; redacted in `Debug` output.
    pub secret_key: String,
    /// Region of the storage backend (presumably S3-style — confirm).
    pub region: String,
}

impl std::fmt::Debug for BinaryCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BinaryCache")
            .field("name", &self.name)
            .field("access_key", &self.access_key)
            // Never print the secret itself.
            .field("secret_key", &"<redacted>")
            .field("region", &self.region)
            .finish()
    }
}
|
|
|
|
/// A project, identified by its name in the `Projects` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Project {
    /// Name of the project.
    pub name: String,
}
|
|
|
|
impl AppState<'_> {
|
|
pub fn new(conf: Configuration) -> Self {
|
|
let rp = "localhost";
|
|
let rp_origin = Url::parse(&conf.url).expect("Invalid URL");
|
|
let builder = WebauthnBuilder::new(rp, &rp_origin).expect("Invalid configuration");
|
|
let builder = builder.rp_name("LocalHost");
|
|
let webauthn = Arc::new(builder.build().expect("Invalid configuration"));
|
|
let hbs = Arc::new(crate::templates::new().unwrap());
|
|
let session: TempSession = TempSession {
|
|
user_registrations: Arc::new(RwLock::new(HashMap::new())),
|
|
user_pending_logins: Arc::new(RwLock::new(HashMap::new())),
|
|
user_sessions: Arc::new(RwLock::new(HashMap::new()))
|
|
};
|
|
let mut pg_config = tokio_postgres::Config::new();
|
|
pg_config.host_path(conf.db_host.unwrap().clone());
|
|
pg_config.dbname(&conf.db_name);
|
|
pg_config.port(conf.db_port.unwrap());
|
|
let mgr_config = ManagerConfig {
|
|
recycling_method: RecyclingMethod::Fast
|
|
};
|
|
let mgr = Manager::from_config(pg_config, NoTls, mgr_config);
|
|
let pool = Pool::builder(mgr).max_size(16).build().unwrap();
|
|
|
|
AppState {
|
|
webauthn,
|
|
db: pool,
|
|
hbs,
|
|
session
|
|
}
|
|
}
|
|
|
|
/**
|
|
Generate a new registration uuid for the given `UserUuid` and
|
|
saves it to the DB.
|
|
*/
|
|
pub async fn generate_registration_uuid(&self, user_id: &UserUuid) -> Result<RegistrationUuid> {
|
|
authentication::generate_registration_uuid(&self.db, user_id).await
|
|
}
|
|
|
|
/**
|
|
Retrieves the registred `User` attached to the `RegistrationUuid`.
|
|
|
|
Returns `Some` `User` if it can find it, `None` if the `RegistrationUuid`
|
|
doesn't exist in the registration table.
|
|
*/
|
|
pub async fn retrieve_registration_user(&self, registration_id: &RegistrationUuid) -> Result<Option<User>> {
|
|
authentication::retrieve_registration_user(&self.db, registration_id).await
|
|
}
|
|
|
|
/**
|
|
Deletes the registration `Uuid` attached to the `UserUuid` passed in
|
|
parameter.
|
|
*/
|
|
pub async fn delete_registration_link(&self, user_id: &RegistrationUuid) -> Result<()> {
|
|
authentication::delete_registration_link(&self.db, user_id).await
|
|
}
|
|
|
|
pub async fn run_migrations(&self) -> Result<()> {
|
|
let mut conn = self.db.get().await?;
|
|
let client = conn.deref_mut().deref_mut();
|
|
let report = embedded::migrations::runner().run_async(client).await?;
|
|
println!("{:?}", report);
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn get_user(&self, username: &str) -> Result<User> {
|
|
authentication::get_user(&self.db, username).await
|
|
}
|
|
|
|
pub async fn save_user(&self, user: &User) -> Result<()> {
|
|
authentication::save_user(&self.db, user).await
|
|
}
|
|
|
|
pub async fn save_user_key(&self, user_id: &UserUuid, passkey: &Key) -> Result<()> {
|
|
authentication::save_user_key(&self.db, user_id, passkey).await
|
|
}
|
|
|
|
pub async fn get_user_keys(&self, user_id: &UserUuid) -> Result<Vec<Key>> {
|
|
authentication::get_user_keys(&self.db, user_id).await
|
|
}
|
|
|
|
/**
|
|
Updates the keys for the user `UserUuid` according to the last
|
|
successful login `AuthenticationResult`.
|
|
*/
|
|
pub async fn update_user_keys(&self, user_id: &UserUuid, auth_res: AuthenticationResult) -> Result<()> {
|
|
authentication::update_user_keys(&self.db, user_id, auth_res).await
|
|
}
|
|
|
|
pub async fn create_binary_cache(&self, binary_cache: &BinaryCache) -> Result<()> {
|
|
let conn = self.db.get().await?;
|
|
let stmt = conn.prepare_cached("INSERT INTO BinaryCaches (name, access_key, secret_key, region) VALUES ($1, $2, $3, $4)").await?;
|
|
let _ = conn.execute(&stmt, &[&binary_cache.name, &binary_cache.access_key, &binary_cache.secret_key, &binary_cache.region]).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn create_project(&self, binary_cache: &BinaryCache, project: &Project) -> Result<()> {
|
|
let conn = self.db.get().await?;
|
|
let stmt = conn.prepare_cached("INSERT INTO Projects (name, binary_cache_id) \
|
|
SELECT $1, b.id FROM BinaryCaches b \
|
|
WHERE b.name = $2").await?;
|
|
let _ = conn.execute(&stmt, &[&project.name, &binary_cache.name]).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn create_project_token(&self, project: &Project) -> Result<ProjectUuid> {
|
|
let conn = self.db.get().await?;
|
|
let token = Uuid::new_v4();
|
|
let stmt = conn.prepare_cached("INSERT INTO ProjectTokens (token, project_id) \
|
|
SELECT $1, p.id FROM Projects p \
|
|
WHERE name = $2").await?;
|
|
let _ = conn.execute(&stmt, &[&token, &project.name]).await?;
|
|
Ok(ProjectUuid(token))
|
|
}
|
|
|
|
pub async fn get_project(&self, token: &ProjectUuid) -> Result<Project> {
|
|
let conn = self.db.get().await?;
|
|
let stmt = conn.prepare_cached("SELECT name FROM Projects p \
|
|
INNER JOIN ProjectTokens t ON p.id = t.project_id \
|
|
WHERE t.token = $1").await?;
|
|
let row = conn.query_one(&stmt, &[&token.0]).await?;
|
|
Ok(Project {
|
|
name: row.get(0)
|
|
})
|
|
}
|
|
|
|
}
|