2
0

match.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. package pattern
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/token"
  6. "go/types"
  7. "reflect"
  8. "golang.org/x/exp/typeparams"
  9. )
  10. var tokensByString = map[string]Token{
  11. "INT": Token(token.INT),
  12. "FLOAT": Token(token.FLOAT),
  13. "IMAG": Token(token.IMAG),
  14. "CHAR": Token(token.CHAR),
  15. "STRING": Token(token.STRING),
  16. "+": Token(token.ADD),
  17. "-": Token(token.SUB),
  18. "*": Token(token.MUL),
  19. "/": Token(token.QUO),
  20. "%": Token(token.REM),
  21. "&": Token(token.AND),
  22. "|": Token(token.OR),
  23. "^": Token(token.XOR),
  24. "<<": Token(token.SHL),
  25. ">>": Token(token.SHR),
  26. "&^": Token(token.AND_NOT),
  27. "+=": Token(token.ADD_ASSIGN),
  28. "-=": Token(token.SUB_ASSIGN),
  29. "*=": Token(token.MUL_ASSIGN),
  30. "/=": Token(token.QUO_ASSIGN),
  31. "%=": Token(token.REM_ASSIGN),
  32. "&=": Token(token.AND_ASSIGN),
  33. "|=": Token(token.OR_ASSIGN),
  34. "^=": Token(token.XOR_ASSIGN),
  35. "<<=": Token(token.SHL_ASSIGN),
  36. ">>=": Token(token.SHR_ASSIGN),
  37. "&^=": Token(token.AND_NOT_ASSIGN),
  38. "&&": Token(token.LAND),
  39. "||": Token(token.LOR),
  40. "<-": Token(token.ARROW),
  41. "++": Token(token.INC),
  42. "--": Token(token.DEC),
  43. "==": Token(token.EQL),
  44. "<": Token(token.LSS),
  45. ">": Token(token.GTR),
  46. "=": Token(token.ASSIGN),
  47. "!": Token(token.NOT),
  48. "!=": Token(token.NEQ),
  49. "<=": Token(token.LEQ),
  50. ">=": Token(token.GEQ),
  51. ":=": Token(token.DEFINE),
  52. "...": Token(token.ELLIPSIS),
  53. "IMPORT": Token(token.IMPORT),
  54. "VAR": Token(token.VAR),
  55. "TYPE": Token(token.TYPE),
  56. "CONST": Token(token.CONST),
  57. "BREAK": Token(token.BREAK),
  58. "CONTINUE": Token(token.CONTINUE),
  59. "GOTO": Token(token.GOTO),
  60. "FALLTHROUGH": Token(token.FALLTHROUGH),
  61. }
  62. func maybeToken(node Node) (Node, bool) {
  63. if node, ok := node.(String); ok {
  64. if tok, ok := tokensByString[string(node)]; ok {
  65. return tok, true
  66. }
  67. return node, false
  68. }
  69. return node, false
  70. }
  71. func isNil(v interface{}) bool {
  72. if v == nil {
  73. return true
  74. }
  75. if _, ok := v.(Nil); ok {
  76. return true
  77. }
  78. return false
  79. }
  80. type matcher interface {
  81. Match(*Matcher, interface{}) (interface{}, bool)
  82. }
  83. type State = map[string]interface{}
  84. type Matcher struct {
  85. TypesInfo *types.Info
  86. State State
  87. }
  88. func (m *Matcher) fork() *Matcher {
  89. state := make(State, len(m.State))
  90. for k, v := range m.State {
  91. state[k] = v
  92. }
  93. return &Matcher{
  94. TypesInfo: m.TypesInfo,
  95. State: state,
  96. }
  97. }
  98. func (m *Matcher) merge(mc *Matcher) {
  99. m.State = mc.State
  100. }
  101. func (m *Matcher) Match(a Node, b ast.Node) bool {
  102. m.State = State{}
  103. _, ok := match(m, a, b)
  104. return ok
  105. }
  106. func Match(a Node, b ast.Node) (*Matcher, bool) {
  107. m := &Matcher{}
  108. ret := m.Match(a, b)
  109. return m, ret
  110. }
  111. // Match two items, which may be (Node, AST) or (AST, AST)
  112. func match(m *Matcher, l, r interface{}) (interface{}, bool) {
  113. if _, ok := r.(Node); ok {
  114. panic("Node mustn't be on right side of match")
  115. }
  116. switch l := l.(type) {
  117. case *ast.ParenExpr:
  118. return match(m, l.X, r)
  119. case *ast.ExprStmt:
  120. return match(m, l.X, r)
  121. case *ast.DeclStmt:
  122. return match(m, l.Decl, r)
  123. case *ast.LabeledStmt:
  124. return match(m, l.Stmt, r)
  125. case *ast.BlockStmt:
  126. return match(m, l.List, r)
  127. case *ast.FieldList:
  128. return match(m, l.List, r)
  129. }
  130. switch r := r.(type) {
  131. case *ast.ParenExpr:
  132. return match(m, l, r.X)
  133. case *ast.ExprStmt:
  134. return match(m, l, r.X)
  135. case *ast.DeclStmt:
  136. return match(m, l, r.Decl)
  137. case *ast.LabeledStmt:
  138. return match(m, l, r.Stmt)
  139. case *ast.BlockStmt:
  140. if r == nil {
  141. return match(m, l, nil)
  142. }
  143. return match(m, l, r.List)
  144. case *ast.FieldList:
  145. if r == nil {
  146. return match(m, l, nil)
  147. }
  148. return match(m, l, r.List)
  149. case *ast.BasicLit:
  150. if r == nil {
  151. return match(m, l, nil)
  152. }
  153. }
  154. if l, ok := l.(matcher); ok {
  155. return l.Match(m, r)
  156. }
  157. if l, ok := l.(Node); ok {
  158. // Matching of pattern with concrete value
  159. return matchNodeAST(m, l, r)
  160. }
  161. if l == nil || r == nil {
  162. return nil, l == r
  163. }
  164. {
  165. ln, ok1 := l.(ast.Node)
  166. rn, ok2 := r.(ast.Node)
  167. if ok1 && ok2 {
  168. return matchAST(m, ln, rn)
  169. }
  170. }
  171. {
  172. obj, ok := l.(types.Object)
  173. if ok {
  174. switch r := r.(type) {
  175. case *ast.Ident:
  176. return obj, obj == m.TypesInfo.ObjectOf(r)
  177. case *ast.SelectorExpr:
  178. return obj, obj == m.TypesInfo.ObjectOf(r.Sel)
  179. default:
  180. return obj, false
  181. }
  182. }
  183. }
  184. {
  185. ln, ok1 := l.([]ast.Expr)
  186. rn, ok2 := r.([]ast.Expr)
  187. if ok1 || ok2 {
  188. if ok1 && !ok2 {
  189. rn = []ast.Expr{r.(ast.Expr)}
  190. } else if !ok1 && ok2 {
  191. ln = []ast.Expr{l.(ast.Expr)}
  192. }
  193. if len(ln) != len(rn) {
  194. return nil, false
  195. }
  196. for i, ll := range ln {
  197. if _, ok := match(m, ll, rn[i]); !ok {
  198. return nil, false
  199. }
  200. }
  201. return r, true
  202. }
  203. }
  204. {
  205. ln, ok1 := l.([]ast.Stmt)
  206. rn, ok2 := r.([]ast.Stmt)
  207. if ok1 || ok2 {
  208. if ok1 && !ok2 {
  209. rn = []ast.Stmt{r.(ast.Stmt)}
  210. } else if !ok1 && ok2 {
  211. ln = []ast.Stmt{l.(ast.Stmt)}
  212. }
  213. if len(ln) != len(rn) {
  214. return nil, false
  215. }
  216. for i, ll := range ln {
  217. if _, ok := match(m, ll, rn[i]); !ok {
  218. return nil, false
  219. }
  220. }
  221. return r, true
  222. }
  223. }
  224. {
  225. ln, ok1 := l.([]*ast.Field)
  226. rn, ok2 := r.([]*ast.Field)
  227. if ok1 || ok2 {
  228. if ok1 && !ok2 {
  229. rn = []*ast.Field{r.(*ast.Field)}
  230. } else if !ok1 && ok2 {
  231. ln = []*ast.Field{l.(*ast.Field)}
  232. }
  233. if len(ln) != len(rn) {
  234. return nil, false
  235. }
  236. for i, ll := range ln {
  237. if _, ok := match(m, ll, rn[i]); !ok {
  238. return nil, false
  239. }
  240. }
  241. return r, true
  242. }
  243. }
  244. panic(fmt.Sprintf("unsupported comparison: %T and %T", l, r))
  245. }
  246. // Match a Node with an AST node
  247. func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) {
  248. switch b := b.(type) {
  249. case []ast.Stmt:
  250. // 'a' is not a List or we'd be using its Match
  251. // implementation.
  252. if len(b) != 1 {
  253. return nil, false
  254. }
  255. return match(m, a, b[0])
  256. case []ast.Expr:
  257. // 'a' is not a List or we'd be using its Match
  258. // implementation.
  259. if len(b) != 1 {
  260. return nil, false
  261. }
  262. return match(m, a, b[0])
  263. case ast.Node:
  264. ra := reflect.ValueOf(a)
  265. rb := reflect.ValueOf(b).Elem()
  266. if ra.Type().Name() != rb.Type().Name() {
  267. return nil, false
  268. }
  269. for i := 0; i < ra.NumField(); i++ {
  270. af := ra.Field(i)
  271. fieldName := ra.Type().Field(i).Name
  272. bf := rb.FieldByName(fieldName)
  273. if (bf == reflect.Value{}) {
  274. panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a))
  275. }
  276. ai := af.Interface()
  277. bi := bf.Interface()
  278. if ai == nil {
  279. return b, bi == nil
  280. }
  281. if _, ok := match(m, ai.(Node), bi); !ok {
  282. return b, false
  283. }
  284. }
  285. return b, true
  286. case nil:
  287. return nil, a == Nil{}
  288. default:
  289. panic(fmt.Sprintf("unhandled type %T", b))
  290. }
  291. }
  292. // Match two AST nodes
  293. func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) {
  294. ra := reflect.ValueOf(a)
  295. rb := reflect.ValueOf(b)
  296. if ra.Type() != rb.Type() {
  297. return nil, false
  298. }
  299. if ra.IsNil() || rb.IsNil() {
  300. return rb, ra.IsNil() == rb.IsNil()
  301. }
  302. ra = ra.Elem()
  303. rb = rb.Elem()
  304. for i := 0; i < ra.NumField(); i++ {
  305. af := ra.Field(i)
  306. bf := rb.Field(i)
  307. if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup {
  308. continue
  309. }
  310. switch af.Kind() {
  311. case reflect.Slice:
  312. if af.Len() != bf.Len() {
  313. return nil, false
  314. }
  315. for j := 0; j < af.Len(); j++ {
  316. if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok {
  317. return nil, false
  318. }
  319. }
  320. case reflect.String:
  321. if af.String() != bf.String() {
  322. return nil, false
  323. }
  324. case reflect.Int:
  325. if af.Int() != bf.Int() {
  326. return nil, false
  327. }
  328. case reflect.Bool:
  329. if af.Bool() != bf.Bool() {
  330. return nil, false
  331. }
  332. case reflect.Ptr, reflect.Interface:
  333. if _, ok := match(m, af.Interface(), bf.Interface()); !ok {
  334. return nil, false
  335. }
  336. default:
  337. panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface()))
  338. }
  339. }
  340. return b, true
  341. }
  342. func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) {
  343. if isNil(b.Node) {
  344. v, ok := m.State[b.Name]
  345. if ok {
  346. // Recall value
  347. return match(m, v, node)
  348. }
  349. // Matching anything
  350. b.Node = Any{}
  351. }
  352. // Store value
  353. if _, ok := m.State[b.Name]; ok {
  354. panic(fmt.Sprintf("binding already created: %s", b.Name))
  355. }
  356. new, ret := match(m, b.Node, node)
  357. if ret {
  358. m.State[b.Name] = new
  359. }
  360. return new, ret
  361. }
  362. func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) {
  363. return node, true
  364. }
  365. func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) {
  366. v := reflect.ValueOf(node)
  367. if v.Kind() == reflect.Slice {
  368. if isNil(l.Head) {
  369. return node, v.Len() == 0
  370. }
  371. if v.Len() == 0 {
  372. return nil, false
  373. }
  374. // OPT(dh): don't check the entire tail if head didn't match
  375. _, ok1 := match(m, l.Head, v.Index(0).Interface())
  376. _, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface())
  377. return node, ok1 && ok2
  378. }
  379. // Our empty list does not equal an untyped Go nil. This way, we can
  380. // tell apart an if with no else and an if with an empty else.
  381. return nil, false
  382. }
  383. func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) {
  384. switch o := node.(type) {
  385. case token.Token:
  386. if tok, ok := maybeToken(s); ok {
  387. return match(m, tok, node)
  388. }
  389. return nil, false
  390. case string:
  391. return o, string(s) == o
  392. case types.TypeAndValue:
  393. return o, o.Value != nil && o.Value.String() == string(s)
  394. default:
  395. return nil, false
  396. }
  397. }
  398. func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) {
  399. o, ok := node.(token.Token)
  400. if !ok {
  401. return nil, false
  402. }
  403. return o, token.Token(tok) == o
  404. }
  405. func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) {
  406. return nil, isNil(node) || reflect.ValueOf(node).IsNil()
  407. }
  408. func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) {
  409. r, ok := match(m, Ident(builtin), node)
  410. if !ok {
  411. return nil, false
  412. }
  413. ident := r.(*ast.Ident)
  414. obj := m.TypesInfo.ObjectOf(ident)
  415. if obj != types.Universe.Lookup(ident.Name) {
  416. return nil, false
  417. }
  418. return ident, true
  419. }
  420. func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) {
  421. r, ok := match(m, Ident(obj), node)
  422. if !ok {
  423. return nil, false
  424. }
  425. ident := r.(*ast.Ident)
  426. id := m.TypesInfo.ObjectOf(ident)
  427. _, ok = match(m, obj.Name, ident.Name)
  428. return id, ok
  429. }
  430. func (fn Function) Match(m *Matcher, node interface{}) (interface{}, bool) {
  431. var name string
  432. var obj types.Object
  433. base := []Node{
  434. Ident{Any{}},
  435. SelectorExpr{Any{}, Any{}},
  436. }
  437. p := Or{
  438. Nodes: append(base,
  439. IndexExpr{Or{Nodes: base}, Any{}},
  440. IndexListExpr{Or{Nodes: base}, Any{}})}
  441. r, ok := match(m, p, node)
  442. if !ok {
  443. return nil, false
  444. }
  445. fun := r
  446. switch idx := fun.(type) {
  447. case *ast.IndexExpr:
  448. fun = idx.X
  449. case *typeparams.IndexListExpr:
  450. fun = idx.X
  451. }
  452. switch fun := fun.(type) {
  453. case *ast.Ident:
  454. obj = m.TypesInfo.ObjectOf(fun)
  455. switch obj := obj.(type) {
  456. case *types.Func:
  457. // OPT(dh): optimize this similar to code.FuncName
  458. name = obj.FullName()
  459. case *types.Builtin:
  460. name = obj.Name()
  461. case *types.TypeName:
  462. name = types.TypeString(obj.Type(), nil)
  463. default:
  464. return nil, false
  465. }
  466. case *ast.SelectorExpr:
  467. obj = m.TypesInfo.ObjectOf(fun.Sel)
  468. switch obj := obj.(type) {
  469. case *types.Func:
  470. // OPT(dh): optimize this similar to code.FuncName
  471. name = obj.FullName()
  472. case *types.TypeName:
  473. name = types.TypeString(obj.Type(), nil)
  474. default:
  475. return nil, false
  476. }
  477. default:
  478. panic("unreachable")
  479. }
  480. _, ok = match(m, fn.Name, name)
  481. return obj, ok
  482. }
  483. func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) {
  484. for _, opt := range or.Nodes {
  485. mc := m.fork()
  486. if ret, ok := match(mc, opt, node); ok {
  487. m.merge(mc)
  488. return ret, true
  489. }
  490. }
  491. return nil, false
  492. }
  493. func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) {
  494. _, ok := match(m, not.Node, node)
  495. if ok {
  496. return nil, false
  497. }
  498. return node, true
  499. }
  500. var integerLiteralQ = MustParse(`(Or (BasicLit "INT" _) (UnaryExpr (Or "+" "-") (IntegerLiteral _)))`)
  501. func (lit IntegerLiteral) Match(m *Matcher, node interface{}) (interface{}, bool) {
  502. matched, ok := match(m, integerLiteralQ.Root, node)
  503. if !ok {
  504. return nil, false
  505. }
  506. tv, ok := m.TypesInfo.Types[matched.(ast.Expr)]
  507. if !ok {
  508. return nil, false
  509. }
  510. if tv.Value == nil {
  511. return nil, false
  512. }
  513. _, ok = match(m, lit.Value, tv)
  514. return matched, ok
  515. }
  516. func (texpr TrulyConstantExpression) Match(m *Matcher, node interface{}) (interface{}, bool) {
  517. expr, ok := node.(ast.Expr)
  518. if !ok {
  519. return nil, false
  520. }
  521. tv, ok := m.TypesInfo.Types[expr]
  522. if !ok {
  523. return nil, false
  524. }
  525. if tv.Value == nil {
  526. return nil, false
  527. }
  528. truly := true
  529. ast.Inspect(expr, func(node ast.Node) bool {
  530. if _, ok := node.(*ast.Ident); ok {
  531. truly = false
  532. return false
  533. }
  534. return true
  535. })
  536. if !truly {
  537. return nil, false
  538. }
  539. _, ok = match(m, texpr.Value, tv)
  540. return expr, ok
  541. }
  542. var (
  543. // Types of fields in go/ast structs that we want to skip
  544. rtTokPos = reflect.TypeOf(token.Pos(0))
  545. rtObject = reflect.TypeOf((*ast.Object)(nil))
  546. rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil))
  547. )
  548. var (
  549. _ matcher = Binding{}
  550. _ matcher = Any{}
  551. _ matcher = List{}
  552. _ matcher = String("")
  553. _ matcher = Token(0)
  554. _ matcher = Nil{}
  555. _ matcher = Builtin{}
  556. _ matcher = Object{}
  557. _ matcher = Function{}
  558. _ matcher = Or{}
  559. _ matcher = Not{}
  560. _ matcher = IntegerLiteral{}
  561. _ matcher = TrulyConstantExpression{}
  562. )