Last active
January 4, 2025 13:35
-
-
Save kran/46798d857bc7268665865ec20b83a736 to your computer and use it in GitHub Desktop.
A small helper for setting up a cozy HTTP server framework that handles requests and manages routes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.sun.net.httpserver.Headers; | |
import com.sun.net.httpserver.HttpContext; | |
import com.sun.net.httpserver.HttpExchange; | |
import com.sun.net.httpserver.HttpServer; | |
import java.io.*; | |
import java.net.InetSocketAddress; | |
import java.net.URLDecoder; | |
import java.nio.charset.StandardCharsets; | |
import java.nio.file.Files; | |
import java.util.*; | |
import java.util.concurrent.Executor; | |
import java.util.function.BiConsumer; | |
import java.util.function.Consumer; | |
import java.util.function.Function; | |
public class Doge { | |
// 请求上下文的包装类 | |
public static class Exchange { | |
public final HttpExchange inner; | |
protected final Map<String, Object> attributes = new HashMap<>(); | |
protected int statusCode = 200; // 默认状态码 | |
protected int maxRequestBodySize = Integer.MAX_VALUE; // 默认无限制 | |
protected byte[] cachedBody; | |
protected Map<String, String> cachedQueryParams; | |
protected static final int DEFAULT_BUFFER_SIZE = 8192; | |
private boolean headersSent = false; | |
private boolean closed = false; | |
public Exchange(HttpExchange inner) { | |
this.inner = inner; | |
} | |
public String path() { | |
return inner.getRequestURI().getPath(); | |
} | |
public String method() { | |
return inner.getRequestMethod(); | |
} | |
public Headers headers() { | |
return inner.getRequestHeaders(); | |
} | |
// 用户定义的JSON输出方法 | |
public void json(Object obj) { | |
// 用户需要在子类中实现具体的JSON序列化逻辑 | |
throw new UnsupportedOperationException("JSON serialization not implemented"); | |
} | |
public Exchange status(int code) { | |
this.statusCode = code; // 仅设置状态码,不发送响应头 | |
return this; | |
} | |
public Exchange contentType(String type) { | |
inner.getResponseHeaders().set("Content-Type", type); | |
return this; | |
} | |
public void download(File file, String filename) { | |
if (!file.exists() || !file.isFile()) { | |
status(404).result("File not found"); | |
return; | |
} | |
contentType("application/octet-stream") | |
.header("Content-Disposition", "attachment; filename=\"" + filename + "\""); | |
try { | |
byte[] data = Files.readAllBytes(file.toPath()); | |
result(data); | |
} catch (IOException e) { | |
throw new RuntimeException("File read failed", e); | |
} | |
} | |
public void redirect(String location) { | |
status(302) | |
.header("Location", location) | |
.result(""); | |
} | |
public Exchange attr(String name, Object value) { | |
attributes.put(name, value); | |
return this; | |
} | |
@SuppressWarnings("unchecked") | |
public <T> T attr(String name) { | |
return (T) attributes.get(name); | |
} | |
public <T> T attr(String name, Class<T> type) { | |
T value = attr(name); | |
if (value == null) { | |
throw new IllegalStateException(type.getSimpleName() + " not found in context"); | |
} | |
return value; | |
} | |
public void sse(Consumer<SseEmitter> handler) { | |
contentType("text/event-stream") | |
.header("Cache-Control", "no-cache") | |
.header("Connection", "keep-alive"); | |
try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(inner.getResponseBody()))) { | |
SseEmitter emitter = new SseEmitter(writer); | |
inner.sendResponseHeaders(200, 0); | |
handler.accept(emitter); | |
} catch (IOException e) { | |
throw new RuntimeException("SSE failed", e); | |
} | |
} | |
public void close() { | |
if (closed) { | |
return; | |
} | |
try { | |
inner.close(); | |
} finally { | |
closed = true; | |
} | |
} | |
public static String urlDecode(String str) { | |
return URLDecoder.decode(str, StandardCharsets.UTF_8); | |
} | |
// 获取查询参数 | |
public Map<String, String> params() { | |
if (cachedQueryParams != null) { | |
return cachedQueryParams; | |
} | |
String query = inner.getRequestURI().getQuery(); | |
cachedQueryParams = new LinkedHashMap<>(); | |
if (query == null || query.trim().equals("")) { | |
return cachedQueryParams; | |
} | |
for (String param : query.split("&")) { | |
int idx = param.indexOf('='); | |
if (idx > 0) { | |
String key = urlDecode(param.substring(0, idx)); | |
String value = urlDecode(param.substring(idx + 1)); | |
cachedQueryParams.put(key, value); | |
} else if (param.length() > 0) { | |
cachedQueryParams.put(urlDecode(param), ""); | |
} | |
} | |
return cachedQueryParams; | |
} | |
// 获取请求体为字符串 | |
public String body() { | |
return new String(bytes(), StandardCharsets.UTF_8); | |
} | |
public byte[] bytes() { | |
if (cachedBody != null) { | |
return cachedBody; | |
} | |
try (InputStream is = inner.getRequestBody(); | |
ByteArrayOutputStream result = new ByteArrayOutputStream()) { | |
byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; | |
int totalRead = 0; | |
int length; | |
while ((length = is.read(buffer)) != -1) { | |
totalRead += length; | |
if (totalRead > maxRequestBodySize) { | |
throw new IOException("Request body too large"); | |
} | |
result.write(buffer, 0, length); | |
} | |
cachedBody = result.toByteArray(); | |
return cachedBody; | |
} catch (IOException e) { | |
throw new RuntimeException("Failed to read request body", e); | |
} | |
} | |
// 设响应头 | |
public Exchange header(String name, String value) { | |
inner.getResponseHeaders().set(name, value); | |
return this; | |
} | |
// 获取单个请求头 | |
public String header(String name) { | |
return inner.getRequestHeaders().getFirst(name); | |
} | |
// 设置Cookie | |
public Exchange cookie(String name, String value, int maxAge) { | |
String cookie = String.format("%s=%s; Max-Age=%d; Path=/", name, value, maxAge); | |
inner.getResponseHeaders().add("Set-Cookie", cookie); | |
return this; | |
} | |
// 获取Cookie | |
public String cookie(String name) { | |
List<String> cookies = inner.getRequestHeaders().get("Cookie"); | |
if (cookies != null) { | |
for (String cookie : cookies) { | |
for (String pair : cookie.split(";")) { | |
String[] keyValue = pair.trim().split("="); | |
if (keyValue.length == 2 && keyValue[0].equals(name)) { | |
return keyValue[1]; | |
} | |
} | |
} | |
} | |
return null; | |
} | |
public void html(String htmlContent) { | |
contentType("text/html; charset=UTF-8"); | |
result(htmlContent); | |
} | |
public void staticFile(File root, String path) { | |
File file = new File(root, path); | |
try { | |
if(!file.exists() || !file.isFile()) { | |
status(404).result("File not found: " + file); | |
return; | |
} | |
if(!file.getCanonicalPath().startsWith(root.getCanonicalPath())) { | |
status(403).result("Forbidden: " + file); | |
return; | |
} | |
byte[] fileBytes = Files.readAllBytes(file.toPath()); | |
String contentType = Files.probeContentType(file.toPath()); | |
contentType(contentType != null ? contentType : "application/octet-stream"); | |
result(fileBytes); | |
} catch (IOException e) { | |
throw new RuntimeException("Read file error: " + file, e); | |
} finally { | |
close(); | |
} | |
} | |
public void result(Object obj) { | |
String response = obj != null ? obj.toString() : "null"; | |
byte[] responseBytes = response.getBytes(StandardCharsets.UTF_8); | |
result(responseBytes); | |
} | |
public void result(byte[] responseBytes) { | |
try { | |
sendHeaders(responseBytes.length); | |
inner.getResponseBody().write(responseBytes); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
throw new RuntimeException("Failed to send response", e); | |
} finally { | |
close(); | |
} | |
} | |
public void result(InputStream inputStream) { | |
try { | |
int availableBytes = inputStream.available(); | |
if (availableBytes > 0) { | |
sendHeaders(availableBytes); | |
} else { | |
inner.getResponseHeaders().set("Transfer-Encoding", "chunked"); | |
sendHeaders(0); // 使用分块传输编码 | |
} | |
byte[] data = new byte[1024]; | |
int nRead; | |
OutputStream os = inner.getResponseBody(); | |
while ((nRead = inputStream.read(data, 0, data.length)) != -1) { | |
os.write(data, 0, nRead); | |
} | |
os.flush(); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
throw new RuntimeException("Failed to send response from InputStream", e); | |
} finally { | |
close(); | |
} | |
} | |
// 设置请求体大小限制 | |
public Exchange maxRequestBodySize(int size) { | |
this.maxRequestBodySize = size; | |
return this; | |
} | |
private void sendHeaders(long length) throws IOException { | |
if (headersSent) { | |
return; | |
} | |
inner.sendResponseHeaders(statusCode, length); | |
headersSent = true; | |
} | |
} | |
// 核心服务器实现类 | |
public static class Server<T extends Exchange> { | |
private final HttpServer server; | |
private final Function<HttpExchange, T> exchangeFactory; | |
private BiConsumer<T, Exception> errorHandler = this::defaultErrorHandler; | |
private int maxRequestBodySize = 2 * 1024 * 1024; // 默认请求体大小限制为2MB | |
private final Deque<String> routeStack = new ArrayDeque<>(); // 路由栈 | |
private final List<RouteFilter<T>> routeFilters = new ArrayList<>(); // 保存带路径的过滤器 | |
public Server(Function<HttpExchange, T> exchangeFactory) { | |
try { | |
this.server = HttpServer.create(); // 初始化为空的HttpServer | |
this.exchangeFactory = exchangeFactory; | |
} catch (IOException e) { | |
throw new RuntimeException("Failed to create HttpServer", e); | |
} | |
} | |
public Server<T> start(int port) throws IOException { | |
return start(port, 0); | |
} | |
public Server<T> start(int port, int backlog) { | |
try { | |
this.server.bind(new InetSocketAddress(port), backlog); | |
this.server.start(); | |
return this; | |
} catch (IOException e) { | |
throw new RuntimeException("start server error", e); | |
} | |
} | |
public Server<T> maxRequestBodySize(int size) { | |
this.maxRequestBodySize = size; | |
return this; | |
} | |
public Server<T> error(BiConsumer<T, Exception> handler) { | |
this.errorHandler = handler; | |
return this; | |
} | |
public Server<T> executor(Executor executor) { | |
server.setExecutor(executor); | |
return this; | |
} | |
public void stop() { | |
server.stop(0); | |
} | |
public void filter(String pathPrefix, BiConsumer<T, Consumer<T>> filter) { | |
routeFilters.add(new RouteFilter<>(pathPrefix, filter)); | |
} | |
public HttpContext route(String path, Consumer<T> handler) { | |
String fullPath = buildFullPath(path); | |
return server.createContext(fullPath, httpExchange -> { | |
T exchange = null; | |
try { | |
exchange = exchangeFactory.apply(httpExchange); | |
exchange.maxRequestBodySize(maxRequestBodySize); | |
applyFilters(exchange, handler, 0); | |
} catch (Exception e) { | |
if (exchange != null) { | |
try { | |
errorHandler.accept(exchange, e); | |
} catch (Exception ex) { | |
// 记录错误处理器的异常,但不抛出 | |
ex.printStackTrace(); | |
} | |
} | |
} finally { | |
if (exchange != null) { | |
exchange.close(); | |
} | |
} | |
}); | |
} | |
public void group(String path, Runnable groupRoutes) { | |
routeStack.push(path); | |
try { | |
groupRoutes.run(); | |
} finally { | |
routeStack.pop(); | |
} | |
} | |
private String buildFullPath(String path) { | |
StringBuilder fullPath = new StringBuilder(); | |
for (String part : routeStack) { | |
fullPath.insert(0, part); | |
} | |
fullPath.append(path); | |
return fullPath.toString(); | |
} | |
private void applyFilters(T exchange, Consumer<T> handler, int index) { | |
if (index < routeFilters.size()) { | |
RouteFilter<T> routeFilter = routeFilters.get(index); | |
if (exchange.path().startsWith(routeFilter.pathPrefix)) { | |
routeFilter.filter.accept(exchange, ex -> applyFilters(ex, handler, index + 1)); | |
} else { | |
applyFilters(exchange, handler, index + 1); | |
} | |
} else { | |
handler.accept(exchange); | |
} | |
} | |
private void defaultErrorHandler(T ctx, Exception e) { | |
int status; | |
String message; | |
if (e instanceof IllegalArgumentException) { | |
status = 400; | |
message = "Bad Request"; | |
} else if (e instanceof SecurityException) { | |
status = 403; | |
message = "Forbidden"; | |
} else if (e instanceof FileNotFoundException) { | |
status = 404; | |
message = "Not Found"; | |
} else if (e instanceof UnsupportedOperationException) { | |
status = 501; | |
message = "Not Implemented"; | |
} else if (e instanceof IOException) { | |
status = 503; | |
message = "Service Unavailable"; | |
} else { | |
status = 500; | |
message = "Internal Server Error"; | |
e.printStackTrace(); | |
} | |
message += ": " + e.getMessage(); | |
ctx.status(status).result(message); | |
} | |
} | |
protected static class RouteFilter<T> { | |
String pathPrefix; | |
BiConsumer<T, Consumer<T>> filter; | |
RouteFilter(String pathPrefix, BiConsumer<T, Consumer<T>> filter) { | |
this.pathPrefix = pathPrefix; | |
this.filter = filter; | |
} | |
} | |
public static class SseEmitter { | |
private final PrintWriter writer; | |
public SseEmitter(PrintWriter writer) { | |
this.writer = writer; | |
} | |
public void send(String data) throws IOException { | |
if (writer.checkError()) { // 检查writer状态 | |
throw new IOException("Client disconnected"); | |
} | |
writer.println("data: " + data); | |
writer.println(); | |
writer.flush(); | |
} | |
public void send(String event, String data) throws IOException { | |
if (writer.checkError()) { // 检查writer状态 | |
throw new IOException("Client disconnected"); | |
} | |
writer.println("event: " + event); | |
writer.println("data: " + data); | |
writer.println(); | |
writer.flush(); | |
} | |
} | |
// 工厂方法 | |
public static <T extends Exchange> Server<T> create(Function<HttpExchange, T> contextFactory) { | |
return new Server<>(contextFactory); | |
} | |
public static Server<Exchange> create() { | |
return new Server<>(Exchange::new); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment