diff --git a/main.go b/main.go index 58ef693..65be33c 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "os" "regexp" "runtime" + "strings" "sync/atomic" "syscall" "time" @@ -251,6 +252,12 @@ func root(w http.ResponseWriter, req *http.Request) { // CustomHandler wraps the default promhttp.Handler with custom logic func metricsHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // To prevent accessing from the bare IP address + if req.Host == "" || net.ParseIP(strings.Split(req.Host, ":")[0]) != nil { + w.WriteHeader(444) + return + } + metrics.Uptime.Set(float64(time.Duration(time.Since(programInit).Seconds()))) promhttp.Handler().ServeHTTP(w, req) }) @@ -296,10 +303,30 @@ func requestPerMinute() { } } -func beforeAll(next http.HandlerFunc) http.HandlerFunc { +func beforeMisc(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer panicHandler(w) + // To prevent accessing from the bare IP address + if req.Host == "" || net.ParseIP(strings.Split(req.Host, ":")[0]) != nil { + w.WriteHeader(444) + return + } + + next(w, req) + } +} + +func beforeProxy(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + defer panicHandler(w) + + // To prevent accessing from the bare IP address + if req.Host == "" || net.ParseIP(strings.Split(req.Host, ":")[0]) != nil { + w.WriteHeader(444) + return + } + if h3s { w.Header().Set("Alt-Svc", "h3=\":8443\"; ma=86400") } @@ -378,9 +405,10 @@ func main() { mux := http.NewServeMux() - mux.HandleFunc("/", root) - mux.HandleFunc("/health", health) - mux.HandleFunc("/stats", stats) + // MISC ROUTES + mux.HandleFunc("/", beforeMisc(root)) + mux.HandleFunc("/health", beforeMisc(health)) + mux.HandleFunc("/stats", beforeMisc(stats)) prometheus.MustRegister(metrics.Uptime) prometheus.MustRegister(metrics.ActiveConnections) @@ -396,13 +424,14 @@ func main() { mux.Handle("/metrics", metricsHandler()) - mux.HandleFunc("/videoplayback", beforeAll(videoplayback)) - mux.HandleFunc("/vi/", beforeAll(vi)) - mux.HandleFunc("/vi_webp/", beforeAll(vi)) - mux.HandleFunc("/sb/", beforeAll(vi)) - mux.HandleFunc("/ggpht/", beforeAll(ggpht)) - mux.HandleFunc("/a/", beforeAll(ggpht)) - mux.HandleFunc("/ytc/", beforeAll(ggpht)) + // PROXY ROUTES + mux.HandleFunc("/videoplayback", beforeProxy(videoplayback)) + mux.HandleFunc("/vi/", beforeProxy(vi)) + mux.HandleFunc("/vi_webp/", beforeProxy(vi)) + mux.HandleFunc("/sb/", beforeProxy(vi)) + mux.HandleFunc("/ggpht/", beforeProxy(ggpht)) + mux.HandleFunc("/a/", beforeProxy(ggpht)) + mux.HandleFunc("/ytc/", beforeProxy(ggpht)) go requestPerSecond() go requestPerMinute()