Skip to content

Instantly share code, notes, and snippets.

@kran
Last active January 4, 2025 13:35
Show Gist options
  • Save kran/46798d857bc7268665865ec20b83a736 to your computer and use it in GitHub Desktop.
Save kran/46798d857bc7268665865ec20b83a736 to your computer and use it in GitHub Desktop.
A small helper that sets up a lightweight HTTP server framework for handling requests and managing routes.
import com.sun.net.httpserver.Headers;
import com.sun.net.httpserver.HttpContext;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import java.io.*;
import java.net.InetSocketAddress;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.Executor;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
/**
 * A tiny HTTP framework layered on the JDK's built-in {@link HttpServer}.
 *
 * <p>{@link Exchange} wraps one request/response pair with convenience helpers, {@link Server}
 * registers routes, route groups and path-prefix filters, and {@link SseEmitter} writes
 * Server-Sent Events frames.
 */
public class Doge {

    /**
     * Request/response wrapper around a single {@link HttpExchange}.
     *
     * <p>Holds per-request mutable state (status code, attributes, cached body and query
     * parameters) without any synchronization. Subclass it and pass a factory to
     * {@link Doge#create(Function)} to add behavior such as {@link #json(Object)}.
     */
    public static class Exchange {
        /** The underlying JDK exchange. */
        public final HttpExchange inner;
        /** Per-request attributes, e.g. values stored by filters for later handlers. */
        protected final Map<String, Object> attributes = new HashMap<>();
        /** Status code sent with the response headers; defaults to 200. */
        protected int statusCode = 200;
        /** Upper bound on request-body bytes read by {@link #bytes()}; unlimited by default. */
        protected int maxRequestBodySize = Integer.MAX_VALUE;
        /** Request body, cached after the first successful {@link #bytes()} call. */
        protected byte[] cachedBody;
        /** Decoded query parameters, cached after the first successful {@link #params()} call. */
        protected Map<String, String> cachedQueryParams;
        protected static final int DEFAULT_BUFFER_SIZE = 8192;
        private boolean headersSent = false;
        private boolean closed = false;

        public Exchange(HttpExchange inner) {
            this.inner = inner;
        }

        /** Returns the (decoded) request path. */
        public String path() {
            return inner.getRequestURI().getPath();
        }

        /** Returns the HTTP request method, e.g. {@code GET}. */
        public String method() {
            return inner.getRequestMethod();
        }

        /** Returns the request headers. */
        public Headers headers() {
            return inner.getRequestHeaders();
        }

        /**
         * Writes {@code obj} as JSON. This base class ships no JSON library, so subclasses
         * must override this with their own serialization.
         *
         * @throws UnsupportedOperationException always, unless overridden
         */
        public void json(Object obj) {
            throw new UnsupportedOperationException("JSON serialization not implemented");
        }

        /**
         * Sets the status code used when the response headers are sent. Only records the
         * code; it has no effect once headers have already been sent.
         */
        public Exchange status(int code) {
            this.statusCode = code;
            return this;
        }

        /** Sets the {@code Content-Type} response header. */
        public Exchange contentType(String type) {
            inner.getResponseHeaders().set("Content-Type", type);
            return this;
        }

        /**
         * Sends {@code file} as an attachment named {@code filename}, or a 404 when it is not
         * an existing regular file. The file is read fully before any download headers are
         * set, so a read failure does not leave attachment headers on an error response.
         *
         * @throws RuntimeException if the file cannot be read
         */
        public void download(File file, String filename) {
            if (!file.exists() || !file.isFile()) {
                status(404).result("File not found");
                return;
            }
            byte[] data;
            try {
                data = Files.readAllBytes(file.toPath());
            } catch (IOException e) {
                throw new RuntimeException("File read failed", e);
            }
            contentType("application/octet-stream")
                    .header("Content-Disposition",
                            "attachment; filename=\"" + escapeQuoted(filename) + "\"");
            result(data);
        }

        /** Sends a 302 redirect to {@code location} with an empty body. */
        public void redirect(String location) {
            status(302)
                    .header("Location", location)
                    .result("");
        }

        /** Stores a per-request attribute. */
        public Exchange attr(String name, Object value) {
            attributes.put(name, value);
            return this;
        }

        /** Returns the attribute {@code name}, or {@code null}; the cast to {@code T} is unchecked. */
        @SuppressWarnings("unchecked")
        public <T> T attr(String name) {
            return (T) attributes.get(name);
        }

