diff --git a/src/lib.rs b/src/lib.rs index b88a2c5..8d33f1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ pub mod websockets; pub use attestation::AttestationGenerator; use bytes::Bytes; -use http::HeaderValue; +use http::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{service::service_fn, Response}; use hyper_util::rt::TokioIo; @@ -45,6 +45,12 @@ const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; /// The header name for giving measurements const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; +/// The header name for giving the forwarded for IP +static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); + +/// The header name for giving the 'real IP' - in our case that of the client +static X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip"); + /// The longest time in seconds to wait between reconnection attempts const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; @@ -122,15 +128,20 @@ impl ProxyServer { /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { let target = self.target.clone(); - let (inbound, _client_addr) = self.listener.accept().await?; + let (inbound, client_addr) = self.listener.accept().await?; let attested_tls_server = self.attested_tls_server.clone(); tokio::spawn(async move { match attested_tls_server.handle_connection(inbound).await { Ok((tls_stream, measurements, attestation_type)) => { - if let Err(err) = - Self::handle_connection(tls_stream, measurements, attestation_type, target) - .await + if let Err(err) = Self::handle_connection( + tls_stream, + measurements, + attestation_type, + target, + client_addr, + ) + .await { warn!("Failed to handle connection: {err}"); } @@ -155,6 +166,7 @@ impl ProxyServer { measurements: Option, remote_attestation_type: AttestationType, target: String, + client_addr: SocketAddr, ) -> Result<(), ProxyError> { tracing::debug!("proxy-server accepted connection"); @@ -163,9 +175,29 @@ impl ProxyServer { // Setup a request handler let service = service_fn(move |mut req| { + let headers = req.headers_mut(); + + // Add or update the HOST header + let old_value = update_header(headers, &http::header::HOST, &target); + tracing::info!("Updating Host header - old value: {old_value:?} new value: {target}",); + + // Add the x-real-ip header + let client_ip = client_addr.ip().to_string(); + update_header(headers, &X_REAL_IP, &client_ip); + + // Add or update the x-forwarded-for header + let new_x_forwarded_for = + match headers.get(&X_FORWARDED_FOR).and_then(|v| v.to_str().ok()) { + Some(existing) if !existing.trim().is_empty() => { + format!("{}, {}", existing.trim(), client_ip) + } + _ => client_ip.clone(), + }; + + update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + // If we have measurements, from the remote peer, add them to the request header let measurements = measurements.clone(); - let headers = req.headers_mut(); if let Some(measurements) = measurements { match measurements.to_header_format() { Ok(header_value) => { @@ -178,10 +210,11 @@ impl ProxyServer { } } } - headers.insert( + + update_header( + headers, ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()) - .expect("Attestation type should be able to be encoded as a header value"), + remote_attestation_type.as_str(), ); let target = target.clone(); @@ -346,11 +379,11 @@ impl ProxyClient { } } } - headers.insert( + + update_header( + headers, ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()).expect( - "Attestation type should be able to be encoded as a header value", - ), + remote_attestation_type.as_str(), ); (Ok(resp.map(|b| b.boxed())), false) } @@ -508,6 +541,25 @@ impl ProxyClient { } } +/// Update a request/response header if we are able to encode the header value +/// +/// This avoids bailing on bad header values - the headers are simply not updated +fn update_header( + headers: &mut HeaderMap, + header_name: K, + header_value: &str, +) -> Option +where + K: http::header::IntoHeaderName + std::fmt::Display, +{ + if let Ok(value) = HeaderValue::from_str(header_value) { + headers.insert(header_name, value) + } else { + error!("Failed to encode {header_name} header value: {header_value}"); + None + } +} + /// An error when running a proxy client or server #[derive(Error, Debug)] pub enum ProxyError {