initial commit

This commit is contained in:
Michael Zhang 2020-10-11 14:32:58 -05:00
commit 24c9f21743
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
10 changed files with 761 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
/subscribe-bot
/db
/config.toml

117
bot.go Normal file
View file

@ -0,0 +1,117 @@
package main
import (
"errors"
"fmt"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/bwmarrin/discordgo"
)
type Bot struct {
*discordgo.Session
mentionRe *regexp.Regexp
db *Db
requests chan int
}
func NewBot(token string, db *Db, requests chan int) (bot *Bot, err error) {
s, err := discordgo.New("Bot " + token)
if err != nil {
return
}
err = s.Open()
if err != nil {
return
}
log.Println("connected to discord")
re, err := regexp.Compile("\\s*<@\\!?" + s.State.User.ID + ">\\s*")
if err != nil {
return
}
bot = &Bot{s, re, db, requests}
s.AddHandler(bot.errWrap(bot.newMessageHandler))
return
}
func (bot *Bot) errWrap(fn interface{}) interface{} {
val := reflect.ValueOf(fn)
origType := reflect.TypeOf(fn)
origTypeIn := make([]reflect.Type, origType.NumIn())
for i := 0; i < origType.NumIn(); i++ {
origTypeIn[i] = origType.In(i)
}
newType := reflect.FuncOf(origTypeIn, []reflect.Type{}, false)
newFunc := reflect.MakeFunc(newType, func(args []reflect.Value) (result []reflect.Value) {
res := val.Call(args)
if len(res) > 0 && !res[0].IsNil() {
err := res[0].Interface().(error)
if err != nil {
msg := fmt.Sprintf("error: %s", err)
channel, _ := bot.UserChannelCreate("100443064228646912")
id, _ := bot.ChannelMessageSend(channel.ID, msg)
log.Println(id, msg)
}
}
return []reflect.Value{}
})
return newFunc.Interface()
}
func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCreate) (err error) {
mentionsMe := false
for _, user := range m.Mentions {
if user.ID == s.State.User.ID {
mentionsMe = true
break
}
}
if !mentionsMe {
return
}
msg := bot.mentionRe.ReplaceAllString(m.Content, " ")
msg = strings.Trim(msg, " ")
parts := strings.Split(msg, " ")
switch strings.ToLower(parts[0]) {
case "track":
if len(parts) < 2 {
err = errors.New("fucked up")
return
}
var mapperId int
mapperId, err = strconv.Atoi(parts[1])
if err != nil {
return
}
err = bot.db.ChannelTrackMapper(m.ChannelID, mapperId, 3)
if err != nil {
return
}
go func() {
time.Sleep(refreshInterval)
bot.requests <- mapperId
}()
bot.MessageReactionAdd(m.ChannelID, m.ID, "\xf0\x9f\x91\x8d")
}
return
}
func (bot *Bot) Close() {
bot.Session.Close()
}

37
config.go Normal file
View file

@ -0,0 +1,37 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"github.com/BurntSushi/toml"
)
type Config struct {
BotToken string `toml:"bot_token"`
ClientId int `toml:"client_id"`
ClientSecret string `toml:"client_secret"`
}
func ReadConfig(path string) (config Config, err error) {
file, err := os.Open(path)
if err != nil {
err = fmt.Errorf("couldn't open file %s: %w", path, err)
return
}
data, err := ioutil.ReadAll(file)
if err != nil {
err = fmt.Errorf("couldn't read data from %s: %w", path, err)
return
}
err = toml.Unmarshal(data, &config)
if err != nil {
err = fmt.Errorf("couldn't parse config data from %s: %w", path, err)
return
}
return
}

240
db.go Normal file
View file