        /**
         * Returns the attribute {@code name}, checked against {@code type}.
         *
         * @throws IllegalStateException if the attribute is absent
         * @throws ClassCastException if the attribute is present but not a {@code type}
         */
        public <T> T attr(String name, Class<T> type) {
            Object value = attr(name);
            if (value == null) {
                throw new IllegalStateException(type.getSimpleName() + " not found in context");
            }
            return type.cast(value);
        }

        /**
         * Starts a Server-Sent Events response and hands an emitter to {@code handler}.
         * The response stream is closed when {@code handler} returns.
         *
         * @throws RuntimeException if the response headers cannot be sent
         */
        public void sse(Consumer<SseEmitter> handler) {
            contentType("text/event-stream")
                    .header("Cache-Control", "no-cache")
                    .header("Connection", "keep-alive");
            try (PrintWriter writer = new PrintWriter(
                    new OutputStreamWriter(inner.getResponseBody(), StandardCharsets.UTF_8))) {
                // Length 0 makes HttpServer stream the body (chunked); honour status()/headersSent.
                sendHeaders(0);
                handler.accept(new SseEmitter(writer));
            } catch (IOException e) {
                throw new RuntimeException("SSE failed", e);
            }
        }

        /** Closes the underlying exchange; subsequent calls are no-ops. */
        public void close() {
            if (closed) {
                return;
            }
            try {
                inner.close();
            } finally {
                closed = true;
            }
        }

        /** Decodes an {@code application/x-www-form-urlencoded} string as UTF-8. */
        public static String urlDecode(String str) {
            return URLDecoder.decode(str, StandardCharsets.UTF_8);
        }

        /**
         * Returns the query parameters in request order. The last value wins for repeated
         * keys, and a key without {@code =} maps to {@code ""}.
         *
         * @throws IllegalArgumentException if a parameter contains a malformed escape
         */
        public Map<String, String> params() {
            if (cachedQueryParams != null) {
                return cachedQueryParams;
            }
            // Split the RAW query: getQuery() is already decoded, so an encoded '&' or '='
            // inside a value would be mistaken for a separator and '%' escapes decoded twice.
            String query = inner.getRequestURI().getRawQuery();
            Map<String, String> parsed = new LinkedHashMap<>();
            if (query != null && !query.trim().isEmpty()) {
                for (String param : query.split("&")) {
                    int idx = param.indexOf('=');
                    if (idx > 0) {
                        parsed.put(urlDecode(param.substring(0, idx)), urlDecode(param.substring(idx + 1)));
                    } else if (!param.isEmpty()) {
                        parsed.put(urlDecode(param), "");
                    }
                }
            }
            // Cache only after a fully successful parse so a decode error is not memoised.
            cachedQueryParams = parsed;
            return cachedQueryParams;
        }

        /** Returns the request body decoded as UTF-8. */
        public String body() {
            return new String(bytes(), StandardCharsets.UTF_8);
        }

        /**
         * Reads (once) and returns the raw request body.
         *
         * @throws RuntimeException if reading fails or the body exceeds {@link #maxRequestBodySize}
         */
        public byte[] bytes() {
            if (cachedBody != null) {
                return cachedBody;
            }
            try (InputStream is = inner.getRequestBody();
                 ByteArrayOutputStream result = new ByteArrayOutputStream()) {
                byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
                long totalRead = 0; // long: cannot overflow when the limit is Integer.MAX_VALUE
                int length;
                while ((length = is.read(buffer)) != -1) {
                    totalRead += length;
                    if (totalRead > maxRequestBodySize) {
                        throw new IOException("Request body too large");
                    }
                    result.write(buffer, 0, length);
                }
                cachedBody = result.toByteArray();
                return cachedBody;
            } catch (IOException e) {
                throw new RuntimeException("Failed to read request body", e);
            }
        }

        /** Sets (replaces) a response header. */
        public Exchange header(String name, String value) {
            inner.getResponseHeaders().set(name, value);
            return this;
        }

        /** Returns the first value of a request header, or {@code null}. */
        public String header(String name) {
            return inner.getRequestHeaders().getFirst(name);
        }

