-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
327 lines (276 loc) · 8.94 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
/* main.go */
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"html/template"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
_ "github.com/go-sql-driver/mysql"
)
var db *sql.DB = connectDb()
// main registers the application routes on http.DefaultServeMux and serves
// HTTP on port 11994. It only returns by terminating the process through
// log.Fatal when the listener fails.
func main() {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/create/server/", createServerHandler},
		{"/create/client/", createClientHandler},
		{"/view/server/", viewServerHandler},
		{"/view/client/", viewClientHandler},
	}
	for _, rt := range routes {
		http.HandleFunc(rt.pattern, rt.handler)
	}

	// "/" is the catch-all pattern: any path no other pattern claims is
	// answered with the home page.
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, "static/home.html")
	})

	// Assets under /static/ are served straight from the ./static directory.
	http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))

	// ListenAndServe always returns a non-nil error, so this only runs on failure.
	log.Fatal(http.ListenAndServe(":11994", nil))
}
// connectDb opens the MySQL database "Ephemeral" and checks that it is
// reachable. Credentials come from the environment variables
// ephemeralUsername and ephemeralPassword. Any failure (bad DSN, unreachable
// server, rejected login) is fatal: log.Fatal exits the process.
func connectDb() *sql.DB {
	// Database (schema) name: the path part of the "user:pass@/dbname" DSN.
	const dbName = "Ephemeral"

	dsn := fmt.Sprintf("%s:%s@/%s",
		os.Getenv("ephemeralUsername"),
		os.Getenv("ephemeralPassword"),
		dbName,
	)

	conn, err := sql.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	// sql.Open does not connect; Ping forces a real round trip.
	if err = conn.Ping(); err != nil {
		log.Fatal(err)
	}
	return conn
}
/* Write the given error message as HTML */
// writeError renders static/error.html with {{.Message}} set to message.
// html/template escapes Message, so the text is safe to embed in the page.
//
// If the template cannot be loaded, the message is still delivered, as a
// plain-text 500 response, instead of panicking inside the request handler
// as template.Must did. Rendering errors are logged rather than silently
// discarded. A successful render keeps the implicit 200 status.
func writeError(w http.ResponseWriter, message string) {
	type Out struct {
		Message string
	}
	tmpl, err := template.ParseFiles("static/error.html")
	if err != nil {
		log.Printf("writeError: loading error template: %v", err)
		http.Error(w, message, http.StatusInternalServerError)
		return
	}
	if err := tmpl.Execute(w, Out{message}); err != nil {
		// Part of the page may already be written, so the status cannot be
		// changed any more; just record the failure.
		log.Printf("writeError: rendering error template: %v", err)
	}
}
/* 128 bit AES */
// encrypt base64-encodes text and encrypts that encoding with AES in CFB
// mode under key. The key size selects the AES variant (16 bytes gives
// AES-128, which is what createServerHandler supplies); other lengths
// aesNewCipher rejects are returned as an error.
//
// Output layout: a random IV of aes.BlockSize bytes followed by the
// ciphertext, i.e. aes.BlockSize + len(base64(text)) bytes in total. decrypt
// expects exactly this layout.
//
// NOTE(review): CFB provides no integrity protection (ciphertext is
// malleable). Moving to an AEAD mode such as GCM would change the format of
// already-stored messages, so it needs a migration plan.
func encrypt(key, text []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	encoded := []byte(base64.StdEncoding.EncodeToString(text))

	out := make([]byte, aes.BlockSize+len(encoded))
	iv, body := out[:aes.BlockSize], out[aes.BlockSize:]
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}

	cipher.NewCFBEncrypter(block, iv).XORKeyStream(body, encoded)
	return out, nil
}
/* 128 bit AES */
// decrypt reverses encrypt. text must be an aes.BlockSize-byte IV followed
// by the AES-CFB encryption of a base64 string; the base64 is decoded and
// the original bytes are returned.
//
// Errors:
//   - key is not a valid AES key size (from aes.NewCipher);
//   - text is shorter than one block ("Ciphertext too short");
//   - the decrypted bytes are not valid base64. Because CFB is
//     unauthenticated, this is also how a wrong key usually shows up.
//
// The caller's text slice is left untouched. The previous version decrypted
// in place and overwrote the caller's ciphertext buffer.
func decrypt(key, text []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	if len(text) < aes.BlockSize {
		return nil, errors.New("Ciphertext too short")
	}

	iv, body := text[:aes.BlockSize], text[aes.BlockSize:]
	encoded := make([]byte, len(body))
	cipher.NewCFBDecrypter(block, iv).XORKeyStream(encoded, body)

	data, err := base64.StdEncoding.DecodeString(string(encoded))
	if err != nil {
		// DecodeString may return a partial result alongside the error;
		// don't hand that to callers.
		return nil, err
	}
	return data, nil
}
/* POST /create/server */
// createServerHandler stores a message that the server encrypts itself.
//
// Form fields:
//   - "text": the plaintext, at most 16000 bytes;
//   - "expireMinutes": stored with the row; falls back to 43200 (30 days)
//     when missing or not an integer.
//
// A fresh random 16-byte key encrypts the text. Only the hex ciphertext is
// inserted (salt column NULL, last column true). The message id and hex key
// are then rendered for the sender through static/create.html.
func createServerHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "POST required", http.StatusMethodNotAllowed)
		return
	}
	internalError := func() {
		http.Error(w, "Something went wrong :(", http.StatusInternalServerError)
	}

	text := r.PostFormValue("text")
	expireMinutes, err := strconv.Atoi(r.PostFormValue("expireMinutes"))
	if err != nil {
		expireMinutes = 43200 // default: 30 days
	}
	if len(text) > 16000 {
		writeError(w, "Message too long. Max character length is 16000.")
		return
	}

	key := make([]byte, 16) // 128-bit AES key
	if _, err := rand.Read(key); err != nil {
		internalError()
		return
	}
	ciphertext, err := encrypt(key, []byte(text))
	if err != nil {
		internalError()
		return
	}

	msgId := generateMsgId(db)
	const insertSQL = "insert into messages values (?, ?, ?, UNIX_TIMESTAMP(), ?, ?)"
	if _, err := db.Exec(insertSQL, msgId, hex.EncodeToString(ciphertext), nil, expireMinutes, true); err != nil {
		internalError()
		return
	}

	page := template.Must(template.ParseFiles("static/create.html"))
	page.Execute(w, struct {
		MsgId string
		Key   string
	}{msgId, hex.EncodeToString(key)})
}
/* POST /create/client */
// createClientHandler stores a message that was already encrypted by the
// client.
//
// Form fields:
//   - "text": the ciphertext, at most 16000 bytes;
//   - "salt": stored as-is next to it;
//   - "expireMinutes": falls back to 43200 (30 days) when missing or not an
//     integer.
//
// The row is inserted with the last column false. The response body is the
// plain-text view URL for the new message.
func createClientHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "POST required", http.StatusMethodNotAllowed)
		return
	}

	var (
		encryptedText = r.PostFormValue("text")
		salt          = r.PostFormValue("salt")
	)
	expireMinutes, err := strconv.Atoi(r.PostFormValue("expireMinutes"))
	if err != nil {
		expireMinutes = 43200 // default: 30 days
	}
	if len(encryptedText) > 16000 {
		writeError(w, "Message too long. Max character length is 16000.")
		return
	}

	msgId := generateMsgId(db)
	const insertSQL = "insert into messages values (?, ?, ?, UNIX_TIMESTAMP(), ?, ?)"
	if _, err := db.Exec(insertSQL, msgId, encryptedText, salt, expireMinutes, false); err != nil {
		http.Error(w, "Something went wrong :(", http.StatusInternalServerError)
		return
	}

	io.WriteString(w, "https://ephemeral.pw/view/client/"+msgId)
}
// viewServerMu serializes takeServerMessage within this process. It must
// live at package level: a mutex declared inside the handler is private to
// each request and excludes nothing.
var viewServerMu sync.Mutex

