diff --git a/map.go b/map.go index 433f275..48610e1 100644 --- a/map.go +++ b/map.go @@ -50,7 +50,11 @@ type serviceMap struct { } // register adds a new service using reflection to extract its methods. -func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error { +func (m *serviceMap) register(rcvr interface{}, name string, passReq bool, f func(method string) string) error { + if f == nil { + f = func(method string) string { return method } + } + // Setup service. s := &service{ name: name, @@ -73,6 +77,7 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error for i := 0; i < s.rcvrType.NumMethod(); i++ { method := s.rcvrType.Method(i) mtype := method.Type + methodName := f(method.Name) // offset the parameter indexes by one if the // service methods accept an HTTP request pointer @@ -117,7 +122,7 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error if returnType := mtype.Out(0); returnType != typeOfError { continue } - s.methods[method.Name] = &serviceMethod{ + s.methods[methodName] = &serviceMethod{ method: method, argsType: args.Elem(), replyType: reply.Elem(), @@ -142,8 +147,8 @@ func (m *serviceMap) register(rcvr interface{}, name string, passReq bool) error // get returns a registered service given a method name. // // The method name uses a dotted notation as in "Service.Method". -func (m *serviceMap) get(method string) (*service, *serviceMethod, error) { - parts := strings.Split(method, ".") +func (m *serviceMap) get(method string, concatStyle string) (*service, *serviceMethod, error) { + parts := strings.Split(method, concatStyle) if len(parts) != 2 { err := fmt.Errorf("rpc: service/method request ill-formed: %q", method) return nil, nil, err diff --git a/server.go b/server.go index 76a3260..c9e1a83 100644 --- a/server.go +++ b/server.go @@ -40,8 +40,9 @@ type CodecRequest interface { // NewServer returns a new RPC server. func NewServer() *Server { return &Server{ - codecs: make(map[string]Codec), - services: new(serviceMap), + codecs: make(map[string]Codec), + services: new(serviceMap), + concatStyle: ".", } } @@ -60,6 +61,16 @@ type Server struct { interceptFunc func(i *RequestInfo) *http.Request beforeFunc func(i *RequestInfo) afterFunc func(i *RequestInfo) + concatStyle string // The style used to concatenate service and method names, e.g., "Service.Method" or "Service/Method" +} + +// SetConcatStyle sets the style used to concatenate service and method names. +// The default style is ".". +// This is useful when you want to use a different style, such as "/". +// The style is used in the method name, e.g., "Service.Method" or "Service/Method". +func (s *Server) SetConcatStyle(concatStyle string) *Server { + s.concatStyle = concatStyle + return s } // RegisterCodec adds a new codec to the server. @@ -78,17 +89,38 @@ func (s *Server) RegisterCodec(codec Codec, contentType string) { // // Methods from the receiver will be extracted if these rules are satisfied: // -// - The receiver is exported (begins with an upper case letter) or local -// (defined in the package registering the service). -// - The method name is exported. -// - The method has three arguments: *http.Request, *args, *reply. -// - All three arguments are pointers. -// - The second and third arguments are exported or local. -// - The method has return type error. +// - The receiver is exported (begins with an upper case letter) or local +// (defined in the package registering the service). +// - The method name is exported. +// - The method has three arguments: *http.Request, *args, *reply. +// - All three arguments are pointers. +// - The second and third arguments are exported or local. +// - The method has return type error. // // All other methods are ignored. func (s *Server) RegisterService(receiver interface{}, name string) error { - return s.services.register(receiver, name, true) + return s.services.register(receiver, name, true, nil) +} + +// RegisterService adds a new service to the server. +// +// The name parameter is optional: if empty it will be inferred from +// the receiver type name. +// The methodF function is optional: it is used to transform the method name +// +// Methods from the receiver will be extracted if these rules are satisfied: +// +// - The receiver is exported (begins with an upper case letter) or local +// (defined in the package registering the service). +// - The method name is exported. +// - The method has three arguments: *http.Request, *args, *reply. +// - All three arguments are pointers. +// - The second and third arguments are exported or local. +// - The method has return type error. +// +// All other methods are ignored. +func (s *Server) RegisterServiceWithMethod(receiver interface{}, name string, methodF func(method string) string) error { + return s.services.register(receiver, name, true, methodF) } // RegisterTCPService adds a new TCP service to the server. @@ -99,24 +131,24 @@ func (s *Server) RegisterService(receiver interface{}, name string) error { // // Methods from the receiver will be extracted if these rules are satisfied: // -// - The receiver is exported (begins with an upper case letter) or local -// (defined in the package registering the service). -// - The method name is exported. -// - The method has two arguments: *args, *reply. -// - Both arguments are pointers. -// - Both arguments are exported or local. -// - The method has return type error. +// - The receiver is exported (begins with an upper case letter) or local +// (defined in the package registering the service). +// - The method name is exported. +// - The method has two arguments: *args, *reply. +// - Both arguments are pointers. +// - Both arguments are exported or local. +// - The method has return type error. // // All other methods are ignored. func (s *Server) RegisterTCPService(receiver interface{}, name string) error { - return s.services.register(receiver, name, false) + return s.services.register(receiver, name, false, nil) } // HasMethod returns true if the given method is registered. // // The method uses a dotted notation as in "Service.Method". func (s *Server) HasMethod(method string) bool { - if _, _, err := s.services.get(method); err == nil { + if _, _, err := s.services.get(method, s.concatStyle); err == nil { return true } return false @@ -180,7 +212,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.writeError(w, 400, errMethod.Error()) return } - serviceSpec, methodSpec, errGet := s.services.get(method) + serviceSpec, methodSpec, errGet := s.services.get(method, s.concatStyle) if errGet != nil { s.writeError(w, 400, errGet.Error()) return @@ -267,3 +299,10 @@ func (s *Server) writeError(w http.ResponseWriter, status int, msg string) { }) } } + +func LowerFirstLetter(method string) string { + if len(method) == 0 { + return method + } + return strings.ToLower(method[:1]) + method[1:] +} diff --git a/server_test.go b/server_test.go index 95d11ad..f9c423e 100644 --- a/server_test.go +++ b/server_test.go @@ -37,11 +37,20 @@ func (t *Service1) Add(req *Service1Request, res *Service1Response) error { type Service2 struct { } +type Service3 struct { +} + +func (t *Service3) Multiply(r *http.Request, req *Service1Request, res *Service1Response) error { + res.Result = req.A * req.B + return nil +} + func TestRegisterService(t *testing.T) { var err error s := NewServer() service1 := new(Service1) service2 := new(Service2) + service3 := new(Service3) // Inferred name. err = s.RegisterService(service1, "") @@ -58,6 +67,11 @@ func TestRegisterService(t *testing.T) { if err == nil { t.Errorf("Expected error on service2") } + + err = s.RegisterServiceWithMethod(service3, "", LowerFirstLetter) + if err != nil || !s.HasMethod("Service3.multiply") { + t.Errorf("Expected to be registered: Service3.multiply") + } } func TestRegisterTCPService(t *testing.T) { @@ -65,6 +79,7 @@ func TestRegisterTCPService(t *testing.T) { s := NewServer() service1 := new(Service1) service2 := new(Service2) + service3 := new(Service3) // Inferred name. err = s.RegisterTCPService(service1, "") @@ -81,6 +96,56 @@ func TestRegisterTCPService(t *testing.T) { if err == nil { t.Errorf("Expected error on service2") } + err = s.RegisterServiceWithMethod(service3, "", LowerFirstLetter) + if err != nil || !s.HasMethod("Service3.multiply") { + t.Errorf("Expected to be registered: Service3.multiply") + } +} + +func TestRegisterServiceWithSlash(t *testing.T) { + var err error + s := NewServer().SetConcatStyle("/") + service1 := new(Service1) + service2 := new(Service2) + + // Inferred name. + err = s.RegisterService(service1, "") + if err != nil || !s.HasMethod("Service1/Multiply") { + t.Errorf("Expected to be registered: Service1/Multiply") + } + // Provided name. + err = s.RegisterService(service1, "Foo") + if err != nil || !s.HasMethod("Foo/Multiply") { + t.Errorf("Expected to be registered: Foo/Multiply") + } + // No methods. + err = s.RegisterService(service2, "") + if err == nil { + t.Errorf("Expected error on service2") + } +} + +func TestRegisterServiceWithUnderline(t *testing.T) { + var err error + s := NewServer().SetConcatStyle("_") + service1 := new(Service1) + service2 := new(Service2) + + // Inferred name. + err = s.RegisterService(service1, "") + if err != nil || !s.HasMethod("Service1_Multiply") { + t.Errorf("Expected to be registered: Service1_Multiply") + } + // Provided name. + err = s.RegisterService(service1, "Foo") + if err != nil || !s.HasMethod("Foo_Multiply") { + t.Errorf("Expected to be registered: Foo_Multiply") + } + // No methods. + err = s.RegisterService(service2, "") + if err == nil { + t.Errorf("Expected error on service2") + } } // MockCodec decodes to Service1.Multiply. diff --git a/v2/map.go b/v2/map.go index dda4216..4a7ca86 100644 --- a/v2/map.go +++ b/v2/map.go @@ -49,7 +49,10 @@ type serviceMap struct { } // register adds a new service using reflection to extract its methods. -func (m *serviceMap) register(rcvr interface{}, name string) error { +func (m *serviceMap) register(rcvr interface{}, name string, f func(method string) string) error { + if f == nil { + f = func(method string) string { return method } + } // Setup service. s := &service{ name: name, @@ -71,7 +74,7 @@ func (m *serviceMap) register(rcvr interface{}, name string) error { for i := 0; i < s.rcvrType.NumMethod(); i++ { method := s.rcvrType.Method(i) mtype := method.Type - // Method must be exported. + methodName := f(method.Name) if method.PkgPath != "" { continue } @@ -101,7 +104,7 @@ func (m *serviceMap) register(rcvr interface{}, name string) error { if returnType := mtype.Out(0); returnType != typeOfError { continue } - s.methods[method.Name] = &serviceMethod{ + s.methods[methodName] = &serviceMethod{ method: method, argsType: args.Elem(), replyType: reply.Elem(), @@ -126,8 +129,8 @@ func (m *serviceMap) register(rcvr interface{}, name string) error { // get returns a registered service given a method name. // // The method name uses a dotted notation as in "Service.Method". -func (m *serviceMap) get(method string) (*service, *serviceMethod, error) { - parts := strings.Split(method, ".") +func (m *serviceMap) get(method string, concatStyle string) (*service, *serviceMethod, error) { + parts := strings.Split(method, concatStyle) if len(parts) != 2 { err := fmt.Errorf("rpc: service/method request ill-formed: %q", method) return nil, nil, err diff --git a/v2/server.go b/v2/server.go index 15e0113..36b28cf 100644 --- a/v2/server.go +++ b/v2/server.go @@ -43,8 +43,9 @@ type CodecRequest interface { // NewServer returns a new RPC server. func NewServer() *Server { return &Server{ - codecs: make(map[string]Codec), - services: new(serviceMap), + codecs: make(map[string]Codec), + services: new(serviceMap), + concatStyle: ".", } } @@ -64,6 +65,16 @@ type Server struct { beforeFunc func(i *RequestInfo) afterFunc func(i *RequestInfo) validateFunc reflect.Value + concatStyle string // The style used to concatenate service and method names, e.g., "Service.Method" or "Service.Method()" +} + +// SetConcatStyle sets the style used to concatenate service and method names. +// The default style is ".". +// This is useful when you want to use a different style, such as "/". +// The style is used in the method name, e.g., "Service.Method" or "Service/Method". +func (s *Server) SetConcatStyle(concatStyle string) *Server { + s.concatStyle = concatStyle + return s } // RegisterCodec adds a new codec to the server. @@ -120,24 +131,45 @@ func (s *Server) RegisterAfterFunc(f func(i *RequestInfo)) { // // Methods from the receiver will be extracted if these rules are satisfied: // -// - The receiver is exported (begins with an upper case letter) or local -// (defined in the package registering the service). -// - The method name is exported. -// - The method has three arguments: *http.Request, *args, *reply. -// - All three arguments are pointers. -// - The second and third arguments are exported or local. -// - The method has return type error. +// - The receiver is exported (begins with an upper case letter) or local +// (defined in the package registering the service). +// - The method name is exported. +// - The method has three arguments: *http.Request, *args, *reply. +// - All three arguments are pointers. +// - The second and third arguments are exported or local. +// - The method has return type error. // // All other methods are ignored. func (s *Server) RegisterService(receiver interface{}, name string) error { - return s.services.register(receiver, name) + return s.services.register(receiver, name, nil) +} + +// RegisterService adds a new service to the server. +// +// The name parameter is optional: if empty it will be inferred from +// the receiver type name. +// The methodF function is optional: it is used to transform the method name +// +// Methods from the receiver will be extracted if these rules are satisfied: +// +// - The receiver is exported (begins with an upper case letter) or local +// (defined in the package registering the service). +// - The method name is exported. +// - The method has three arguments: *http.Request, *args, *reply. +// - All three arguments are pointers. +// - The second and third arguments are exported or local. +// - The method has return type error. +// +// All other methods are ignored. +func (s *Server) RegisterServiceWithMethod(receiver interface{}, name string, methodF func(method string) string) error { + return s.services.register(receiver, name, methodF) } // HasMethod returns true if the given method is registered. // // The method uses a dotted notation as in "Service.Method". func (s *Server) HasMethod(method string) bool { - if _, _, err := s.services.get(method); err == nil { + if _, _, err := s.services.get(method, s.concatStyle); err == nil { return true } return false @@ -173,7 +205,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { codecReq.WriteError(w, http.StatusBadRequest, errMethod) return } - serviceSpec, methodSpec, errGet := s.services.get(method) + serviceSpec, methodSpec, errGet := s.services.get(method, s.concatStyle) if errGet != nil { codecReq.WriteError(w, http.StatusBadRequest, errGet) return @@ -273,3 +305,10 @@ func WriteError(w http.ResponseWriter, status int, msg string) { w.WriteHeader(status) fmt.Fprint(w, msg) } + +func LowerFirstLetter(method string) string { + if len(method) == 0 { + return method + } + return strings.ToLower(method[:1]) + method[1:] +} diff --git a/v2/server_test.go b/v2/server_test.go index 2f25f3e..e68946a 100644 --- a/v2/server_test.go +++ b/v2/server_test.go @@ -36,11 +36,20 @@ func (t *Service1) Multiply(r *http.Request, req *Service1Request, res *Service1 type Service2 struct { } +type Service3 struct { +} + +func (t *Service3) Multiply(r *http.Request, req *Service1Request, res *Service1Response) error { + res.Result = req.A * req.B + return nil +} + func TestRegisterService(t *testing.T) { var err error s := NewServer() service1 := new(Service1) service2 := new(Service2) + service3 := new(Service3) // Inferred name. err = s.RegisterService(service1, "") @@ -57,6 +66,57 @@ func TestRegisterService(t *testing.T) { if err == nil { t.Errorf("Expected error on service2") } + + err = s.RegisterServiceWithMethod(service3, "", LowerFirstLetter) + if err != nil || !s.HasMethod("Service3.multiply") { + t.Errorf("Expected to be registered: Service3.multiply") + } +} + +func TestRegisterServiceWithSlash(t *testing.T) { + var err error + s := NewServer().SetConcatStyle("/") + service1 := new(Service1) + service2 := new(Service2) + + // Inferred name. + err = s.RegisterService(service1, "") + if err != nil || !s.HasMethod("Service1/Multiply") { + t.Errorf("Expected to be registered: Service1/Multiply") + } + // Provided name. + err = s.RegisterService(service1, "Foo") + if err != nil || !s.HasMethod("Foo/Multiply") { + t.Errorf("Expected to be registered: Foo/Multiply") + } + // No methods. + err = s.RegisterService(service2, "") + if err == nil { + t.Errorf("Expected error on service2") + } +} + +func TestRegisterServiceWithUnderline(t *testing.T) { + var err error + s := NewServer().SetConcatStyle("_") + service1 := new(Service1) + service2 := new(Service2) + + // Inferred name. + err = s.RegisterService(service1, "") + if err != nil || !s.HasMethod("Service1_Multiply") { + t.Errorf("Expected to be registered: Service1_Multiply") + } + // Provided name. + err = s.RegisterService(service1, "Foo") + if err != nil || !s.HasMethod("Foo_Multiply") { + t.Errorf("Expected to be registered: Foo_Multiply") + } + // No methods. + err = s.RegisterService(service2, "") + if err == nil { + t.Errorf("Expected error on service2") + } } // MockCodec decodes to Service1.Multiply.