mirror of
https://github.com/FrankerFaceZ/FrankerFaceZ.git
synced 2025-06-27 21:05:53 +00:00
327 lines
8 KiB
Go
327 lines
8 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/FrankerFaceZ/FrankerFaceZ/socketserver/server/rate"
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
// LastSavedMessage is one cached command reply plus the moment it lapses.
type LastSavedMessage struct {
	// Expires is when the entry stops being valid. The zero Time means the
	// entry never expires; the janitor skips such entries.
	Expires time.Time
	// Data is the raw argument string (the /cached_pub "args" value) that is
	// replayed to clients as the message arguments.
	Data string
}
|
|
|
|
// map is command -> channel -> data
|
|
|
|
// CachedLastMessages is of CacheTypeLastOnly.
|
|
// Not actually cleaned up by reaper goroutine every ~hour.
|
|
var CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
|
|
var CachedLSMLock sync.RWMutex
|
|
|
|
var singleFlighter singleflight.Group
|
|
|
|
func cachedMessageJanitor() {
|
|
for {
|
|
time.Sleep(1 * time.Hour)
|
|
cachedMessageJanitor_do()
|
|
}
|
|
}
|
|
|
|
func cachedMessageJanitor_do() {
|
|
CachedLSMLock.Lock()
|
|
defer CachedLSMLock.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
for cmd, chanMap := range CachedLastMessages {
|
|
for channel, msg := range chanMap {
|
|
if !msg.Expires.IsZero() && msg.Expires.Before(now) {
|
|
delete(chanMap, channel)
|
|
}
|
|
}
|
|
if len(chanMap) == 0 {
|
|
delete(CachedLastMessages, cmd)
|
|
}
|
|
}
|
|
}
|
|
|
|
// DumpBacklogData drops all /cached_pub data.
|
|
func DumpBacklogData() {
|
|
CachedLSMLock.Lock()
|
|
CachedLastMessages = make(map[Command]map[string]LastSavedMessage)
|
|
CachedLSMLock.Unlock()
|
|
}
|
|
|
|
// SendBacklogForNewClient sends any backlog data relevant to a new client.
|
|
// This should be done when the client sends a `ready` message.
|
|
// This will only send data for CacheTypePersistent and CacheTypeLastOnly because those do not involve timestamps.
|
|
func SendBacklogForNewClient(client *ClientInfo) {
|
|
client.Mutex.Lock() // reading CurrentChannels
|
|
curChannels := make([]string, len(client.CurrentChannels))
|
|
copy(curChannels, client.CurrentChannels)
|
|
client.Mutex.Unlock()
|
|
|
|
CachedLSMLock.RLock()
|
|
for cmd, chanMap := range CachedLastMessages {
|
|
if chanMap == nil {
|
|
continue
|
|
}
|
|
for _, channel := range curChannels {
|
|
msg, ok := chanMap[channel]
|
|
if ok {
|
|
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
|
|
msg.parseOrigArguments()
|
|
client.Send(msg)
|
|
}
|
|
}
|
|
}
|
|
CachedLSMLock.RUnlock()
|
|
}
|
|
|
|
func SendBacklogForChannel(client *ClientInfo, channel string) {
|
|
CachedLSMLock.RLock()
|
|
for cmd, chanMap := range CachedLastMessages {
|
|
if chanMap == nil {
|
|
continue
|
|
}
|
|
if msg, ok := chanMap[channel]; ok {
|
|
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: msg.Data}
|
|
msg.parseOrigArguments()
|
|
client.Send(msg)
|
|
}
|
|
}
|
|
CachedLSMLock.RUnlock()
|
|
}
|
|
|
|
// timestampArray abstracts an indexable sequence of timestamps: Len reports
// the element count and GetTime returns the time at a position (presumably
// valid for 0 <= i < Len() — confirm against implementations).
// NOTE(review): no implementation or caller is visible in this section.
type timestampArray interface {
	GetTime(int) time.Time
	Len() int
}
|
|
|
|
// the CachedLSMLock must be held when calling this
|
|
func saveLastMessage(cmd Command, channel string, expires time.Time, data string, deleting bool) {
|
|
chanMap, ok := CachedLastMessages[cmd]
|
|
if !ok {
|
|
if deleting {
|
|
return
|
|
}
|
|
chanMap = make(map[string]LastSavedMessage)
|
|
CachedLastMessages[cmd] = chanMap
|
|
}
|
|
|
|
if deleting {
|
|
delete(chanMap, channel)
|
|
} else {
|
|
chanMap[channel] = LastSavedMessage{Expires: expires, Data: data}
|
|
}
|
|
}
|
|
|
|
func HTTPBackendDropBacklog(w http.ResponseWriter, r *http.Request) {
|
|
r.ParseForm()
|
|
formData, err := Backend.secureForm.Unseal(r.Form)
|
|
if err != nil {
|
|
w.WriteHeader(403)
|
|
fmt.Fprintf(w, "Error: %v", err)
|
|
return
|
|
}
|
|
|
|
confirm := formData.Get("confirm")
|
|
if confirm == "1" {
|
|
DumpBacklogData()
|
|
}
|
|
}
|
|
|
|
func rateLimitFromFormData(formData url.Values) (rate.Limiter, error) {
|
|
rateCount := formData.Get("rateCount")
|
|
if rateCount != "" {
|
|
c, err := strconv.ParseInt(rateCount, 10, 32)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "rateCount")
|
|
}
|
|
d, err := time.ParseDuration(formData.Get("rateTime"))
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "rateTime")
|
|
}
|
|
return rate.NewRateLimit(int(c), d), nil
|
|
}
|
|
return rate.Unlimited(), nil
|
|
}
|
|
|
|
// HTTPBackendCachedPublish handles the /cached_pub route.
|
|
// It publishes a message to clients, and then updates the in-server cache for the message.
|
|
//
|
|
// The 'channel' parameter is a comma-separated list of topics to publish the message to.
|
|
// The 'args' parameter is the JSON-encoded command data.
|
|
// If the 'delete' parameter is present, an entry is removed from the cache instead of publishing a message.
|
|
// If the 'expires' parameter is not specified, the message will not expire (though it is only kept in-memory).
|
|
func HTTPBackendCachedPublish(w http.ResponseWriter, r *http.Request) {
|
|
r.ParseForm()
|
|
formData, err := Backend.secureForm.Unseal(r.Form)
|
|
if err != nil {
|
|
w.WriteHeader(403)
|
|
fmt.Fprintf(w, "Error: %v", err)
|
|
return
|
|
}
|
|
|
|
cmd := CommandPool.InternCommand(formData.Get("cmd"))
|
|
json := formData.Get("args")
|
|
channel := formData.Get("channel")
|
|
deleteMode := formData.Get("delete") != ""
|
|
timeStr := formData.Get("expires")
|
|
var expires time.Time
|
|
if timeStr != "" {
|
|
timeNum, err := strconv.ParseInt(timeStr, 10, 64)
|
|
if err != nil {
|
|
w.WriteHeader(422)
|
|
fmt.Fprintf(w, "error parsing time: %v", err)
|
|
return
|
|
}
|
|
expires = time.Unix(timeNum, 0)
|
|
}
|
|
rl, err := rateLimitFromFormData(formData)
|
|
if err != nil {
|
|
w.WriteHeader(422)
|
|
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
|
|
return
|
|
}
|
|
|
|
var count int
|
|
msg := ClientMessage{MessageID: -1, Command: cmd, origArguments: json}
|
|
msg.parseOrigArguments()
|
|
|
|
channels := strings.Split(channel, ",")
|
|
CachedLSMLock.Lock()
|
|
for _, channel := range channels {
|
|
saveLastMessage(cmd, channel, expires, json, deleteMode)
|
|
}
|
|
CachedLSMLock.Unlock()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go rl.Run()
|
|
go func() {
|
|
count = PublishToMultiple(channels, msg, rl)
|
|
wg.Done()
|
|
rl.Close()
|
|
}()
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
select {
|
|
case <-time.After(3 * time.Second):
|
|
count = -1
|
|
case <-ch:
|
|
}
|
|
w.Write([]byte(strconv.Itoa(count)))
|
|
}
|
|
|
|
// HTTPBackendUncachedPublish handles the /uncached_pub route.
|
|
// The backend can POST here to publish a message to clients with no caching.
|
|
// The POST arguments are `cmd`, `args`, `channel`, and `scope`.
|
|
// If "scope" is "global", then "channel" is not used.
|
|
func HTTPBackendUncachedPublish(w http.ResponseWriter, r *http.Request) {
|
|
r.ParseForm()
|
|
formData, err := Backend.secureForm.Unseal(r.Form)
|
|
if err != nil {
|
|
w.WriteHeader(403)
|
|
fmt.Fprintf(w, "Error: %v", err)
|
|
return
|
|
}
|
|
|
|
cmd := formData.Get("cmd")
|
|
json := formData.Get("args")
|
|
channel := formData.Get("channel")
|
|
scope := formData.Get("scope")
|
|
|
|
if cmd == "" {
|
|
w.WriteHeader(422)
|
|
fmt.Fprint(w, "Error: cmd cannot be blank")
|
|
return
|
|
}
|
|
if channel == "" && scope != "global" {
|
|
w.WriteHeader(422)
|
|
fmt.Fprint(w, "Error: channel must be specified")
|
|
return
|
|
}
|
|
rl, err := rateLimitFromFormData(formData)
|
|
if err != nil {
|
|
w.WriteHeader(422)
|
|
fmt.Fprintf(w, "error parsing ratelimit: %v", err)
|
|
return
|
|
}
|
|
|
|
cm := ClientMessage{MessageID: -1, Command: CommandPool.InternCommand(cmd), origArguments: json}
|
|
cm.parseOrigArguments()
|
|
|
|
var count int
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go rl.Run()
|
|
go func() {
|
|
switch scope {
|
|
default:
|
|
count = PublishToMultiple(strings.Split(channel, ","), cm, rl)
|
|
case "global":
|
|
count = PublishToAll(cm, rl)
|
|
}
|
|
wg.Done()
|
|
rl.Close()
|
|
}()
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
select {
|
|
case <-time.After(3 * time.Second):
|
|
count = -1
|
|
case <-ch:
|
|
}
|
|
w.Write([]byte(strconv.Itoa(count)))
|
|
|
|
}
|
|
|
|
// HTTPGetSubscriberCount handles the /get_sub_count route.
|
|
// It replies with the number of clients subscribed to a pub/sub topic.
|
|
// A "global" option is not available, use fetch(/stats).CurrentClientCount instead.
|
|
func HTTPGetSubscriberCount(w http.ResponseWriter, r *http.Request) {
|
|
r.ParseForm()
|
|
formData, err := Backend.secureForm.Unseal(r.Form)
|
|
if err != nil {
|
|
w.WriteHeader(403)
|
|
fmt.Fprintf(w, "Error: %v", err)
|
|
return
|
|
}
|
|
|
|
channel := formData.Get("channel")
|
|
|
|
fmt.Fprint(w, CountSubscriptions(strings.Split(channel, ",")))
|
|
}
|
|
|
|
func HTTPListAllTopics(w http.ResponseWriter, r *http.Request) {
|
|
r.ParseForm()
|
|
_, err := Backend.secureForm.Unseal(r.Form)
|
|
if err != nil {
|
|
//w.WriteHeader(403)
|
|
//fmt.Fprintf(w, "Error: %v", err)
|
|
//return
|
|
}
|
|
|
|
topicList, _, _ := singleFlighter.Do("/all_topics", func() (interface{}, error) {
|
|
return GetAllTopics(), nil
|
|
})
|
|
w.WriteHeader(200)
|
|
json.NewEncoder(w).Encode(topicList)
|
|
}
|