diff --git a/torrential/Cargo.lock b/torrential/Cargo.lock index b3e798b7..24eadec9 100644 --- a/torrential/Cargo.lock +++ b/torrential/Cargo.lock @@ -504,6 +504,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -523,9 +534,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", + "futures-macro", "futures-task", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1744,6 +1757,7 @@ dependencies = [ "criterion", "dashmap", "droplet-rs", + "futures-util", "log", "rand", "reqwest", diff --git a/torrential/Cargo.toml b/torrential/Cargo.toml index a4a8e02d..88c95fe9 100644 --- a/torrential/Cargo.toml +++ b/torrential/Cargo.toml @@ -30,6 +30,7 @@ serde_json = "1.0.145" url = { version = "2.5.7", default-features = false } tokio-util = { version = "0.7.17", features = ["io"] } async-trait = "0.1.89" +futures-util = "0.3.31" [lints.clippy] pedantic = { level = "warn", priority = -1 } diff --git a/torrential/src/handlers.rs b/torrential/src/handlers.rs index b54bee6d..f7b0fe18 100644 --- a/torrential/src/handlers.rs +++ b/torrential/src/handlers.rs @@ -1,43 +1,10 @@ use std::sync::Arc; -use axum::{ - body::Body, - extract::{Path, State}, - response::{AppendHeaders, IntoResponse}, -}; -use dashmap::{DashMap, mapref::one::RefMut}; -use droplet_rs::versions::types::{MinimumFileObject, VersionFile}; -use log::{error, info}; -use reqwest::{StatusCode, header}; -use tokio::sync::SemaphorePermit; -use tokio_util::io::ReaderStream; +use axum::{Json, extract::State}; +use reqwest::StatusCode; +use serde::Deserialize; -use crate::{ - DownloadContext, GLOBAL_CONTEXT_SEMAPHORE, download::create_download_context, state::AppState, -}; - -pub async fn serve_file( - State(state): State>, - Path((game_id, version_name, chunk_id)): Path<(String, String, String)>, -) -> Result { - let context_cache = &state.context_cache; - - let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?; - context.reset_last_access(); - - let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?; - let reader = get_file_reader(&mut context, relative_filename, start, end).await?; - - let stream = ReaderStream::new(reader); - let body: Body = Body::from_stream(stream); - - let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([ - (header::CONTENT_TYPE, "application/octet-stream".to_owned()), - (header::CONTENT_LENGTH, (end - start).to_string()), - ]); - - Ok((headers, body)) -} +use crate::state::AppState; pub async fn healthcheck(State(state): State>) -> StatusCode { let initialised = state.token.initialized(); @@ -47,80 +14,16 @@ pub async fn healthcheck(State(state): State>) -> StatusCode { StatusCode::OK } -async fn acquire_permit<'a>() -> SemaphorePermit<'a> { - return GLOBAL_CONTEXT_SEMAPHORE - .acquire() - .await - .expect("failed to acquire semaphore"); -} -fn lookup_chunk( - chunk_id: &String, - context: &RefMut<'_, (String, String), DownloadContext>, -) -> Result<(String, usize, usize), StatusCode> { - context - .chunk_lookup_table - .get(chunk_id) - .cloned() - .ok_or(StatusCode::NOT_FOUND) -} -async fn get_file_reader( - context: &mut RefMut<'_, (String, String), DownloadContext>, - relative_filename: String, - start: usize, - end: usize, -) -> Result, StatusCode> { - context - .backend - .reader( - &VersionFile { - relative_filename: relative_filename.clone(), - permission: 0, - size: 0, - }, - start as u64, - end as u64, - ) - .await - .map_err(|v| { - error!("reader error: {v:?}"); - StatusCode::INTERNAL_SERVER_ERROR - }) -} -async fn get_or_generate_context<'a>( - state: &Arc, - context_cache: &'a DashMap<(String, String), DownloadContext>, +#[derive(Deserialize)] +pub struct InvalidateBody { game_id: String, version_name: String, -) -> Result, StatusCode> { - let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - let key = (game_id.clone(), version_name.clone()); - - if let Some(context) = context_cache.get_mut(&key) { - Ok(context) - } else { - let permit = acquire_permit().await; - - // Check if it's been done while we've been sitting here - if let Some(already_done) = context_cache.get_mut(&key) { - Ok(already_done) - } else { - info!("generating context..."); - let context_result = create_download_context( - &*state.metadata_provider, - &*state.backend_factory, - initialisation_data, - game_id.clone(), - version_name.clone(), - ) - .await?; - - state.context_cache.insert(key.clone(), context_result); - - info!("continuing download"); - - drop(permit); - - Ok(context_cache.get_mut(&key).unwrap()) - } - } +} + +pub async fn invalidate( + State(state): State>, + Json(payload): Json, +) -> StatusCode { + state.context_cache.remove(&(payload.game_id, payload.version_name)); + StatusCode::OK } diff --git a/torrential/src/lib.rs b/torrential/src/lib.rs index 4c70554f..4a28441b 100644 --- a/torrential/src/lib.rs +++ b/torrential/src/lib.rs @@ -1,5 +1,6 @@ use tokio::sync::Semaphore; mod download; +pub mod serve; pub mod handlers; mod manifest; mod remote; diff --git a/torrential/src/main.rs b/torrential/src/main.rs index b9fbb195..c1ac1b63 100644 --- a/torrential/src/main.rs +++ b/torrential/src/main.rs @@ -10,10 +10,9 @@ use axum::{ use dashmap::DashMap; use log::info; use simple_logger::SimpleLogger; -use tokio::sync::OnceCell; +use tokio::{runtime::Handle, sync::OnceCell}; use torrential::{ - DropBackendFactory, DropLibraryProvider, DropContextProvider, handlers, set_token, - state::AppState, + DropBackendFactory, DropContextProvider, DropLibraryProvider, handlers, serve, set_token, state::AppState }; use url::Url; @@ -25,6 +24,9 @@ async fn main() { set_current_dir(working_directory).expect("failed to change working directory"); } + let metrics = Handle::current().metrics(); + info!("using {} threads", metrics.num_workers()); + let remote_url = get_remote_url(); let shared_state = Arc::new(AppState { @@ -44,11 +46,12 @@ async fn main() { fn setup_app(shared_state: Arc) -> Router { Router::new() .route( - "/api/v1/depot/{game_id}/{version_name}/{chunk_id}", - get(handlers::serve_file), + "/api/v1/depot/{game_id}/{version_name}/{*chunk_ids}", + get(serve::serve_file), ) .route("/token", post(set_token)) .route("/healthcheck", get(handlers::healthcheck)) + .route("/invalid", post(handlers::invalidate)) .with_state(shared_state) } diff --git a/torrential/src/serve.rs b/torrential/src/serve.rs new file mode 100644 index 00000000..b16c9619 --- /dev/null +++ b/torrential/src/serve.rs @@ -0,0 +1,135 @@ +use std::sync::Arc; + +use axum::{ + body::Body, + extract::{Path, State}, + http::HeaderMap, + response::{AppendHeaders, IntoResponse}, +}; +use dashmap::{DashMap, mapref::one::RefMut}; +use droplet_rs::versions::types::{MinimumFileObject, VersionFile}; +use log::{error, info}; +use reqwest::{StatusCode, header}; +use tokio::sync::SemaphorePermit; +use tokio_util::io::ReaderStream; +use futures_util::{StreamExt as _, stream}; + + +use crate::{ + DownloadContext, GLOBAL_CONTEXT_SEMAPHORE, download::create_download_context, state::AppState, +}; + +pub async fn serve_file( + State(state): State>, + Path((game_id, version_name, chunk_ids)): Path<(String, String, String)>, +) -> Result { + let context_cache = &state.context_cache; + + let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?; + context.reset_last_access(); + + let chunk_ids = chunk_ids.split("/").collect::>(); + let mut streams = Vec::with_capacity(chunk_ids.len()); + let mut content_lengths = Vec::with_capacity(chunk_ids.len()); + let mut total_size = 0; + for chunk_id in chunk_ids { + let (relative_filename, start, end) = lookup_chunk(chunk_id, &context)?; + let reader = get_file_reader(&mut context, relative_filename, start, end).await?; + + let stream = ReaderStream::new(reader); + streams.push(stream); + content_lengths.push((end - start).to_string()); + + total_size += end - start; + } + + let stream = stream::iter(streams).flatten(); + let body: Body = Body::from_stream(stream); + + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/octet-stream".parse().unwrap()); + headers.insert("Content-Length", total_size.to_string().parse().unwrap()); + headers.insert( + "Content-Lengths", + content_lengths.join(",").parse().unwrap(), + ); + + Ok((headers, body)) +} +async fn acquire_permit<'a>() -> SemaphorePermit<'a> { + return GLOBAL_CONTEXT_SEMAPHORE + .acquire() + .await + .expect("failed to acquire semaphore"); +} +fn lookup_chunk( + chunk_id: &str, + context: &RefMut<'_, (String, String), DownloadContext>, +) -> Result<(String, usize, usize), StatusCode> { + context + .chunk_lookup_table + .get(chunk_id) + .cloned() + .ok_or(StatusCode::NOT_FOUND) +} +async fn get_file_reader( + context: &mut RefMut<'_, (String, String), DownloadContext>, + relative_filename: String, + start: usize, + end: usize, +) -> Result, StatusCode> { + context + .backend + .reader( + &VersionFile { + relative_filename: relative_filename.clone(), + permission: 0, + size: 0, + }, + start as u64, + end as u64, + ) + .await + .map_err(|v| { + error!("reader error: {v:?}"); + StatusCode::INTERNAL_SERVER_ERROR + }) +} +async fn get_or_generate_context<'a>( + state: &Arc, + context_cache: &'a DashMap<(String, String), DownloadContext>, + game_id: String, + version_name: String, +) -> Result, StatusCode> { + let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; + let key = (game_id.clone(), version_name.clone()); + + if let Some(context) = context_cache.get_mut(&key) { + Ok(context) + } else { + let permit = acquire_permit().await; + + // Check if it's been done while we've been sitting here + if let Some(already_done) = context_cache.get_mut(&key) { + Ok(already_done) + } else { + info!("generating context for {}...", game_id); + let context_result = create_download_context( + &*state.metadata_provider, + &*state.backend_factory, + initialisation_data, + game_id.clone(), + version_name.clone(), + ) + .await?; + + state.context_cache.insert(key.clone(), context_result); + + info!("continuing download for {}", game_id); + + drop(permit); + + Ok(context_cache.get_mut(&key).unwrap()) + } + } +}