// rhttp/src/server/mod.rs

use axum::{
Router,
body::Body,
extract::{Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::any,
};
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{error, info};
use crate::config::{RouteRule, ServerConfig, SiteConfig};
/// Shared application state: the parsed server configuration.
///
/// axum clones this into every request handler; the configuration sits behind an
/// [`Arc`], so a clone only bumps a reference count.
#[derive(Clone)]
pub struct ProxyServer {
    /// Configuration shared by all request handlers.
    pub config: Arc<ServerConfig>,
}

impl ProxyServer {
    /// Wraps `config` so it can be shared cheaply between handlers.
    pub fn new(config: ServerConfig) -> Self {
        Self {
            config: Arc::new(config),
        }
    }

    /// Binds `0.0.0.0:<port>` (port taken from the configuration) and serves every
    /// request through the `handle_request` fallback handler.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound or if `axum::serve` fails.
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new()
            .fallback(any(handle_request))
            .with_state(self.clone());
        let addr = format!("0.0.0.0:{}", self.config.port);
        let listener = TcpListener::bind(&addr).await?;
        info!("Starting rhttpd server on {}", addr);
        axum::serve(listener, app).await?;
        Ok(())
    }

    /// Looks up the site configured for `hostname`.
    ///
    /// The raw value is tried first. A `Host` header may carry a `:port` suffix, so
    /// on a miss the lookup is retried with the port removed. This lets
    /// `example.com:8080` reach a site keyed `example.com`, while configurations
    /// keyed with an explicit port keep working.
    pub fn find_site(&self, hostname: &str) -> Option<&SiteConfig> {
        self.config
            .sites
            .get(hostname)
            .or_else(|| self.config.sites.get(strip_port(hostname)))
    }

    /// Returns the first route of `site` (in configuration order) whose path
    /// pattern matches `path`.
    ///
    /// Pattern syntax:
    /// - `*` matches every path.
    /// - `<prefix>/*` matches `<prefix>` itself and anything below it, on a `/`
    ///   boundary (so `/static/*` matches `/static/a.css` but not `/staticfoo`).
    /// - Any other pattern must equal the path exactly.
    pub fn find_route<'a>(&self, site: &'a SiteConfig, path: &str) -> Option<&'a RouteRule> {
        site.routes.iter().find(|route| {
            let pattern = match route {
                RouteRule::Static { path_pattern, .. }
                | RouteRule::ReverseProxy { path_pattern, .. }
                | RouteRule::ForwardProxy { path_pattern, .. }
                | RouteRule::TcpProxy { path_pattern, .. } => path_pattern,
            };
            path_matches(pattern, path)
        })
    }
}

/// Returns whether request `path` matches route `pattern`; see
/// [`ProxyServer::find_route`] for the pattern syntax.
fn path_matches(pattern: &str, path: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    match pattern.strip_suffix("/*") {
        // The prefix must end at a segment boundary: "/api" or "/api/...", not "/apix".
        Some(base) => path
            .strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/')),
        None => path == pattern,
    }
}

