Zion Boggan
repos/jwt-differential-fuzzer/targets/golang-jwt/server.go
zionboggan.com ↗
121 lines · go
History for this file →
1
package main
2
 
3
import (
4
	"crypto/x509"
5
	"encoding/json"
6
	"encoding/pem"
7
	"fmt"
8
	"io"
9
	"log"
10
	"net/http"
11
 
12
	"github.com/golang-jwt/jwt/v5"
13
)
14
 
15
// Identity of this verification target, echoed in every response and in the
// startup log line.
const (
	// libID is the short name reported in the "lib" field of each response.
	libID = "gojwt"
	// libVersion is the version string reported in the "version" field.
	// NOTE(review): presumably mirrors the golang-jwt version pinned in
	// go.mod; confirm when bumping the dependency.
	libVersion = "5.2.1"
)
17
 
18
// req is the JSON request body accepted by the verify endpoint.
type req struct {
	// Token is the compact JWT string to verify.
	Token string `json:"token"`
	// Key is the verification key material. It is decoded as arbitrary JSON;
	// parseKey only accepts it when it is a JSON string.
	Key any `json:"key"`
	// Algs is the allowlist of "alg" header values the caller permits.
	Algs []string `json:"algs"`
}
23
 
24
// res is the JSON response body describing the outcome of one verification.
type res struct {
	// Valid reports whether the token was accepted.
	Valid bool `json:"valid"`
	// Claims holds the token's claims on success; it is left nil (encoded as
	// JSON null) on failure.
	Claims map[string]any `json:"claims"`
	// Error carries the failure reason and is omitted from the JSON when empty.
	Error string `json:"error,omitempty"`
	// Lib names the library that produced the verdict.
	Lib string `json:"lib"`
	// Version is the version string of that library.
	Version string `json:"version"`
}
31
 
32
// parseKey converts the request's key material into the key value handed to
// the JWT library for the algorithm named alg.
//
// keyMat must hold a string; any other dynamic type yields an
// "unsupported key type" error. The first two bytes of alg select the
// handling:
//
//   - "HS": the string's raw bytes are returned.
//   - "RS", "PS": the string is PEM-decoded and parsed as a PKIX public key;
//     if that fails the same bytes are tried as an X.509 certificate and the
//     certificate's public key is returned.
//   - "ES": the string is PEM-decoded and parsed as a PKIX public key.
//   - anything else, including names shorter than two bytes: the string's
//     raw bytes are returned.
//
// The concrete type of the returned key is not checked against alg here.
func parseKey(keyMat any, alg string) (any, error) {
	keyStr, ok := keyMat.(string)
	if !ok {
		return nil, fmt.Errorf("unsupported key type")
	}

	// Clamp rather than slice alg[:2] directly: slicing panics on names
	// shorter than two bytes, whereas clamping lets them reach the default
	// branch like any other unrecognised name.
	family := alg
	if len(family) > 2 {
		family = family[:2]
	}

	switch family {
	case "HS":
		return []byte(keyStr), nil
	case "RS", "PS":
		der, err := decodePEM(keyStr)
		if err != nil {
			return nil, err
		}
		pub, err := x509.ParsePKIXPublicKey(der)
		if err == nil {
			return pub, nil
		}
		cert, certErr := x509.ParseCertificate(der)
		if certErr != nil {
			// Both attempts failed; report the PKIX error, as a bare public
			// key is the first format tried.
			return nil, err
		}
		return cert.PublicKey, nil
	case "ES":
		der, err := decodePEM(keyStr)
		if err != nil {
			return nil, err
		}
		pub, err := x509.ParsePKIXPublicKey(der)
		if err != nil {
			return nil, err
		}
		return pub, nil
	default:
		return []byte(keyStr), nil
	}
}

// decodePEM returns the DER bytes of the first PEM block found in s, or a
// "pem decode failed" error when s contains no PEM block.
func decodePEM(s string) ([]byte, error) {
	block, _ := pem.Decode([]byte(s))
	if block == nil {
		return nil, fmt.Errorf("pem decode failed")
	}
	return block.Bytes, nil
}
68
 
69
// verdict decodes a JSON verification request from body, runs the token
// through jwt.Parse, and reports the outcome.
//
// Every returned res carries libID and libVersion. Failures set Valid to false
// and put the reason in Error: "bad json: ..." when body does not decode,
// "no algs" when the allowlist is empty, the parser's error text when parsing
// fails, or "invalid" when the parser returns no error but the token is not
// marked valid. On success Valid is true and Claims holds the token's claims.
func verdict(body []byte) res {
	// reject builds the failure response for the given reason.
	reject := func(reason string) res {
		return res{Valid: false, Error: reason, Lib: libID, Version: libVersion}
	}

	var in req
	if err := json.Unmarshal(body, &in); err != nil {
		return reject("bad json: " + err.Error())
	}
	if len(in.Algs) == 0 {
		return reject("no algs")
	}

	// lookupKey is the key callback given to jwt.Parse. It re-checks the
	// token's "alg" header against the caller's allowlist before building a
	// key; the same allowlist is also passed to the parser below.
	lookupKey := func(tok *jwt.Token) (any, error) {
		name, isString := tok.Header["alg"].(string)
		if !isString {
			return nil, fmt.Errorf("no alg")
		}
		for _, allowed := range in.Algs {
			if allowed == name {
				return parseKey(in.Key, name)
			}
		}
		return nil, fmt.Errorf("alg %s not in allowlist", name)
	}

	tok, err := jwt.Parse(in.Token, lookupKey, jwt.WithValidMethods(in.Algs))
	switch {
	case err != nil:
		return reject(err.Error())
	case !tok.Valid:
		return reject("invalid")
	}

	// A failed assertion leaves the claims nil, which encodes as JSON null.
	payload, _ := tok.Claims.(jwt.MapClaims)
	return res{Valid: true, Claims: payload, Lib: libID, Version: libVersion}
}
105
 
106
func handler(w http.ResponseWriter, r *http.Request) {
107
	if r.Method != http.MethodPost || r.URL.Path != "/verify" {
108
		w.WriteHeader(http.StatusNotFound)
109
		return
110
	}
111
	body, _ := io.ReadAll(r.Body)
112
	out := verdict(body)
113
	w.Header().Set("Content-Type", "application/json")
114
	json.NewEncoder(w).Encode(out)
115
}
116
 
117
func main() {
118
	http.HandleFunc("/verify", handler)
119
	log.Printf("[%s %s] listening :7005", libID, libVersion)
120
	log.Fatal(http.ListenAndServe(":7005", nil))
121
}