package structr

import (
	"errors"
	"fmt"
	"slices"
	"testing"

	"codeberg.org/gruf/go-byteutil"
	"codeberg.org/gruf/go-kv/format"
)

// TestCache runs the full cache test-suite against each of
// the test struct types, each within its own named subtest.
func TestCache(t *testing.T) {
	for _, sub := range []struct {
		name string
		run  func(*testing.T)
	}{
		{name: "structA", run: func(t *testing.T) { testCache(t, testStructA) }},
		{name: "structB", run: func(t *testing.T) { testCache(t, testStructB) }},
		{name: "structC", run: func(t *testing.T) { testCache(t, testStructC) }},
	} {
		t.Run(sub.name, sub.run)
	}
}

// BenchmarkCacheGet benchmarks cache lookups for each of
// the test struct types, each within its own named sub-benchmark.
func BenchmarkCacheGet(b *testing.B) {
	for _, sub := range []struct {
		name string
		run  func(*testing.B)
	}{
		{name: "structA", run: func(b *testing.B) { benchmarkCacheGet(b, testStructA) }},
		{name: "structB", run: func(b *testing.B) { benchmarkCacheGet(b, testStructB) }},
		{name: "structC", run: func(b *testing.B) { benchmarkCacheGet(b, testStructC) }},
	} {
		b.Run(sub.name, sub.run)
	}
}

// BenchmarkCachePut benchmarks cache insertions for each of
// the test struct types, each within its own named sub-benchmark.
func BenchmarkCachePut(b *testing.B) {
	for _, sub := range []struct {
		name string
		run  func(*testing.B)
	}{
		{name: "structA", run: func(b *testing.B) { benchmarkCachePut(b, testStructA) }},
		{name: "structB", run: func(b *testing.B) { benchmarkCachePut(b, testStructB) }},
		{name: "structC", run: func(b *testing.B) { benchmarkCachePut(b, testStructC) }},
	} {
		b.Run(sub.name, sub.run)
	}
}

