001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtdispatch.*;
021    
022    import java.io.IOException;
023    import java.net.*;
024    import java.nio.ByteBuffer;
025    import java.nio.channels.ReadableByteChannel;
026    import java.nio.channels.SelectionKey;
027    import java.nio.channels.SocketChannel;
028    import java.nio.channels.WritableByteChannel;
029    import java.util.LinkedList;
030    import java.util.concurrent.TimeUnit;
031    
032    /**
033     * An implementation of the {@link org.fusesource.hawtdispatch.transport.Transport} interface using raw tcp/ip
034     *
035     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
036     */
037    public class TcpTransport extends ServiceBase implements Transport {
038    
039        abstract static class SocketState {
040            void onStop(Task onCompleted) {
041            }
042            void onCanceled() {
043            }
044            boolean is(Class<? extends SocketState> clazz) {
045                return getClass()==clazz;
046            }
047        }
048    
049        static class DISCONNECTED extends SocketState{}
050    
051        class CONNECTING extends SocketState{
052            void onStop(Task onCompleted) {
053                trace("CONNECTING.onStop");
054                CANCELING state = new CANCELING();
055                socketState = state;
056                state.onStop(onCompleted);
057            }
058            void onCanceled() {
059                trace("CONNECTING.onCanceled");
060                CANCELING state = new CANCELING();
061                socketState = state;
062                state.onCanceled();
063            }
064        }
065    
066        class CONNECTED extends SocketState {
067    
068            public CONNECTED() {
069                localAddress = channel.socket().getLocalSocketAddress();
070                remoteAddress = channel.socket().getRemoteSocketAddress();
071            }
072    
073            void onStop(Task onCompleted) {
074                trace("CONNECTED.onStop");
075                CANCELING state = new CANCELING();
076                socketState = state;
077                state.add(createDisconnectTask());
078                state.onStop(onCompleted);
079            }
080            void onCanceled() {
081                trace("CONNECTED.onCanceled");
082                CANCELING state = new CANCELING();
083                socketState = state;
084                state.add(createDisconnectTask());
085                state.onCanceled();
086            }
087            Task createDisconnectTask() {
088                return new Task(){
089                    public void run() {
090                        listener.onTransportDisconnected();
091                    }
092                };
093            }
094        }
095    
096        class CANCELING extends SocketState {
097            private LinkedList<Task> runnables =  new LinkedList<Task>();
098            private int remaining;
099            private boolean dispose;
100    
101            public CANCELING() {
102                if( readSource!=null ) {
103                    remaining++;
104                    readSource.cancel();
105                }
106                if( writeSource!=null ) {
107                    remaining++;
108                    writeSource.cancel();
109                }
110            }
111            void onStop(Task onCompleted) {
112                trace("CANCELING.onCompleted");
113                add(onCompleted);
114                dispose = true;
115            }
116            void add(Task onCompleted) {
117                if( onCompleted!=null ) {
118                    runnables.add(onCompleted);
119                }
120            }
121            void onCanceled() {
122                trace("CANCELING.onCanceled");
123                remaining--;
124                if( remaining!=0 ) {
125                    return;
126                }
127                try {
128                    channel.close();
129                } catch (IOException ignore) {
130                }
131                socketState = new CANCELED(dispose);
132                for (Task runnable : runnables) {
133                    runnable.run();
134                }
135                if (dispose) {
136                    dispose();
137                }
138            }
139        }
140    
141        class CANCELED extends SocketState {
142            private boolean disposed;
143    
144            public CANCELED(boolean disposed) {
145                this.disposed=disposed;
146            }
147    
148            void onStop(Task onCompleted) {
149                trace("CANCELED.onStop");
150                if( !disposed ) {
151                    disposed = true;
152                    dispose();
153                }
154                onCompleted.run();
155            }
156        }
157    
    // Locations handed to connecting(); localLocation may be null, in which case no bind is done.
    protected URI remoteLocation;
    protected URI localLocation;
    // Receives the transport callbacks (connected, disconnected, failure, command, refill).
    protected TransportListener listener;
    // Codec used to read and write commands; cleared by dispose().
    protected ProtocolCodec codec;

    // Socket channel assigned by connected() or opened by connecting().
    protected SocketChannel channel;

    // Current state of the socket life cycle; starts out DISCONNECTED.
    protected SocketState socketState = new DISCONNECTED();

    protected DispatchQueue dispatchQueue;
    // OP_CONNECT source while a connect is pending, OP_READ source once connected.
    private DispatchSource readSource;
    // OP_WRITE source created in onConnected().
    private DispatchSource writeSource;
    // Custom source whose handler calls flush(); offer() merges into it after a successful write.
    protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
    // Custom source whose handler calls drainInbound(); merged when a drain pass hits its limit.
    protected CustomDispatchSource<Integer, Integer> yieldSource;

    // When true, resolveHostName() maps the local host name to "localhost".
    protected boolean useLocalHost = true;

    // Read/write allowances granted per reset period (the reset is scheduled every second);
    // 0 disables limiting for that direction.
    int maxReadRate;
    int maxWriteRate;
    // Socket options applied in initializeChannel().
    int receiveBufferSize = 1024*64;
    int sendBufferSize = 1024*64;
    boolean keepAlive = true;


    // Values accepted by setTrafficClass(); the chosen value is passed to Socket.setTrafficClass().
    public static final int IPTOS_LOWCOST = 0x02;
    public static final int IPTOS_RELIABILITY = 0x04;
    public static final int IPTOS_THROUGHPUT = 0x08;
    public static final int IPTOS_LOWDELAY = 0x10;

    int trafficClass = IPTOS_THROUGHPUT;

    // Created in onConnected() only when a read or write rate limit is configured; otherwise null.
    protected RateLimitingChannel rateLimitingChannel;
    // Captured from the socket when the CONNECTED state is created.
    SocketAddress localAddress;
    SocketAddress remoteAddress;
192    
193        class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel {
194    
195            int read_allowance = maxReadRate;
196            boolean read_suspended = false;
197            int read_resume_counter = 0;
198            int write_allowance = maxWriteRate;
199            boolean write_suspended = false;
200    
201            public void resetAllowance() {
202                if( read_allowance != maxReadRate || write_allowance != maxWriteRate) {
203                    read_allowance = maxReadRate;
204                    write_allowance = maxWriteRate;
205                    if( write_suspended ) {
206                        write_suspended = false;
207                        resumeWrite();
208                    }
209                    if( read_suspended ) {
210                        read_suspended = false;
211                        resumeRead();
212                        for( int i=0; i < read_resume_counter ; i++ ) {
213                            resumeRead();
214                        }
215                    }
216                }
217            }
218    
219            public int read(ByteBuffer dst) throws IOException {
220                if( maxReadRate ==0 ) {
221                    return channel.read(dst);
222                } else {
223                    int remaining = dst.remaining();
224                    if( read_allowance ==0 || remaining ==0 ) {
225                        return 0;
226                    }
227    
228                    int reduction = 0;
229                    if( remaining > read_allowance) {
230                        reduction = remaining - read_allowance;
231                        dst.limit(dst.limit() - reduction);
232                    }
233                    int rc=0;
234                    try {
235                        rc = channel.read(dst);
236                        read_allowance -= rc;
237                    } finally {
238                        if( reduction!=0 ) {
239                            if( dst.remaining() == 0 ) {
240                                // we need to suspend the read now until we get
241                                // a new allowance..
242                                readSource.suspend();
243                                read_suspended = true;
244                            }
245                            dst.limit(dst.limit() + reduction);
246                        }
247                    }
248                    return rc;
249                }
250            }
251    
252            public int write(ByteBuffer src) throws IOException {
253                if( maxWriteRate ==0 ) {
254                    return channel.write(src);
255                } else {
256                    int remaining = src.remaining();
257                    if( write_allowance ==0 || remaining ==0 ) {
258                        return 0;
259                    }
260    
261                    int reduction = 0;
262                    if( remaining > write_allowance) {
263                        reduction = remaining - write_allowance;
264                        src.limit(src.limit() - reduction);
265                    }
266                    int rc = 0;
267                    try {
268                        rc = channel.write(src);
269                        write_allowance -= rc;
270                    } finally {
271                        if( reduction!=0 ) {
272                            if( src.remaining() == 0 ) {
273                                // we need to suspend the read now until we get
274                                // a new allowance..
275                                write_suspended = true;
276                                suspendWrite();
277                            }
278                            src.limit(src.limit() + reduction);
279                        }
280                    }
281                    return rc;
282                }
283            }
284    
285            public boolean isOpen() {
286                return channel.isOpen();
287            }
288    
289            public void close() throws IOException {
290                channel.close();
291            }
292    
293            public void resumeRead() {
294                if( read_suspended ) {
295                    read_resume_counter += 1;
296                } else {
297                    _resumeRead();
298                }
299            }
300    
301        }
302    
303        private final Task CANCEL_HANDLER = new Task() {
304            public void run() {
305                socketState.onCanceled();
306            }
307        };
308    
309        static final class OneWay {
310            final Object command;
311            final Retained retained;
312    
313            public OneWay(Object command, Retained retained) {
314                this.command = command;
315                this.retained = retained;
316            }
317        }
318    
319        public void connected(SocketChannel channel) throws IOException, Exception {
320            this.channel = channel;
321            initializeChannel();
322            this.socketState = new CONNECTED();
323        }
324    
325        protected void initializeChannel() throws Exception {
326            this.channel.configureBlocking(false);
327            Socket socket = channel.socket();
328            try {
329                socket.setReuseAddress(true);
330            } catch (SocketException e) {
331            }
332            try {
333                socket.setSoLinger(true, 0);
334            } catch (SocketException e) {
335            }
336            try {
337                socket.setTrafficClass(trafficClass);
338            } catch (SocketException e) {
339            }
340            try {
341                socket.setKeepAlive(keepAlive);
342            } catch (SocketException e) {
343            }
344            try {
345                socket.setTcpNoDelay(true);
346            } catch (SocketException e) {
347            }
348            try {
349                socket.setReceiveBufferSize(receiveBufferSize);
350            } catch (SocketException e) {
351            }
352            try {
353                socket.setSendBufferSize(sendBufferSize);
354            } catch (SocketException e) {
355            }
356    
357            if( channel!=null && codec!=null ) {
358                initializeCodec();
359            }
360        }
361    
362        protected void initializeCodec() throws Exception {
363            codec.setReadableByteChannel(readChannel());
364            codec.setWritableByteChannel(writeChannel());
365            if( codec instanceof TransportAware ) {
366                ((TransportAware)codec).setTransport(this);
367            }
368        }
369    
370        public void connecting(URI remoteLocation, URI localLocation) throws IOException, Exception {
371            this.channel = SocketChannel.open();
372            initializeChannel();
373            this.remoteLocation = remoteLocation;
374            this.localLocation = localLocation;
375    
376            if (localLocation != null) {
377                InetSocketAddress localAddress = new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort());
378                channel.socket().bind(localAddress);
379            }
380    
381            String host = resolveHostName(remoteLocation.getHost());
382            InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
383            channel.connect(remoteAddress);
384            this.socketState = new CONNECTING();
385        }
386    
387    
388        public DispatchQueue getDispatchQueue() {
389            return dispatchQueue;
390        }
391    
392        public void setDispatchQueue(DispatchQueue queue) {
393            this.dispatchQueue = queue;
394            if(readSource!=null) readSource.setTargetQueue(queue);
395            if(writeSource!=null) writeSource.setTargetQueue(queue);
396            if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
397            if(yieldSource!=null) yieldSource.setTargetQueue(queue);
398        }
399    
400        public void _start(Task onCompleted) {
401            try {
402                if (socketState.is(CONNECTING.class) ) {
403                    trace("connecting...");
404                    // this allows the connect to complete..
405                    readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue);
406                    readSource.setEventHandler(new Task() {
407                        public void run() {
408                            if (getServiceState() != STARTED) {
409                                return;
410                            }
411                            try {
412                                trace("connected.");
413                                channel.finishConnect();
414                                readSource.setCancelHandler(null);
415                                readSource.cancel();
416                                readSource=null;
417                                socketState = new CONNECTED();
418                                onConnected();
419                            } catch (IOException e) {
420                                onTransportFailure(e);
421                            }
422                        }
423                    });
424                    readSource.setCancelHandler(CANCEL_HANDLER);
425                    readSource.resume();
426    
427                } else if (socketState.is(CONNECTED.class) ) {
428                    dispatchQueue.execute(new Task() {
429                        public void run() {
430                            try {
431                                trace("was connected.");
432                                onConnected();
433                            } catch (IOException e) {
434                                 onTransportFailure(e);
435                            }
436                        }
437                    });
438                } else {
439                    System.err.println("cannot be started.  socket state is: "+socketState);
440                }
441            } finally {
442                if( onCompleted!=null ) {
443                    onCompleted.run();
444                }
445            }
446        }
447    
448        public void _stop(final Task onCompleted) {
449            trace("stopping.. at state: "+socketState);
450            socketState.onStop(onCompleted);
451        }
452    
453        protected String resolveHostName(String host) throws UnknownHostException {
454            String localName = InetAddress.getLocalHost().getHostName();
455            if (localName != null && isUseLocalHost()) {
456                if (localName.equals(host)) {
457                    return "localhost";
458                }
459            }
460            return host;
461        }
462    
463        protected void onConnected() throws IOException {
464            yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
465            yieldSource.setEventHandler(new Task() {
466                public void run() {
467                    drainInbound();
468                }
469            });
470            yieldSource.resume();
471            drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
472            drainOutboundSource.setEventHandler(new Task() {
473                public void run() {
474                    flush();
475                }
476            });
477            drainOutboundSource.resume();
478    
479            readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
480            writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
481    
482            readSource.setCancelHandler(CANCEL_HANDLER);
483            writeSource.setCancelHandler(CANCEL_HANDLER);
484    
485            readSource.setEventHandler(new Task() {
486                public void run() {
487                    drainInbound();
488                }
489            });
490            writeSource.setEventHandler(new Task() {
491                public void run() {
492                    flush();
493                }
494            });
495    
496            if( maxReadRate !=0 || maxWriteRate !=0 ) {
497                rateLimitingChannel = new RateLimitingChannel();
498                schedualRateAllowanceReset();
499            }
500            listener.onTransportConnected();
501        }
502    
503        private void schedualRateAllowanceReset() {
504            dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Task(){
505                public void run() {
506                    if( !socketState.is(CONNECTED.class) ) {
507                        return;
508                    }
509                    rateLimitingChannel.resetAllowance();
510                    schedualRateAllowanceReset();
511                }
512            });
513        }
514    
515        private void dispose() {
516            if( readSource!=null ) {
517                readSource.cancel();
518                readSource=null;
519            }
520    
521            if( writeSource!=null ) {
522                writeSource.cancel();
523                writeSource=null;
524            }
525            this.codec = null;
526        }
527    
528        public void onTransportFailure(IOException error) {
529            listener.onTransportFailure(error);
530            socketState.onCanceled();
531        }
532    
533    
534        public boolean full() {
535            return codec==null || codec.full();
536        }
537    
538        boolean rejectingOffers;
539    
540        public boolean offer(Object command) {
541            dispatchQueue.assertExecuting();
542            try {
543                if (!socketState.is(CONNECTED.class)) {
544                    throw new IOException("Not connected.");
545                }
546                if (getServiceState() != STARTED) {
547                    throw new IOException("Not running.");
548                }
549    
550                ProtocolCodec.BufferState rc = codec.write(command);
551                rejectingOffers = codec.full();
552                switch (rc ) {
553                    case FULL:
554                        return false;
555                    default:
556                        drainOutboundSource.merge(1);
557                        return true;
558                }
559            } catch (IOException e) {
560                onTransportFailure(e);
561                return false;
562            }
563    
564        }
565    
566        boolean writeResumedForCodecFlush = false;
567    
568        /**
569         *
570         */
571        public void flush() {
572            dispatchQueue.assertExecuting();
573            if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
574                return;
575            }
576            try {
577                if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
578                    if( writeResumedForCodecFlush) {
579                        writeResumedForCodecFlush = false;
580                        suspendWrite();
581                    }
582                    rejectingOffers = false;
583                    listener.onRefill();
584    
585                } else {
586                    if(!writeResumedForCodecFlush) {
587                        writeResumedForCodecFlush = true;
588                        resumeWrite();
589                    }
590                }
591            } catch (IOException e) {
592                onTransportFailure(e);
593            }
594        }
595    
596        protected boolean transportFlush() throws IOException {
597            return true;
598        }
599    
600        protected void drainInbound() {
601            if (!getServiceState().isStarted() || readSource.isSuspended()) {
602                return;
603            }
604            try {
605                long initial = codec.getReadCounter();
606                // Only process upto 2 x the read buffer worth of data at a time so we can give
607                // other connections a chance to process their requests.
608                while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
609                    Object command = codec.read();
610                    if ( command!=null ) {
611                        try {
612                            listener.onTransportCommand(command);
613                        } catch (Throwable e) {
614                            e.printStackTrace();
615                            onTransportFailure(new IOException("Transport listener failure."));
616                        }
617    
618                        // the transport may be suspended after processing a command.
619                        if (getServiceState() == STOPPED || readSource.isSuspended()) {
620                            return;
621                        }
622                    } else {
623                        return;
624                    }
625                }
626                yieldSource.merge(1);
627            } catch (IOException e) {
628                onTransportFailure(e);
629            }
630        }
631    
632        public SocketAddress getLocalAddress() {
633            return localAddress;
634        }
635    
636        public SocketAddress getRemoteAddress() {
637            return remoteAddress;
638        }
639    
640        private boolean assertConnected() {
641            try {
642                if ( !isConnected() ) {
643                    throw new IOException("Not connected.");
644                }
645                return true;
646            } catch (IOException e) {
647                onTransportFailure(e);
648            }
649            return false;
650        }
651    
652        public void suspendRead() {
653            if( isConnected() && readSource!=null ) {
654                readSource.suspend();
655            }
656        }
657    
658    
659        public void resumeRead() {
660            if( isConnected() && readSource!=null ) {
661                if( rateLimitingChannel!=null ) {
662                    rateLimitingChannel.resumeRead();
663                } else {
664                    _resumeRead();
665                }
666            }
667        }
668    
669        private void _resumeRead() {
670            readSource.resume();
671            dispatchQueue.execute(new Task(){
672                public void run() {
673                    drainInbound();
674                }
675            });
676        }
677    
678        protected void suspendWrite() {
679            if( isConnected() && writeSource!=null ) {
680                writeSource.suspend();
681            }
682        }
683    
684        protected void resumeWrite() {
685            if( isConnected() && writeSource!=null ) {
686                writeSource.resume();
687            }
688        }
689    
690        public TransportListener getTransportListener() {
691            return listener;
692        }
693    
694        public void setTransportListener(TransportListener transportListener) {
695            this.listener = transportListener;
696        }
697    
698        public ProtocolCodec getProtocolCodec() {
699            return codec;
700        }
701    
702        public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
703            this.codec = protocolCodec;
704            if( channel!=null && codec!=null ) {
705                initializeCodec();
706            }
707        }
708    
709        public boolean isConnected() {
710            return socketState.is(CONNECTED.class);
711        }
712    
713        public boolean isClosed() {
714            return getServiceState() == STOPPED;
715        }
716    
717        public boolean isUseLocalHost() {
718            return useLocalHost;
719        }
720    
721        /**
722         * Sets whether 'localhost' or the actual local host name should be used to
723         * make local connections. On some operating systems such as Macs its not
724         * possible to connect as the local host name so localhost is better.
725         */
726        public void setUseLocalHost(boolean useLocalHost) {
727            this.useLocalHost = useLocalHost;
728        }
729    
730        private void trace(String message) {
731            // TODO:
732        }
733    
734        public SocketChannel getSocketChannel() {
735            return channel;
736        }
737    
738        public ReadableByteChannel readChannel() {
739            if(rateLimitingChannel!=null) {
740                return rateLimitingChannel;
741            } else {
742                return channel;
743            }
744        }
745    
746        public WritableByteChannel writeChannel() {
747            if(rateLimitingChannel!=null) {
748                return rateLimitingChannel;
749            } else {
750                return channel;
751            }
752        }
753    
754        public int getMaxReadRate() {
755            return maxReadRate;
756        }
757    
758        public void setMaxReadRate(int maxReadRate) {
759            this.maxReadRate = maxReadRate;
760        }
761    
762        public int getMaxWriteRate() {
763            return maxWriteRate;
764        }
765    
766        public void setMaxWriteRate(int maxWriteRate) {
767            this.maxWriteRate = maxWriteRate;
768        }
769    
770        public int getTrafficClass() {
771            return trafficClass;
772        }
773    
774        public void setTrafficClass(int trafficClass) {
775            this.trafficClass = trafficClass;
776        }
777    
778        public int getReceiveBufferSize() {
779            return receiveBufferSize;
780        }
781    
782        public void setReceiveBufferSize(int receiveBufferSize) {
783            this.receiveBufferSize = receiveBufferSize;
784        }
785    
786        public int getSendBufferSize() {
787            return sendBufferSize;
788        }
789    
790        public void setSendBufferSize(int sendBufferSize) {
791            this.sendBufferSize = sendBufferSize;
792        }
793    
794        public boolean isKeepAlive() {
795            return keepAlive;
796        }
797    
798        public void setKeepAlive(boolean keepAlive) {
799            this.keepAlive = keepAlive;
800        }
801    
802    }