From 35732a2822d89feca856cb3a4623067939fcd659 Mon Sep 17 00:00:00 2001 From: Emman Date: Fri, 15 Apr 2022 11:14:30 +0800 Subject: [PATCH] #45, #46 --- smx509/cert_pool.go | 8 ++- smx509/cert_pool_test.go | 126 +++++++++++++++++++++++++++------------ 2 files changed, 94 insertions(+), 40 deletions(-) diff --git a/smx509/cert_pool.go b/smx509/cert_pool.go index df04f61..cf1dd52 100644 --- a/smx509/cert_pool.go +++ b/smx509/cert_pool.go @@ -73,7 +73,8 @@ func (s *CertPool) cert(n int) (*Certificate, error) { return s.lazyCerts[n].getCert() } -func (s *CertPool) copy() *CertPool { +// Clone returns a copy of s. +func (s *CertPool) Clone() *CertPool { p := &CertPool{ byName: make(map[string][]int, len(s.byName)), lazyCerts: make([]lazyCert, len(s.lazyCerts)), @@ -105,7 +106,7 @@ func (s *CertPool) copy() *CertPool { // New changes in the system cert pool might not be reflected in subsequent calls. func SystemCertPool() (*CertPool, error) { if sysRoots := systemRootsPool(); sysRoots != nil { - return sysRoots.copy(), nil + return sysRoots.Clone(), nil } return loadSystemRoots() @@ -248,6 +249,9 @@ func (s *CertPool) Subjects() [][]byte { // Equal reports whether s and other are equal. func (s *CertPool) Equal(other *CertPool) bool { + if s == nil || other == nil { + return s == other + } if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) { return false } diff --git a/smx509/cert_pool_test.go b/smx509/cert_pool_test.go index 700f2ed..2e81a44 100644 --- a/smx509/cert_pool_test.go +++ b/smx509/cert_pool_test.go @@ -3,52 +3,102 @@ package smx509 import "testing" func TestCertPoolEqual(t *testing.T) { - a, b := NewCertPool(), NewCertPool() - if !a.Equal(b) { - t.Error("two empty pools not equal") - } - tc := &Certificate{Raw: []byte{1, 2, 3}, RawSubject: []byte{2}} - a.AddCert(tc) - if a.Equal(b) { - t.Error("empty pool equals non-empty pool") - } - - b.AddCert(tc) - if !a.Equal(b) { - t.Error("two non-empty pools not equal") - } - otherTC := &Certificate{Raw: []byte{9, 8, 7}, RawSubject: []byte{8}} - a.AddCert(otherTC) - if a.Equal(b) { - t.Error("non-equal pools equal") - } - systemA, err := SystemCertPool() + emptyPool := NewCertPool() + nonSystemPopulated := NewCertPool() + nonSystemPopulated.AddCert(tc) + nonSystemPopulatedAlt := NewCertPool() + nonSystemPopulatedAlt.AddCert(otherTC) + emptySystem, err := SystemCertPool() if err != nil { - t.Fatalf("unable to load system cert pool: %s", err) + t.Fatal(err) } - systemB, err := SystemCertPool() + populatedSystem, err := SystemCertPool() if err != nil { - t.Fatalf("unable to load system cert pool: %s", err) + t.Fatal(err) } - if !systemA.Equal(systemB) { - t.Error("two empty system pools not equal") + populatedSystem.AddCert(tc) + populatedSystemAlt, err := SystemCertPool() + if err != nil { + t.Fatal(err) + } + populatedSystemAlt.AddCert(otherTC) + tests := []struct { + name string + a *CertPool + b *CertPool + equal bool + }{ + { + name: "two empty pools", + a: emptyPool, + b: emptyPool, + equal: true, + }, + { + name: "one empty pool, one populated pool", + a: emptyPool, + b: nonSystemPopulated, + equal: false, + }, + { + name: "two populated pools", + a: nonSystemPopulated, + b: nonSystemPopulated, + equal: true, + }, + { + name: "two populated pools, different content", + a: nonSystemPopulated, + b: nonSystemPopulatedAlt, + equal: false, + }, + { + name: "two empty system pools", + a: emptySystem, + b: emptySystem, + equal: true, + }, + { + name: "one empty system pool, one populated system pool", + a: emptySystem, + b: populatedSystem, + equal: false, + }, + { + name: "two populated system pools", + a: populatedSystem, + b: populatedSystem, + equal: true, + }, + { + name: "two populated pools, different content", + a: populatedSystem, + b: populatedSystemAlt, + equal: false, + }, + { + name: "two nil pools", + a: nil, + b: nil, + equal: true, + }, + { + name: "one nil pool, one empty pool", + a: nil, + b: emptyPool, + equal: false, + }, } - systemA.AddCert(tc) - if systemA.Equal(systemB) { - t.Error("empty system pool equals non-empty system pool") - } - - systemB.AddCert(tc) - if !systemA.Equal(systemB) { - t.Error("two non-empty system pools not equal") - } - - systemA.AddCert(otherTC) - if systemA.Equal(systemB) { - t.Error("non-equal system pools equal") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + equal := tc.a.Equal(tc.b) + if equal != tc.equal { + t.Errorf("Unexpected Equal result: got %t, want %t", equal, tc.equal) + } + }) } }