/* GET /view/server */
// viewServerHandler reveals a server-encrypted message at most once.
//
// URL shape: /view/server/{msgId}/{hexKey}, optionally with a trailing "/".
// The message is decrypted with the key from the URL and deleted from the
// database before it is shown. If the row is missing, the key is wrong, or
// the delete does not go through, the generic "not found" page is returned,
// so the response does not reveal which of these happened. User agents
// containing a blacklisted substring are turned away first.
func viewServerHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "GET required", http.StatusMethodNotAllowed)
		return
	}
	/* Blacklist sites that GET the url before sending to recipient */
	blacklist := [...]string{"facebook"}
	for _, e := range blacklist {
		if strings.Contains(r.UserAgent(), e) {
			fmt.Fprintf(w, "Go away %s! This is only for the recipient!", e)
			return
		}
	}
	/* ephemeral.pw/view/server/msgId/key/ */
	queryString := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/view/server/"), "/")
	params := strings.Split(queryString, "/")
	if len(params) != 2 {
		writeError(w, "Message not found. It may have been deleted.")
		return
	}
	msgId, keyString := params[0], params[1]
	keyBytes, err := hex.DecodeString(keyString)
	if err != nil {
		/* Key is not hex */
		writeError(w, "Message not found. It may have been deleted.")
		return
	}

	// Load the template BEFORE consuming the message. takeServerMessage
	// destroys it, so a template failure afterwards would lose it for good.
	tmpl, err := template.ParseFiles("static/viewServer.html")
	if err != nil {
		log.Printf("viewServerHandler: loading template: %v", err)
		http.Error(w, "Something went wrong :(", http.StatusInternalServerError)
		return
	}

	message, ok := takeServerMessage(msgId, keyBytes)
	if !ok {
		writeError(w, "Message not found. It may have been deleted.")
		return
	}

	type Out struct {
		Message []string
	}
	// html/template escapes every line itself. Pre-escaping with
	// HTMLEscapeString, as earlier code did, double-escaped "<", "&" and
	// quotes, so recipients saw literal "&lt;" in their message.
	if err := tmpl.Execute(w, Out{strings.Split(message, "\n")}); err != nil {
		log.Printf("viewServerHandler: rendering template: %v", err)
	}
}

