Skip to content
78 changes: 65 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub mod websockets;
pub use attestation::AttestationGenerator;

use bytes::Bytes;
use http::HeaderValue;
use http::{HeaderMap, HeaderName, HeaderValue};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{service::service_fn, Response};
use hyper_util::rt::TokioIo;
Expand Down Expand Up @@ -45,6 +45,12 @@ const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type";
/// The header name for giving measurements
const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement";

/// The header name for giving the forwarded for IP
static X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");

/// The header name for giving the 'real IP' - in our case that of the client
static X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip");

/// The longest time in seconds to wait between reconnection attempts
const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120;

Expand Down Expand Up @@ -122,15 +128,20 @@ impl ProxyServer {
/// Accept an incoming connection and handle it in a seperate task
pub async fn accept(&self) -> Result<(), ProxyError> {
let target = self.target.clone();
let (inbound, _client_addr) = self.listener.accept().await?;
let (inbound, client_addr) = self.listener.accept().await?;
let attested_tls_server = self.attested_tls_server.clone();

tokio::spawn(async move {
match attested_tls_server.handle_connection(inbound).await {
Ok((tls_stream, measurements, attestation_type)) => {
if let Err(err) =
Self::handle_connection(tls_stream, measurements, attestation_type, target)
.await
if let Err(err) = Self::handle_connection(
tls_stream,
measurements,
attestation_type,
target,
client_addr,
)
.await
{
warn!("Failed to handle connection: {err}");
}
Expand All @@ -155,6 +166,7 @@ impl ProxyServer {
measurements: Option<MultiMeasurements>,
remote_attestation_type: AttestationType,
target: String,
client_addr: SocketAddr,
) -> Result<(), ProxyError> {
tracing::debug!("proxy-server accepted connection");

Expand All @@ -163,9 +175,29 @@ impl ProxyServer {

// Setup a request handler
let service = service_fn(move |mut req| {
let headers = req.headers_mut();

// Add or update the HOST header
let old_value = update_header(headers, &http::header::HOST, &target);
tracing::info!("Updating Host header - old value: {old_value:?} new value: {target}",);

// Add the x-real-ip header
let client_ip = client_addr.ip().to_string();
update_header(headers, &X_REAL_IP, &client_ip);

// Add or update the x-forwarded-for header
let new_x_forwarded_for =
match headers.get(&X_FORWARDED_FOR).and_then(|v| v.to_str().ok()) {
Some(existing) if !existing.trim().is_empty() => {
format!("{}, {}", existing.trim(), client_ip)
}
_ => client_ip.clone(),
};

update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for);

// If we have measurements, from the remote peer, add them to the request header
let measurements = measurements.clone();
let headers = req.headers_mut();
if let Some(measurements) = measurements {
match measurements.to_header_format() {
Ok(header_value) => {
Expand All @@ -178,10 +210,11 @@ impl ProxyServer {
}
}
}
headers.insert(

update_header(
headers,
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str())
.expect("Attestation type should be able to be encoded as a header value"),
remote_attestation_type.as_str(),
);

let target = target.clone();
Expand Down Expand Up @@ -346,11 +379,11 @@ impl ProxyClient {
}
}
}
headers.insert(

update_header(
headers,
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str()).expect(
"Attestation type should be able to be encoded as a header value",
),
remote_attestation_type.as_str(),
);
(Ok(resp.map(|b| b.boxed())), false)
}
Expand Down Expand Up @@ -508,6 +541,25 @@ impl ProxyClient {
}
}

/// Update a request/response header if we are able to encode the header value
///
/// This avoids bailing on bad header values - the headers are simply not updated
fn update_header<K>(
headers: &mut HeaderMap,
header_name: K,
header_value: &str,
) -> Option<HeaderValue>
where
K: http::header::IntoHeaderName + std::fmt::Display,
{
if let Ok(value) = HeaderValue::from_str(header_value) {
headers.insert(header_name, value)
} else {
error!("Failed to encode {header_name} header value: {header_value}");
None
}
}

/// An error when running a proxy client or server
#[derive(Error, Debug)]
pub enum ProxyError {
Expand Down