Skip to content

Commit b6dd8ea

Browse files
committed
feat: add client retrieval and pagination to stores and tests
- Add GetClients method to Store interface for retrieving all clients - Implement GetClients for in-memory and Redis stores - Introduce tests for GetClients covering empty, small, and large client sets in both memory and Redis stores - Add Redis tests to verify correct pagination when retrieving large numbers of clients - Improve Redis store tests to handle missing Docker environment gracefully Signed-off-by: appleboy <[email protected]>
1 parent 1d01f61 commit b6dd8ea

File tree

6 files changed

+274
-23
lines changed

6 files changed

+274
-23
lines changed

pkg/core/store.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type Store interface {
3434
DeleteAuthorizationCode(ctx context.Context, clientID string) error
3535

3636
GetClient(ctx context.Context, clientID string) (*Client, error)
37+
GetClients(ctx context.Context) ([]*Client, error)
3738
CreateClient(ctx context.Context, client *Client) error
3839
UpdateClient(ctx context.Context, client *Client) error
3940
DeleteClient(ctx context.Context, clientID string) error

pkg/store/factory_test.go

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,19 @@ func TestParseStoreType(t *testing.T) {
6767

6868
func TestStoreType_String(t *testing.T) {
6969
tests := []struct {
70-
name string
70+
name string
7171
storeType StoreType
72-
expected string
72+
expected string
7373
}{
7474
{
75-
name: "memory to string",
75+
name: "memory to string",
7676
storeType: StoreTypeMemory,
77-
expected: "memory",
77+
expected: "memory",
7878
},
7979
{
80-
name: "redis to string",
80+
name: "redis to string",
8181
storeType: StoreTypeRedis,
82-
expected: "redis",
82+
expected: "redis",
8383
},
8484
}
8585

@@ -95,29 +95,29 @@ func TestStoreType_String(t *testing.T) {
9595

9696
func TestStoreType_IsValid(t *testing.T) {
9797
tests := []struct {
98-
name string
98+
name string
9999
storeType StoreType
100-
expected bool
100+
expected bool
101101
}{
102102
{
103-
name: "memory is valid",
103+
name: "memory is valid",
104104
storeType: StoreTypeMemory,
105-
expected: true,
105+
expected: true,
106106
},
107107
{
108-
name: "redis is valid",
108+
name: "redis is valid",
109109
storeType: StoreTypeRedis,
110-
expected: true,
110+
expected: true,
111111
},
112112
{
113-
name: "invalid type",
113+
name: "invalid type",
114114
storeType: StoreType("invalid"),
115-
expected: false,
115+
expected: false,
116116
},
117117
{
118-
name: "empty type",
118+
name: "empty type",
119119
storeType: StoreType(""),
120-
expected: false,
120+
expected: false,
121121
},
122122
}
123123

@@ -167,6 +167,13 @@ func TestFactory_Create_Memory(t *testing.T) {
167167
}
168168

169169
func TestFactory_Create_Redis(t *testing.T) {
170+
// Recover from panic (e.g., Docker not available)
171+
defer func() {
172+
if r := recover(); r != nil {
173+
t.Skipf("Cannot setup Redis container (Docker may not be running): %v", r)
174+
}
175+
}()
176+
170177
ctx := context.Background()
171178

172179
// Setup Redis container using testcontainers
@@ -192,7 +199,6 @@ func TestFactory_Create_Redis(t *testing.T) {
192199
factory := NewFactory(config)
193200

194201
store, err := factory.Create()
195-
196202
// Skip test if Redis is not available
197203
if err != nil {
198204
t.Skipf("Redis not available, skipping test: %v", err)
@@ -231,17 +237,17 @@ func TestFactory_Create_InvalidType(t *testing.T) {
231237

232238
func TestNewStore(t *testing.T) {
233239
tests := []struct {
234-
name string
235-
config Config
236-
wantErr bool
240+
name string
241+
config Config
242+
wantErr bool
237243
wantType interface{}
238244
}{
239245
{
240246
name: "create memory store",
241247
config: Config{
242248
Type: StoreTypeMemory,
243249
},
244-
wantErr: false,
250+
wantErr: false,
245251
wantType: (*MemoryStore)(nil),
246252
},
247253
{

pkg/store/memory.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,16 @@ func (m *MemoryStore) DeleteClient(ctx context.Context, clientID string) error {
176176
delete(m.clients, clientID)
177177
return nil
178178
}
179+
180+
// GetClients retrieves all clients from memory.
181+
func (m *MemoryStore) GetClients(ctx context.Context) ([]*core.Client, error) {
182+
m.mu.RLock()
183+
defer m.mu.RUnlock()
184+
185+
clients := make([]*core.Client, 0, len(m.clients))
186+
for _, client := range m.clients {
187+
clients = append(clients, client)
188+
}
189+
190+
return clients, nil
191+
}

pkg/store/memory_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package store
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"sync"
78
"testing"
89
"time"
@@ -713,3 +714,92 @@ func TestMemoryStore_Client_Concurrent(t *testing.T) {
713714

714715
wg.Wait()
715716
}
717+
718+
func TestMemoryStore_GetClients(t *testing.T) {
719+
store := NewMemoryStore()
720+
ctx := context.Background()
721+
722+
// 1. Test with an empty store
723+
clients, err := store.GetClients(ctx)
724+
if err != nil {
725+
t.Fatalf("GetClients() on empty store failed: %v", err)
726+
}
727+
if len(clients) != 0 {
728+
t.Fatalf("Expected 0 clients, got %d", len(clients))
729+
}
730+
731+
// 2. Add some clients
732+
client1 := &core.Client{ID: "client1", Secret: "secret1"}
733+
client2 := &core.Client{ID: "client2", Secret: "secret2"}
734+
if err := store.CreateClient(ctx, client1); err != nil {
735+
t.Fatalf("Failed to create client1: %v", err)
736+
}
737+
if err := store.CreateClient(ctx, client2); err != nil {
738+
t.Fatalf("Failed to create client2: %v", err)
739+
}
740+
741+
// 3. Test with multiple clients
742+
clients, err = store.GetClients(ctx)
743+
if err != nil {
744+
t.Fatalf("GetClients() with multiple clients failed: %v", err)
745+
}
746+
if len(clients) != 2 {
747+
t.Fatalf("Expected 2 clients, got %d", len(clients))
748+
}
749+
750+
// Check if the correct clients are returned (order is not guaranteed)
751+
found1 := false
752+
found2 := false
753+
for _, c := range clients {
754+
if c.ID == "client1" {
755+
found1 = true
756+
}
757+
if c.ID == "client2" {
758+
found2 = true
759+
}
760+
}
761+
if !found1 || !found2 {
762+
t.Errorf("Did not find all clients. Found1: %v, Found2: %v", found1, found2)
763+
}
764+
}
765+
766+
func TestMemoryStore_GetClients_LargeNumber(t *testing.T) {
767+
store := NewMemoryStore()
768+
ctx := context.Background()
769+
numClients := 150
770+
771+
// Create a large number of clients
772+
for i := 0; i < numClients; i++ {
773+
client := &core.Client{
774+
ID: fmt.Sprintf("client-large-%d", i),
775+
Secret: "secret",
776+
}
777+
if err := store.CreateClient(ctx, client); err != nil {
778+
t.Fatalf("Failed to create client %d: %v", i, err)
779+
}
780+
}
781+
782+
// Get all clients
783+
clients, err := store.GetClients(ctx)
784+
if err != nil {
785+
t.Fatalf("GetClients() with large number of clients failed: %v", err)
786+
}
787+
788+
// Verify the number of clients retrieved
789+
if len(clients) != numClients {
790+
t.Errorf("Expected %d clients, but got %d", numClients, len(clients))
791+
}
792+
793+
// Verify all clients are present
794+
clientMap := make(map[string]bool)
795+
for _, c := range clients {
796+
clientMap[c.ID] = true
797+
}
798+
799+
for i := 0; i < numClients; i++ {
800+
clientID := fmt.Sprintf("client-large-%d", i)
801+
if !clientMap[clientID] {
802+
t.Errorf("Client %s was not found in the retrieved list", clientID)
803+
}
804+
}
805+
}

pkg/store/redis.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,45 @@ func (r *RedisStore) DeleteClient(ctx context.Context, clientID string) error {
259259

260260
return nil
261261
}
262+
263+
// GetClients retrieves all clients from Redis.
264+
// This operation can be expensive and should be used with caution in production.
265+
func (r *RedisStore) GetClients(ctx context.Context) ([]*core.Client, error) {
266+
var clients []*core.Client
267+
var cursor uint64
268+
269+
for {
270+
scanCmd := r.client.B().Scan().Cursor(cursor).Match(clientPrefix + "*").Count(100).Build()
271+
scanResult, err := r.client.Do(ctx, scanCmd).AsScanEntry()
272+
if err != nil {
273+
return nil, fmt.Errorf("failed to scan for clients in redis: %w", err)
274+
}
275+
276+
if len(scanResult.Elements) > 0 {
277+
mgetCmd := r.client.B().Mget().Key(scanResult.Elements...).Build()
278+
values, err := r.client.Do(ctx, mgetCmd).AsStrSlice()
279+
if err != nil {
280+
return nil, fmt.Errorf("failed to get clients from redis: %w", err)
281+
}
282+
283+
for _, val := range values {
284+
if val == "" {
285+
continue
286+
}
287+
var client core.Client
288+
if err := json.Unmarshal([]byte(val), &client); err != nil {
289+
// Log or skip corrupted data
290+
continue
291+
}
292+
clients = append(clients, &client)
293+
}
294+
}
295+
296+
cursor = scanResult.Cursor
297+
if cursor == 0 {
298+
break
299+
}
300+
}
301+
302+
return clients, nil
303+
}

pkg/store/redis_test.go

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ func setupRedisContainer(ctx context.Context) (string, error) {
2121
Image: "redis:7-alpine",
2222
ExposedPorts: []string{"6379/tcp"},
2323
WaitingFor: wait.ForAll(
24-
wait.ForLog("Ready to accept connections").WithStartupTimeout(30 * time.Second),
25-
wait.ForListeningPort("6379/tcp").WithStartupTimeout(30 * time.Second),
24+
wait.ForLog("Ready to accept connections").WithStartupTimeout(30*time.Second),
25+
wait.ForListeningPort("6379/tcp").WithStartupTimeout(30*time.Second),
2626
),
2727
}
2828

@@ -958,3 +958,102 @@ func TestRedisStore_GetClient_CacheInvalidation(t *testing.T) {
958958

959959
t.Log("Cache invalidation working correctly after client update")
960960
}
961+
962+
func TestRedisStore_GetClients(t *testing.T) {
963+
store, cleanup := setupRedisStore(t)
964+
if store == nil {
965+
return // Skip if Redis not available
966+
}
967+
defer cleanup()
968+
969+
ctx := context.Background()
970+
971+
// 1. Test with an empty store
972+
clients, err := store.GetClients(ctx)
973+
if err != nil {
974+
t.Fatalf("GetClients() on empty store failed: %v", err)
975+
}
976+
if len(clients) != 0 {
977+
t.Fatalf("Expected 0 clients, got %d", len(clients))
978+
}
979+
980+
// 2. Add some clients
981+
client1 := &core.Client{ID: "client1-redis", Secret: "secret1"}
982+
client2 := &core.Client{ID: "client2-redis", Secret: "secret2"}
983+
if err := store.CreateClient(ctx, client1); err != nil {
984+
t.Fatalf("Failed to create client1: %v", err)
985+
}
986+
if err := store.CreateClient(ctx, client2); err != nil {
987+
t.Fatalf("Failed to create client2: %v", err)
988+
}
989+
990+
// 3. Test with multiple clients
991+
clients, err = store.GetClients(ctx)
992+
if err != nil {
993+
t.Fatalf("GetClients() with multiple clients failed: %v", err)
994+
}
995+
if len(clients) != 2 {
996+
t.Fatalf("Expected 2 clients, got %d", len(clients))
997+
}
998+
999+
// Check if the correct clients are returned
1000+
found1 := false
1001+
found2 := false
1002+
for _, c := range clients {
1003+
if c.ID == "client1-redis" {
1004+
found1 = true
1005+
}
1006+
if c.ID == "client2-redis" {
1007+
found2 = true
1008+
}
1009+
}
1010+
if !found1 || !found2 {
1011+
t.Errorf("Did not find all clients. Found1: %v, Found2: %v", found1, found2)
1012+
}
1013+
}
1014+
1015+
func TestRedisStore_GetClients_Pagination(t *testing.T) {
1016+
store, cleanup := setupRedisStore(t)
1017+
if store == nil {
1018+
return // Skip if Redis not available
1019+
}
1020+
defer cleanup()
1021+
1022+
ctx := context.Background()
1023+
numClients := 150 // More than the SCAN COUNT of 100
1024+
1025+
// Create a large number of clients
1026+
for i := 0; i < numClients; i++ {
1027+
client := &core.Client{
1028+
ID: fmt.Sprintf("client-pagination-%d", i),
1029+
Secret: "secret",
1030+
}
1031+
if err := store.CreateClient(ctx, client); err != nil {
1032+
t.Fatalf("Failed to create client %d: %v", i, err)
1033+
}
1034+
}
1035+
1036+
// Get all clients
1037+
clients, err := store.GetClients(ctx)
1038+
if err != nil {
1039+
t.Fatalf("GetClients() with pagination failed: %v", err)
1040+
}
1041+
1042+
// Verify the number of clients retrieved
1043+
if len(clients) != numClients {
1044+
t.Errorf("Expected %d clients, but got %d", numClients, len(clients))
1045+
}
1046+
1047+
// Verify all clients are present
1048+
clientMap := make(map[string]bool)
1049+
for _, c := range clients {
1050+
clientMap[c.ID] = true
1051+
}
1052+
1053+
for i := 0; i < numClients; i++ {
1054+
clientID := fmt.Sprintf("client-pagination-%d", i)
1055+
if !clientMap[clientID] {
1056+
t.Errorf("Client %s was not found in the retrieved list", clientID)
1057+
}
1058+
}
1059+
}

0 commit comments

Comments
 (0)