001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.udp;
018
019import java.io.EOFException;
020import java.io.IOException;
021import java.net.BindException;
022import java.net.DatagramSocket;
023import java.net.InetSocketAddress;
024import java.net.SocketAddress;
025import java.net.SocketException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.channels.AsynchronousCloseException;
029import java.nio.channels.DatagramChannel;
030import java.security.cert.X509Certificate;
031
032import org.apache.activemq.Service;
033import org.apache.activemq.command.Command;
034import org.apache.activemq.command.Endpoint;
035import org.apache.activemq.openwire.OpenWireFormat;
036import org.apache.activemq.transport.Transport;
037import org.apache.activemq.transport.TransportThreadSupport;
038import org.apache.activemq.transport.reliable.ExceptionIfDroppedReplayStrategy;
039import org.apache.activemq.transport.reliable.ReplayBuffer;
040import org.apache.activemq.transport.reliable.ReplayStrategy;
041import org.apache.activemq.transport.reliable.Replayer;
042import org.apache.activemq.util.InetAddressUtil;
043import org.apache.activemq.util.IntSequenceGenerator;
044import org.apache.activemq.util.ServiceStopper;
045import org.slf4j.Logger;
046import org.slf4j.LoggerFactory;
047
048/**
049 * An implementation of the {@link Transport} interface using raw UDP
050 */
051public class UdpTransport extends TransportThreadSupport implements Transport, Service, Runnable {
052
053    private static final Logger LOG = LoggerFactory.getLogger(UdpTransport.class);
054
055    private static final int MAX_BIND_ATTEMPTS = 50;
056    private static final long BIND_ATTEMPT_DELAY = 100;
057
058    private CommandChannel commandChannel;
059    private OpenWireFormat wireFormat;
060    private ByteBufferPool bufferPool;
061    private ReplayStrategy replayStrategy = new ExceptionIfDroppedReplayStrategy();
062    private ReplayBuffer replayBuffer;
063    private int datagramSize = 4 * 1024;
064    private SocketAddress targetAddress;
065    private SocketAddress originalTargetAddress;
066    private DatagramChannel channel;
067    private boolean trace;
068    private boolean useLocalHost = false;
069    private int port;
070    private int minmumWireFormatVersion;
071    private String description;
072    private IntSequenceGenerator sequenceGenerator;
073    private boolean replayEnabled = true;
074
075    protected UdpTransport(OpenWireFormat wireFormat) throws IOException {
076        this.wireFormat = wireFormat;
077    }
078
079    public UdpTransport(OpenWireFormat wireFormat, URI remoteLocation) throws UnknownHostException, IOException {
080        this(wireFormat);
081        this.targetAddress = createAddress(remoteLocation);
082        description = remoteLocation.toString() + "@";
083    }
084
085    public UdpTransport(OpenWireFormat wireFormat, SocketAddress socketAddress) throws IOException {
086        this(wireFormat);
087        this.targetAddress = socketAddress;
088        this.description = getProtocolName() + "ServerConnection@";
089    }
090
091    /**
092     * Used by the server transport
093     */
094    public UdpTransport(OpenWireFormat wireFormat, int port) throws UnknownHostException, IOException {
095        this(wireFormat);
096        this.port = port;
097        this.targetAddress = null;
098        this.description = getProtocolName() + "Server@";
099    }
100
101    /**
102     * Creates a replayer for working with the reliable transport
103     */
104    public Replayer createReplayer() throws IOException {
105        if (replayEnabled) {
106            return getCommandChannel();
107        }
108        return null;
109    }
110
111    /**
112     * A one way asynchronous send
113     */
114    @Override
115    public void oneway(Object command) throws IOException {
116        oneway(command, targetAddress);
117    }
118
119    /**
120     * A one way asynchronous send to a given address
121     */
122    public void oneway(Object command, SocketAddress address) throws IOException {
123        if (LOG.isDebugEnabled()) {
124            LOG.debug("Sending oneway from: " + this + " to target: " + targetAddress + " command: " + command);
125        }
126        checkStarted();
127        commandChannel.write((Command)command, address);
128    }
129
130    /**
131     * @return pretty print of 'this'
132     */
133    @Override
134    public String toString() {
135        if (description != null) {
136            return description + port;
137        } else {
138            return getProtocolUriScheme() + targetAddress + "@" + port;
139        }
140    }
141
142    /**
143     * reads packets from a Socket
144     */
145    @Override
146    public void run() {
147        LOG.trace("Consumer thread starting for: " + toString());
148        while (!isStopped()) {
149            try {
150                Command command = commandChannel.read();
151                doConsume(command);
152            } catch (AsynchronousCloseException e) {
153                // DatagramChannel closed
154                try {
155                    stop();
156                } catch (Exception e2) {
157                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
158                }
159            } catch (SocketException e) {
160                // DatagramSocket closed
161                LOG.debug("Socket closed: " + e, e);
162                try {
163                    stop();
164                } catch (Exception e2) {
165                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
166                }
167            } catch (EOFException e) {
168                // DataInputStream closed
169                LOG.debug("Socket closed: " + e, e);
170                try {
171                    stop();
172                } catch (Exception e2) {
173                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
174                }
175            } catch (Exception e) {
176                try {
177                    stop();
178                } catch (Exception e2) {
179                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
180                }
181                if (e instanceof IOException) {
182                    onException((IOException)e);
183                } else {
184                    LOG.error("Caught: " + e, e);
185                    e.printStackTrace();
186                }
187            }
188        }
189    }
190
191    /**
192     * We have received the WireFormatInfo from the server on the actual channel
193     * we should use for all future communication with the server, so lets set
194     * the target to be the actual channel that the server has chosen for us to
195     * talk on.
196     */
197    public void setTargetEndpoint(Endpoint newTarget) {
198        if (newTarget instanceof DatagramEndpoint) {
199            DatagramEndpoint endpoint = (DatagramEndpoint)newTarget;
200            SocketAddress address = endpoint.getAddress();
201            if (address != null) {
202                if (originalTargetAddress == null) {
203                    originalTargetAddress = targetAddress;
204                }
205                targetAddress = address;
206                commandChannel.setTargetAddress(address);
207            }
208        }
209    }
210
211    // Properties
212    // -------------------------------------------------------------------------
213    public boolean isTrace() {
214        return trace;
215    }
216
217    public void setTrace(boolean trace) {
218        this.trace = trace;
219    }
220
221    public int getDatagramSize() {
222        return datagramSize;
223    }
224
225    public void setDatagramSize(int datagramSize) {
226        this.datagramSize = datagramSize;
227    }
228
229    public boolean isUseLocalHost() {
230        return useLocalHost;
231    }
232
233    /**
234     * Sets whether 'localhost' or the actual local host name should be used to
235     * make local connections. On some operating systems such as Macs its not
236     * possible to connect as the local host name so localhost is better.
237     */
238    public void setUseLocalHost(boolean useLocalHost) {
239        this.useLocalHost = useLocalHost;
240    }
241
242    public CommandChannel getCommandChannel() throws IOException {
243        if (commandChannel == null) {
244            commandChannel = createCommandChannel();
245        }
246        return commandChannel;
247    }
248
249    /**
250     * Sets the implementation of the command channel to use.
251     */
252    public void setCommandChannel(CommandDatagramChannel commandChannel) {
253        this.commandChannel = commandChannel;
254    }
255
256    public ReplayStrategy getReplayStrategy() {
257        return replayStrategy;
258    }
259
260    /**
261     * Sets the strategy used to replay missed datagrams
262     */
263    public void setReplayStrategy(ReplayStrategy replayStrategy) {
264        this.replayStrategy = replayStrategy;
265    }
266
267    public int getPort() {
268        return port;
269    }
270
271    /**
272     * Sets the port to connect on
273     */
274    public void setPort(int port) {
275        this.port = port;
276    }
277
278    public int getMinmumWireFormatVersion() {
279        return minmumWireFormatVersion;
280    }
281
282    public void setMinmumWireFormatVersion(int minmumWireFormatVersion) {
283        this.minmumWireFormatVersion = minmumWireFormatVersion;
284    }
285
286    public OpenWireFormat getWireFormat() {
287        return wireFormat;
288    }
289
290    public IntSequenceGenerator getSequenceGenerator() {
291        if (sequenceGenerator == null) {
292            sequenceGenerator = new IntSequenceGenerator();
293        }
294        return sequenceGenerator;
295    }
296
297    public void setSequenceGenerator(IntSequenceGenerator sequenceGenerator) {
298        this.sequenceGenerator = sequenceGenerator;
299    }
300
301    public boolean isReplayEnabled() {
302        return replayEnabled;
303    }
304
305    /**
306     * Sets whether or not replay should be enabled when using the reliable
307     * transport. i.e. should we maintain a buffer of messages that can be
308     * replayed?
309     */
310    public void setReplayEnabled(boolean replayEnabled) {
311        this.replayEnabled = replayEnabled;
312    }
313
314    public ByteBufferPool getBufferPool() {
315        if (bufferPool == null) {
316            bufferPool = new DefaultBufferPool();
317        }
318        return bufferPool;
319    }
320
321    public void setBufferPool(ByteBufferPool bufferPool) {
322        this.bufferPool = bufferPool;
323    }
324
325    public ReplayBuffer getReplayBuffer() {
326        return replayBuffer;
327    }
328
329    public void setReplayBuffer(ReplayBuffer replayBuffer) throws IOException {
330        this.replayBuffer = replayBuffer;
331        getCommandChannel().setReplayBuffer(replayBuffer);
332    }
333
334    // Implementation methods
335    // -------------------------------------------------------------------------
336
337    /**
338     * Creates an address from the given URI
339     */
340    protected InetSocketAddress createAddress(URI remoteLocation) throws UnknownHostException, IOException {
341        String host = resolveHostName(remoteLocation.getHost());
342        return new InetSocketAddress(host, remoteLocation.getPort());
343    }
344
345    protected String resolveHostName(String host) throws UnknownHostException {
346        String localName = InetAddressUtil.getLocalHostName();
347        if (localName != null && isUseLocalHost()) {
348            if (localName.equals(host)) {
349                return "localhost";
350            }
351        }
352        return host;
353    }
354
355    @Override
356    protected void doStart() throws Exception {
357        getCommandChannel().start();
358
359        super.doStart();
360    }
361
362    protected CommandChannel createCommandChannel() throws IOException {
363        SocketAddress localAddress = createLocalAddress();
364        channel = DatagramChannel.open();
365
366        channel = connect(channel, targetAddress);
367
368        DatagramSocket socket = channel.socket();
369        bind(socket, localAddress);
370        if (port == 0) {
371            port = socket.getLocalPort();
372        }
373
374        return createCommandDatagramChannel();
375    }
376
377    protected CommandChannel createCommandDatagramChannel() {
378        return new CommandDatagramChannel(this, getWireFormat(), getDatagramSize(), getTargetAddress(), createDatagramHeaderMarshaller(), getChannel(), getBufferPool());
379    }
380
381    protected void bind(DatagramSocket socket, SocketAddress localAddress) throws IOException {
382        channel.configureBlocking(true);
383
384        if (LOG.isDebugEnabled()) {
385            LOG.debug("Binding to address: " + localAddress);
386        }
387
388        //
389        // We have noticed that on some platfoms like linux, after you close
390        // down
391        // a previously bound socket, it can take a little while before we can
392        // bind it again.
393        //
394        for (int i = 0; i < MAX_BIND_ATTEMPTS; i++) {
395            try {
396                socket.bind(localAddress);
397                return;
398            } catch (BindException e) {
399                if (i + 1 == MAX_BIND_ATTEMPTS) {
400                    throw e;
401                }
402                try {
403                    Thread.sleep(BIND_ATTEMPT_DELAY);
404                } catch (InterruptedException e1) {
405                    Thread.currentThread().interrupt();
406                    throw e;
407                }
408            }
409        }
410
411    }
412
413    protected DatagramChannel connect(DatagramChannel channel, SocketAddress targetAddress2) throws IOException {
414        // TODO
415        // connect to default target address to avoid security checks each time
416        // channel = channel.connect(targetAddress);
417
418        return channel;
419    }
420
421    protected SocketAddress createLocalAddress() {
422        return new InetSocketAddress(port);
423    }
424
425    @Override
426    protected void doStop(ServiceStopper stopper) throws Exception {
427        if (channel != null) {
428            channel.close();
429        }
430    }
431
432    protected DatagramHeaderMarshaller createDatagramHeaderMarshaller() {
433        return new DatagramHeaderMarshaller();
434    }
435
436    protected String getProtocolName() {
437        return "Udp";
438    }
439
440    protected String getProtocolUriScheme() {
441        return "udp://";
442    }
443
444    protected SocketAddress getTargetAddress() {
445        return targetAddress;
446    }
447
448    protected DatagramChannel getChannel() {
449        return channel;
450    }
451
452    protected void setChannel(DatagramChannel channel) {
453        this.channel = channel;
454    }
455
456    public InetSocketAddress getLocalSocketAddress() {
457        if (channel == null) {
458            return null;
459        } else {
460            return (InetSocketAddress)channel.socket().getLocalSocketAddress();
461        }
462    }
463
464    @Override
465    public String getRemoteAddress() {
466        if (targetAddress != null) {
467            return "" + targetAddress;
468        }
469        return null;
470    }
471
472    @Override
473    public int getReceiveCounter() {
474        if (commandChannel == null) {
475            return 0;
476        }
477        return commandChannel.getReceiveCounter();
478    }
479
480    @Override
481    public X509Certificate[] getPeerCertificates() {
482        return null;
483    }
484
485    @Override
486    public void setPeerCertificates(X509Certificate[] certificates) {
487    }
488}