package memstore

import (
	"bytes"
	"encoding/base32"
	"encoding/gob"
	"fmt"
	"net/http"
	"strings"

	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
)

// MemStore is an in-memory implementation of gorilla/sessions, suitable
// for use in tests and development environments. Do not use in production.
// Values are cached in a map. The cache is protected and can be used by
// multiple goroutines.
//
// Only the session ID is sent to the client: Codecs encode and decode that
// ID inside the cookie, while the session values themselves stay in this
// process's memory and are lost when the process exits. Options is the
// template copied into every session returned by New.
//
// NOTE(review): goroutine safety relies on the cache type defined elsewhere
// in this package; the exported Codecs and Options fields are plain fields
// and are not guarded, so do not mutate them while requests are served.
type MemStore struct {
	Codecs  []securecookie.Codec
	Options *sessions.Options
	cache   *cache
}

// valueType is the concrete map type of sessions.Session.Values; it is the
// type kept in the cache and deep-copied between the cache and sessions.
type valueType map[interface{}]interface{}

// NewMemStore creates a MemStore whose cookies carry only a session ID,
// encoded with codecs built from keyPairs.
//
// keyPairs are consumed two at a time so keys can be rotated: the first key
// of each pair authenticates the cookie and the second encrypts it. Every
// pair needs an authentication key; the encryption key may be nil, or left
// out entirely in the final pair.
//
// Authentication keys of 32 or 64 bytes are recommended. An encryption key,
// when present, must be 16, 24, or 32 bytes long to select AES-128, AES-192,
// or AES-256 respectively.
//
// securecookie.GenerateRandomKey() is a convenient way to produce strong keys.
//
// The returned store defaults to Path "/" and a MaxAge of thirty days.
func NewMemStore(keyPairs ...[]byte) *MemStore {
	const thirtyDays = 86400 * 30 // seconds

	s := &MemStore{
		Codecs:  securecookie.CodecsFromPairs(keyPairs...),
		Options: &sessions.Options{Path: "/", MaxAge: thirtyDays},
		cache:   newCache(),
	}
	// Push the default lifetime down into every securecookie codec as well,
	// so the codecs and the cookie options agree on MaxAge.
	s.MaxAge(s.Options.MaxAge)
	return s
}

// Get returns the session called name, registering it in the request's
// session registry so that later calls for the same request reuse it.
//
// When no session exists yet, a fresh one is returned; check IsNew on the
// result to tell a new session from an existing one.
//
// When a session cookie is present but cannot be decoded, a fresh session
// is returned together with the decoding error.
func (m *MemStore) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(m, name)
}

// New builds the session called name for r without registering it.
//
// Unlike Get, every call to New decodes the cookie again; Get decodes once
// per request and then hands back the registered session.
//
// The session starts as a new one carrying a private copy of m.Options. It
// is promoted to an existing session (IsNew == false) only when the request
// has a cookie named name, that cookie decodes to a session ID, and the
// cache holds values for that ID; those values are copied into the session.
// A cookie that fails to decode yields the new session plus the error.
func (m *MemStore) New(r *http.Request, name string) (*sessions.Session, error) {
	s := sessions.NewSession(m, name)
	opts := *m.Options
	s.Options = &opts
	s.IsNew = true

	cookie, err := r.Cookie(name)
	if err != nil {
		// No cookie: nothing to restore.
		return s, nil
	}

	if err := securecookie.DecodeMulti(name, cookie.Value, &s.ID, m.Codecs...); err != nil {
		// Undecodable cookie: report it but still hand back a usable session.
		return s, err
	}

	cached, found := m.cache.value(s.ID)
	if !found {
		// Known ID but no cached values: treat it as a new session.
		return s, nil
	}

	s.Values = m.copy(cached)
	s.IsNew = false
	return s, nil
}

// Save adds a single session to the response.
//
// The session values are deep-copied into the in-memory cache under s.ID
// (a random ID is generated the first time a session is saved), and a
// cookie holding the codec-encoded ID is written to w.
//
// Set Options.MaxAge to -1 or call MaxAge(-1) before saving the session to
// delete all values in it: the cache entry is dropped, s.Values is emptied
// and a cookie with an empty value and the negative MaxAge is written.
//
// Values are copied through encoding/gob, so every concrete type stored in
// s.Values must be registered with gob.Register. An unencodable value is
// returned as an error, as the cookie-based stores do, and in that case
// neither the cache nor the response is modified.
func (m *MemStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error {
	if s.Options.MaxAge < 0 {
		m.cache.delete(s.ID)
		for k := range s.Values {
			delete(s.Values, k)
		}
		http.SetCookie(w, sessions.NewCookie(s.Name(), "", s.Options))
		return nil
	}

	// Snapshot the values before touching any state so that an unencodable
	// value aborts the save cleanly instead of panicking mid-request.
	values, err := copyValues(s.Values)
	if err != nil {
		return err
	}

	if s.ID == "" {
		// GenerateRandomKey returns nil when the system random source fails;
		// without this check every such session would share the ID "".
		key := securecookie.GenerateRandomKey(32)
		if key == nil {
			return fmt.Errorf("memstore: could not generate a random session ID")
		}
		s.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(key), "=")
	}

	encoded, err := securecookie.EncodeMulti(s.Name(), s.ID, m.Codecs...)
	if err != nil {
		return err
	}
	m.cache.setValue(s.ID, values)
	http.SetCookie(w, sessions.NewCookie(s.Name(), encoded, s.Options))
	return nil
}

// MaxAge sets the maximum age, in seconds, for the store and for every
// securecookie codec it uses. The new value applies to sessions created by
// New afterwards, since each one copies m.Options. Individual sessions can
// be deleted by setting Options.MaxAge = -1 for that session.
func (m *MemStore) MaxAge(age int) {
	m.Options.MaxAge = age

	// Set the maxAge for each securecookie instance.
	for _, codec := range m.Codecs {
		if sc, ok := codec.(*securecookie.SecureCookie); ok {
			sc.MaxAge(age)
		}
	}
}

// copy returns a deep copy of v. It is used on values read back from the
// cache, which Save already pushed through the same gob round trip, so a
// failure here should only happen if that invariant is broken; it panics.
func (m *MemStore) copy(v valueType) valueType {
	value, err := copyValues(v)
	if err != nil {
		panic(fmt.Errorf("could not copy memstore value: %w", err))
	}
	return value
}

// copyValues deep-copies v by encoding it with gob and decoding the result
// into a fresh map. Encoding fails for concrete types that were not
// registered with gob.Register; that failure is returned, not panicked.
func copyValues(v valueType) (valueType, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(v); err != nil {
		return nil, fmt.Errorf("memstore: encoding session values to gob: %w", err)
	}
	var value valueType
	if err := gob.NewDecoder(&buf).Decode(&value); err != nil {
		return nil, fmt.Errorf("memstore: decoding session values from gob: %w", err)
	}
	return value, nil
}