Ticket #11301: oops.go

File oops.go, 3.5 KB (added by dcf, 6 years ago)

pluggable transport that drops its first few connections

Line 
1// Client transport that closes its first "-n" SOCKS connection after time "-t".
2//
3// To build and run:
4//      export GOPATH=$PWD/go
5//      go get
6//      go build
7//      tor -f oops-torrc
8//
9// Usage (in torrc):
10//      UseBridges 1
11//      Bridge oops 173.255.221.44:9001
12//      ClientTransportPlugin oops exec ./oops -t 3s -n 10 --log oops.log
13//      DataDirectory datadir
14//      Log info stderr
15//      SocksPort 9099
16package main
17
18import (
19        "flag"
20        "io"
21        "log"
22        "net"
23        "os"
24        "os/signal"
25        "sync"
26        "syscall"
27        "time"
28)
29
30import "git.torproject.org/pluggable-transports/goptlib.git"
31
32var ptInfo pt.ClientInfo
33var oopsTimeout time.Duration
34var numOops int
35
36// When a connection handler starts, +1 is written to this channel; when it
37// ends, -1 is written.
38var handlerChan = make(chan int)
39
40func copyLoop(a, b net.Conn, count int) {
41        var wg sync.WaitGroup
42        wg.Add(2)
43
44        go func() {
45                io.Copy(b, a)
46                wg.Done()
47        }()
48        go func() {
49                io.Copy(a, b)
50                wg.Done()
51        }()
52
53        if count < numOops {
54                go func() {
55                        <-time.After(oopsTimeout)
56                        log.Printf("oops! connection %d\n", count)
57                        a.Close()
58                        b.Close()
59                }()
60        }
61
62        wg.Wait()
63}
64
65func handler(conn *pt.SocksConn, count int) error {
66        handlerChan <- 1
67        defer func() {
68                handlerChan <- -1
69        }()
70
71        defer conn.Close()
72        remote, err := net.Dial("tcp", conn.Req.Target)
73        if err != nil {
74                conn.Reject()
75                return err
76        }
77        defer remote.Close()
78        err = conn.Grant(remote.RemoteAddr().(*net.TCPAddr))
79        if err != nil {
80                return err
81        }
82
83        copyLoop(conn, remote, count)
84
85        return nil
86}
87
88func acceptLoop(ln *pt.SocksListener) error {
89        defer ln.Close()
90        var count int
91        for {
92                conn, err := ln.AcceptSocks()
93                if err != nil {
94                        if e, ok := err.(net.Error); ok && !e.Temporary() {
95                                return err
96                        }
97                        log.Print(err)
98                        continue
99                }
100                log.Printf("got connection %d", count)
101                go handler(conn, count)
102                count += 1
103        }
104}
105
106func main() {
107        var err error
108        var timeoutStr string
109        var logFilename string
110
111        flag.IntVar(&numOops, "n", 0, "how many first connections to drop")
112        flag.StringVar(&timeoutStr, "t", "5s", "how long to wait before dropping the connection")
113        flag.StringVar(&logFilename, "log", "", "name of log file")
114        flag.Parse()
115
116        if logFilename != "" {
117                f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
118                if err != nil {
119                        log.Fatal(err)
120                }
121                defer f.Close()
122                log.SetOutput(f)
123        }
124
125        oopsTimeout, err = time.ParseDuration(timeoutStr)
126        if err != nil {
127                log.Fatalf("can't parse duration: %s\n", err)
128        }
129
130        log.Printf("starting %v", os.Args)
131
132        ptInfo, err = pt.ClientSetup([]string{"oops"})
133        if err != nil {
134                log.Fatal(err)
135        }
136
137        listeners := make([]net.Listener, 0)
138        for _, methodName := range ptInfo.MethodNames {
139                switch methodName {
140                case "oops":
141                        ln, err := pt.ListenSocks("tcp", "127.0.0.1:0")
142                        if err != nil {
143                                pt.CmethodError(methodName, err.Error())
144                                break
145                        }
146                        go acceptLoop(ln)
147                        pt.Cmethod(methodName, ln.Version(), ln.Addr())
148                        listeners = append(listeners, ln)
149                default:
150                        pt.CmethodError(methodName, "no such method")
151                }
152        }
153        pt.CmethodsDone()
154
155        var numHandlers int = 0
156        var sig os.Signal
157        sigChan := make(chan os.Signal, 1)
158        signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
159
160        // wait for first signal
161        sig = nil
162        for sig == nil {
163                select {
164                case n := <-handlerChan:
165                        numHandlers += n
166                case sig = <-sigChan:
167                }
168        }
169        for _, ln := range listeners {
170                ln.Close()
171        }
172
173        if sig == syscall.SIGTERM {
174                return
175        }
176
177        // wait for second signal or no more handlers
178        sig = nil
179        for sig == nil && numHandlers != 0 {
180                select {
181                case n := <-handlerChan:
182                        numHandlers += n
183                case sig = <-sigChan:
184                }
185        }
186}