// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bsoncodec

import (
	"encoding"
	"errors"
	"fmt"
	"reflect"
	"strconv"

	"go.mongodb.org/mongo-driver/bson/bsonoptions"
	"go.mongodb.org/mongo-driver/bson/bsonrw"
	"go.mongodb.org/mongo-driver/bson/bsontype"
)

// defaultMapCodec is a package-level MapCodec created by NewMapCodec with no
// options supplied.
var defaultMapCodec = NewMapCodec()

// MapCodec is the Codec used for map values.
//
// Deprecated: MapCodec will not be directly configurable in Go Driver 2.0. To
// configure the map encode and decode behavior, use the configuration methods
// on a [go.mongodb.org/mongo-driver/bson.Encoder] or
// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the map encode and
// decode behavior for a mongo.Client, use
// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions].
//
// For example, to configure a mongo.Client to marshal nil Go maps as empty BSON
// documents, use:
//
//	opt := options.Client().SetBSONOptions(&options.BSONOptions{
//	    NilMapAsEmpty: true,
//	})
//
// See the deprecation notice for each field in MapCodec for the corresponding
// settings.
type MapCodec struct {
	// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
	// value passed to Decode before unmarshaling BSON documents into them.
	//
	// Deprecated: Use bson.Decoder.ZeroMaps or options.BSONOptions.ZeroMaps instead.
	DecodeZerosMap bool

	// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
	// BSON null.
	//
	// Deprecated: Use bson.Encoder.NilMapAsEmpty or options.BSONOptions.NilMapAsEmpty instead.
	EncodeNilAsEmpty bool

	// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
	// strings using fmt.Sprint instead of the default string conversion logic.
	//
	// The setting also affects decoding: when it is enabled, KeyUnmarshaler implementations are
	// not consulted for map keys, and floating-point map keys are parsed with strconv.ParseFloat
	// instead of being rejected as an unsupported key type.
	//
	// Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt or
	// options.BSONOptions.StringifyMapKeysWithFmt instead.
	EncodeKeysWithStringer bool
}

// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
//
// When MapCodec encodes a map key without fmt-based stringification, keys whose kind is string
// are used directly; for all other keys KeyMarshaler is checked before encoding.TextMarshaler.
// A nil pointer key that implements KeyMarshaler is encoded as the empty string without
// MarshalKey being called.
type KeyMarshaler interface {
	// MarshalKey returns the BSON document field name to use for the key, or an error if the
	// key cannot be represented as a string.
	MarshalKey() (key string, err error)
}

// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
//
// When decoding, MapCodec checks whether a pointer to the map's key type implements
// KeyUnmarshaler and, if so, calls UnmarshalKey on a newly allocated key value. The check is
// skipped when MapCodec.EncodeKeysWithStringer is enabled.
type KeyUnmarshaler interface {
	// UnmarshalKey populates the receiver from the BSON document field name key.
	UnmarshalKey(key string) error
}

// NewMapCodec builds a MapCodec from the merged result of opts. Each option
// that is explicitly set (non-nil) after merging is copied onto the codec;
// options left unset keep the MapCodec zero value (false).
//
// Deprecated: NewMapCodec will not be available in Go Driver 2.0. See
// [MapCodec] for more details.
func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
	merged := bsonoptions.MergeMapCodecOptions(opts...)

	mc := new(MapCodec)
	if zeros := merged.DecodeZerosMap; zeros != nil {
		mc.DecodeZerosMap = *zeros
	}
	if nilAsEmpty := merged.EncodeNilAsEmpty; nilAsEmpty != nil {
		mc.EncodeNilAsEmpty = *nilAsEmpty
	}
	if withStringer := merged.EncodeKeysWithStringer; withStringer != nil {
		mc.EncodeKeysWithStringer = *withStringer
	}
	return mc
}

// EncodeValue is the ValueEncoder for map[*]* types.
//
// It returns a ValueEncoderError when val is invalid or is not a map. A nil
// map is written as BSON null unless EncodeNilAsEmpty or the EncodeContext's
// nil-map-as-empty setting is enabled; otherwise the map is written as a BSON
// document.
func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Map {
		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
	}

	writeNull := val.IsNil() && !(mc.EncodeNilAsEmpty || ec.nilMapAsEmpty)
	if writeNull {
		// A WriteNull failure on a nil map most likely means the target is a
		// top-level document, which cannot be null. We cannot tell for sure,
		// so the error is dropped here and we fall through to WriteDocument:
		// any deeper problem will surface from that call instead. Reflection
		// operations such as MapKeys are valid on a nil map, so encoding it
		// as an (empty) document below is safe.
		if nullErr := vw.WriteNull(); nullErr == nil {
			return nil
		}
	}

	docWriter, err := vw.WriteDocument()
	if err != nil {
		return err
	}
	return mc.mapEncodeValue(ec, docWriter, val, nil)
}