        /** Adds a {@code Set-Cookie} header with {@code Path=/} and the given Max-Age. */
        public Exchange cookie(String name, String value, int maxAge) {
            String cookie = String.format("%s=%s; Max-Age=%d; Path=/", name, value, maxAge);
            inner.getResponseHeaders().add("Set-Cookie", cookie);
            return this;
        }

        /**
         * Returns the value of request cookie {@code name}, or {@code null} if absent.
         * Only the first {@code =} separates name from value, so values may contain {@code =}.
         */
        public String cookie(String name) {
            List<String> cookieHeaders = inner.getRequestHeaders().get("Cookie");
            if (cookieHeaders == null) {
                return null;
            }
            for (String cookieHeader : cookieHeaders) {
                for (String pair : cookieHeader.split(";")) {
                    String trimmed = pair.trim();
                    int eq = trimmed.indexOf('=');
                    if (eq > 0 && trimmed.substring(0, eq).equals(name)) {
                        return trimmed.substring(eq + 1);
                    }
                }
            }
            return null;
        }

        /** Sends {@code htmlContent} as {@code text/html; charset=UTF-8}. */
        public void html(String htmlContent) {
            contentType("text/html; charset=UTF-8");
            result(htmlContent);
        }

        /**
         * Serves {@code path} relative to {@code root}. Responds 403 when the resolved file
         * (after symlinks and {@code ..}) lies outside {@code root}, and 404 when it is not a
         * regular file. The containment check runs first so existence outside the root is not
         * revealed, and response bodies echo only the requested path, not server locations.
         *
         * @throws RuntimeException if the file cannot be resolved or read
         */
        public void staticFile(File root, String path) {
            File file = new File(root, path);
            try {
                Path rootPath = root.getCanonicalFile().toPath();
                Path filePath = file.getCanonicalFile().toPath();
                // Path.startsWith compares whole name elements, so "/www" does not match "/www-x".
                if (!filePath.startsWith(rootPath)) {
                    status(403).result("Forbidden: " + path);
                    return;
                }
                if (!Files.isRegularFile(filePath)) {
                    status(404).result("File not found: " + path);
                    return;
                }
                byte[] fileBytes = Files.readAllBytes(filePath);
                String contentType = Files.probeContentType(filePath);
                contentType(contentType != null ? contentType : "application/octet-stream");
                result(fileBytes);
            } catch (IOException e) {
                throw new RuntimeException("Read file error: " + file, e);
            } finally {
                close();
            }
        }

        /** Sends {@code obj.toString()} (or {@code "null"}) as a UTF-8 body and closes. */
        public void result(Object obj) {
            String response = obj != null ? obj.toString() : "null";
            result(response.getBytes(StandardCharsets.UTF_8));
        }

        /**
         * Sends {@code responseBytes} with an exact Content-Length and closes the exchange.
         * An empty array, or any HEAD request, is sent with no body at all (length -1);
         * HttpServer treats length 0 as "chunked" and rejects body bytes on HEAD.
         *
         * @throws RuntimeException if the response cannot be written
         */
        public void result(byte[] responseBytes) {
            try {
                boolean noBody = responseBytes.length == 0 || isHeadRequest();
                sendHeaders(noBody ? -1 : responseBytes.length);
                if (!noBody) {
                    inner.getResponseBody().write(responseBytes);
                }
            } catch (IOException e) {
                throw new RuntimeException("Failed to send response", e);
            } finally {
                close();
            }
        }

        /**
         * Streams {@code inputStream} to the client with chunked encoding and closes the
         * exchange. The length is never taken from {@link InputStream#available()}, which is
         * only an estimate. {@code inputStream} itself is NOT closed; the caller owns it.
         *
         * @throws RuntimeException if copying fails
         */
        public void result(InputStream inputStream) {
            try {
                if (isHeadRequest()) {
                    sendHeaders(-1);
                    return;
                }
                sendHeaders(0); // 0 = unknown length; HttpServer applies chunked encoding itself
                OutputStream os = inner.getResponseBody();
                inputStream.transferTo(os);
                os.flush();
            } catch (IOException e) {
                throw new RuntimeException("Failed to send response from InputStream", e);
            } finally {
                close();
            }
        }

        /** Sets the request-body size limit enforced by {@link #bytes()}. */
        public Exchange maxRequestBodySize(int size) {
            this.maxRequestBodySize = size;
            return this;
        }

        /** Sends the status line and headers once; later calls are ignored. */
        private void sendHeaders(long length) throws IOException {
            if (headersSent) {
                return;
            }
            inner.sendResponseHeaders(statusCode, length);
            headersSent = true;
        }