// testCache exercises the lifecycle of a Cache[*T] configured with the
// indices and copy function of the given test case. It covers:
//   - index lookup validation (unknown / foreign indices panic),
//   - Put followed by Get,
//   - Invalidate,
//   - Store with a failing store callback (nothing stored) and with a
//     succeeding one (everything stored),
//   - Load / LoadOne with and without load errors,
//   - Clear.
//
// An Invalidate hook records the formatted representation of every
// value passed to it, so each stage can verify the hook fired for all
// test values.
func testCache[T any](t *testing.T, test test[T]) {
	var c Cache[*T]

	// fmtValue renders a value into the string used
	// to key the invalidation tracking map below.
	fmtValue := func(value *T) string {
		var buf byteutil.Buffer
		format.Appendf(&buf, "{:?}", value)
		return buf.String()
	}

	// Create invalidate function hook
	// to track invalidated value ptrs.
	invalidated := make(map[string]bool)
	invalidateFn := func(value *T) {
		invalidated[fmtValue(value)] = true
	}

	// requireAllInvalidated ensures the invalidate hook was
	// called for every test value, then resets the tracking
	// map ready for the next stage of the test.
	requireAllInvalidated := func() {
		t.Helper()
		for _, value := range test.values {
			if !invalidated[fmtValue(value)] {
				t.Fatalf("expected value was not invalidated: %+v", value)
			}
		}
		clear(invalidated)
	}

	// requireLen ensures the cache holds exactly want entries.
	requireLen := func(want int, msg string) {
		t.Helper()
		if have := c.Len(); have != want {
			t.Fatalf("%s: have=%d want=%d", msg, have, want)
		}
	}

	// clearCache empties the cache and ensures it really is empty.
	clearCache := func() {
		t.Helper()
		t.Log("Clear", c.Len())
		c.Clear()
		requireLen(0, "cache not empty after clear")
	}

	// Initialize the struct cache.
	c.Init(CacheConfig[*T]{
		Indices:    test.indices,
		MaxSize:    len(test.values),
		Copy:       test.copyfn,
		Invalidate: invalidateFn,
	})

	// Check that fake indices cause panic
	for _, index := range test.indices {
		fake := index.Fields + "!"
		catchpanic(t, func() {
			c.Index(fake)
		}, "unknown index: "+fake)
	}

	// Check that an index not belonging
	// to this cache causes panic.
	catchpanic(t, func() {
		wrong := new(Index)
		c.Invalidate(wrong)
	}, "invalid index for cache")

	// Insert all values.
	t.Logf("Put: %v", test.values)
	c.Put(test.values...)
	requireLen(len(test.values), "cache not fully populated after put")

	// Ensure all values were invalidated
	// on insert via the callback function.
	requireAllInvalidated()

	// Check that we have each of these values
	// stored in all expected indices in cache.
	testCacheGetValues(t, &c, test)

	// Invalidate each of these values from
	// the cache. It's easier to just iterate
	// through all the values for all indices
	// instead of getting particular about which
	// value is cached in which particular index.
	for _, index := range test.indices {
		var keys []Key

		// Get associated structr index.
		idx := c.Index(index.Fields)

		for _, value := range test.values {
			// generate struct key parts for value.
			parts, ok := indexkey(idx, value)
			if !ok {
				continue
			}

			// generate key from parts.
			key := idx.Key(parts...)

			if !index.AllowZero && key.Zero() {
				// Key parts contain a zero value and this
				// index does not allow that. Skip lookup.
				continue
			}

			// add index key to keys.
			keys = append(keys, key)
		}

		// Invalidate all keys in index.
		t.Logf("Invalidate: %s %v", index.Fields, keys)
		c.Invalidate(idx, keys...)
	}

	// Ensure cache empty after invalidation, and that
	// all values were invalidated via the callback.
	requireLen(0, "cache not empty after invalidation")
	requireAllInvalidated()

	// Store all values using the store function, though
	// returning an error (which shouldn't store them!).
	errStore := errors.New("oh no!")
	t.Log("testCacheStoreValueWithError")
	for _, value := range test.values {
		err := c.Store(value, func() error {
			return errStore
		})
		if !errors.Is(err, errStore) {
			t.Fatalf("store callback error was not returned: %v", err)
		}
	}

	// Nothing may have been stored, but every
	// value should still have been invalidated.
	requireLen(0, "cache not empty after failed store")
	requireAllInvalidated()

	// Store all values using the store function, this
	// time using no error, to ensure they get stored.
	t.Log("testCacheStoreValueNoError")
	for _, value := range test.values {
		err := c.Store(value, func() error {
			return nil
		})
		if err != nil {
			t.Fatalf("unexpected store error: %v", err)
		}
	}

	// Ensure everything was stored and invalidated
	// on insert via the callback function.
	requireLen(len(test.values), "cache not fully populated after store")
	requireAllInvalidated()

	// Now test fetching values with load callback,
	// followed by a regular get to ensure cached.
	clearCache()
	testCacheLoadValuesNoError(t, &c, test)
	testCacheGetValues(t, &c, test)

	// Now test loading values with an error returned
	// during load callback. To ensure error is cached
	// but also correctly invalidated on put.
	clearCache()
	testCacheLoadValuesWithError(t, &c, test)
	t.Log("testCachePutValues")
	c.Put(test.values...)
	testCacheLoadValuesNoError(t, &c, test)

	// Now test loading values with error returned
	// multiple times, followed by same situations
	// as the previous test, to ensure we don't get
	// double results in unique indices.
	clearCache()
	testCacheLoadValuesWithError(t, &c, test)
	testCacheLoadValuesWithError(t, &c, test)
	t.Log("testCachePutValues")
	c.Put(test.values...)
	testCacheLoadValuesNoError(t, &c, test)

	// print final debug.
	c.Clear()
	fmt.Println(c.Debug())
}

// benchmarkCacheGet measures parallel GetOne throughput. The cache is
// populated with every test value, then each benchmark iteration performs
// a GetOne for every (index, key) pair that those values generate.
func benchmarkCacheGet[T any](b *testing.B, test test[T]) {
	// lookup is a single precomputed GetOne call.
	type lookup struct {
		Index *Index
		Key   Key
	}

	var c Cache[*T]
	c.Init(CacheConfig[*T]{
		Indices: test.indices,
		MaxSize: len(test.values),
		Copy:    test.copyfn,
	})
	c.Put(test.values...)

	// Precompute every lookup the inserted values make
	// possible, so key generation stays out of the timer.
	var lookups []lookup
	for _, cfg := range test.indices {
		idx := c.Index(cfg.Fields)
		for _, value := range test.values {
			parts, ok := indexkey(idx, value)
			if !ok {
				continue
			}

			// Keep the key unless it is zero in an
			// index that is not allowed zero keys.
			if key := idx.Key(parts...); cfg.AllowZero || !key.Zero() {
				lookups = append(lookups, lookup{Index: idx, Key: key})
			}
		}
	}

	b.ResetTimer()

	// Run all Get lookups in parallel!
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			for i := range lookups {
				_, _ = c.GetOne(lookups[i].Index, lookups[i].Key)
			}
		}
	})
}