@ -0,0 +1,240 @@
package main
// Database is laid out like this:
// mapper/<mapper_id>/trackers/<channel_id> -> priority
// mapper/<mapper_id>/latestEvent
// channel/<channel_id>/tracks/<mapper_id> -> priority
import (
"strconv"
bolt "go.etcd.io/bbolt"
)
var (
LATEST_EVENT = []byte("latestEvent")
)
type Db struct {
*bolt.DB
api *Osuapi
}
func OpenDb(path string, api *Osuapi) (db *Db, err error) {
inner, err := bolt.Open(path, 0666, nil)
db = &Db{inner, api}
return
}
// Loop over channels that are tracking this specific mapper
func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error) (err error) {
err = db.DB.View(func(tx *bolt.Tx) error {
mapper := getMapper(tx, mapperId)
if mapper == nil {
return nil
}
trackers := mapper.Bucket([]byte("trackers"))
if trackers == nil {
return nil
}
c := trackers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
channelId := string(k)
err := fn(channelId)
if err != nil {
return err
}
}
return nil
})
return
}
// Loop over tracked mappers
func (db *Db) IterTrackedMappers(fn func(userId int) error) (err error) {
err = db.DB.View(func(tx *bolt.Tx) error {
mappers := tx.Bucket([]byte("mapper"))
if mappers == nil {
return nil
}
c := mappers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
userId, err := strconv.Atoi(string(k))
if err != nil {
return err
}
err = fn(userId)
if err != nil {
return err
}
}
return nil
})
return
}
// Get a list of channels that are tracking this mapper
func (db *Db) GetMapperTrackers(userId int) (trackersList []string) {
trackersList = make([]string, 0)
db.DB.View(func(tx *bolt.Tx) error {
mapper, err := getMapperMut(tx, userId)
if err != nil {
return err
}
trackers := mapper.Bucket([]byte("trackers"))
if trackers == nil {
return nil
}
c := trackers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
channelId := string(k)
trackersList = append(trackersList, channelId)
}
return nil
})
return
}
// Update the latest event of a mapper to the given one
func (db *Db) UpdateMapperLatestEvent(userId int, eventId int) (err error) {
err = db.DB.Update(func(tx *bolt.Tx) error {
mapper, err := getMapperMut(tx, userId)
if err != nil {
return err
}
err = mapper.Put(LATEST_EVENT, []byte(strconv.Itoa(eventId)))
if err != nil {
return err
}
return nil
})
return
}
// Get the latest event ID of this mapper, if they have one
func (db *Db) MapperLastEvent(userId int) (has bool, id int) {
has = false
id = -1
db.DB.View(func(tx *bolt.Tx) error {
mapper := getMapper(tx, userId)
if mapper == nil {
return nil
}
lastEventId := mapper.Get(LATEST_EVENT)
if lastEventId == nil {
return nil
}
var err error
id, err = strconv.Atoi(string(lastEventId))
if err != nil {
return nil
}
has = true
return nil
})
return
}
// Start tracking a new mapper (if they're not already tracked)
func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) {
events, err := db.api.GetUserEvents(mapperId, 1, 0)
if err != nil {
return
}
err = db.Batch(func(tx *bolt.Tx) error {
{
mapper, err := getMapperMut(tx, mapperId)
if err != nil {
return err
}
if len(events) > 0 {
latestEventId := strconv.Itoa(events[0].ID)
mapper.Put(LATEST_EVENT, []byte(latestEventId))
}
trackers, err := mapper.CreateBucketIfNotExists([]byte("trackers"))
if err != nil {
return err
}
err = trackers.Put([]byte(channelId), []byte(strconv.Itoa(priority)))
if err != nil {
return err
}
}
{
channels, err := tx.CreateBucketIfNotExists([]byte("channels"))
if err != nil {
return err
}
channel, err := channels.CreateBucketIfNotExists([]byte(channelId))
if err != nil {
return err
}
tracks, err := channel.CreateBucketIfNotExists([]byte("tracks"))
if err != nil {
return err
}
err = tracks.Put([]byte(strconv.Itoa(mapperId)), []byte(strconv.Itoa(priority)))
if err != nil {
return err
}
}
return nil
})
return
}
func (db *Db) Close() {
db.DB.Close()
}
func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) {
mappers := tx.Bucket([]byte("mapper"))
if mappers == nil {
return nil
}
mapper = mappers.Bucket([]byte(strconv.Itoa(userId)))
if mapper == nil {
return nil
}
return
}
func getMapperMut(tx *bolt.Tx, userId int) (mapper *bolt.Bucket, err error) {
mappers, err := tx.CreateBucketIfNotExists([]byte("mapper"))
if err != nil {
return
}
mapper, err = mappers.CreateBucketIfNotExists([]byte(strconv.Itoa(userId)))
if err != nil {
return
}
return
}

