@@ 15,6 15,7 @@ import (
"github.com/vektah/gqlparser/gqlerror"
+ "git.sr.ht/~sircmpwn/git.sr.ht/api/crypto"
"git.sr.ht/~sircmpwn/git.sr.ht/api/database"
)
@@ 61,6 62,77 @@ func authError(w http.ResponseWriter, reason string, code int) {
w.Write(b)
}
+type AuthCookie struct {
+ Name string `json:"name"`
+}
+
+func cookieAuth(db *sql.DB, cookie *http.Cookie,
+ w http.ResponseWriter, r *http.Request, next http.Handler) {
+
+ payload := crypto.Decrypt([]byte(cookie.Value))
+ if payload == nil {
+ authError(w, "Invalid authentication cookie", http.StatusForbidden)
+ return
+ }
+
+ var (
+ auth AuthCookie
+ err error
+ rows *sql.Rows
+ user User
+ )
+ if err := json.Unmarshal(payload, &auth); err != nil {
+ authError(w, "Invalid authentication cookie", http.StatusForbidden)
+ return
+ }
+
+ query := database.
+ Select(context.TODO(), []string{
+ `u.id`, `u.username`,
+ `u.created`, `u.updated`,
+ `u.email`,
+ `u.user_type`,
+ `u.url`, `u.location`, `u.bio`,
+ `u.suspension_notice`,
+ }).
+ From(`"user" u`).
+ Where(`u.username = ?`, auth.Name)
+ if rows, err = query.RunWith(db).Query(); err != nil {
+ panic(err)
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ panic(err)
+ }
+ authError(w, "Invalid or expired OAuth token", http.StatusForbidden)
+ return
+ }
+ if err := rows.Scan(&user.ID, &user.Username, &user.Created, &user.Updated,
+ &user.Email, &user.UserType, &user.URL, &user.Location, &user.Bio,
+ &user.SuspensionNotice); err != nil {
+ panic(err)
+ }
+ if rows.Next() {
+ if err := rows.Err(); err != nil {
+ panic(err)
+ }
+ panic(errors.New("Multiple matching user accounts; invariant broken"))
+ }
+
+ if user.UserType == USER_SUSPENDED {
+ authError(w, fmt.Sprintf("Account suspended with the following notice: %s\nContact support",
+ user.SuspensionNotice), http.StatusForbidden)
+ return
+ }
+
+ ctx := context.WithValue(r.Context(), userCtxKey, &user)
+
+ r = r.WithContext(ctx)
+ next.ServeHTTP(w, r)
+}
+
func Middleware(db *sql.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ 69,6 141,12 @@ func Middleware(db *sql.DB) func(http.Handler) http.Handler {
return
}
+ cookie, err := r.Cookie("sr.ht.unified-login.v1")
+ if err == nil {
+ cookieAuth(db, cookie, w, r, next)
+ return
+ }
+
auth := r.Header.Get("Authorization")
if auth == "" {
authError(w, `Authorization header is required.
@@ 100,7 178,6 @@ Expected 'Authorization: Bearer <token>'`, http.StatusForbidden)
}
var (
- err error
expires time.Time
rows *sql.Rows
scopes string