// mapEncodeValue writes every entry of the map val as an element of dw and
// then ends the document.
//
// Each map key is converted to a field name with encodeKey. If collisionFn is
// non-nil it is called with that field name and must return true when the name
// is already taken; this is mainly used for inline maps in the struct codec,
// and a collision aborts encoding with an error. An element for which
// lookupElementEncoder reports errInvalidValue is written as BSON null.
//
// A failed encoder lookup for the element type is tolerated when the element
// type is an interface, because the encoder is then resolved per element.
func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {

	elemType := val.Type().Elem()
	encoder, err := ec.LookupEncoder(elemType)
	if err != nil && elemType.Kind() != reflect.Interface {
		return err
	}

	keys := val.MapKeys()
	for _, key := range keys {
		keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
		if err != nil {
			return err
		}

		if collisionFn != nil && collisionFn(keyStr) {
			// Report the encoded field name: formatting the reflect.Value key
			// with %s yields "%!s(int=1)"-style noise for non-string keys.
			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", keyStr)
		}

		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
		if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
			return lookupErr
		}

		vw, err := dw.WriteDocumentElement(keyStr)
		if err != nil {
			return err
		}

		// An invalid element value has no encoder to run; emit null instead.
		if errors.Is(lookupErr, errInvalidValue) {
			err = vw.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		err = currEncoder.EncodeValue(ec, vw, currVal)
		if err != nil {
			return err
		}
	}

	return dw.WriteDocumentEnd()
}

// DecodeValue is the ValueDecoder for map types.
//
// val must be a map that is either settable or non-nil; otherwise a
// ValueDecoderError is returned. BSON null and undefined set val to its zero
// value. An embedded document (or a reader reporting type 0) is decoded
// element by element into val, allocating the map first when it is nil. Any
// other BSON type is an error.
//
// Existing entries are removed before decoding when DecodeZerosMap or the
// DecodeContext's zero-maps setting is enabled; otherwise decoded elements are
// added to, or overwrite, the entries already present.
func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
		return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
	}
	mapType := val.Type()

	switch vrType := vr.Type(); vrType {
	case bsontype.Null:
		val.Set(reflect.Zero(mapType))
		return vr.ReadNull()
	case bsontype.Undefined:
		val.Set(reflect.Zero(mapType))
		return vr.ReadUndefined()
	case bsontype.Type(0), bsontype.EmbeddedDocument:
		// Document-like input: decoded below.
	default:
		return fmt.Errorf("cannot decode %v into a %s", vrType, mapType)
	}

	docReader, err := vr.ReadDocument()
	if err != nil {
		return err
	}

	if val.IsNil() {
		val.Set(reflect.MakeMap(mapType))
	}

	if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
		clearMap(val)
	}

	elemType := mapType.Elem()
	decoder, err := dc.LookupDecoder(elemType)
	if err != nil {
		return err
	}
	// A failed assertion leaves elemTypeDecoder nil, which is passed through
	// as-is to decodeTypeOrValueWithInfo.
	elemTypeDecoder, _ := decoder.(typeDecoder)

	// For interface{} elements, record the map type as the ancestor on the
	// decode context handed to the element decoder.
	if elemType == tEmpty {
		dc.Ancestor = mapType
	}

	keyType := mapType.Key()

	for {
		name, elemReader, readErr := docReader.ReadElement()
		switch {
		case errors.Is(readErr, bsonrw.ErrEOD):
			// End of document: every element has been decoded.
			return nil
		case readErr != nil:
			return readErr
		}

		mapKey, keyErr := mc.decodeKey(name, keyType)
		if keyErr != nil {
			return keyErr
		}

		elem, elemErr := decodeTypeOrValueWithInfo(decoder, elemTypeDecoder, dc, elemReader, elemType, true)
		if elemErr != nil {
			return newDecodeError(name, elemErr)
		}

		val.SetMapIndex(mapKey, elem)
	}
}

// clearMap removes every entry from the map held by m. m must be a
// reflect.Value of kind Map; a nil or empty map is left untouched.
func clearMap(m reflect.Value) {
	keys := m.MapKeys()
	for i := range keys {
		// SetMapIndex with the zero reflect.Value deletes the key.
		m.SetMapIndex(keys[i], reflect.Value{})
	}
}

