package grpcreflect

import (
	"fmt"

	"google.golang.org/grpc"
	"google.golang.org/grpc/reflection"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"

	"github.com/jhump/protoreflect/v2/protoresolve"
)

// GRPCServer is the interface provided by a gRPC server. It is an alias for
// reflection.GRPCServer. In addition to being a service registrar (for
// registering services and handlers), it also has an accessor, GetServiceInfo,
// for retrieving metadata about all registered services.
type GRPCServer = reflection.GRPCServer

// LoadServiceDescriptors returns the descriptors of every service registered
// with the given gRPC server, keyed by the service name reported by the
// server's GetServiceInfo.
//
// If a service's metadata is itself a protoreflect.ServiceDescriptor, that
// descriptor is used as is. Otherwise the descriptor is resolved by service
// name. The first resolution failure aborts the load: a nil map is returned
// along with that error.
func LoadServiceDescriptors(s GRPCServer) (map[string]protoreflect.ServiceDescriptor, error) {
	infos := s.GetServiceInfo()
	result := make(map[string]protoreflect.ServiceDescriptor, len(infos))
	for svcName, svcInfo := range infos {
		// Prefer a schema carried directly in the service metadata.
		if fromMetadata, isDesc := svcInfo.Metadata.(protoreflect.ServiceDescriptor); isDesc {
			result[svcName] = fromMetadata
			continue
		}
		// Metadata did not carry a descriptor; resolve it by name instead.
		resolved, err := findServiceDescriptor(svcName)
		if err != nil {
			return nil, err
		}
		result[svcName] = resolved
	}
	return result, nil
}

// LoadServiceDescriptor loads a rich descriptor for a given service description
// from generated gRPC code. Generated code contains an exported symbol with a
// name like "<Service>_ServiceDesc" which is the service's description. It is
// used internally to register a service implementation with a GRPC server. But
// it can also be used by this package to retrieve the rich descriptor for the
// service.
//
// If svc.Metadata already holds a protoreflect.ServiceDescriptor, it is
// returned directly. Otherwise the descriptor is resolved using
// svc.ServiceName, and any resolution error is returned as is.
func LoadServiceDescriptor(svc *grpc.ServiceDesc) (protoreflect.ServiceDescriptor, error) {
	// See if the service description provides the schema in its metadata.
	if sd, ok := svc.Metadata.(protoreflect.ServiceDescriptor); ok {
		return sd, nil
	}
	// findServiceDescriptor already names the service when the lookup fails,
	// so wrapping its error again here would only repeat the same prefix.
	return findServiceDescriptor(svc.ServiceName)
}

// findServiceDescriptor looks up name in protoregistry.GlobalFiles and returns
// the match as a service descriptor. A failed lookup is reported as an error
// that names the service and wraps the registry's error. If the name resolves
// to a descriptor that is not a service, the error produced by
// protoresolve.NewUnexpectedTypeError is returned instead.
func findServiceDescriptor(name string) (protoreflect.ServiceDescriptor, error) {
	found, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(name))
	if err != nil {
		return nil, fmt.Errorf("could not resolve descriptor for service %q: %w", name, err)
	}
	if svcDesc, isService := found.(protoreflect.ServiceDescriptor); isService {
		return svcDesc, nil
	}
	// The name exists in the registry but refers to some other kind of element.
	return nil, protoresolve.NewUnexpectedTypeError(protoresolve.DescriptorKindService, found, "")
}