10
go.mod Normal file
View file

@ -0,0 +1,10 @@
module subscribe-bot
go 1.14
require (
github.com/BurntSushi/toml v0.3.1
github.com/bwmarrin/discordgo v0.22.0
go.etcd.io/bbolt v1.3.5
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520
)

13
go.sum Normal file
View file

@ -0,0 +1,13 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/bwmarrin/discordgo v0.22.0 h1:uBxY1HmlVCsW1IuaPjpCGT6A2DBwRn0nvOguQIxDdFM=
github.com/bwmarrin/discordgo v0.22.0/go.mod h1:c1WtWUGN6nREDmzIpyTp/iD3VYt4Fpx+bVyfBG7JE+M=
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16 h1:y6ce7gCWtnH+m3dCjzQ1PCuwl28DDIc3VNnvY29DlIA=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 h1:Bx6FllMpG4NWDOfhMBz1VR2QYNp/SAOHPIAsaVmxfPo=
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

63
main.go Normal file
View file

@ -0,0 +1,63 @@
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
)
var exit_chan = make(chan int)
func main() {
configPath := flag.String("config", "config.toml", "Path to the config file (defaults to config.toml)")
flag.Parse()
config, err := ReadConfig(*configPath)
requests := make(chan int)
api := NewOsuapi(&config)
db, err := OpenDb("db", api)
if err != nil {
log.Fatal(err)
}
bot, err := NewBot(config.BotToken, db, requests)
if err != nil {
log.Fatal(err)
}
go RunScraper(bot, db, api, requests)
signal_chan := make(chan os.Signal, 1)
signal.Notify(signal_chan,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
go func() {
for {
s := <-signal_chan
switch s {
case syscall.SIGHUP:
fallthrough
case syscall.SIGINT:
fallthrough
case syscall.SIGTERM:
fallthrough
case syscall.SIGQUIT:
exit_chan <- 0
default:
exit_chan <- 1
}
}
}()
code := <-exit_chan
db.Close()
bot.Close()
os.Exit(code)
}

32
models.go Normal file
View file

@ -0,0 +1,32 @@
package main
type Event struct {
CreatedAt string `json:"created_at"`
ID int `json:"id"`
Type string `json:"type"`
// type: achievement
Achievement Achievement `json:"achievement,omitempty"`
// type: beatmapsetApprove
// type: beatmapsetDelete
// type: beatmapsetRevive
// type: beatmapsetUpdate
// type: beatmapsetUpload
Beatmapset Beatmapset `json:"beatmapset,omitempty"`
User User `json:"user,omitempty"`
}
type Achievement struct{}
type Beatmapset struct {
Title string `json:"title"`
URL string `json:"url"`
}
type User struct {
Username string `json:"username"`
URL string `json:"url"`
PreviousUsername string `json:"previousUsername,omitempty"`
}

126
osuapi.go Normal file
View file

@ -0,0 +1,126 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"time"
"golang.org/x/sync/semaphore"
)
const BASE_URL = "https://osu.ppy.sh/api/v2"
type Osuapi struct {
lock *semaphore.Weighted
token string
expires time.Time
clientId int
clientSecret string
}
func NewOsuapi(config *Config) *Osuapi {
// want to cap at around 1000 requests a minute, OSU cap is 1200
lock := semaphore.NewWeighted(1000)
return &Osuapi{lock, "", time.Now(), config.ClientId, config.ClientSecret}
}
func (api *Osuapi) Token() (token string, err error) {
if time.Now().Before(api.expires) {
token = api.token
return
}
data := fmt.Sprintf(
"client_id=%d&client_secret=%s&grant_type=client_credentials&scope=public",
api.clientId,
api.clientSecret,
)
resp, err := http.Post("https://osu.ppy.sh/oauth/token", "application/x-www-form-urlencoded", strings.NewReader(data))
if err != nil {
return
}
var osuToken OsuToken
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
err = json.Unmarshal(respBody, &osuToken)
if err != nil {
return
}
log.Println("got new access token", osuToken.AccessToken[:12]+"...")
api.token = osuToken.AccessToken
api.expires = time.Now().Add(time.Duration(osuToken.ExpiresIn) * time.Second)
token = api.token
return
}
func (api *Osuapi) Request(action string, url string, result interface{}) (err error) {
err = api.lock.Acquire(context.TODO(), 1)
if err != nil {
return
}
apiUrl := BASE_URL + url
req, err := http.NewRequest(action, apiUrl, nil)
token, err := api.Token()
if err != nil {
return
}
req.Header.Add("Authorization", "Bearer "+token)
if err != nil {
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
err = json.Unmarshal(data, result)
if err != nil {
return
}
// release the lock after 1 minute
go func() {
time.Sleep(time.Minute)
api.lock.Release(1)
}()
return
}
func (api *Osuapi) GetUserEvents(userId int, limit int, offset int) (events []Event, err error) {
url := fmt.Sprintf(
"/users/%d/recent_activity?limit=%d&offset=%d",
userId,
limit,
offset,
)
err = api.Request("GET", url, &events)
if err != nil {
return
}
return
}
type OsuToken struct {
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}

120
scrape.go Normal file
View file

@ -0,0 +1,120 @@
package main
import (
"fmt"
"log"
"time"
)
var (
refreshInterval = 60 * time.Second
)
func RunScraper(bot *Bot, db *Db, api *Osuapi, requests chan int) {
// start timers
go startTimers(db, requests)
for userId := range requests {
log.Println("scraping", userId)
newMaps, err := getNewMaps(db, api, userId)
if err != nil {
log.Println("err getting new maps:", err)
exit_chan <- 1
}
db.IterTrackingChannels(userId, func(channelId string) error {
for _, beatmap := range newMaps {
bot.ChannelMessageSend(channelId, fmt.Sprintf("new beatmap event [%s](%s)", beatmap.Title, beatmap.URL))
}
return nil
})
// wait a minute and put them back into the queue
go func() {
time.Sleep(refreshInterval)
requests <- userId
}()
}
}
func getNewMaps(db *Db, api *Osuapi, userId int) (newMaps []Beatmapset, err error) {
// see if there's a last event
hasLastEvent, lastEventId := db.MapperLastEvent(userId)
newMaps = make([]Beatmapset, 0)
var (
events []Event
newLatestEvent = 0
updateLatestEvent = false
)
if hasLastEvent {
log.Printf("last event id for %d is %d\n", userId, lastEventId)
offset := 0
loop:
for {
log.Println("loading user events from", offset)
events, err = api.GetUserEvents(userId, 50, offset)
if err != nil {
err = fmt.Errorf("couldn't load events for user %d, offset %d: %w", userId, offset, err)
return
}
if len(events) == 0 {
break
}
for _, event := range events {
if event.ID == lastEventId {
break loop
}
if event.ID > newLatestEvent {
updateLatestEvent = true
newLatestEvent = event.ID
}
if event.Type == "beatmapsetUpload" ||
event.Type == "beatmapsetRevive" ||
event.Type == "beatmapsetUpdate" {
newMaps = append(newMaps, event.Beatmapset)
}
}
offset += len(events)
}
} else {
log.Printf("no last event id found for %d\n", userId)
events, err = api.GetUserEvents(userId, 50, 0)
if err != nil {
return
}
for _, event := range events {
if event.ID > newLatestEvent {
updateLatestEvent = true
newLatestEvent = event.ID
}
if event.Type == "beatmapsetUpload" ||
event.Type == "beatmapsetRevive" ||
event.Type == "beatmapsetUpdate" {
newMaps = append(newMaps, event.Beatmapset)
}
}
}
if updateLatestEvent {
err = db.UpdateMapperLatestEvent(userId, newLatestEvent)
if err != nil {
return
}
}
return
}
func startTimers(db *Db, requests chan int) {
db.IterTrackedMappers(func(userId int) error {
requests <- userId
return nil
})
}