diff --git a/torrential/src/download.rs b/torrential/src/download.rs index d65b545f..ebcf63de 100644 --- a/torrential/src/download.rs +++ b/torrential/src/download.rs @@ -1,35 +1,28 @@ -use std::{collections::HashMap, time::Instant}; +use std::{collections::HashMap, path::PathBuf, time::Instant}; -use droplet_rs::versions::create_backend_constructor; +use droplet_rs::versions::{create_backend_constructor, types::VersionBackend}; use reqwest::StatusCode; -use crate::{AppInitData, DownloadContext, remote::{LibraryBackend, fetch_download_context}, util::ErrorOption}; - +use crate::{ + AppInitData, DownloadContext, + remote::{ContextResponseBody, LibraryBackend, fetch_download_context}, + util::ErrorOption, +}; pub async fn create_download_context<'a>( init_data: &AppInitData, game_id: String, version_name: String, -) -> Result, ErrorOption> { +) -> Result { let context = fetch_download_context(init_data.token.clone(), game_id, version_name.clone()).await?; - let (version_path, backend) = init_data - .libraries - .get(&context.library_id) - .ok_or(StatusCode::NOT_FOUND)?; - let version_path = version_path.join(context.library_path.clone()); - let version_path = match backend { - LibraryBackend::Filesystem => version_path.join(version_name), - LibraryBackend::FlatFilesystem => version_path, - }; + let backend = generate_backend(init_data, &context, &version_name)??; - let backend = - create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; - let backend = backend()?; - - let mut chunk_lookup_table = - HashMap::with_capacity_and_hasher(context.manifest.values().map(|v| v.ids.len()).sum(), Default::default()); + let mut chunk_lookup_table = HashMap::with_capacity_and_hasher( + context.manifest.values().map(|v| v.ids.len()).sum(), + Default::default(), + ); for (path, file_chunks) in context.manifest { let mut start = 0; @@ -47,3 +40,26 @@ pub async fn create_download_context<'a>( Ok(download_context) } + +fn generate_backend( + init_data: &AppInitData, + context: &ContextResponseBody, + version_name: &String, +) -> Result, anyhow::Error>, StatusCode> { + let (version_path, backend) = init_data + .libraries + .get(&context.library_id) + .ok_or(StatusCode::NOT_FOUND)?; + + let version_path = version_path.join(&context.library_path); + let version_path = match backend { + LibraryBackend::Filesystem => version_path.join(version_name), + LibraryBackend::FlatFilesystem => version_path, + }; + + let backend = + create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + let backend = backend(); + Ok(backend) +} diff --git a/torrential/src/main.rs b/torrential/src/main.rs index 3a9ae247..b9310b67 100644 --- a/torrential/src/main.rs +++ b/torrential/src/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; -use dashmap::DashMap; -use droplet_rs::versions::types::{VersionBackend, VersionFile}; +use dashmap::{DashMap, mapref::one::RefMut}; +use droplet_rs::versions::types::{MinimumFileObject, VersionBackend, VersionFile}; use reqwest::header; use simple_logger::SimpleLogger; use std::{ @@ -19,11 +19,11 @@ use axum::{ }; use log::{error, info}; use serde::Deserialize; -use tokio::sync::{OnceCell, Semaphore}; +use tokio::sync::{OnceCell, Semaphore, SemaphorePermit}; use crate::{ download::create_download_context, - remote::{LibraryBackend, fetch_library_sources}, + remote::{LibraryBackend, LibrarySource, fetch_library_sources}, }; mod download; @@ -33,9 +33,9 @@ mod util; static GLOBAL_CONTEXT_SEMAPHORE: Semaphore = Semaphore::const_new(1); -struct DownloadContext<'a> { +struct DownloadContext { chunk_lookup_table: HashMap, - backend: Box, + backend: Box, last_access: Instant, } @@ -45,51 +45,122 @@ struct AppInitData { libraries: HashMap, } -struct AppState<'a> { +struct AppState { token: OnceCell, - context_cache: DashMap<(String, String), DownloadContext<'a>>, + context_cache: DashMap<(String, String), DownloadContext>, +} + +#[tokio::main] +async fn main() { + initialise_logger(); + + if let Ok(working_directory) = std::env::var("WORKING_DIRECTORY") { + set_current_dir(working_directory).expect("failed to change working directory"); + } + + let shared_state = Arc::new(AppState { + token: OnceCell::new(), + context_cache: DashMap::new(), + }); + + let app = setup_app(shared_state); + + serve(app).await.unwrap(); +} + +fn setup_app(shared_state: Arc) -> Router { + Router::new() + .route( + "/api/v1/depot/{game_id}/{version_name}/{chunk_id}", + get(serve_file), + ) + .route("/token", post(set_token)) + .route("/healthcheck", get(healthcheck)) + .with_state(shared_state) +} +async fn serve(app: Router) -> Result<(), std::io::Error> { + let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap(); + info!("started depot server"); + axum::serve(listener, app).await +} + +async fn set_token( + State(state): State>, + Json(payload): Json, +) -> Result { + if check_token_exists(&state, &payload) { + return Ok(StatusCode::OK); + } + + let token = payload.token; + + let library_sources = fetch_library_sources(&token).await.map_err(|v| { + error!("{v:?}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let valid_library_sources = filter_library_sources(library_sources); + + set_generated_token(state, token, valid_library_sources)?; + + info!("connected to drop server successfully"); + + Ok(StatusCode::OK) } async fn serve_file( - State(state): State>>, + State(state): State>, Path((game_id, version_name, chunk_id)): Path<(String, String, String)>, ) -> Result { - let init_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - let key = (game_id.clone(), version_name.clone()); - - let mut context = if let Some(context) = state.context_cache.get_mut(&key) { - context - } else { - let permit = GLOBAL_CONTEXT_SEMAPHORE - .acquire() - .await - .expect("failed to acquire semaphore"); - - // Check if it's been done while we've been sitting here - if let Some(already_done) = state.context_cache.get_mut(&key) { - already_done - } else { - info!("generating context..."); - let context_result = - create_download_context(init_data, game_id.clone(), version_name.clone()).await?; - state.context_cache.insert(key.clone(), context_result); - - info!("continuing download"); - - drop(permit); - - state.context_cache.get_mut(&key).unwrap() - } - }; + let context_cache = &state.context_cache; + let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?; context.last_access = Instant::now(); - let (relative_filename, start, end) = context + let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?; + let reader = get_file_reader(&mut context, relative_filename, start, end).await?; + + let stream = ReaderStream::new(reader); + let body: Body = Body::from_stream(stream); + + let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([ + (header::CONTENT_TYPE, "application/octet-stream".to_owned()), + (header::CONTENT_LENGTH, (end - start).to_string()), + ]); + + Ok((headers, body)) +} + +fn initialise_logger() { + SimpleLogger::new() + .with_level(log::LevelFilter::Info) + .init() + .unwrap(); +} + +async fn acquire_permit<'a>() -> SemaphorePermit<'a> { + GLOBAL_CONTEXT_SEMAPHORE + .acquire() + .await + .expect("failed to acquire semaphore") +} +fn lookup_chunk( + chunk_id: &String, + context: &RefMut<'_, (String, String), DownloadContext>, +) -> Result<(String, usize, usize), StatusCode> { + context .chunk_lookup_table - .get(&chunk_id) + .get(chunk_id) .cloned() - .ok_or(StatusCode::NOT_FOUND)?; - let reader = context + .ok_or(StatusCode::NOT_FOUND) +} +async fn get_file_reader( + context: &mut RefMut<'_, (String, String), DownloadContext>, + relative_filename: String, + start: usize, + end: usize, +) -> Result, StatusCode> { + context .backend .reader( &VersionFile { @@ -104,17 +175,40 @@ async fn serve_file( .map_err(|v| { error!("reader error: {v:?}"); StatusCode::INTERNAL_SERVER_ERROR - })?; + }) +} +async fn get_or_generate_context<'a>( + state: &Arc, + context_cache: &'a DashMap<(String, String), DownloadContext>, + game_id: String, + version_name: String, +) -> Result, StatusCode> { + let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; + let key = (game_id.clone(), version_name.clone()); - let stream = ReaderStream::new(reader); - let body = Body::from_stream(stream); + if let Some(context) = context_cache.get_mut(&key) { + Ok(context) + } else { + let permit = acquire_permit().await; - let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([ - (header::CONTENT_TYPE, "application/octet-stream".to_owned()), - (header::CONTENT_LENGTH, (end - start).to_string()), - ]); + // Check if it's been done while we've been sitting here + if let Some(already_done) = context_cache.get_mut(&key) { + Ok(already_done) + } else { + info!("generating context..."); + let context_result = + create_download_context(initialisation_data, game_id.clone(), version_name.clone()) + .await?; - Ok((headers, body)) + state.context_cache.insert(key.clone(), context_result); + + info!("continuing download"); + + drop(permit); + + Ok(context_cache.get_mut(&key).unwrap()) + } + } } #[derive(Deserialize)] @@ -122,33 +216,27 @@ struct TokenPayload { token: String, } -async fn healthcheck(State(state): State>>) -> StatusCode { - let inited = state.token.initialized(); - if !inited { +async fn healthcheck(State(state): State>) -> StatusCode { + let initialised = state.token.initialized(); + if !initialised { return StatusCode::SERVICE_UNAVAILABLE; } return StatusCode::OK; } -async fn set_token( - State(state): State>>, - Json(payload): Json, -) -> Result { +fn check_token_exists(state: &Arc, payload: &TokenPayload) -> bool { if let Some(existing_data) = state.token.get() { if existing_data.token != payload.token { panic!("already set up but provided with a different token"); } - return Ok(StatusCode::OK); + return true; } - - let token = payload.token; - - let library_sources = fetch_library_sources(token.clone()).await.map_err(|v| { - error!("{v:?}"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - let valid_library_sources = library_sources + false +} +fn filter_library_sources( + library_sources: Vec, +) -> HashMap { + library_sources .into_iter() .filter(|v| { matches!( @@ -170,51 +258,19 @@ async fn set_token( (v.id, (path, v.backend)) }) - .collect::>(); - + .collect() +} +fn set_generated_token( + state: Arc, + token: String, + libraries: HashMap, +) -> Result<(), StatusCode> { state .token - .set(AppInitData { - token, - libraries: valid_library_sources, - }) + .set(AppInitData { token, libraries }) .map_err(|err| { error!("failed to set token: {err:?}"); StatusCode::INTERNAL_SERVER_ERROR })?; - - info!("connected to drop server successfully"); - - Ok(StatusCode::OK) -} - -#[tokio::main] -async fn main() { - SimpleLogger::new() - .with_level(log::LevelFilter::Info) - .init() - .unwrap(); - - if let Ok(working_directory) = std::env::var("WORKING_DIRECTORY") { - set_current_dir(working_directory).expect("failed to change working directory"); - } - - let shared_state = Arc::new(AppState { - token: OnceCell::new(), - context_cache: DashMap::new(), - }); - - let app = Router::new() - .route( - "/api/v1/depot/{game_id}/{version_name}/{chunk_id}", - get(serve_file), - ) - .route("/token", post(set_token)) - .route("/healthcheck", get(healthcheck)) - .with_state(shared_state); - - // run our app with hyper, listening globally on port 3000 - let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap(); - info!("started depot server"); - axum::serve(listener, app).await.unwrap(); + Ok(()) } diff --git a/torrential/src/remote.rs b/torrential/src/remote.rs index e40f4cc0..d242cbc9 100644 --- a/torrential/src/remote.rs +++ b/torrential/src/remote.rs @@ -66,7 +66,8 @@ pub async fn fetch_download_context( .text() .await .unwrap_or("(failed to read body)".to_string()) - ).into()); + ) + .into()); } let context: ContextResponseBody = context_response.json().await?; @@ -74,22 +75,21 @@ pub async fn fetch_download_context( Ok(context) } - #[derive(Deserialize, Debug)] #[non_exhaustive] pub enum LibraryBackend { Filesystem, - FlatFilesystem + FlatFilesystem, } #[derive(Deserialize)] pub struct LibrarySource { pub options: serde_json::Value, pub id: String, - pub backend: LibraryBackend + pub backend: LibraryBackend, } -pub async fn fetch_library_sources(token: String) -> Result> { +pub async fn fetch_library_sources(token: &String) -> Result> { let source_response = CLIENT .get(REMOTE_URL.join("/api/v1/admin/library/sources")?) .header("Authorization", format!("Bearer {}", token)) @@ -110,4 +110,4 @@ pub async fn fetch_library_sources(token: String) -> Result> let library_sources: Vec = source_response.json().await?; Ok(library_sources) -} \ No newline at end of file +}