// benchmarkCachePut measures parallel Put throughput. Each benchmark
// iteration re-inserts every test value (one Put call per value) into a
// cache whose capacity equals the number of test values.
func benchmarkCachePut[T any](b *testing.B, test test[T]) {
	cfg := CacheConfig[*T]{
		Indices: test.indices,
		MaxSize: len(test.values),
		Copy:    test.copyfn,
	}

	var c Cache[*T]
	c.Init(cfg)

	b.ResetTimer()

	// Run all Put writes in parallel!
	b.RunParallel(func(pb *testing.PB) {
		values := test.values
		for pb.Next() {
			for i := range values {
				c.Put(values[i])
			}
		}
	})
}

// testCacheGetValues checks that every test value can be fetched from the
// cache under each configured index that should hold it:
//   - unique indices are checked via GetOne, comparing the single result;
//   - indices configured with Multiple are checked via Get, comparing the
//     full group of values expected under that key.
func testCacheGetValues[T any](t *testing.T, c *Cache[*T], test test[T]) {
	testCacheOnEachIndexable(
		t, c, test,

		// onEachSingle:
		func(t *testing.T, index *Index, key Key, value *T) {
			t.Log("GetOne:", index.Name(), key)

			// Check for value under key.
			check, ok := c.GetOne(index, key)

			if !ok {
				t.Fatalf("could not find value in cache under: %s %+v", index.Name(), key)
			}

			if !test.equalfn(check, value) {
				t.Fatalf("incorrect value in cache under: %s %+v", index.Name(), key)
			}
		},

		// onEachMulti:
		func(t *testing.T, index *Index, key Key, values []*T) {
			t.Log("Get:", index.Name(), key)

			// Check for values under key.
			check := c.Get(index, key)

			// "have" is what the cache returned, "want" is the expected group.
			if len(check) != len(values) {
				t.Fatalf("incorrect no. values in cache under: %s %+v have=%d want=%d", index.Name(), key, len(check), len(values))
			}

			for _, value := range values {
				if !slices.ContainsFunc(check, func(check *T) bool {
					return test.equalfn(value, check)
				}) {
					t.Fatalf("missing expected value in cache under: %s %+v", index.Name(), key)
				}
			}
		},
	)
}

// testCacheLoadValuesNoError checks Load / LoadOne for every test value
// under each configured index, using load callbacks that succeed. Whether
// values are already cached or get loaded via the callback, the results
// must equal the expected values and no error may be returned. In
// particular, a previously cached load error must not leak through.
func testCacheLoadValuesNoError[T any](t *testing.T, c *Cache[*T], test test[T]) {
	testCacheOnEachIndexable(
		t, c, test,

		// onEachSingle:
		func(t *testing.T, index *Index, key Key, value *T) {
			t.Log("LoadOneNoError:", index.Name(), key)

			// Check cache for this value, else load it using callback.
			check, err := c.LoadOne(index, key, func() (*T, error) {
				return value, nil
			})

			// An error here most likely means a previously
			// cached load error was not invalidated.
			if err != nil {
				t.Fatalf("unexpected error loading from cache under: %s %+v: %v", index.Name(), key, err)
			}

			if !test.equalfn(check, value) {
				t.Fatalf("incorrect value in cache under: %s %+v", index.Name(), key)
			}
		},

		// onEachMulti:
		func(t *testing.T, index *Index, key Key, values []*T) {
			t.Log("LoadNoError:", index.Name(), key)

			keys := []Key{key}

			// Check cache for values under key parts, else load it using callback.
			check, err := c.Load(index, keys, func(ks []Key) ([]*T, error) {

				// Check that provided input keys equals originally input.
				if !slices.EqualFunc(keys, ks, func(k1, k2 Key) bool {
					return k1.Equal(k2)
				}) {
					t.Fatalf("unexpected keys passed to load: %s, %+v", index.Name(), ks)
				}

				return values, nil
			})

			if err != nil {
				t.Fatalf("unexpected error loading from cache under: %s %+v: %v", index.Name(), key, err)
			}

			if len(check) != len(values) {
				t.Fatalf("incorrect no. values in cache under: %s %+v have=%d want=%d", index.Name(), key, len(check), len(values))
			}

			for _, value := range values {
				if !slices.ContainsFunc(check, func(check *T) bool {
					return test.equalfn(value, check)
				}) {
					t.Fatalf("missing expected value in cache under: %s %+v", index.Name(), key)
				}
			}
		},
	)
}