        /** HEAD responses must not carry a body. */
        private boolean isHeadRequest() {
            return "HEAD".equalsIgnoreCase(method());
        }

        /** Escapes {@code \} and {@code "} so the value fits inside an HTTP quoted-string. */
        private static String escapeQuoted(String value) {
            return Objects.toString(value).replace("\\", "\\\\").replace("\"", "\\\"");
        }
    }

    /**
     * Route registry and lifecycle wrapper around an initially unbound {@link HttpServer}.
     *
     * @param <T> the {@link Exchange} type handed to handlers, produced by the factory
     */
    public static class Server<T extends Exchange> {
        private final HttpServer server;
        private final Function<HttpExchange, T> exchangeFactory;
        private BiConsumer<T, Exception> errorHandler = this::defaultErrorHandler;
        /** Body size limit applied to every exchange; 2 MiB by default. */
        private int maxRequestBodySize = 2 * 1024 * 1024;
        /** Prefixes of the currently open {@link #group} calls, innermost first. */
        private final Deque<String> routeStack = new ArrayDeque<>();
        /** Path-prefix filters, consulted in registration order on every request. */
        private final List<RouteFilter<T>> routeFilters = new ArrayList<>();

        public Server(Function<HttpExchange, T> exchangeFactory) {
            try {
                this.server = HttpServer.create(); // unbound until start()
                this.exchangeFactory = exchangeFactory;
            } catch (IOException e) {
                throw new RuntimeException("Failed to create HttpServer", e);
            }
        }

        /**
         * Binds to {@code port} with the default backlog and starts serving.
         * The {@code throws} clause is kept for source compatibility; bind failures are
         * actually reported as {@link RuntimeException}.
         */
        public Server<T> start(int port) throws IOException {
            return start(port, 0);
        }

        /** Binds to {@code port} with the given backlog and starts serving. */
        public Server<T> start(int port, int backlog) {
            try {
                this.server.bind(new InetSocketAddress(port), backlog);
                this.server.start();
                return this;
            } catch (IOException e) {
                throw new RuntimeException("start server error", e);
            }
        }

        /** Sets the request-body limit applied to exchanges created after this call. */
        public Server<T> maxRequestBodySize(int size) {
            this.maxRequestBodySize = size;
            return this;
        }

        /** Replaces the handler invoked when a filter or route throws. */
        public Server<T> error(BiConsumer<T, Exception> handler) {
            this.errorHandler = handler;
            return this;
        }

        /** Sets the executor used by the underlying {@link HttpServer}. */
        public Server<T> executor(Executor executor) {
            server.setExecutor(executor);
            return this;
        }

        /** Stops the server immediately. */
        public void stop() {
            server.stop(0);
        }

        /**
         * Registers a filter for requests whose path starts with {@code pathPrefix} (plain
         * string prefix). The filter calls the supplied continuation to proceed.
         */
        public void filter(String pathPrefix, BiConsumer<T, Consumer<T>> filter) {
            routeFilters.add(new RouteFilter<>(pathPrefix, filter));
        }

        /**
         * Registers {@code handler} under the current group prefix plus {@code path}.
         * Exceptions are routed to the error handler. If the exchange factory itself fails,
         * the error is logged, a bare 500 is attempted, and the raw exchange is still closed.
         */
        public HttpContext route(String path, Consumer<T> handler) {
            String fullPath = buildFullPath(path);
            return server.createContext(fullPath, httpExchange -> {
                T exchange = null;
                try {
                    exchange = exchangeFactory.apply(httpExchange);
                    exchange.maxRequestBodySize(maxRequestBodySize);
                    applyFilters(exchange, handler, 0);
                } catch (Exception e) {
                    if (exchange != null) {
                        try {
                            errorHandler.accept(exchange, e);
                        } catch (Exception ex) {
                            // The error handler itself failed: log it but never propagate.
                            ex.printStackTrace();
                        }
                    } else {
                        e.printStackTrace();
                        try {
                            httpExchange.sendResponseHeaders(500, -1);
                        } catch (IOException ignored) {
                            // Best effort only; the exchange is closed below regardless.
                        }
                    }
                } finally {
                    if (exchange != null) {
                        exchange.close();
                    } else {
                        httpExchange.close();
                    }
                }
            });
        }