/// Removes a trailing `:port` from a `Host`-style value.
///
/// Bracketed IPv6 literals (`[::1]:8080` -> `[::1]`) keep their brackets.
/// The value is returned unchanged in two cases:
/// - it has several `:` outside brackets (an unbracketed IPv6 address);
/// - the text after the last `:` is not all ASCII digits.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        return match host.find(']') {
            Some(end) => &host[..=end],
            None => host,
        };
    }
    match host.rsplit_once(':') {
        Some((name, port))
            if !name.contains(':') && port.bytes().all(|b| b.is_ascii_digit()) =>
        {
            name
        }
        _ => host,
    }
}
/// Fallback handler for every request. It resolves the site from the request's
/// host, picks the first matching route and dispatches on the route kind.
///
/// The host is taken from the first source that yields a value:
/// - the `Host` header, if present and valid visible ASCII;
/// - the authority of the request URI (e.g. HTTP/2 requests, where the host
///   arrives as `:authority` rather than a `Host` header);
/// - `"localhost"` as a last resort.
///
/// Unknown sites and unmatched paths get `404`. Forward-proxy and TCP-proxy
/// routes are not implemented yet and get `501`.
pub async fn handle_request(State(server): State<ProxyServer>, req: Request<Body>) -> Response {
    let hostname = req
        .headers()
        .get("host")
        .and_then(|h| h.to_str().ok())
        .or_else(|| req.uri().authority().map(|authority| authority.as_str()))
        .unwrap_or("localhost");
    // Owned copy: `req` is moved into the route handler below while `path` is still needed.
    let path = req.uri().path().to_string();
    info!("Request: {} {} {}", req.method(), hostname, path);
    // Find site configuration
    let site = match server.find_site(hostname) {
        Some(site) => site,
        None => {
            error!("Site not found for hostname: {}", hostname);
            return (StatusCode::NOT_FOUND, "Site not found").into_response();
        }
    };
    // Find matching route
    let route = match server.find_route(site, &path) {
        Some(route) => route,
        None => {
            error!("No route found for path: {}", path);
            return (StatusCode::NOT_FOUND, "Route not found").into_response();
        }
    };
    // Handle request based on route type
    match route {
        RouteRule::Static { root, index, .. } => {
            handle_static_request(req, root, index.as_deref(), &path).await
        }
        RouteRule::ReverseProxy { target, .. } => handle_reverse_proxy(req, target).await,
        RouteRule::ForwardProxy { .. } => {
            // A forward proxy's target cannot be derived from the path alone, so this
            // needs a different request flow; until then, report it as unimplemented.
            (
                StatusCode::NOT_IMPLEMENTED,
                "Forward proxy not fully implemented yet",
            )
                .into_response()
        }
        RouteRule::TcpProxy {
            target, protocol, ..
        } => {
            // Raw TCP tunnelling is not wired up yet; tell the client what was configured.
            info!(
                "TCP proxy requested for {} with protocol {:?}",
                target, protocol
            );
            (
                StatusCode::NOT_IMPLEMENTED,
                format!(
                    "TCP proxy to {} (protocol: {:?}) - use CONNECT method for raw TCP",
                    target, protocol
                ),
            )
                .into_response()
        }
    }
}
async fn handle_static_request(
_req: Request<Body>,
root: &std::path::Path,
index_files: Option<&[String]>,
path: &str,
) -> Response {
let file_path = root.join(&path[1..]); // Remove leading '/'
// If it's a directory, try index files
if file_path.is_dir()
&& let Some(index_files) = index_files
{
for index_file in index_files {
let index_path = file_path.join(index_file);
if index_path.exists() {
match std::fs::read_to_string(&index_path) {
Ok(content) => {
let mime_type = mime_guess::from_path(&index_path)
.first_or_octet_stream()
.to_string();
return Response::builder()
.status(StatusCode::OK)
.header("Content-Type", mime_type)
.body(Body::from(content))
.unwrap()
.into_response();
}
Err(_) => continue,
}
}
}
}
// Try to read the file
if file_path.exists() && file_path.is_file() {
match std::fs::read_to_string(&file_path) {
Ok(content) => {
let mime_type = mime_guess::from_path(&file_path)
.first_or_octet_stream()
.to_string();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", mime_type)
.body(Body::from(content))
.unwrap()
.into_response()
}
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read file").into_response(),
}
} else {
(StatusCode::NOT_FOUND, "File not found").into_response()
}
}
async fn handle_reverse_proxy(req: Request<Body>, target: &str) -> Response {
use reqwest::Client;
let client = Client::new();
let method = req.method();
let url = format!(
"{}{}",
target,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
// Convert axum Method to reqwest Method
let reqwest_method = match method.as_str() {
"GET" => reqwest::Method::GET,
"POST" => reqwest::Method::POST,
"PUT" => reqwest::Method::PUT,
"DELETE" => reqwest::Method::DELETE,
"HEAD" => reqwest::Method::HEAD,
"OPTIONS" => reqwest::Method::OPTIONS,
"PATCH" => reqwest::Method::PATCH,
_ => reqwest::Method::GET,
};
let mut builder = client.request(reqwest_method, &url);
// Copy headers
for (name, value) in req.headers() {
if name != "host" {
let name_str = name.to_string();
if let Ok(value_str) = value.to_str() {
builder = builder.header(name_str, value_str);
}
}
}
// Copy body
let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to read request body: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read body").into_response();
}
};
match builder.body(body_bytes).send().await {
Ok(resp) => {
let status = resp.status();
let mut response = Response::builder().status(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
);
// Copy response headers
for (name, value) in resp.headers() {
let name_str = name.to_string();
if let Ok(value_str) = value.to_str() {
response = response.header(name_str, value_str);
}
}
// Get response body
match resp.bytes().await {
Ok(body_bytes) => response
.body(Body::from(body_bytes))
.unwrap()
.into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(StatusCode::BAD_GATEWAY, "Failed to read response").into_response()
}
}
}
Err(e) => {
error!("Proxy request failed: {}", e);
(StatusCode::BAD_GATEWAY, "Proxy request failed").into_response()
}
}
}