This commit is contained in:
2025-04-12 12:38:00 +08:00
parent 5341dfcd1a
commit 52bfd05b80
47 changed files with 4730 additions and 2 deletions

View File

@@ -0,0 +1,11 @@
package httprule
import "strings"
// FieldPath describes the path for a field from a message.
// Individual segments are in snake case (same as in protobuf file).
type FieldPath []string
func (f FieldPath) String() string {
return strings.Join(f, ".")
}

94
internal/httprule/rule.go Normal file
View File

@@ -0,0 +1,94 @@
package httprule
import (
"fmt"
"net/http"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
func Get(m protoreflect.MethodDescriptor) (*annotations.HttpRule, bool) {
descriptor, ok := proto.GetExtension(m.Options(), annotations.E_Http).(*annotations.HttpRule)
if !ok || descriptor == nil {
return nil, false
}
return descriptor, true
}
type Rule struct {
// The HTTP method to use.
Method string
// The template describing the URL to use.
Template Template
Body string
AdditionalRules []Rule
}
func ParseRule(httpRule *annotations.HttpRule) (Rule, error) {
method, err := httpRuleMethod(httpRule)
if err != nil {
return Rule{}, err
}
url, err := httpRuleURL(httpRule)
if err != nil {
return Rule{}, err
}
template, err := ParseTemplate(url)
if err != nil {
return Rule{}, err
}
additional := make([]Rule, len(httpRule.GetAdditionalBindings()))
for i, r := range httpRule.GetAdditionalBindings() {
a, err := ParseRule(r)
if err != nil {
return Rule{}, fmt.Errorf("parse additional binding %d: %w", i, err)
}
additional[i] = a
}
return Rule{
Method: method,
Template: template,
Body: httpRule.GetBody(),
AdditionalRules: additional,
}, nil
}
func httpRuleURL(rule *annotations.HttpRule) (string, error) {
switch v := rule.GetPattern().(type) {
case *annotations.HttpRule_Get:
return v.Get, nil
case *annotations.HttpRule_Post:
return v.Post, nil
case *annotations.HttpRule_Delete:
return v.Delete, nil
case *annotations.HttpRule_Patch:
return v.Patch, nil
case *annotations.HttpRule_Put:
return v.Put, nil
case *annotations.HttpRule_Custom:
return v.Custom.GetPath(), nil
default:
return "", fmt.Errorf("http rule does not have an URL defined")
}
}
func httpRuleMethod(rule *annotations.HttpRule) (string, error) {
switch v := rule.GetPattern().(type) {
case *annotations.HttpRule_Get:
return http.MethodGet, nil
case *annotations.HttpRule_Post:
return http.MethodPost, nil
case *annotations.HttpRule_Delete:
return http.MethodDelete, nil
case *annotations.HttpRule_Patch:
return http.MethodPatch, nil
case *annotations.HttpRule_Put:
return http.MethodPut, nil
case *annotations.HttpRule_Custom:
return v.Custom.GetKind(), nil
default:
return "", fmt.Errorf("http rule does not have an URL defined")
}
}

View File

@@ -0,0 +1,397 @@
package httprule
import "fmt"
// Template represents a http path template.
//
// Example: `/v1/{name=books/*}:publish`.
type Template struct {
Segments []Segment
Verb string
}
// Segment represents a single segment of a Template.
type Segment struct {
Kind SegmentKind
Literal string
Variable VariableSegment
}
type SegmentKind int
const (
SegmentKindLiteral SegmentKind = iota
SegmentKindMatchSingle
SegmentKindMatchMultiple
SegmentKindVariable
)
// VariableSegment represents a variable segment.
type VariableSegment struct {
FieldPath FieldPath
Segments []Segment
}
func ParseTemplate(s string) (Template, error) {
p := &parser{
content: s,
}
template, err := p.parse()
if err != nil {
return Template{}, err
}
if err := validate(template); err != nil {
return Template{}, err
}
return template, nil
}
type parser struct {
content string
// The next pos in content to read
pos int
// The currently read rune in content
tok rune
}
func (p *parser) parse() (Template, error) {
// Grammar.
// Template = "/" Segments [ Verb ] ;
// Segments = Segment { "/" Segment } ;
// Segment = "*" | "**" | LITERAL | Variable ;
// Variable = "{" FieldPath [ "=" Segments ] "}" ;
// FieldPath = IDENT { "." IDENT } ;
// Verb = ":" LITERAL ;.
p.next()
if err := p.expect('/'); err != nil {
return Template{}, err
}
segments, err := p.parseSegments()
if err != nil {
return Template{}, err
}
var verb string
if p.tok == ':' {
v, err := p.parseVerb()
if err != nil {
return Template{}, err
}
verb = v
}
if p.tok != -1 {
return Template{}, fmt.Errorf("expected EOF, got %q", p.tok)
}
return Template{
Segments: segments,
Verb: verb,
}, nil
}
func (p *parser) parseSegments() ([]Segment, error) {
seg, err := p.parseSegment()
if err != nil {
return nil, err
}
if p.tok == '/' {
p.next()
rest, err := p.parseSegments()
if err != nil {
return nil, err
}
return append([]Segment{seg}, rest...), nil
}
return []Segment{seg}, nil
}
func (p *parser) parseSegment() (Segment, error) {
switch {
case p.tok == '*' && p.peek() == '*':
return p.parseMatchMultipleSegment(), nil
case p.tok == '*':
return p.parseMatchSingleSegment(), nil
case p.tok == '{':
return p.parseVariableSegment()
default:
return p.parseLiteralSegment()
}
}
func (p *parser) parseMatchMultipleSegment() Segment {
p.next()
p.next()
return Segment{
Kind: SegmentKindMatchMultiple,
}
}
func (p *parser) parseMatchSingleSegment() Segment {
p.next()
return Segment{
Kind: SegmentKindMatchSingle,
}
}
func (p *parser) parseLiteralSegment() (Segment, error) {
lit, err := p.parseLiteral()
if err != nil {
return Segment{}, err
}
return Segment{
Kind: SegmentKindLiteral,
Literal: lit,
}, nil
}
func (p *parser) parseVariableSegment() (Segment, error) {
if err := p.expect('{'); err != nil {
return Segment{}, err
}
fieldPath, err := p.parseFieldPath()
if err != nil {
return Segment{}, err
}
segments := []Segment{
{Kind: SegmentKindMatchSingle},
}
if p.tok == '=' {
p.next()
s, err := p.parseSegments()
if err != nil {
return Segment{}, err
}
segments = s
}
if err := p.expect('}'); err != nil {
return Segment{}, err
}
return Segment{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: fieldPath,
Segments: segments,
},
}, nil
}
func (p *parser) parseVerb() (string, error) {
if err := p.expect(':'); err != nil {
return "", err
}
return p.parseLiteral()
}
func (p *parser) parseFieldPath() ([]string, error) {
fp, err := p.parseIdent()
if err != nil {
return nil, err
}
if p.tok == '.' {
p.next()
rest, err := p.parseFieldPath()
if err != nil {
return nil, err
}
return append([]string{fp}, rest...), nil
}
return []string{fp}, nil
}
// parseLiteral consumes input as long as next token(s) belongs to pchars, as defined in RFC3986.
// Returns an error if not literal is found.
//
// https://www.ietf.org/rfc/rfc3986.txt, P.49
//
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func (p *parser) parseLiteral() (string, error) {
var literal []rune
startPos := p.pos
for {
if isSingleCharPChar(p.tok) {
literal = append(literal, p.tok)
p.next()
continue
}
if p.tok == '%' && isHexDigit(p.peekN(1)) && isHexDigit(p.peekN(2)) {
literal = append(literal, p.tok)
p.next()
literal = append(literal, p.tok)
p.next()
literal = append(literal, p.tok)
p.next()
continue
}
break
}
if len(literal) == 0 {
return "", fmt.Errorf("expected literal at position %d, found %s", startPos-1, p.tokenString())
}
return string(literal), nil
}
func (p *parser) parseIdent() (string, error) {
var ident []rune
startPos := p.pos
for {
if isAlpha(p.tok) || isDigit(p.tok) || p.tok == '_' {
ident = append(ident, p.tok)
p.next()
continue
}
break
}
if len(ident) == 0 {
return "", fmt.Errorf("expected identifier at position %d, found %s", startPos-1, p.tokenString())
}
return string(ident), nil
}
func (p *parser) next() {
if p.pos < len(p.content) {
p.tok = rune(p.content[p.pos])
p.pos++
} else {
p.tok = -1
p.pos = len(p.content)
}
}
func (p parser) tokenString() string {
if p.tok == -1 {
return "EOF"
}
return fmt.Sprintf("%q", p.tok)
}
func (p *parser) peek() rune {
return p.peekN(1)
}
func (p *parser) peekN(n int) rune {
if offset := p.pos + n - 1; offset < len(p.content) {
return rune(p.content[offset])
}
return -1
}
func (p *parser) expect(r rune) error {
if p.tok != r {
return fmt.Errorf("expected token %q at position %d, found %s", r, p.pos, p.tokenString())
}
p.next()
return nil
}
// https://www.ietf.org/rfc/rfc3986.txt, P.49
//
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
// pct-encoded = "%" HEXDIG HEXDIG
func isSingleCharPChar(r rune) bool {
if isAlpha(r) || isDigit(r) {
return true
}
switch r {
case '@', '-', '.', '_', '~', '!',
'$', '&', '\'', '(', ')', '*', '+',
',', ';', '=': // ':'
return true
}
return false
}
func isAlpha(r rune) bool {
return ('A' <= r && r <= 'Z') || ('a' <= r && r <= 'z')
}
func isDigit(r rune) bool {
return '0' <= r && r <= '9'
}
func isHexDigit(r rune) bool {
switch {
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'F':
return true
case 'a' <= r && r <= 'f':
return true
}
return false
}
// validate validates parts of the template that are
// allowed by the grammar, but disallowed in practice.
//
// - nested variable segments
// - '**' for segments other than the last.
func validate(t Template) error {
// check for nested variable segments
for _, s1 := range t.Segments {
if s1.Kind != SegmentKindVariable {
continue
}
for _, s2 := range s1.Variable.Segments {
if s2.Kind == SegmentKindVariable {
return fmt.Errorf("nested variable segment is not allowed")
}
}
}
// check for '**' that are not the last part of the template
for i, s := range t.Segments {
if i == len(t.Segments)-1 {
continue
}
if s.Kind == SegmentKindMatchMultiple {
return fmt.Errorf("'**' only allowed as last part of template")
}
if s.Kind == SegmentKindVariable {
for _, s2 := range s.Variable.Segments {
if s2.Kind == SegmentKindMatchMultiple {
return fmt.Errorf("'**' only allowed as last part of template")
}
}
}
}
// check for variable where '**' is not last part
for _, s := range t.Segments {
if s.Kind != SegmentKindVariable {
continue
}
for i, s2 := range s.Variable.Segments {
if i == len(s.Variable.Segments)-1 {
continue
}
if s2.Kind == SegmentKindMatchMultiple {
return fmt.Errorf("'**' only allowed as the last part of the template")
}
}
}
// check for top level expansions
for _, s := range t.Segments {
if s.Kind == SegmentKindMatchSingle {
return fmt.Errorf("'*' must only be used in variables")
}
if s.Kind == SegmentKindMatchMultiple {
return fmt.Errorf("'**' must only be used in variables")
}
}
// check for duplicate variable bindings
seen := make(map[string]struct{})
for _, s := range t.Segments {
if s.Kind == SegmentKindVariable {
field := s.Variable.FieldPath.String()
if _, ok := seen[s.Variable.FieldPath.String()]; ok {
return fmt.Errorf("variable '%s' bound multiple times", field)
}
seen[field] = struct{}{}
}
}
return nil
}

View File

@@ -0,0 +1,178 @@
package httprule
import (
"testing"
"gotest.tools/v3/assert"
)
func Test_ParseTemplate(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
input string
path Template
}{
{
input: "/v1/messages",
path: Template{
Segments: []Segment{
{Kind: SegmentKindLiteral, Literal: "v1"},
{Kind: SegmentKindLiteral, Literal: "messages"},
},
},
},
{
input: "/v1/messages:peek",
path: Template{
Segments: []Segment{
{Kind: SegmentKindLiteral, Literal: "v1"},
{Kind: SegmentKindLiteral, Literal: "messages"},
},
Verb: "peek",
},
},
{
input: "/{id}",
path: Template{
Segments: []Segment{
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"id"},
Segments: []Segment{
{Kind: SegmentKindMatchSingle},
},
},
},
},
},
},
{
input: "/{message.id}",
path: Template{
Segments: []Segment{
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"message", "id"},
Segments: []Segment{
{Kind: SegmentKindMatchSingle},
},
},
},
},
},
},
{
input: "/{id=messages/*}",
path: Template{
Segments: []Segment{
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"id"},
Segments: []Segment{
{Kind: SegmentKindLiteral, Literal: "messages"},
{Kind: SegmentKindMatchSingle},
},
},
},
},
},
},
{
input: "/{id=messages/*/threads/*}",
path: Template{
Segments: []Segment{
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"id"},
Segments: []Segment{
{Kind: SegmentKindLiteral, Literal: "messages"},
{Kind: SegmentKindMatchSingle},
{Kind: SegmentKindLiteral, Literal: "threads"},
{Kind: SegmentKindMatchSingle},
},
},
},
},
},
},
{
input: "/{id=**}",
path: Template{
Segments: []Segment{
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"id"},
Segments: []Segment{
{Kind: SegmentKindMatchMultiple},
},
},
},
},
},
},
{
input: "/v1/messages/{message}/threads/{thread}",
path: Template{
Segments: []Segment{
{Kind: SegmentKindLiteral, Literal: "v1"},
{Kind: SegmentKindLiteral, Literal: "messages"},
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"message"},
Segments: []Segment{
{Kind: SegmentKindMatchSingle},
},
},
},
{Kind: SegmentKindLiteral, Literal: "threads"},
{
Kind: SegmentKindVariable,
Variable: VariableSegment{
FieldPath: []string{"thread"},
Segments: []Segment{
{Kind: SegmentKindMatchSingle},
},
},
},
},
},
},
} {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
got, err := ParseTemplate(tt.input)
assert.NilError(t, err)
assert.DeepEqual(t, tt.path, got)
})
}
}
func Test_ParseTemplate_Invalid(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
template string
expected string
}{
{template: "", expected: "expected token '/' at position 0, found EOF"},
{template: "//", expected: "expected literal at position 1, found '/'"},
{template: "/v1:", expected: "expected literal at position 3, found EOF"},
{template: "/v1/:", expected: "expected literal at position 4, found ':'"},
{template: "/{name=messages/{id}}", expected: "nested variable segment is not allowed"},
{template: "/**/*", expected: "'**' only allowed as last part of template"},
{template: "/v1/messages/*", expected: "'*' must only be used in variables"},
{template: "/v1/{id}/{id}", expected: "variable 'id' bound multiple times"},
} {
t.Run(tt.template, func(t *testing.T) {
t.Parallel()
_, err := ParseTemplate(tt.template)
assert.Check(t, err != nil)
assert.ErrorContains(t, err, tt.expected)
})
}
}