-
Notifications
You must be signed in to change notification settings - Fork 65
/
main.go
132 lines (112 loc) · 2.77 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package main
import (
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"net"
	"strings"
	"time"

	"github.com/liberal-boy/tls-shunt-proxy/config"
	"github.com/liberal-boy/tls-shunt-proxy/handler"
	"github.com/liberal-boy/tls-shunt-proxy/sniffer"
	"github.com/stevenjohnstone/sni"
)
const (
	// version is the release identifier printed in the startup banner.
	version = "0.8.1"
)
var (
	// conf is the process-wide configuration. main assigns it from
	// config.ReadConfig before the accept loop starts; the connection
	// handlers only read from it (Listen, VHosts).
	conf config.Config
)
func main() {
fmt.Println("tls-shunt-proxy version", version)
configPath := flag.String("config", "./config.yaml", "Path to config file")
flag.Parse()
var err error
conf, err = config.ReadConfig(*configPath)
if err != nil {
log.Fatalf("failed to read config %s: %v", *configPath, err)
}
if conf.RedirectHttps != "" {
handler.ServeRedirectHttps(conf.RedirectHttps)
}
listenAndServe()
}
func listenAndServe() {
ln, err := net.Listen("tcp", conf.Listen)
if err != nil {
log.Fatalf("failed to listen on %s: %v", conf.Listen, err)
}
defer func() { _ = ln.Close() }()
for {
conn, err := ln.Accept()
if err != nil {
log.Printf("fail to establish conn: %v\n", err)
continue
}
go handle(conn)
}
}
// handle asks sni.ServerNameFromConn for the SNI server name of a freshly
// accepted connection and routes the connection returned by that call
// (presumably one that still yields the bytes sni consumed — confirm in the
// sni package) via handleWithServerName.
//
// When no server name can be obtained, the error is logged and the ORIGINAL
// conn is given to the handler.SentHttpToHttps plain-text handler.
func handle(conn net.Conn) {
	name, routedConn, err := sni.ServerNameFromConn(conn)
	if err == nil {
		handleWithServerName(routedConn, name)
		return
	}

	log.Printf("fail to obtain server name: %v\n", err)
	handler.NewPlainTextHandler(handler.SentHttpToHttps).Handle(conn)
}
func handleWithServerName(conn net.Conn, serverName string) {
vh, has := conf.VHosts[strings.ToLower(serverName)]
if !has {
log.Printf("no available vhost for %s\n", serverName)
handler.NewPlainTextHandler(handler.NoCertificateAvailable).Handle(conn)
return
}
if vh.TlsConfig != nil {
conn = tlsOffloading(conn, vh.TlsConfig)
sniffConn := sniffer.NewPeekPreDataConn(conn)
conn = sniffConn
switch sniffConn.Type {
case sniffer.TypeHttp:
if handleHttp(sniffConn, vh) {
return
}
case sniffer.TypeHttp2:
if handleHttp2(sniffConn, vh) {
return
}
case sniffer.TypeTrojan:
if handleTrojan(sniffConn, vh) {
return
}
}
}
vh.Default.Handle(conn)
}
// handleHttp routes a sniffed HTTP request on conn and reports whether some
// handler took the connection.
//
// The first entry of vh.PathHandlers whose Path is a prefix of the request
// path wins. Its TrimPrefix is stripped from the path (via conn.SetPath)
// before its Handler is invoked. With no matching entry, vh.Http is used
// unless it is handler.NoopHandler, in which case false is returned.
func handleHttp(conn *sniffer.SniffConn, vh config.VHost) bool {
	reqPath := conn.GetPath()

	for i := range vh.PathHandlers {
		ph := vh.PathHandlers[i]
		if !strings.HasPrefix(reqPath, ph.Path) {
			continue
		}
		conn.SetPath(strings.TrimPrefix(reqPath, ph.TrimPrefix))
		ph.Handler.Handle(conn)
		return true
	}

	if vh.Http == handler.NoopHandler {
		return false
	}
	vh.Http.Handle(conn)
	return true
}
// handleHttp2 gives a sniffed HTTP/2 connection to vh.Http2 when that is not
// handler.NoopHandler. Otherwise it defers to the path-based routing of
// handleHttp. It reports whether the connection was taken.
func handleHttp2(conn *sniffer.SniffConn, vh config.VHost) bool {
	if vh.Http2 == handler.NoopHandler {
		return handleHttp(conn, vh)
	}
	vh.Http2.Handle(conn)
	return true
}
// handleTrojan gives a sniffed Trojan connection to vh.Trojan when that is
// not handler.NoopHandler, and reports whether it did so.
func handleTrojan(conn *sniffer.SniffConn, vh config.VHost) bool {
	configured := vh.Trojan != handler.NoopHandler
	if configured {
		vh.Trojan.Handle(conn)
	}
	return configured
}
// tlsOffloading wraps conn as the server side of a TLS session described by
// tlsConfig. Per the crypto/tls documentation, tls.Server performs no I/O
// itself: the handshake runs on the first Read/Write (or an explicit
// Handshake) on the returned *tls.Conn. tls.Server also requires tlsConfig
// to carry at least one certificate or a GetCertificate callback.
func tlsOffloading(conn net.Conn, tlsConfig *tls.Config) *tls.Conn {
	serverSide := tls.Server(conn, tlsConfig)
	return serverSide
}