From 8b27e005d0c6bfaa180dbb8662cfbf1cc52ba409 Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Tue, 14 Apr 2020 14:38:51 -0400 Subject: [PATCH] loaders: use squirrel for query construction --- api/auth/auth.go | 19 ++--- api/database/ql.go | 80 ++++++++++++++++++++ api/database/sq.go | 31 ++++++++ api/go.mod | 2 + api/go.sum | 9 +++ api/graph/model/blob.go | 8 +- api/graph/model/commit.go | 16 ++-- api/graph/model/object.go | 1 - api/graph/model/repository.go | 100 +++++++++++++++---------- api/graph/model/tree.go | 12 +-- api/graph/model/user.go | 59 ++++++++------- api/graph/model/util.go | 10 ++- api/loaders/middleware.go | 136 +++++++++++++++++++--------------- api/server.go | 4 +- 14 files changed, 328 insertions(+), 159 deletions(-) create mode 100644 api/database/ql.go create mode 100644 api/database/sq.go diff --git a/api/auth/auth.go b/api/auth/auth.go index b376e23..bd73a90 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -17,6 +17,7 @@ import ( ) var userCtxKey = &contextKey{"user"} + type contextKey struct { name string } @@ -24,14 +25,14 @@ type contextKey struct { var bearerRegex = regexp.MustCompile(`^[0-9a-f]{32}$`) const ( - USER_UNCONFIRMED = "unconfirmed" - USER_ACTIVE_NON_PAYING = "active_non_paying" - USER_ACTIVE_FREE = "active_free" - USER_ACTIVE_PAYING = "active_paying" - USER_ACTIVE_DELINQUENT = "active_delinquent" - USER_ADMIN = "admin" - USER_UNKNOWN = "unknown" - USER_SUSPENDED = "suspended" + USER_UNCONFIRMED = "unconfirmed" + USER_ACTIVE_NON_PAYING = "active_non_paying" + USER_ACTIVE_FREE = "active_free" + USER_ACTIVE_PAYING = "active_paying" + USER_ACTIVE_DELINQUENT = "active_delinquent" + USER_ADMIN = "admin" + USER_UNKNOWN = "unknown" + USER_SUSPENDED = "suspended" ) type User struct { @@ -80,7 +81,7 @@ Expected 'Authentication: Bearer '`, http.StatusForbidden) } var bearer string - switch (z[0]) { + switch z[0] { case "Bearer": token := []byte(z[1]) if !bearerRegex.Match(token) { diff --git a/api/database/ql.go b/api/database/ql.go new file mode 100644 index 0000000..459ddfa --- /dev/null +++ b/api/database/ql.go @@ -0,0 +1,80 @@ +package database + +import ( + "context" + "sort" + + "github.com/vektah/gqlparser/v2/ast" + + "git.sr.ht/~sircmpwn/gqlgen/graphql" +) + +func ColumnsFor(ctx context.Context, alias string, + colMap map[string]string) []string { + + var fields []graphql.CollectedField + if graphql.GetFieldContext(ctx) != nil { + fields = graphql.CollectFieldsCtx(ctx, nil) + } else { + // Collect all fields if we are not in an active graphql context + for qlCol, _ := range colMap { + fields = append(fields, graphql.CollectedField{ + &ast.Field{Name: qlCol}, nil, + }) + } + } + + sort.Slice(fields, func(a, b int) bool { + return fields[a].Name < fields[b].Name + }) + + var columns []string + for _, qlCol := range fields { + if sqlCol, ok := colMap[qlCol.Name]; ok { + if alias != "" { + columns = append(columns, alias+"."+sqlCol) + } else { + columns = append(columns, sqlCol) + } + } + } + + return columns +} + +func FieldsFor(ctx context.Context, + colMap map[string]interface{}) []interface{} { + + var qlFields []graphql.CollectedField + if graphql.GetFieldContext(ctx) != nil { + qlFields = graphql.CollectFieldsCtx(ctx, nil) + } else { + // Collect all fields if we are not in an active graphql context + for qlCol, _ := range colMap { + qlFields = append(qlFields, graphql.CollectedField{ + &ast.Field{Name: qlCol}, nil, + }) + } + } + + sort.Slice(qlFields, func(a, b int) bool { + return qlFields[a].Name < qlFields[b].Name + }) + + var fields []interface{} + for _, qlField := range qlFields { + if field, ok := colMap[qlField.Name]; ok { + fields = append(fields, field) + } + } + + return fields +} + +func WithAlias(alias, col string) string { + if alias != "" { + return alias + "." + col + } else { + return col + } +} diff --git a/api/database/sq.go b/api/database/sq.go new file mode 100644 index 0000000..82c8a2b --- /dev/null +++ b/api/database/sq.go @@ -0,0 +1,31 @@ +package database + +import ( + "context" + "fmt" + + sq "github.com/Masterminds/squirrel" +) + +type Selectable interface { + As(alias string) Selectable + Select(ctx context.Context) []string + Fields(ctx context.Context) []interface{} +} + +func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder { + q := sq.Select().PlaceholderFormat(sq.Dollar) + for _, col := range cols { + switch col := col.(type) { + case string: + q = q.Columns(col) + case []string: + q = q.Columns(col...) + case Selectable: + q = q.Columns(col.Select(ctx)...) + default: + panic(fmt.Errorf("Unknown selectable type %T", col)) + } + } + return q +} diff --git a/api/go.mod b/api/go.mod index a7c8f53..4229e0f 100644 --- a/api/go.mod +++ b/api/go.mod @@ -5,9 +5,11 @@ go 1.14 require ( git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3 // indirect git.sr.ht/~sircmpwn/gqlgen v0.0.0-20200412134447-57d7234737d4 + github.com/Masterminds/squirrel v1.2.0 github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/go-chi/chi v3.3.2+incompatible github.com/go-git/go-git/v5 v5.0.0 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gorilla/websocket v1.4.2 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/lib/pq v1.3.0 diff --git a/api/go.sum b/api/go.sum index e784227..1ea91e5 100644 --- a/api/go.sum +++ b/api/go.sum @@ -5,6 +5,8 @@ git.sr.ht/~sircmpwn/git.sr.ht v0.0.0-20200413150414-046cd382d7b7 h1:PYRTIcsHR5W+ git.sr.ht/~sircmpwn/gqlgen v0.0.0-20200412134447-57d7234737d4 h1:J/Sb88htNHzZaN6ZEF8BnRWj3LzYoTrOL4WRhZEEiQE= git.sr.ht/~sircmpwn/gqlgen v0.0.0-20200412134447-57d7234737d4/go.mod h1:W1cijL2EqAyL1eo1WAJ3ijNVkZM2okpYyCF5TRu1VfI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/squirrel v1.2.0 h1:K1NhbTO21BWG47IVR0OnIZuE0LZcXAYqywrC3Ko53KI= +github.com/Masterminds/squirrel v1.2.0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA= github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/agnivade/levenshtein v1.0.3 h1:M5ZnqLOoZR8ygVq0FfkXsNOKzMCk0xRiow0R5+5VkQ0= github.com/agnivade/levenshtein v1.0.3/go.mod h1:4SFRZbbXWLF4MU1T9Qg0pGgH3Pjs+t6ie5efyrwRJXs= @@ -38,6 +40,8 @@ github.com/go-git/go-git/v5 v5.0.0 h1:k5RWPm4iJwYtfWoxIJy4wJX9ON7ihPeZZYC1fLYDnp github.com/go-git/go-git/v5 v5.0.0/go.mod h1:oYD8y9kWsGINPFJoLdaScGCN6dlKg23blmClfZwtUVA= github.com/gogo/protobuf v1.0.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.2.0 h1:VJtLvh6VQym50czpZzx07z/kw9EgAxI3x1ZB8taTMQQ= @@ -57,6 +61,10 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/logrusorgru/aurora v0.0.0-20200102142835-e9ef32dff381/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= @@ -93,6 +101,7 @@ github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeV github.com/shurcooL/vfsgen v0.0.0-20180121065927-ffb13db8def0/go.mod h1:TrYk7fJVaAttu97ZZKrO9UbRa8izdowaMIZcxYMbVaw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= diff --git a/api/graph/model/blob.go b/api/graph/model/blob.go index d7d7389..fff363c 100644 --- a/api/graph/model/blob.go +++ b/api/graph/model/blob.go @@ -10,10 +10,10 @@ import ( ) type Blob struct { - Type ObjectType `json:"type"` - ID string `json:"id"` - ShortID string `json:"shortId"` - Raw string `json:"raw"` + Type ObjectType `json:"type"` + ID string `json:"id"` + ShortID string `json:"shortId"` + Raw string `json:"raw"` blob *object.Blob repo *git.Repository diff --git a/api/graph/model/commit.go b/api/graph/model/commit.go index 93532b9..6ae93d3 100644 --- a/api/graph/model/commit.go +++ b/api/graph/model/commit.go @@ -6,10 +6,10 @@ import ( ) type Commit struct { - Type ObjectType `json:"type"` - ID string `json:"id"` - ShortID string `json:"shortId"` - Raw string `json:"raw"` + Type ObjectType `json:"type"` + ID string `json:"id"` + ShortID string `json:"shortId"` + Raw string `json:"raw"` commit *object.Commit repo *git.Repository @@ -23,17 +23,17 @@ func (c *Commit) Message() string { func (c *Commit) Author() *Signature { return &Signature{ - Name: c.commit.Author.Name, + Name: c.commit.Author.Name, Email: c.commit.Author.Email, - Time: c.commit.Author.When, + Time: c.commit.Author.When, } } func (c *Commit) Committer() *Signature { return &Signature{ - Name: c.commit.Committer.Name, + Name: c.commit.Committer.Name, Email: c.commit.Committer.Email, - Time: c.commit.Committer.When, + Time: c.commit.Committer.When, } } diff --git a/api/graph/model/object.go b/api/graph/model/object.go index 8d976ae..2d7ffd0 100644 --- a/api/graph/model/object.go +++ b/api/graph/model/object.go @@ -51,4 +51,3 @@ func LookupObject(repo *git.Repository, hash plumbing.Hash) (Object, error) { return nil, errors.New("Unknown object type") } } - diff --git a/api/graph/model/repository.go b/api/graph/model/repository.go index 0c8d77f..801a3ce 100644 --- a/api/graph/model/repository.go +++ b/api/graph/model/repository.go @@ -6,52 +6,29 @@ import ( "time" "github.com/go-git/go-git/v5" + + "git.sr.ht/~sircmpwn/git.sr.ht/api/database" ) type Repository struct { - ID int `json:"id"` - Created time.Time `json:"created"` - Updated time.Time `json:"updated"` - Name string `json:"name"` - Description *string `json:"description"` - Visibility Visibility `json:"visibility"` - UpstreamURL *string `json:"upstreamUrl"` - Objects []Object `json:"objects"` - Log []*Commit `json:"log"` - Tree *Tree `json:"tree"` - File *Blob `json:"file"` - RevparseSingle Object `json:"revparse_single"` + ID int `json:"id"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` + Name string `json:"name"` + Description *string `json:"description"` + Visibility Visibility `json:"visibility"` + UpstreamURL *string `json:"upstreamUrl"` + Objects []Object `json:"objects"` + Log []*Commit `json:"log"` + Tree *Tree `json:"tree"` + File *Blob `json:"file"` + RevparseSingle Object `json:"revparse_single"` Path string OwnerID int - repo *git.Repository -} - -func (r *Repository) Columns(ctx context.Context, tbl string) string { - columns := ColumnsFor(ctx, map[string]string{ - "id": "id", - "created": "created", - "updated": "updated", - "name": "name", - "description": "description", - "visibility": "visibility", - "upstreamUrl": "upstream_uri", - }, tbl) - return strings.Join(append(columns, tbl + ".path", tbl + ".owner_id"), ", ") -} - -func (r *Repository) Fields(ctx context.Context) []interface{} { - fields := FieldsFor(ctx, map[string]interface{}{ - "id": &r.ID, - "created": &r.Created, - "updated": &r.Updated, - "name": &r.Name, - "description": &r.Description, - "visibility": &r.Visibility, - "upstream_url": &r.UpstreamURL, - }) - return append(fields, &r.Path, &r.OwnerID) + alias string + repo *git.Repository } func (r *Repository) Repo() *git.Repository { @@ -73,3 +50,48 @@ func (r *Repository) Head() *Reference { } return &Reference{Ref: ref, Repo: r.repo} } + +func (r *Repository) Columns(ctx context.Context, tbl string) string { + columns := ColumnsFor(ctx, map[string]string{ + "id": "id", + "created": "created", + "updated": "updated", + "name": "name", + "description": "description", + "visibility": "visibility", + "upstreamUrl": "upstream_uri", + }, tbl) + return strings.Join(append(columns, tbl+".path", tbl+".owner_id"), ", ") +} + +func (r *Repository) Select(ctx context.Context) []string { + return append(database.ColumnsFor(ctx, r.alias, map[string]string{ + "id": "id", + "created": "created", + "updated": "updated", + "name": "name", + "description": "description", + "visibility": "visibility", + "upstreamUrl": "upstream_uri", + }), + database.WithAlias(r.alias, "path"), + database.WithAlias(r.alias, "owner_id")) +} + +func (r *Repository) As(alias string) database.Selectable { + r.alias = alias + return r +} + +func (r *Repository) Fields(ctx context.Context) []interface{} { + fields := FieldsFor(ctx, map[string]interface{}{ + "id": &r.ID, + "created": &r.Created, + "updated": &r.Updated, + "name": &r.Name, + "description": &r.Description, + "visibility": &r.Visibility, + "upstream_url": &r.UpstreamURL, + }) + return append(fields, &r.Path, &r.OwnerID) +} diff --git a/api/graph/model/tree.go b/api/graph/model/tree.go index a781e6a..222dce9 100644 --- a/api/graph/model/tree.go +++ b/api/graph/model/tree.go @@ -9,10 +9,10 @@ import ( ) type Tree struct { - Type ObjectType `json:"type"` - ID string `json:"id"` - ShortID string `json:"shortId"` - Raw string `json:"raw"` + Type ObjectType `json:"type"` + ID string `json:"id"` + ShortID string `json:"shortId"` + Raw string `json:"raw"` tree *object.Tree repo *git.Repository @@ -21,8 +21,8 @@ type Tree struct { func (Tree) IsObject() {} type TreeEntry struct { - Name string `json:"name"` - Mode int `json:"mode"` + Name string `json:"name"` + Mode int `json:"mode"` hash plumbing.Hash repo *git.Repository diff --git a/api/graph/model/user.go b/api/graph/model/user.go index ae06687..b74ec4d 100644 --- a/api/graph/model/user.go +++ b/api/graph/model/user.go @@ -2,19 +2,22 @@ package model import ( "context" - "strings" "time" + + "git.sr.ht/~sircmpwn/git.sr.ht/api/database" ) type User struct { - ID int `json:"id"` - Created time.Time `json:"created"` - Updated time.Time `json:"updated"` - Username string `json:"username"` - Email string `json:"email"` - URL *string `json:"url"` - Location *string `json:"location"` - Bio *string `json:"bio"` + ID int `json:"id"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` + Username string `json:"username"` + Email string `json:"email"` + URL *string `json:"url"` + Location *string `json:"location"` + Bio *string `json:"bio"` + + alias string } func (User) IsEntity() {} @@ -23,29 +26,33 @@ func (u *User) CanonicalName() string { return "~" + u.Username } -func (u *User) Columns(ctx context.Context, tbl string) string { - columns := ColumnsFor(ctx, map[string]string{ - "id": "id", - "created": "created", - "updated": "updated", +func (u *User) Select(ctx context.Context) []string { + return database.ColumnsFor(ctx, u.alias, map[string]string{ + "id": "id", + "created": "created", + "updated": "updated", "username": "username", - "email": "email", - "url": "url", + "email": "email", + "url": "url", "location": "location", - "bio": "bio", - }, tbl) - return strings.Join(columns, ", ") + "bio": "bio", + }) +} + +func (u *User) As(alias string) database.Selectable { + u.alias = alias + return u } func (u *User) Fields(ctx context.Context) []interface{} { - return FieldsFor(ctx, map[string]interface{}{ - "id": &u.ID, - "created": &u.Created, - "updated": &u.Updated, + return database.FieldsFor(ctx, map[string]interface{}{ + "id": &u.ID, + "created": &u.Created, + "updated": &u.Updated, "username": &u.Username, - "email": &u.Email, - "url": &u.URL, + "email": &u.Email, + "url": &u.URL, "location": &u.Location, - "bio": &u.Bio, + "bio": &u.Bio, }) } diff --git a/api/graph/model/util.go b/api/graph/model/util.go index 310b3a9..771af24 100644 --- a/api/graph/model/util.go +++ b/api/graph/model/util.go @@ -23,14 +23,18 @@ func ColumnsFor(ctx context.Context, } } - sort.Slice(fields, func (a, b int) bool { + sort.Slice(fields, func(a, b int) bool { return fields[a].Name < fields[b].Name }) var columns []string for _, qlCol := range fields { if sqlCol, ok := colMap[qlCol.Name]; ok { - columns = append(columns, tbl + "." + sqlCol) + if tbl != "" { + columns = append(columns, tbl+"."+sqlCol) + } else { + columns = append(columns, sqlCol) + } } } @@ -52,7 +56,7 @@ func FieldsFor(ctx context.Context, } } - sort.Slice(qlFields, func (a, b int) bool { + sort.Slice(qlFields, func(a, b int) bool { return qlFields[a].Name < qlFields[b].Name }) diff --git a/api/loaders/middleware.go b/api/loaders/middleware.go index a1bddcd..758e4ff 100644 --- a/api/loaders/middleware.go +++ b/api/loaders/middleware.go @@ -13,13 +13,16 @@ import ( "net/http" "time" + sq "github.com/Masterminds/squirrel" "github.com/lib/pq" "git.sr.ht/~sircmpwn/git.sr.ht/api/auth" + "git.sr.ht/~sircmpwn/git.sr.ht/api/database" "git.sr.ht/~sircmpwn/git.sr.ht/api/graph/model" ) var loadersCtxKey = &contextKey{"user"} + type contextKey struct { name string } @@ -33,23 +36,24 @@ type Loaders struct { } func fetchUsersByID(ctx context.Context, - db *sql.DB) func (ids []int) ([]*model.User, []error) { - return func (ids []int) ([]*model.User, []error) { + db *sql.DB) func(ids []int) ([]*model.User, []error) { + return func(ids []int) ([]*model.User, []error) { var ( err error rows *sql.Rows ) - if rows, err = db.QueryContext(ctx,` - SELECT `+(&model.User{}).Columns(ctx, "u")+` - FROM "user" u - WHERE u.id = ANY($1)`, pq.Array(ids)); err != nil { + query := database. + Select(ctx, (&model.User{}).As(`u`)). + From(`"user" u`). + Where(sq.Expr(`u.id = ANY(?)`, pq.Array(ids))) + if rows, err = query.RunWith(db).QueryContext(ctx); err != nil { panic(err) } defer rows.Close() usersById := map[int]*model.User{} for rows.Next() { - user := model.User{} + var user model.User if err := rows.Scan(user.Fields(ctx)...); err != nil { panic(err) } @@ -69,16 +73,17 @@ func fetchUsersByID(ctx context.Context, } func fetchUsersByName(ctx context.Context, - db *sql.DB) func (names []string) ([]*model.User, []error) { - return func (names []string) ([]*model.User, []error) { + db *sql.DB) func(names []string) ([]*model.User, []error) { + return func(names []string) ([]*model.User, []error) { var ( err error rows *sql.Rows ) - if rows, err = db.QueryContext(ctx,` - SELECT `+(&model.User{}).Columns(ctx, "u")+` - FROM "user" u - WHERE u.username = ANY($1)`, pq.Array(names)); err != nil { + query := database. + Select(ctx, (&model.User{}).As(`u`)). + From(`"user" u`). + Where(sq.Expr(`u.username = ANY(?)`, pq.Array(names))) + if rows, err = query.RunWith(db).QueryContext(ctx); err != nil { panic(err) } defer rows.Close() @@ -105,23 +110,26 @@ func fetchUsersByName(ctx context.Context, } func fetchRepositoriesByID(ctx context.Context, - db *sql.DB) func (ids []int) ([]*model.Repository, []error) { - return func (ids []int) ([]*model.Repository, []error) { + db *sql.DB) func(ids []int) ([]*model.Repository, []error) { + return func(ids []int) ([]*model.Repository, []error) { var ( err error rows *sql.Rows ) - if rows, err = db.QueryContext(ctx, ` - SELECT DISTINCT `+(&model.Repository{}).Columns(ctx, "repo")+` - FROM repository repo - FULL OUTER JOIN - access ON repo.id = access.repo_id - WHERE - repo.id = ANY($2) - AND (access.user_id = $1 - OR repo.owner_id = $1 - OR repo.visibility != 'private') - `, auth.ForContext(ctx).ID, pq.Array(ids)); err != nil { + authUser := auth.ForContext(ctx) + query := database. + Select(ctx, (&model.Repository{}).As(`repo`)). + Distinct(). + From(`repository repo`). + LeftJoin(`access ON repo.id = access.repo_id`). + Where(sq.And{ + sq.Expr(`repo.id = ANY(?)`, pq.Array(ids)), + sq.Or{ + sq.Expr(`? IN (access.user_id, repo.owner_id)`, authUser.ID), + sq.Expr(`repo.visibility != 'private'`), + }, + }) + if rows, err = query.RunWith(db).QueryContext(ctx); err != nil { panic(err) } defer rows.Close() @@ -148,17 +156,21 @@ func fetchRepositoriesByID(ctx context.Context, } func fetchRepositoriesByName(ctx context.Context, - db *sql.DB) func (names []string) ([]*model.Repository, []error) { - return func (names []string) ([]*model.Repository, []error) { + db *sql.DB) func(names []string) ([]*model.Repository, []error) { + return func(names []string) ([]*model.Repository, []error) { var ( err error rows *sql.Rows ) - if rows, err = db.QueryContext(ctx, ` - SELECT DISTINCT `+(&model.Repository{}).Columns(ctx, "repo")+` - FROM repository repo - WHERE repo.name = ANY($2) AND repo.owner_id = $1 - `, auth.ForContext(ctx).ID, pq.Array(names)); err != nil { + query := database. + Select(ctx, (&model.Repository{}).As(`repo`)). + Distinct(). + From(`repository repo`). + Where(sq.And{ + sq.Expr(`repo.name = ANY(?)`, pq.Array(names)), + sq.Expr(`repo.owner_id = ?`, auth.ForContext(ctx).ID), + }) + if rows, err = query.RunWith(db).QueryContext(ctx); err != nil { panic(err) } defer rows.Close() @@ -185,8 +197,8 @@ func fetchRepositoriesByName(ctx context.Context, } func fetchRepositoriesByOwnerRepoName(ctx context.Context, - db *sql.DB) func (names [][2]string) ([]*model.Repository, []error) { - return func (names [][2]string) ([]*model.Repository, []error) { + db *sql.DB) func(names [][2]string) ([]*model.Repository, []error) { + return func(names [][2]string) ([]*model.Repository, []error) { var ( err error rows *sql.Rows @@ -198,25 +210,27 @@ func fetchRepositoriesByOwnerRepoName(ctx context.Context, // and repo names _names[i] = name[0] + "/" + name[1] } - if rows, err = db.QueryContext(ctx, ` - WITH user_repo AS ( + query := database. + Select(ctx). + Prefix(`WITH user_repo AS ( SELECT substring(un for position('/' in un)-1) AS owner, substring(un from position('/' in un)+1) AS repo - FROM unnest($2::text[]) un - ) - SELECT DISTINCT - `+(&model.Repository{}).Columns(ctx, "repo")+`, - u.username - FROM user_repo ur - JOIN "user" u ON ur.owner = u.username - JOIN repository repo ON ur.repo = repo.name AND u.id = repo.owner_id - LEFT JOIN access ON repo.id = access.repo_id - WHERE - access.user_id = $1 - OR repo.owner_id = $1 - OR repo.visibility != 'private'`, - auth.ForContext(ctx).ID, pq.Array(_names)); err != nil { + FROM unnest(?::text[]) un)`, pq.Array(_names)). + Columns((&model.Repository{}).As(`repo`).Select(ctx)...). + Columns(`u.username`). + Distinct(). + From(`user_repo ur`). + Join(`"user" u on ur.owner = u.username`). + Join(`repository repo ON ur.repo = repo.name + AND u.id = repo.owner_id`). + LeftJoin(`access ON repo.id = access.repo_id`). + Where(sq.Or{ + sq.Expr(`? IN (access.user_id, repo.owner_id)`, + auth.ForContext(ctx).ID), + sq.Expr(`repo.visibility != 'private'`), + }) + if rows, err = query.RunWith(db).QueryContext(ctx); err != nil { panic(err) } defer rows.Close() @@ -250,28 +264,28 @@ func Middleware(db *sql.DB) func(http.Handler) http.Handler { ctx := context.WithValue(r.Context(), loadersCtxKey, &Loaders{ UsersByID: UsersByIDLoader{ maxBatch: 100, - wait: 1 * time.Millisecond, - fetch: fetchUsersByID(r.Context(), db), + wait: 1 * time.Millisecond, + fetch: fetchUsersByID(r.Context(), db), }, UsersByName: UsersByNameLoader{ maxBatch: 100, - wait: 1 * time.Millisecond, - fetch: fetchUsersByName(r.Context(), db), + wait: 1 * time.Millisecond, + fetch: fetchUsersByName(r.Context(), db), }, RepositoriesByID: RepositoriesByIDLoader{ maxBatch: 100, - wait: 1 * time.Millisecond, - fetch: fetchRepositoriesByID(r.Context(), db), + wait: 1 * time.Millisecond, + fetch: fetchRepositoriesByID(r.Context(), db), }, RepositoriesByName: RepositoriesByNameLoader{ maxBatch: 100, - wait: 1 * time.Millisecond, - fetch: fetchRepositoriesByName(r.Context(), db), + wait: 1 * time.Millisecond, + fetch: fetchRepositoriesByName(r.Context(), db), }, RepositoriesByOwnerRepoName: RepositoriesByOwnerRepoNameLoader{ maxBatch: 100, - wait: 1 * time.Millisecond, - fetch: fetchRepositoriesByOwnerRepoName(r.Context(), db), + wait: 1 * time.Millisecond, + fetch: fetchRepositoriesByOwnerRepoName(r.Context(), db), }, }) r = r.WithContext(ctx) diff --git a/api/server.go b/api/server.go index 2548be3..9602a92 100644 --- a/api/server.go +++ b/api/server.go @@ -11,13 +11,13 @@ import ( "git.sr.ht/~sircmpwn/gqlgen/graphql/playground" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" - "github.com/vaughan0/go-ini" _ "github.com/lib/pq" + "github.com/vaughan0/go-ini" "git.sr.ht/~sircmpwn/git.sr.ht/api/auth" - "git.sr.ht/~sircmpwn/git.sr.ht/api/loaders" "git.sr.ht/~sircmpwn/git.sr.ht/api/graph" "git.sr.ht/~sircmpwn/git.sr.ht/api/graph/generated" + "git.sr.ht/~sircmpwn/git.sr.ht/api/loaders" ) const defaultAddr = ":8080" -- 2.38.4