// encodeKey converts the map key val into the string used as a BSON document
// field name.
//
// If EncodeKeysWithStringer is set on the codec, or encodeKeysWithStringer is
// true, the key is formatted with fmt.Sprint. Otherwise the first matching
// rule applies:
//   - keys of string kind are used directly;
//   - keys implementing KeyMarshaler use MarshalKey;
//   - keys implementing encoding.TextMarshaler use MarshalText;
//   - signed and unsigned integer keys are formatted in base 10.
//
// A nil pointer key that implements one of the marshaler interfaces encodes
// as the empty string. Any other key type results in an error.
func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
	useFmt := mc.EncodeKeysWithStringer || encodeKeysWithStringer
	if useFmt {
		return fmt.Sprint(val), nil
	}

	kind := val.Kind()
	if kind == reflect.String {
		return val.String(), nil
	}

	// Marshaler methods must not be invoked on a nil pointer receiver.
	nilPtr := kind == reflect.Ptr && val.IsNil()

	// Case order matters: KeyMarshaler takes precedence over TextMarshaler.
	switch marshaler := val.Interface().(type) {
	case KeyMarshaler:
		if nilPtr {
			return "", nil
		}
		name, err := marshaler.MarshalKey()
		if err != nil {
			return "", err
		}
		return name, nil
	case encoding.TextMarshaler:
		if nilPtr {
			return "", nil
		}
		text, err := marshaler.MarshalText()
		if err != nil {
			return "", err
		}
		return string(text), nil
	}

	switch kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(val.Uint(), 10), nil
	default:
		return "", fmt.Errorf("unsupported key type: %v", val.Type())
	}
}

// Reflected interface types used when checking whether a pointer to a map's
// key type implements one of the key unmarshaling interfaces.
var (
	keyUnmarshalerType  = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
	textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)

// decodeKey converts the BSON document field name key into a reflect.Value of
// type keyType that can be used as a map key.
//
// The first matching rule applies:
//  1. If EncodeKeysWithStringer is not enabled and a pointer to keyType
//     implements KeyUnmarshaler, a new key is allocated and UnmarshalKey is
//     called on it.
//  2. If a pointer to keyType implements encoding.TextUnmarshaler, a new key
//     is allocated and UnmarshalText is called on it.
//  3. Otherwise the conversion depends on keyType's kind: string kinds are
//     converted directly; signed and unsigned integer kinds are parsed in
//     base 10 and rejected if they overflow keyType; floating-point kinds are
//     parsed at keyType's precision, but only when EncodeKeysWithStringer is
//     enabled. All other kinds are unsupported.
//
// The returned value always has type keyType when the error is nil. When the
// error is non-nil the returned value must not be used.
func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
	keyVal := reflect.ValueOf(key)
	var err error
	switch {
	// First, if EncodeKeysWithStringer is not enabled, try to decode with KeyUnmarshaler.
	case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
		keyVal = reflect.New(keyType)
		v := keyVal.Interface().(KeyUnmarshaler)
		err = v.UnmarshalKey(key)
		keyVal = keyVal.Elem()
	// Try to decode encoding.TextUnmarshalers.
	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
		keyVal = reflect.New(keyType)
		v := keyVal.Interface().(encoding.TextUnmarshaler)
		err = v.UnmarshalText([]byte(key))
		keyVal = keyVal.Elem()
	// Otherwise, go to type specific behavior.
	default:
		switch keyType.Kind() {
		case reflect.String:
			keyVal = keyVal.Convert(keyType)
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			n, parseErr := strconv.ParseInt(key, 10, 64)
			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
				err = fmt.Errorf("failed to unmarshal number key %v", key)
				// Do not convert an unparsable or out-of-range number.
				break
			}
			keyVal = reflect.ValueOf(n).Convert(keyType)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			n, parseErr := strconv.ParseUint(key, 10, 64)
			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
				err = fmt.Errorf("failed to unmarshal number key %v", key)
				break
			}
			keyVal = reflect.ValueOf(n).Convert(keyType)
		case reflect.Float32, reflect.Float64:
			// Float keys can only have been produced by fmt-based key
			// stringification, so they are only accepted in that mode.
			if !mc.EncodeKeysWithStringer {
				return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
			}
			// Parse at the key type's own precision so float32 keys round
			// correctly and out-of-range values are reported as errors.
			parsed, parseErr := strconv.ParseFloat(key, keyType.Bits())
			if parseErr != nil {
				return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), parseErr)
			}
			// ParseFloat always yields float64; convert so the value is
			// assignable to float32 and named float key types.
			keyVal = reflect.ValueOf(parsed).Convert(keyType)
		default:
			return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
		}
	}
	return keyVal, err
}