        /** Runs {@code groupRoutes} with {@code path} prepended to every route it registers. */
        public void group(String path, Runnable groupRoutes) {
            routeStack.push(path);
            try {
                groupRoutes.run();
            } finally {
                routeStack.pop();
            }
        }

        /** Joins the open group prefixes (outermost first) with {@code path}. */
        private String buildFullPath(String path) {
            StringBuilder fullPath = new StringBuilder();
            for (String part : routeStack) {
                fullPath.insert(0, part); // stack iterates innermost first, so prepend
            }
            fullPath.append(path);
            return fullPath.toString();
        }

        /** Runs the matching filters from {@code index} onward, then {@code handler}. */
        private void applyFilters(T exchange, Consumer<T> handler, int index) {
            if (index < routeFilters.size()) {
                RouteFilter<T> routeFilter = routeFilters.get(index);
                if (exchange.path().startsWith(routeFilter.pathPrefix)) {
                    routeFilter.filter.accept(exchange, ex -> applyFilters(ex, handler, index + 1));
                } else {
                    applyFilters(exchange, handler, index + 1);
                }
            } else {
                handler.accept(exchange);
            }
        }

        /** Maps common exception types to status codes and sends a plain-text message. */
        private void defaultErrorHandler(T ctx, Exception e) {
            int status;
            String message;
            if (e instanceof IllegalArgumentException) {
                status = 400;
                message = "Bad Request";
            } else if (e instanceof SecurityException) {
                status = 403;
                message = "Forbidden";
            } else if (e instanceof FileNotFoundException) {
                status = 404;
                message = "Not Found";
            } else if (e instanceof UnsupportedOperationException) {
                status = 501;
                message = "Not Implemented";
            } else if (e instanceof IOException) {
                status = 503;
                message = "Service Unavailable";
            } else {
                status = 500;
                message = "Internal Server Error";
                e.printStackTrace();
            }
            String detail = e.getMessage();
            if (detail != null) {
                message += ": " + detail; // avoid a literal ": null" suffix
            }
            ctx.status(status).result(message);
        }
    }

    /** A filter bound to a request-path prefix. */
    protected static class RouteFilter<T> {
        String pathPrefix;
        BiConsumer<T, Consumer<T>> filter;

        RouteFilter(String pathPrefix, BiConsumer<T, Consumer<T>> filter) {
            this.pathPrefix = pathPrefix;
            this.filter = filter;
        }
    }

    /** Writes Server-Sent Events frames to a {@link PrintWriter}. */
    public static class SseEmitter {
        private final PrintWriter writer;

        public SseEmitter(PrintWriter writer) {
            this.writer = writer;
        }

        /**
         * Sends an unnamed event. Multi-line {@code data} becomes one {@code data:} line
         * per line, as the SSE format requires.
         *
         * @throws IOException if the writer has failed (e.g. the client disconnected)
         */
        public void send(String data) throws IOException {
            emit(null, data);
        }

        /**
         * Sends an event named {@code event} (a null name is written as {@code "null"}).
         *
         * @throws IOException if the writer has failed (e.g. the client disconnected)
         */
        public void send(String event, String data) throws IOException {
            emit(Objects.toString(event), data);
        }

        /** Builds and flushes one frame; {@code event == null} omits the event field. */
        private void emit(String event, String data) throws IOException {
            if (writer.checkError()) { // an earlier write already failed
                throw new IOException("Client disconnected");
            }
            StringBuilder frame = new StringBuilder();
            if (event != null) {
                frame.append("event: ").append(event).append('\n');
            }
            // -1 keeps trailing empty lines so the client reconstructs the exact payload.
            for (String line : Objects.toString(data).split("\r\n|\r|\n", -1)) {
                frame.append("data: ").append(line).append('\n');
            }
            frame.append('\n');
            writer.print(frame);
            writer.flush();
            if (writer.checkError()) { // surface a failure of THIS frame immediately
                throw new IOException("Client disconnected");
            }
        }
    }

    /** Creates a server whose handlers receive exchanges built by {@code contextFactory}. */
    public static <T extends Exchange> Server<T> create(Function<HttpExchange, T> contextFactory) {
        return new Server<>(contextFactory);
    }

    /** Creates a server whose handlers receive plain {@link Exchange} instances. */
    public static Server<Exchange> create() {
        return new Server<>(Exchange::new);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment