| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467 |
- package pattern
- import (
- "fmt"
- "go/ast"
- "go/token"
- "reflect"
- )
- type Pattern struct {
- Root Node
- // Relevant contains instances of ast.Node that could potentially
- // initiate a successful match of the pattern.
- Relevant []reflect.Type
- }
- func MustParse(s string) Pattern {
- p := &Parser{AllowTypeInfo: true}
- pat, err := p.Parse(s)
- if err != nil {
- panic(err)
- }
- return pat
- }
- func roots(node Node) []reflect.Type {
- switch node := node.(type) {
- case Or:
- var out []reflect.Type
- for _, el := range node.Nodes {
- out = append(out, roots(el)...)
- }
- return out
- case Not:
- return roots(node.Node)
- case Binding:
- return roots(node.Node)
- case Nil, nil:
- // this branch is reached via bindings
- return allTypes
- default:
- Ts, ok := nodeToASTTypes[reflect.TypeOf(node)]
- if !ok {
- panic(fmt.Sprintf("internal error: unhandled type %T", node))
- }
- return Ts
- }
- }
// allTypes is the list of AST node pointer types used when a pattern places
// no restriction on where a match may start. roots returns it for Nil and
// nil nodes, and nodeToASTTypes uses it for Any and TrulyConstantExpression.
var allTypes = []reflect.Type{
	reflect.TypeOf((*ast.RangeStmt)(nil)),
	reflect.TypeOf((*ast.AssignStmt)(nil)),
	reflect.TypeOf((*ast.IndexExpr)(nil)),
	reflect.TypeOf((*ast.Ident)(nil)),
	reflect.TypeOf((*ast.ValueSpec)(nil)),
	reflect.TypeOf((*ast.GenDecl)(nil)),
	reflect.TypeOf((*ast.BinaryExpr)(nil)),
	reflect.TypeOf((*ast.ForStmt)(nil)),
	reflect.TypeOf((*ast.ArrayType)(nil)),
	reflect.TypeOf((*ast.DeferStmt)(nil)),
	reflect.TypeOf((*ast.MapType)(nil)),
	reflect.TypeOf((*ast.ReturnStmt)(nil)),
	reflect.TypeOf((*ast.SliceExpr)(nil)),
	reflect.TypeOf((*ast.StarExpr)(nil)),
	reflect.TypeOf((*ast.UnaryExpr)(nil)),
	reflect.TypeOf((*ast.SendStmt)(nil)),
	reflect.TypeOf((*ast.SelectStmt)(nil)),
	reflect.TypeOf((*ast.ImportSpec)(nil)),
	reflect.TypeOf((*ast.IfStmt)(nil)),
	reflect.TypeOf((*ast.GoStmt)(nil)),
	reflect.TypeOf((*ast.Field)(nil)),
	reflect.TypeOf((*ast.SelectorExpr)(nil)),
	reflect.TypeOf((*ast.StructType)(nil)),
	reflect.TypeOf((*ast.KeyValueExpr)(nil)),
	reflect.TypeOf((*ast.FuncType)(nil)),
	reflect.TypeOf((*ast.FuncLit)(nil)),
	reflect.TypeOf((*ast.FuncDecl)(nil)),
	reflect.TypeOf((*ast.ChanType)(nil)),
	reflect.TypeOf((*ast.CallExpr)(nil)),
	reflect.TypeOf((*ast.CaseClause)(nil)),
	reflect.TypeOf((*ast.CommClause)(nil)),
	reflect.TypeOf((*ast.CompositeLit)(nil)),
	reflect.TypeOf((*ast.EmptyStmt)(nil)),
	reflect.TypeOf((*ast.SwitchStmt)(nil)),
	reflect.TypeOf((*ast.TypeSwitchStmt)(nil)),
	reflect.TypeOf((*ast.TypeAssertExpr)(nil)),
	reflect.TypeOf((*ast.TypeSpec)(nil)),
	reflect.TypeOf((*ast.InterfaceType)(nil)),
	reflect.TypeOf((*ast.BranchStmt)(nil)),
	reflect.TypeOf((*ast.IncDecStmt)(nil)),
	reflect.TypeOf((*ast.BasicLit)(nil)),
}
- var nodeToASTTypes = map[reflect.Type][]reflect.Type{
- reflect.TypeOf(String("")): nil,
- reflect.TypeOf(Token(0)): nil,
- reflect.TypeOf(List{}): {reflect.TypeOf((*ast.BlockStmt)(nil)), reflect.TypeOf((*ast.FieldList)(nil))},
- reflect.TypeOf(Builtin{}): {reflect.TypeOf((*ast.Ident)(nil))},
- reflect.TypeOf(Object{}): {reflect.TypeOf((*ast.Ident)(nil))},
- reflect.TypeOf(Function{}): {reflect.TypeOf((*ast.Ident)(nil)), reflect.TypeOf((*ast.SelectorExpr)(nil))},
- reflect.TypeOf(Any{}): allTypes,
- reflect.TypeOf(RangeStmt{}): {reflect.TypeOf((*ast.RangeStmt)(nil))},
- reflect.TypeOf(AssignStmt{}): {reflect.TypeOf((*ast.AssignStmt)(nil))},
- reflect.TypeOf(IndexExpr{}): {reflect.TypeOf((*ast.IndexExpr)(nil))},
- reflect.TypeOf(Ident{}): {reflect.TypeOf((*ast.Ident)(nil))},
- reflect.TypeOf(ValueSpec{}): {reflect.TypeOf((*ast.ValueSpec)(nil))},
- reflect.TypeOf(GenDecl{}): {reflect.TypeOf((*ast.GenDecl)(nil))},
- reflect.TypeOf(BinaryExpr{}): {reflect.TypeOf((*ast.BinaryExpr)(nil))},
- reflect.TypeOf(ForStmt{}): {reflect.TypeOf((*ast.ForStmt)(nil))},
- reflect.TypeOf(ArrayType{}): {reflect.TypeOf((*ast.ArrayType)(nil))},
- reflect.TypeOf(DeferStmt{}): {reflect.TypeOf((*ast.DeferStmt)(nil))},
- reflect.TypeOf(MapType{}): {reflect.TypeOf((*ast.MapType)(nil))},
- reflect.TypeOf(ReturnStmt{}): {reflect.TypeOf((*ast.ReturnStmt)(nil))},
- reflect.TypeOf(SliceExpr{}): {reflect.TypeOf((*ast.SliceExpr)(nil))},
- reflect.TypeOf(StarExpr{}): {reflect.TypeOf((*ast.StarExpr)(nil))},
- reflect.TypeOf(UnaryExpr{}): {reflect.TypeOf((*ast.UnaryExpr)(nil))},
- reflect.TypeOf(SendStmt{}): {reflect.TypeOf((*ast.SendStmt)(nil))},
- reflect.TypeOf(SelectStmt{}): {reflect.TypeOf((*ast.SelectStmt)(nil))},
- reflect.TypeOf(ImportSpec{}): {reflect.TypeOf((*ast.ImportSpec)(nil))},
- reflect.TypeOf(IfStmt{}): {reflect.TypeOf((*ast.IfStmt)(nil))},
- reflect.TypeOf(GoStmt{}): {reflect.TypeOf((*ast.GoStmt)(nil))},
- reflect.TypeOf(Field{}): {reflect.TypeOf((*ast.Field)(nil))},
- reflect.TypeOf(SelectorExpr{}): {reflect.TypeOf((*ast.SelectorExpr)(nil))},
- reflect.TypeOf(StructType{}): {reflect.TypeOf((*ast.StructType)(nil))},
- reflect.TypeOf(KeyValueExpr{}): {reflect.TypeOf((*ast.KeyValueExpr)(nil))},
- reflect.TypeOf(FuncType{}): {reflect.TypeOf((*ast.FuncType)(nil))},
- reflect.TypeOf(FuncLit{}): {reflect.TypeOf((*ast.FuncLit)(nil))},
- reflect.TypeOf(FuncDecl{}): {reflect.TypeOf((*ast.FuncDecl)(nil))},
- reflect.TypeOf(ChanType{}): {reflect.TypeOf((*ast.ChanType)(nil))},
- reflect.TypeOf(CallExpr{}): {reflect.TypeOf((*ast.CallExpr)(nil))},
- reflect.TypeOf(CaseClause{}): {reflect.TypeOf((*ast.CaseClause)(nil))},
- reflect.TypeOf(CommClause{}): {reflect.TypeOf((*ast.CommClause)(nil))},
- reflect.TypeOf(CompositeLit{}): {reflect.TypeOf((*ast.CompositeLit)(nil))},
- reflect.TypeOf(EmptyStmt{}): {reflect.TypeOf((*ast.EmptyStmt)(nil))},
- reflect.TypeOf(SwitchStmt{}): {reflect.TypeOf((*ast.SwitchStmt)(nil))},
- reflect.TypeOf(TypeSwitchStmt{}): {reflect.TypeOf((*ast.TypeSwitchStmt)(nil))},
- reflect.TypeOf(TypeAssertExpr{}): {reflect.TypeOf((*ast.TypeAssertExpr)(nil))},
- reflect.TypeOf(TypeSpec{}): {reflect.TypeOf((*ast.TypeSpec)(nil))},
- reflect.TypeOf(InterfaceType{}): {reflect.TypeOf((*ast.InterfaceType)(nil))},
- reflect.TypeOf(BranchStmt{}): {reflect.TypeOf((*ast.BranchStmt)(nil))},
- reflect.TypeOf(IncDecStmt{}): {reflect.TypeOf((*ast.IncDecStmt)(nil))},
- reflect.TypeOf(BasicLit{}): {reflect.TypeOf((*ast.BasicLit)(nil))},
- reflect.TypeOf(IntegerLiteral{}): {reflect.TypeOf((*ast.BasicLit)(nil)), reflect.TypeOf((*ast.UnaryExpr)(nil))},
- reflect.TypeOf(TrulyConstantExpression{}): allTypes, // this is an over-approximation, which is fine
- }
// requiresTypeInfo names the node types that rely on type information.
// populateNode refuses to build any of them unless type information is
// allowed.
var requiresTypeInfo = map[string]bool{
	"Builtin":  true,
	"Function": true,
	"Object":   true,
}
- type Parser struct {
- // Allow nodes that rely on type information
- AllowTypeInfo bool
- lex *lexer
- cur item
- last *item
- items chan item
- }
- func (p *Parser) Parse(s string) (Pattern, error) {
- p.cur = item{}
- p.last = nil
- p.items = nil
- fset := token.NewFileSet()
- p.lex = &lexer{
- f: fset.AddFile("<input>", -1, len(s)),
- input: s,
- items: make(chan item),
- }
- go p.lex.run()
- p.items = p.lex.items
- root, err := p.node()
- if err != nil {
- // drain lexer if parsing failed
- for range p.lex.items {
- }
- return Pattern{}, err
- }
- if item := <-p.lex.items; item.typ != itemEOF {
- return Pattern{}, fmt.Errorf("unexpected token %s after end of pattern", item.typ)
- }
- return Pattern{
- Root: root,
- Relevant: roots(root),
- }, nil
- }
- func (p *Parser) next() item {
- if p.last != nil {
- n := *p.last
- p.last = nil
- return n
- }
- var ok bool
- p.cur, ok = <-p.items
- if !ok {
- p.cur = item{typ: eof}
- }
- return p.cur
- }
- func (p *Parser) rewind() {
- p.last = &p.cur
- }
- func (p *Parser) peek() item {
- n := p.next()
- p.rewind()
- return n
- }
- func (p *Parser) accept(typ itemType) (item, bool) {
- n := p.next()
- if n.typ == typ {
- return n, true
- }
- p.rewind()
- return item{}, false
- }
- func (p *Parser) unexpectedToken(valid string) error {
- if p.cur.typ == itemError {
- return fmt.Errorf("error lexing input: %s", p.cur.val)
- }
- var got string
- switch p.cur.typ {
- case itemTypeName, itemVariable, itemString:
- got = p.cur.val
- default:
- got = "'" + p.cur.typ.String() + "'"
- }
- pos := p.lex.f.Position(token.Pos(p.cur.pos))
- return fmt.Errorf("%s: expected %s, found %s", pos, valid, got)
- }
- func (p *Parser) node() (Node, error) {
- if _, ok := p.accept(itemLeftParen); !ok {
- return nil, p.unexpectedToken("'('")
- }
- typ, ok := p.accept(itemTypeName)
- if !ok {
- return nil, p.unexpectedToken("Node type")
- }
- var objs []Node
- for {
- if _, ok := p.accept(itemRightParen); ok {
- break
- } else {
- p.rewind()
- obj, err := p.object()
- if err != nil {
- return nil, err
- }
- objs = append(objs, obj)
- }
- }
- return p.populateNode(typ.val, objs)
- }
- func populateNode(typ string, objs []Node, allowTypeInfo bool) (Node, error) {
- T, ok := structNodes[typ]
- if !ok {
- return nil, fmt.Errorf("unknown node %s", typ)
- }
- if !allowTypeInfo && requiresTypeInfo[typ] {
- return nil, fmt.Errorf("Node %s requires type information", typ)
- }
- pv := reflect.New(T)
- v := pv.Elem()
- if v.NumField() == 1 {
- f := v.Field(0)
- if f.Type().Kind() == reflect.Slice {
- // Variadic node
- f.Set(reflect.AppendSlice(f, reflect.ValueOf(objs)))
- return v.Interface().(Node), nil
- }
- }
- if len(objs) != v.NumField() {
- return nil, fmt.Errorf("tried to initialize node %s with %d values, expected %d", typ, len(objs), v.NumField())
- }
- for i := 0; i < v.NumField(); i++ {
- f := v.Field(i)
- if f.Kind() == reflect.String {
- if obj, ok := objs[i].(String); ok {
- f.Set(reflect.ValueOf(string(obj)))
- } else {
- return nil, fmt.Errorf("first argument of (Binding name node) must be string, but got %s", objs[i])
- }
- } else {
- f.Set(reflect.ValueOf(objs[i]))
- }
- }
- return v.Interface().(Node), nil
- }
- func (p *Parser) populateNode(typ string, objs []Node) (Node, error) {
- return populateNode(typ, objs, p.AllowTypeInfo)
- }
- var structNodes = map[string]reflect.Type{
- "Any": reflect.TypeOf(Any{}),
- "Ellipsis": reflect.TypeOf(Ellipsis{}),
- "List": reflect.TypeOf(List{}),
- "Binding": reflect.TypeOf(Binding{}),
- "RangeStmt": reflect.TypeOf(RangeStmt{}),
- "AssignStmt": reflect.TypeOf(AssignStmt{}),
- "IndexExpr": reflect.TypeOf(IndexExpr{}),
- "Ident": reflect.TypeOf(Ident{}),
- "Builtin": reflect.TypeOf(Builtin{}),
- "ValueSpec": reflect.TypeOf(ValueSpec{}),
- "GenDecl": reflect.TypeOf(GenDecl{}),
- "BinaryExpr": reflect.TypeOf(BinaryExpr{}),
- "ForStmt": reflect.TypeOf(ForStmt{}),
- "ArrayType": reflect.TypeOf(ArrayType{}),
- "DeferStmt": reflect.TypeOf(DeferStmt{}),
- "MapType": reflect.TypeOf(MapType{}),
- "ReturnStmt": reflect.TypeOf(ReturnStmt{}),
- "SliceExpr": reflect.TypeOf(SliceExpr{}),
- "StarExpr": reflect.TypeOf(StarExpr{}),
- "UnaryExpr": reflect.TypeOf(UnaryExpr{}),
- "SendStmt": reflect.TypeOf(SendStmt{}),
- "SelectStmt": reflect.TypeOf(SelectStmt{}),
- "ImportSpec": reflect.TypeOf(ImportSpec{}),
- "IfStmt": reflect.TypeOf(IfStmt{}),
- "GoStmt": reflect.TypeOf(GoStmt{}),
- "Field": reflect.TypeOf(Field{}),
- "SelectorExpr": reflect.TypeOf(SelectorExpr{}),
- "StructType": reflect.TypeOf(StructType{}),
- "KeyValueExpr": reflect.TypeOf(KeyValueExpr{}),
- "FuncType": reflect.TypeOf(FuncType{}),
- "FuncLit": reflect.TypeOf(FuncLit{}),
- "FuncDecl": reflect.TypeOf(FuncDecl{}),
- "ChanType": reflect.TypeOf(ChanType{}),
- "CallExpr": reflect.TypeOf(CallExpr{}),
- "CaseClause": reflect.TypeOf(CaseClause{}),
- "CommClause": reflect.TypeOf(CommClause{}),
- "CompositeLit": reflect.TypeOf(CompositeLit{}),
- "EmptyStmt": reflect.TypeOf(EmptyStmt{}),
- "SwitchStmt": reflect.TypeOf(SwitchStmt{}),
- "TypeSwitchStmt": reflect.TypeOf(TypeSwitchStmt{}),
- "TypeAssertExpr": reflect.TypeOf(TypeAssertExpr{}),
- "TypeSpec": reflect.TypeOf(TypeSpec{}),
- "InterfaceType": reflect.TypeOf(InterfaceType{}),
- "BranchStmt": reflect.TypeOf(BranchStmt{}),
- "IncDecStmt": reflect.TypeOf(IncDecStmt{}),
- "BasicLit": reflect.TypeOf(BasicLit{}),
- "Object": reflect.TypeOf(Object{}),
- "Function": reflect.TypeOf(Function{}),
- "Or": reflect.TypeOf(Or{}),
- "Not": reflect.TypeOf(Not{}),
- "IntegerLiteral": reflect.TypeOf(IntegerLiteral{}),
- "TrulyConstantExpression": reflect.TypeOf(TrulyConstantExpression{}),
- }
- func (p *Parser) object() (Node, error) {
- n := p.next()
- switch n.typ {
- case itemLeftParen:
- p.rewind()
- node, err := p.node()
- if err != nil {
- return node, err
- }
- if p.peek().typ == itemColon {
- p.next()
- tail, err := p.object()
- if err != nil {
- return node, err
- }
- return List{Head: node, Tail: tail}, nil
- }
- return node, nil
- case itemLeftBracket:
- p.rewind()
- return p.array()
- case itemVariable:
- v := n
- if v.val == "nil" {
- return Nil{}, nil
- }
- var b Binding
- if _, ok := p.accept(itemAt); ok {
- o, err := p.node()
- if err != nil {
- return nil, err
- }
- b = Binding{
- Name: v.val,
- Node: o,
- }
- } else {
- p.rewind()
- b = Binding{Name: v.val}
- }
- if p.peek().typ == itemColon {
- p.next()
- tail, err := p.object()
- if err != nil {
- return b, err
- }
- return List{Head: b, Tail: tail}, nil
- }
- return b, nil
- case itemBlank:
- if p.peek().typ == itemColon {
- p.next()
- tail, err := p.object()
- if err != nil {
- return Any{}, err
- }
- return List{Head: Any{}, Tail: tail}, nil
- }
- return Any{}, nil
- case itemString:
- return String(n.val), nil
- default:
- return nil, p.unexpectedToken("object")
- }
- }
- func (p *Parser) array() (Node, error) {
- if _, ok := p.accept(itemLeftBracket); !ok {
- return nil, p.unexpectedToken("'['")
- }
- var objs []Node
- for {
- if _, ok := p.accept(itemRightBracket); ok {
- break
- } else {
- p.rewind()
- obj, err := p.object()
- if err != nil {
- return nil, err
- }
- objs = append(objs, obj)
- }
- }
- tail := List{}
- for i := len(objs) - 1; i >= 0; i-- {
- l := List{
- Head: objs[i],
- Tail: tail,
- }
- tail = l
- }
- return tail, nil
- }
/*
Node    ::= itemLeftParen itemTypeName Object* itemRightParen
Object  ::= Node | Array | Binding | itemVariable | itemBlank | itemString
Array   ::= itemLeftBracket Object* itemRightBracket
Array   ::= Object itemColon Object
Binding ::= itemVariable itemAt Node
*/
|