// errInTest is the sentinel error returned by the load callbacks in
// testCacheLoadValuesWithError. The tests check that it is propagated
// to the caller and, for LoadOne, also cached.
var errInTest = errors.New("oh no an error")

// testCacheLoadValuesWithError checks Load / LoadOne for every test value
// under each configured index, using load callbacks that fail with
// errInTest:
//   - LoadOne must return no value and errInTest. A second LoadOne for the
//     same key must return the cached error without calling the callback.
//   - Load must return errInTest.
func testCacheLoadValuesWithError[T any](t *testing.T, c *Cache[*T], test test[T]) {
	testCacheOnEachIndexable(
		t, c, test,

		// onEachSingle:
		func(t *testing.T, index *Index, key Key, _ *T) {
			t.Log("LoadOneWithError:", index.Name(), key)

			// Check cache for value but return error on callback.
			value, err := c.LoadOne(index, key, func() (*T, error) {
				return nil, errInTest
			})

			if value != nil {
				t.Fatalf("value remained cached: %s %+v %+v", index.Name(), key, value)
			}

			if !errors.Is(err, errInTest) {
				t.Fatalf("errInTest was not returned after callback: %s %+v: got %v", index.Name(), key, err)
			}

			// Check the cache again, ensuring error was cached:
			// the callback must not be invoked a second time.
			_, err = c.LoadOne(index, key, func() (*T, error) {
				t.Fatal("callback should not be called")
				return nil, nil
			})

			if !errors.Is(err, errInTest) {
				t.Fatalf("errInTest was not cached: %s %+v: got %v", index.Name(), key, err)
			}
		},

		// onEachMulti:
		func(t *testing.T, index *Index, key Key, _ []*T) {
			t.Log("LoadWithError:", index.Name(), key)

			keys := []Key{key}

			// Check cache for values but return error on load callback.
			_, err := c.Load(index, keys, func([]Key) ([]*T, error) {
				return nil, errInTest
			})

			if !errors.Is(err, errInTest) {
				t.Fatalf("errInTest was not returned after callback: %s %+v: got %v", index.Name(), key, err)
			}
		},
	)
}

// testCacheOnEachIndexable walks every configured index of the test case.
// For each lookup key that the test values generate within an index, it
// invokes one of two callbacks:
//   - onEachSingle, once per value, for unique indices;
//   - onEachMulti, once per group of values sharing a key, for indices
//     configured with Multiple.
//
// Values are skipped when no key parts can be generated for them, or when
// their key is zero in an index that does not allow zero keys.
func testCacheOnEachIndexable[T any](
	t *testing.T,
	c *Cache[*T],
	test test[T],
	onEachSingle func(t *testing.T, index *Index, key Key, value *T),
	onEachMulti func(t *testing.T, index *Index, key Key, values []*T),
) {
	for _, cfg := range test.indices {
		idx := c.Index(cfg.Fields)

		// lookupKey generates the key for value within idx, reporting
		// false when no key can be built or it must not be looked up.
		lookupKey := func(value *T) (Key, bool) {
			var key Key
			parts, ok := indexkey(idx, value)
			if !ok {
				return key, false
			}
			key = idx.Key(parts...)
			if !cfg.AllowZero && key.Zero() {
				return key, false
			}
			return key, true
		}

		if cfg.Multiple {
			// Several values may share one key here, so split the test
			// values into their expected groups. Every member of a group
			// yields the same key, so the first one stands in for all.
			for _, group := range groupValues(c, cfg.Fields, test.values) {
				if key, ok := lookupKey(group[0]); ok {
					onEachMulti(t, idx, key, group)
				}
			}
			continue
		}

		// Unique index: at most one value per key.
		for _, value := range test.values {
			if key, ok := lookupKey(value); ok {
				onEachSingle(t, idx, key, value)
			}
		}
	}
}

// groupValues groups all provided values by the key they generate within
// the named index, for use with indices configured with "Multiple". The
// returned map is keyed by each key's string form (Key.Key()).
//
// Values for which no key parts can be generated are omitted. Zero keys
// are NOT filtered here. That decision depends on the index's AllowZero
// setting, which testCacheOnEachIndexable already applies to each group's
// key, exactly as it does for unique indices.
func groupValues[T any](c *Cache[*T], index string, values []*T) map[string][]*T {
	idx := c.Index(index)
	groups := make(map[string][]*T)
	for _, value := range values {
		parts, ok := indexkey(idx, value)
		if !ok {
			continue
		}
		keystr := idx.Key(parts...).Key()
		groups[keystr] = append(groups[keystr], value)
	}
	return groups
}
