|
| 1 | +package com.google.net.stubby.stub; |
| 2 | + |
| 3 | +import com.google.common.io.BaseEncoding; |
| 4 | +import com.google.common.util.concurrent.ListenableFuture; |
| 5 | +import com.google.net.stubby.Call; |
| 6 | +import com.google.net.stubby.Channel; |
| 7 | +import com.google.net.stubby.Metadata; |
| 8 | +import com.google.net.stubby.MethodDescriptor; |
| 9 | +import com.google.net.stubby.Status; |
| 10 | +import com.google.net.stubby.context.ForwardingChannel; |
| 11 | +import com.google.protobuf.GeneratedMessage; |
| 12 | +import com.google.protobuf.InvalidProtocolBufferException; |
| 13 | + |
| 14 | +import java.util.concurrent.atomic.AtomicReference; |
| 15 | + |
| 16 | +/** |
| 17 | + * Utility functions for binding and receiving headers |
| 18 | + */ |
| 19 | +public class MetadataUtils { |
| 20 | + |
| 21 | + /** |
| 22 | + * Attach a set of request headers to a stub. |
| 23 | + * @param stub to bind the headers to. |
| 24 | + * @param extraHeaders the headers to be passed by each call on the returned stub. |
| 25 | + * @return an implementation of the stub with extraHeaders bound to each call. |
| 26 | + */ |
| 27 | + @SuppressWarnings("unchecked") |
| 28 | + public static <T extends AbstractStub> T attachHeaders( |
| 29 | + T stub, |
| 30 | + final Metadata.Headers extraHeaders) { |
| 31 | + return (T) stub.configureNewStub().setChannel(attachHeaders(stub.getChannel(), extraHeaders)) |
| 32 | + .build(); |
| 33 | + } |
| 34 | + |
| 35 | + /** |
| 36 | + * Attach a set of request headers to a channel. |
| 37 | + * |
| 38 | + * @param channel to channel to intercept. |
| 39 | + * @param extraHeaders the headers to be passed by each call on the returned stub. |
| 40 | + * @return an implementation of the channel with extraHeaders bound to each call. |
| 41 | + */ |
| 42 | + @SuppressWarnings("unchecked") |
| 43 | + public static Channel attachHeaders(Channel channel, final Metadata.Headers extraHeaders) { |
| 44 | + return new ForwardingChannel(channel) { |
| 45 | + @Override |
| 46 | + public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) { |
| 47 | + return new ForwardingCall<ReqT, RespT>(delegate.newCall(method)) { |
| 48 | + @Override |
| 49 | + public void start(Listener<RespT> responseListener, Metadata.Headers headers) { |
| 50 | + headers.merge(extraHeaders); |
| 51 | + delegate.start(responseListener, headers); |
| 52 | + } |
| 53 | + }; |
| 54 | + } |
| 55 | + }; |
| 56 | + } |
| 57 | + |
| 58 | + /** |
| 59 | + * Capture the last received metadata for a stub. Useful for testing |
| 60 | + * @param stub to capture for |
| 61 | + * @param headersCapture to record the last received headers |
| 62 | + * @param trailersCapture to record the last received trailers |
| 63 | + * @return an implementation of the stub with extraHeaders bound to each call. |
| 64 | + */ |
| 65 | + @SuppressWarnings("unchecked") |
| 66 | + public static <T extends AbstractStub> T captureMetadata( |
| 67 | + T stub, |
| 68 | + AtomicReference<Metadata.Headers> headersCapture, |
| 69 | + AtomicReference<Metadata.Trailers> trailersCapture) { |
| 70 | + return (T) stub.configureNewStub().setChannel( |
| 71 | + captureMetadata(stub.getChannel(), headersCapture, trailersCapture)) |
| 72 | + .build(); |
| 73 | + } |
| 74 | + |
| 75 | + /** |
| 76 | + * Capture the last received metadata on a channel. Useful for testing |
| 77 | + * |
| 78 | + * @param channel to channel to capture for. |
| 79 | + * @param headersCapture to record the last received headers |
| 80 | + * @param trailersCapture to record the last received trailers |
| 81 | + * @return an implementation of the channel with captures installed. |
| 82 | + */ |
| 83 | + @SuppressWarnings("unchecked") |
| 84 | + public static Channel captureMetadata(Channel channel, |
| 85 | + final AtomicReference<Metadata.Headers> headersCapture, |
| 86 | + final AtomicReference<Metadata.Trailers> trailersCapture) { |
| 87 | + return new ForwardingChannel(channel) { |
| 88 | + @Override |
| 89 | + public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) { |
| 90 | + return new ForwardingCall<ReqT, RespT>(delegate.newCall(method)) { |
| 91 | + @Override |
| 92 | + public void start(Listener<RespT> responseListener, Metadata.Headers headers) { |
| 93 | + headersCapture.set(null); |
| 94 | + trailersCapture.set(null); |
| 95 | + delegate.start(new ForwardingListener<RespT>(responseListener) { |
| 96 | + @Override |
| 97 | + public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { |
| 98 | + headersCapture.set(headers); |
| 99 | + return super.onHeaders(headers); |
| 100 | + } |
| 101 | + |
| 102 | + @Override |
| 103 | + public void onClose(Status status, Metadata.Trailers trailers) { |
| 104 | + trailersCapture.set(trailers); |
| 105 | + super.onClose(status, trailers); |
| 106 | + } |
| 107 | + }, headers); |
| 108 | + } |
| 109 | + }; |
| 110 | + } |
| 111 | + }; |
| 112 | + } |
| 113 | + |
| 114 | + /** |
| 115 | + * Produce a metadata key for a generated protobuf type. |
| 116 | + */ |
| 117 | + public static <T extends GeneratedMessage> Metadata.Key<T> keyForProto(final T instance) { |
| 118 | + return Metadata.Key.of(instance.getDescriptorForType().getFullName(), |
| 119 | + new Metadata.Marshaller<T>() { |
| 120 | + @Override |
| 121 | + public byte[] toBytes(T value) { |
| 122 | + return value.toByteArray(); |
| 123 | + } |
| 124 | + |
| 125 | + @Override |
| 126 | + public String toAscii(T value) { |
| 127 | + return BaseEncoding.base64().encode(value.toByteArray()); |
| 128 | + } |
| 129 | + |
| 130 | + @Override |
| 131 | + @SuppressWarnings("unchecked") |
| 132 | + public T parseBytes(byte[] serialized) { |
| 133 | + try { |
| 134 | + return (T) instance.getParserForType().parseFrom(serialized); |
| 135 | + } catch (InvalidProtocolBufferException ipbe) { |
| 136 | + throw new IllegalArgumentException(ipbe); |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + @Override |
| 141 | + public T parseAscii(String ascii) { |
| 142 | + return parseBytes(BaseEncoding.base64().decode(ascii)); |
| 143 | + } |
| 144 | + }); |
| 145 | + } |
| 146 | +} |
0 commit comments