Skip to content

Commit

Permalink
just hold
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielLiu1123 committed Dec 2, 2023
1 parent c812f13 commit 1cd54e3
Show file tree
Hide file tree
Showing 15 changed files with 212 additions and 42 deletions.
15 changes: 14 additions & 1 deletion examples/grpc-sample-api/src/main/proto/sample/pet/v1/pet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ option java_multiple_files = true;
option java_package = "com.freemanan.sample.pet.v1";

import "google/protobuf/wrappers.proto";
import "google/api/annotations.proto";

message Pet {
string name = 1;
Expand All @@ -15,9 +16,21 @@ message Pet {

message GetPetRequest {
string name = 1;
Pet pet = 2;
google.protobuf.StringValue pet_name = 3;
}

service PetService {
rpc GetPet(GetPetRequest) returns (Pet);
rpc GetPet(GetPetRequest) returns (Pet) {
option(google.api.http) = {
post: "/v1/foo",
body: "*",
additional_bindings: [
{
post: "/v1/bar"
}
]
};
}
rpc GetPetName(google.protobuf.StringValue) returns (google.protobuf.StringValue);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public ResponseEntity<ErrorResponse> handleStatusRuntimeException(StatusRuntimeE
@Data
@NoArgsConstructor
@AllArgsConstructor
static class ErrorResponse {
public static class ErrorResponse {
private int code;
private String message;
private Object data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ grpc:
server:
reflection:
enabled: true
exception-handling:
use-default: false
server:
error:
include-message: always
include-message: always
8 changes: 7 additions & 1 deletion examples/proto-validate/src/main/proto/fm/foo/v1/foo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package fm.foo.v1;

import "google/protobuf/empty.proto";
import "buf/validate/validate.proto";
import "google/api/annotations.proto";

option java_multiple_files = true;
option java_package = "com.freemanan.foo.v1.api";
Expand All @@ -25,5 +26,10 @@ message Foo {
}

service FooService {
rpc InsertFoo(Foo) returns (Foo) {}
rpc InsertFoo(Foo) returns (Foo) {
option(google.api.http) = {
post: "/v1/foo",
body: "*"
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.grpc.TlsServerCredentials;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.GrpcUtil;
import jakarta.annotation.Nullable;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -133,6 +134,12 @@ public int getPort() {
return server.getPort();
}

@Nullable
@Override
public Object getServer() {
return server;
}

@Override
public void stop() {
if (isRunning.get()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,9 @@ public boolean isAutoStartup() {
public int getPort() {
return DUMMY_PORT;
}

@Override
public Object getServer() {
return null;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.freemanan.starter.grpc.server;

import jakarta.annotation.Nullable;
import org.springframework.context.SmartLifecycle;

/**
Expand All @@ -13,4 +14,10 @@ public interface GrpcServer extends SmartLifecycle {
* @return port number
*/
int getPort();

/**
* Get the server object.
*/
@Nullable
Object getServer();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
Expand Down Expand Up @@ -83,12 +84,9 @@ private static Method getWithInterceptorsMethod() {
}
}

@SneakyThrows
protected Object applyInterceptor4Stub(ClientInterceptor clientInterceptor, Object stub) {
try {
return withInterceptorsMethod.invoke(stub, (Object) new ClientInterceptor[] {clientInterceptor});
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new IllegalStateException(e);
}
return withInterceptorsMethod.invoke(stub, (Object) new ClientInterceptor[] {clientInterceptor});
}

protected Message convert2ProtobufMessage(Class<?> messageClass, InputStream is) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.freemanan.starter.grpc.extensions.jsontranscoder.webflux.WebFluxProtobufHandlerAdaptor;
import io.grpc.BindableService;
import io.grpc.Metadata;
import java.util.stream.Collectors;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand Down Expand Up @@ -47,7 +48,8 @@ static class WebMvc {
@ConditionalOnMissingBean
public WebMvcGrpcServiceHandlerMapping webMvcGrpcServiceHandlerMapping(
ObjectProvider<BindableService> grpcServices) {
return new WebMvcGrpcServiceHandlerMapping(grpcServices);
return new WebMvcGrpcServiceHandlerMapping(
grpcServices.orderedStream().collect(Collectors.toList()));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
import com.google.protobuf.Message;
import io.grpc.BindableService;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.stub.StreamObserver;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.experimental.UtilityClass;
import org.springframework.aop.framework.AopProxyUtils;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.http.HttpHeaders;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.server.NotAcceptableStatusException;

Expand Down Expand Up @@ -57,15 +60,20 @@ private static Method findMethod(BindableService service, String fullMethodName)
return null;
}

public static Map<String, HandlerMethod> getPathToMethod(ObjectProvider<BindableService> grpcServiceProvider) {
public static Map<String, HandlerMethod> getPathToMethod(Collection<BindableService> grpcServiceProvider) {
return grpcServiceProvider.stream()
.map(bs -> Tuple2.of(bs.bindService(), bs))
.flatMap(en -> en.getT1().getMethods().stream()
.map(m -> Tuple2.of(m.getMethodDescriptor().getFullMethodName(), en.getT2())))
.map(en -> Tuple3.of(en.getT1(), en.getT2(), findMethod(en.getT2(), en.getT1())))
.filter(en -> en.getT3() != null)
.map(en -> Tuple2.of("/" + en.getT1(), new HandlerMethod(en.getT2(), en.getT3())))
.collect(Collectors.toMap(Tuple2::getT1, Tuple2::getT2));
.map(bindableService -> new ServiceDefinitionPair(bindableService.bindService(), bindableService))
.flatMap(pair -> pair.getServiceDefinition().getMethods().stream()
.map(md -> new MethodHandlerInfo(
md.getMethodDescriptor().getFullMethodName(),
pair.getBindableService(),
findMethod(
pair.getBindableService(),
md.getMethodDescriptor().getFullMethodName()))))
.filter(info -> info.getMethod() != null)
.collect(Collectors.toMap(
info -> "/" + info.getFullMethodName(),
info -> new HandlerMethod(info.getBindableService(), info.getMethod())));
}

public static boolean isGrpcHandleMethod(Object handler) {
Expand All @@ -77,6 +85,10 @@ public static boolean isGrpcHandleMethod(Object handler) {
* @return true if json string is a json object or json array
*/
public static boolean isJson(String json) {
if (!StringUtils.hasText(json)) {
return false;
}
json = json.trim();
return (json.startsWith("{") && json.endsWith("}")) || (json.startsWith("[") && json.endsWith("]"));
}

Expand Down Expand Up @@ -104,4 +116,17 @@ public static boolean anyCompatible(List<MediaType> mediaTypes, MediaType otherM
public static NotAcceptableStatusException notAcceptableException() {
return new NotAcceptableStatusException("Could not find acceptable representation");
}

@Data
private static final class ServiceDefinitionPair {
private final ServerServiceDefinition serviceDefinition;
private final BindableService bindableService;
}

@Data
private static final class MethodHandlerInfo {
private final String fullMethodName;
private final BindableService bindableService;
private final Method method;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import com.google.protobuf.FloatValue;
import com.google.protobuf.Int32Value;
import com.google.protobuf.Int64Value;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.StringValue;
import com.google.protobuf.UInt32Value;
import com.google.protobuf.UInt64Value;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import lombok.SneakyThrows;
import lombok.experimental.UtilityClass;

/**
Expand All @@ -24,10 +24,10 @@ public class ProtoUtil {
private static final JsonFormat.Printer printer = JsonFormat.printer().omittingInsignificantWhitespace();

/**
* Check if protobuf message is simple value.
* Check if the protobuf message is a simple value.
*
* @param message protobuf message
* @return true if message is simple value
* @return true if the message is simple value
*/
public static boolean isSimpleValueMessage(Message message) {
if (isWrapperType(message.getClass())) {
Expand All @@ -53,12 +53,9 @@ public static boolean isSimpleValueMessage(Message message) {
* @param message protobuf message
* @return JSON string
*/
@SneakyThrows
public static String toJson(Message message) {
try {
return printer.print(message);
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException("Can't convert message to JSON", e);
}
return printer.print(message);
}

private static boolean isWrapperType(Class<?> clz) {
Expand Down
Loading

0 comments on commit 1cd54e3

Please sign in to comment.