// takeServerMessage loads message msgId, decrypts it with key and deletes
// it. It returns the plaintext and true only if all three steps succeed and
// this call is the one that removed the row. On any failure it returns
// ("", false), and the row is kept when decryption fails.
func takeServerMessage(msgId string, key []byte) (string, bool) {
	viewServerMu.Lock()
	defer viewServerMu.Unlock()

	var encryptedText string
	err := db.QueryRow("SELECT encrypted_text FROM messages WHERE id = ?", msgId).Scan(&encryptedText)
	if err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Printf("takeServerMessage: loading message: %v", err)
		}
		return "", false
	}
	ciphertext, err := hex.DecodeString(encryptedText)
	if err != nil {
		return "", false
	}
	plaintext, err := decrypt(key, ciphertext)
	if err != nil {
		/* Valid msgId, but invalid key */
		return "", false
	}

	res, err := db.Exec("DELETE FROM messages WHERE id = ? LIMIT 1", msgId)
	if err != nil {
		// Showing a message that could not be deleted would break the
		// one-time-read guarantee, so refuse instead.
		log.Printf("takeServerMessage: deleting message: %v", err)
		return "", false
	}
	// Exactly one affected row means this request consumed the message.
	// Zero means a concurrent request (possibly in another server process)
	// already did.
	if n, err := res.RowsAffected(); err != nil || n != 1 {
		return "", false
	}
	return string(plaintext), true
}
// viewClientMu serializes takeClientMessage within this process. It must
// live at package level: a mutex declared inside the handler is private to
// each request and excludes nothing.
var viewClientMu sync.Mutex

