Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ type serviceMap struct {
}

// register adds a new service using reflection to extract its methods.
func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error {
func (m *serviceMap) register(rcvr interface{}, name string, passReq bool, f func(method string) string) error {
if f == nil {
f = func(method string) string { return method }
}

// Setup service.
s := &service{
name: name,
Expand All @@ -73,6 +77,7 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error
for i := 0; i < s.rcvrType.NumMethod(); i++ {
method := s.rcvrType.Method(i)
mtype := method.Type
methodName := f(method.Name)

// offset the parameter indexes by one if the
// service methods accept an HTTP request pointer
Expand Down Expand Up @@ -117,7 +122,7 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error
if returnType := mtype.Out(0); returnType != typeOfError {
continue
}
s.methods[method.Name] = &serviceMethod{
s.methods[methodName] = &serviceMethod{
method: method,
argsType: args.Elem(),
replyType: reply.Elem(),
Expand All @@ -142,8 +147,8 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error
// get returns a registered service given a method name.
//
// The method name uses a dotted notation as in "Service.Method".
func (m *serviceMap) get(method string) (*service, *serviceMethod, error) {
parts := strings.Split(method, ".")
func (m *serviceMap) get(method string, concatStyle string) (*service, *serviceMethod, error) {
parts := strings.Split(method, concatStyle)
if len(parts) != 2 {
err := fmt.Errorf("rpc: service/method request ill-formed: %q", method)
return nil, nil, err
Expand Down
79 changes: 59 additions & 20 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ type CodecRequest interface {
// NewServer returns a new RPC server.
func NewServer() *Server {
return &Server{
codecs: make(map[string]Codec),
services: new(serviceMap),
codecs: make(map[string]Codec),
services: new(serviceMap),
concatStyle: ".",
}
}

Expand All @@ -60,6 +61,16 @@ type Server struct {
interceptFunc func(i *RequestInfo) *http.Request
beforeFunc func(i *RequestInfo)
afterFunc func(i *RequestInfo)
concatStyle string // The style used to concatenate service and method names, e.g., "Service.Method" or "Service/Method"
}

// SetConcatStyle sets the style used to concatenate service and method names.
// The default style is ".".
// This is useful when you want to use a different style, such as "/".
// The style is used in the method name, e.g., "Service.Method" or "Service/Method".
func (s *Server) SetConcatStyle(concatStyle string) *Server {
s.concatStyle = concatStyle
return s
}

// RegisterCodec adds a new codec to the server.
Expand All @@ -78,17 +89,38 @@ func (s *Server) RegisterCodec(codec Codec, contentType string) {
//
// Methods from the receiver will be extracted if these rules are satisfied:
//
// - The receiver is exported (begins with an upper case letter) or local
// (defined in the package registering the service).
// - The method name is exported.
// - The method has three arguments: *http.Request, *args, *reply.
// - All three arguments are pointers.
// - The second and third arguments are exported or local.
// - The method has return type error.
// - The receiver is exported (begins with an upper case letter) or local
// (defined in the package registering the service).
// - The method name is exported.
// - The method has three arguments: *http.Request, *args, *reply.
// - All three arguments are pointers.
// - The second and third arguments are exported or local.
// - The method has return type error.
//
// All other methods are ignored.
func (s *Server) RegisterService(receiver interface{}, name string) error {
return s.services.register(receiver, name, true)
return s.services.register(receiver, name, true, nil)
}

// RegisterService adds a new service to the server.
//
// The name parameter is optional: if empty it will be inferred from
// the receiver type name.
// The methodF function is optional: it is used to transform the method name
//
// Methods from the receiver will be extracted if these rules are satisfied:
//
// - The receiver is exported (begins with an upper case letter) or local
// (defined in the package registering the service).
// - The method name is exported.
// - The method has three arguments: *http.Request, *args, *reply.
// - All three arguments are pointers.
// - The second and third arguments are exported or local.
// - The method has return type error.
//
// All other methods are ignored.
func (s *Server) RegisterServiceWithMethod(receiver interface{}, name string, methodF func(method string) string) error {
return s.services.register(receiver, name, true, methodF)
}

// RegisterTCPService adds a new TCP service to the server.
Expand All @@ -99,24 +131,24 @@ func (s *Server) RegisterService(receiver interface{}, name string) error {
//
// Methods from the receiver will be extracted if these rules are satisfied:
//
// - The receiver is exported (begins with an upper case letter) or local
// (defined in the package registering the service).
// - The method name is exported.
// - The method has two arguments: *args, *reply.
// - Both arguments are pointers.
// - Both arguments are exported or local.
// - The method has return type error.
// - The receiver is exported (begins with an upper case letter) or local
// (defined in the package registering the service).
// - The method name is exported.
// - The method has two arguments: *args, *reply.
// - Both arguments are pointers.
// - Both arguments are exported or local.
// - The method has return type error.
//
// All other methods are ignored.
func (s *Server) RegisterTCPService(receiver interface{}, name string) error {
return s.services.register(receiver, name, false)
return s.services.register(receiver, name, false, nil)
}

// HasMethod returns true if the given method is registered.
//
// The method uses a dotted notation as in "Service.Method".
func (s *Server) HasMethod(method string) bool {
if _, _, err := s.services.get(method); err == nil {
if _, _, err := s.services.get(method, s.concatStyle); err == nil {
return true
}
return false
Expand Down Expand Up @@ -180,7 +212,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.writeError(w, 400, errMethod.Error())
return
}
serviceSpec, methodSpec, errGet := s.services.get(method)
serviceSpec, methodSpec, errGet := s.services.get(method, s.concatStyle)
if errGet != nil {
s.writeError(w, 400, errGet.Error())
return
Expand Down Expand Up @@ -267,3 +299,10 @@ func (s *Server) writeError(w http.ResponseWriter, status int, msg string) {
})
}
}

func LowerFirstLetter(method string) string {
if len(method) == 0 {
return method
}
return strings.ToLower(method[:1]) + method[1:]
}
65 changes: 65 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,20 @@ func (t *Service1) Add(req *Service1Request, res *Service1Response) error {
type Service2 struct {
}

type Service3 struct {
}

func (t *Service3) Multiply(r *http.Request, req *Service1Request, res *Service1Response) error {
res.Result = req.A * req.B
return nil
}

func TestRegisterService(t *testing.T) {
var err error
s := NewServer()
service1 := new(Service1)
service2 := new(Service2)
service3 := new(Service3)

// Inferred name.
err = s.RegisterService(service1, "")
Expand All @@ -58,13 +67,19 @@ func TestRegisterService(t *testing.T) {
if err == nil {
t.Errorf("Expected error on service2")
}

err = s.RegisterServiceWithMethod(service3, "", LowerFirstLetter)
if err != nil || !s.HasMethod("Service3.multiply") {
t.Errorf("Expected to be registered: Service3.multiply")
}
}

func TestRegisterTCPService(t *testing.T) {
var err error
s := NewServer()
service1 := new(Service1)
service2 := new(Service2)
service3 := new(Service3)

// Inferred name.
err = s.RegisterTCPService(service1, "")
Expand All @@ -81,6 +96,56 @@ func TestRegisterTCPService(t *testing.T) {
if err == nil {
t.Errorf("Expected error on service2")
}
err = s.RegisterServiceWithMethod(service3, "", LowerFirstLetter)
if err != nil || !s.HasMethod("Service3.multiply") {
t.Errorf("Expected to be registered: Service3.multiply")
}
}

func TestRegisterServiceWithSlash(t *testing.T) {
var err error
s := NewServer().SetConcatStyle("/")
service1 := new(Service1)
service2 := new(Service2)

// Inferred name.
err = s.RegisterService(service1, "")
if err != nil || !s.HasMethod("Service1/Multiply") {
t.Errorf("Expected to be registered: Service1/Multiply")
}
// Provided name.
err = s.RegisterService(service1, "Foo")
if err != nil || !s.HasMethod("Foo/Multiply") {
t.Errorf("Expected to be registered: Foo/Multiply")
}
// No methods.
err = s.RegisterService(service2, "")
if err == nil {
t.Errorf("Expected error on service2")
}
}

func TestRegisterServiceWithUnderline(t *testing.T) {
var err error
s := NewServer().SetConcatStyle("_")
service1 := new(Service1)
service2 := new(Service2)

// Inferred name.
err = s.RegisterService(service1, "")
if err != nil || !s.HasMethod("Service1_Multiply") {
t.Errorf("Expected to be registered: Service1_Multiply")
}
// Provided name.
err = s.RegisterService(service1, "Foo")
if err != nil || !s.HasMethod("Foo_Multiply") {
t.Errorf("Expected to be registered: Foo_Multiply")
}
// No methods.
err = s.RegisterService(service2, "")
if err == nil {
t.Errorf("Expected error on service2")
}
}

// MockCodec decodes to Service1.Multiply.
Expand Down
13 changes: 8 additions & 5 deletions v2/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ type serviceMap struct {
}

// register adds a new service using reflection to extract its methods.
func (m *serviceMap) register(rcvr interface{}, name string) error {
func (m *serviceMap) register(rcvr interface{}, name string, f func(method string) string) error {
if f == nil {
f = func(method string) string { return method }
}
// Setup service.
s := &service{
name: name,
Expand All @@ -71,7 +74,7 @@ func (m *serviceMap) register(rcvr interface{}, name string) error {
for i := 0; i < s.rcvrType.NumMethod(); i++ {
method := s.rcvrType.Method(i)
mtype := method.Type
// Method must be exported.
methodName := f(method.Name)
if method.PkgPath != "" {
continue
}
Expand Down Expand Up @@ -101,7 +104,7 @@ func (m *serviceMap) register(rcvr interface{}, name string) error {
if returnType := mtype.Out(0); returnType != typeOfError {
continue
}
s.methods[method.Name] = &serviceMethod{
s.methods[methodName] = &serviceMethod{
method: method,
argsType: args.Elem(),
replyType: reply.Elem(),
Expand All @@ -126,8 +129,8 @@ func (m *serviceMap) register(rcvr interface{}, name string) error {
// get returns a registered service given a method name.
//
// The method name uses a dotted notation as in "Service.Method".
func (m *serviceMap) get(method string) (*service, *serviceMethod, error) {
parts := strings.Split(method, ".")
func (m *serviceMap) get(method string, concatStyle string) (*service, *serviceMethod, error) {
parts := strings.Split(method, concatStyle)
if len(parts) != 2 {
err := fmt.Errorf("rpc: service/method request ill-formed: %q", method)
return nil, nil, err
Expand Down
Loading