/* GET /view/client */
// viewClientHandler hands a client-encrypted message (ciphertext plus its
// stored salt) to static/viewClient.html at most once.
//
// URL shape: /view/client/{msgId}, optionally with a trailing "/". The row
// is deleted before it is rendered. If it is missing or the delete does not
// go through, the generic "not found" page is returned. User agents
// containing a blacklisted substring are turned away first.
func viewClientHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "GET required", http.StatusMethodNotAllowed)
		return
	}
	/* Blacklist sites that GET the url before sending to recipient */
	blacklist := [...]string{"facebook"}
	for _, e := range blacklist {
		if strings.Contains(r.UserAgent(), e) {
			fmt.Fprintf(w, "Go away %s! This is only for the recipient!", e)
			return
		}
	}
	/* ephemeral.pw/view/client/msgId */
	queryString := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/view/client/"), "/")
	params := strings.Split(queryString, "/")
	if len(params) != 1 {
		writeError(w, "Message not found. It may have been deleted.")
		return
	}
	msgId := params[0]

	// Load the template BEFORE consuming the message. takeClientMessage
	// destroys it, so a template failure afterwards would lose it for good.
	tmpl, err := template.ParseFiles("static/viewClient.html")
	if err != nil {
		log.Printf("viewClientHandler: loading template: %v", err)
		http.Error(w, "Something went wrong :(", http.StatusInternalServerError)
		return
	}

	encryptedText, salt, ok := takeClientMessage(msgId)
	if !ok {
		writeError(w, "Message not found. It may have been deleted.")
		return
	}

	type Out struct {
		Message string
		Salt    string
	}
	if err := tmpl.Execute(w, Out{encryptedText, salt}); err != nil {
		log.Printf("viewClientHandler: rendering template: %v", err)
	}
}

// takeClientMessage loads the ciphertext and salt of message msgId and
// deletes the row. ok is true only if the load succeeded and this call is
// the one that removed the row; otherwise all results are zero.
func takeClientMessage(msgId string) (encryptedText, salt string, ok bool) {
	viewClientMu.Lock()
	defer viewClientMu.Unlock()

	err := db.QueryRow("SELECT encrypted_text, salt FROM messages WHERE id = ?", msgId).Scan(&encryptedText, &salt)
	if err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Printf("takeClientMessage: loading message: %v", err)
		}
		return "", "", false
	}

	res, err := db.Exec("DELETE FROM messages WHERE id = ? LIMIT 1", msgId)
	if err != nil {
		// Showing a message that could not be deleted would break the
		// one-time-read guarantee, so refuse instead.
		log.Printf("takeClientMessage: deleting message: %v", err)
		return "", "", false
	}
	// Exactly one affected row means this request consumed the message.
	// Zero means a concurrent request (possibly in another server process)
	// already did.
	if n, err := res.RowsAffected(); err != nil || n != 1 {
		return "", "", false
	}
	return encryptedText, salt, true
}
/* Generate unique 64 random bits */
// generateMsgId returns 16 lowercase hex characters (64 random bits) that no
// row in the messages table currently uses.
//
// The previous recursive version ignored the collision-check error. During
// a database outage `available` stayed false and the function recursed
// forever, until Go's fatal (unrecoverable) stack overflow killed the whole
// server. This version:
//   - retries random-source failures and genuine collisions at most
//     maxAttempts times;
//   - if the collision check itself fails, returns the candidate anyway.
//     A 64-bit random collision is negligible, and the caller's INSERT
//     against the same database reports the real failure.
//
// It panics only if maxAttempts candidates all fail, which net/http turns
// into a failed request instead of a crashed process.
func generateMsgId(db *sql.DB) string {
	const maxAttempts = 10
	buf := make([]byte, 8)
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if _, err := rand.Read(buf); err != nil {
			continue
		}
		id := hex.EncodeToString(buf)

		/* Check for collision */
		var available bool
		err := db.QueryRow("SELECT COUNT(*) = 0 FROM messages WHERE id = ?", id).Scan(&available)
		if err != nil {
			log.Printf("generateMsgId: collision check failed: %v", err)
			return id
		}
		if available {
			return id
		}
	}
	panic("generateMsgId: could not generate a unique message id")
}