commit 24c9f21743174ff0bea72f42159502748951edba Author: Michael Zhang Date: Sun Oct 11 14:32:58 2020 -0500 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..92f116f --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/subscribe-bot +/db +/config.toml diff --git a/bot.go b/bot.go new file mode 100644 index 0000000..0924574 --- /dev/null +++ b/bot.go @@ -0,0 +1,117 @@ +package main + +import ( + "errors" + "fmt" + "log" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + "github.com/bwmarrin/discordgo" +) + +type Bot struct { + *discordgo.Session + mentionRe *regexp.Regexp + db *Db + requests chan int +} + +func NewBot(token string, db *Db, requests chan int) (bot *Bot, err error) { + s, err := discordgo.New("Bot " + token) + if err != nil { + return + } + + err = s.Open() + if err != nil { + return + } + log.Println("connected to discord") + + re, err := regexp.Compile("\\s*<@\\!?" + s.State.User.ID + ">\\s*") + if err != nil { + return + } + + bot = &Bot{s, re, db, requests} + s.AddHandler(bot.errWrap(bot.newMessageHandler)) + return +} + +func (bot *Bot) errWrap(fn interface{}) interface{} { + val := reflect.ValueOf(fn) + origType := reflect.TypeOf(fn) + origTypeIn := make([]reflect.Type, origType.NumIn()) + for i := 0; i < origType.NumIn(); i++ { + origTypeIn[i] = origType.In(i) + } + newType := reflect.FuncOf(origTypeIn, []reflect.Type{}, false) + newFunc := reflect.MakeFunc(newType, func(args []reflect.Value) (result []reflect.Value) { + res := val.Call(args) + if len(res) > 0 && !res[0].IsNil() { + err := res[0].Interface().(error) + if err != nil { + msg := fmt.Sprintf("error: %s", err) + channel, _ := bot.UserChannelCreate("100443064228646912") + id, _ := bot.ChannelMessageSend(channel.ID, msg) + log.Println(id, msg) + } + } + return []reflect.Value{} + }) + return newFunc.Interface() +} + +func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCreate) (err error) { + mentionsMe := false + for _, user := range m.Mentions { + if user.ID == s.State.User.ID { + mentionsMe = true + break + } + } + + if !mentionsMe { + return + } + + msg := bot.mentionRe.ReplaceAllString(m.Content, " ") + msg = strings.Trim(msg, " ") + + parts := strings.Split(msg, " ") + switch strings.ToLower(parts[0]) { + case "track": + if len(parts) < 2 { + err = errors.New("fucked up") + return + } + + var mapperId int + mapperId, err = strconv.Atoi(parts[1]) + if err != nil { + return + } + + err = bot.db.ChannelTrackMapper(m.ChannelID, mapperId, 3) + if err != nil { + return + } + + go func() { + time.Sleep(refreshInterval) + bot.requests <- mapperId + }() + + bot.MessageReactionAdd(m.ChannelID, m.ID, "\xf0\x9f\x91\x8d") + } + + return +} + +func (bot *Bot) Close() { + bot.Session.Close() +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..6a842fe --- /dev/null +++ b/config.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/BurntSushi/toml" +) + +type Config struct { + BotToken string `toml:"bot_token"` + ClientId int `toml:"client_id"` + ClientSecret string `toml:"client_secret"` +} + +func ReadConfig(path string) (config Config, err error) { + file, err := os.Open(path) + if err != nil { + err = fmt.Errorf("couldn't open file %s: %w", path, err) + return + } + + data, err := ioutil.ReadAll(file) + if err != nil { + err = fmt.Errorf("couldn't read data from %s: %w", path, err) + return + } + + err = toml.Unmarshal(data, &config) + if err != nil { + err = fmt.Errorf("couldn't parse config data from %s: %w", path, err) + return + } + + return +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..21823dd --- /dev/null +++ b/db.go @@ -0,0 +1,240 @@ +package main + +// Database is laid out like this: +// mapper//trackers/ -> priority +// mapper//latestEvent +// channel//tracks/ -> priority + +import ( + "strconv" + + bolt "go.etcd.io/bbolt" +) + +var ( + LATEST_EVENT = []byte("latestEvent") +) + +type Db struct { + *bolt.DB + api *Osuapi +} + +func OpenDb(path string, api *Osuapi) (db *Db, err error) { + inner, err := bolt.Open(path, 0666, nil) + db = &Db{inner, api} + return +} + +// Loop over channels that are tracking this specific mapper +func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error) (err error) { + err = db.DB.View(func(tx *bolt.Tx) error { + mapper := getMapper(tx, mapperId) + if mapper == nil { + return nil + } + + trackers := mapper.Bucket([]byte("trackers")) + if trackers == nil { + return nil + } + + c := trackers.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + channelId := string(k) + err := fn(channelId) + if err != nil { + return err + } + } + + return nil + }) + + return +} + +// Loop over tracked mappers +func (db *Db) IterTrackedMappers(fn func(userId int) error) (err error) { + err = db.DB.View(func(tx *bolt.Tx) error { + mappers := tx.Bucket([]byte("mapper")) + if mappers == nil { + return nil + } + + c := mappers.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + userId, err := strconv.Atoi(string(k)) + if err != nil { + return err + } + + err = fn(userId) + if err != nil { + return err + } + } + + return nil + }) + + return +} + +// Get a list of channels that are tracking this mapper +func (db *Db) GetMapperTrackers(userId int) (trackersList []string) { + trackersList = make([]string, 0) + db.DB.View(func(tx *bolt.Tx) error { + mapper, err := getMapperMut(tx, userId) + if err != nil { + return err + } + + trackers := mapper.Bucket([]byte("trackers")) + if trackers == nil { + return nil + } + + c := trackers.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + channelId := string(k) + trackersList = append(trackersList, channelId) + } + + return nil + }) + return +} + +// Update the latest event of a mapper to the given one +func (db *Db) UpdateMapperLatestEvent(userId int, eventId int) (err error) { + err = db.DB.Update(func(tx *bolt.Tx) error { + mapper, err := getMapperMut(tx, userId) + if err != nil { + return err + } + + err = mapper.Put(LATEST_EVENT, []byte(strconv.Itoa(eventId))) + if err != nil { + return err + } + + return nil + }) + return +} + +// Get the latest event ID of this mapper, if they have one +func (db *Db) MapperLastEvent(userId int) (has bool, id int) { + has = false + id = -1 + db.DB.View(func(tx *bolt.Tx) error { + mapper := getMapper(tx, userId) + if mapper == nil { + return nil + } + + lastEventId := mapper.Get(LATEST_EVENT) + if lastEventId == nil { + return nil + } + + var err error + id, err = strconv.Atoi(string(lastEventId)) + if err != nil { + return nil + } + + has = true + return nil + }) + + return +} + +// Start tracking a new mapper (if they're not already tracked) +func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) { + events, err := db.api.GetUserEvents(mapperId, 1, 0) + if err != nil { + return + } + + err = db.Batch(func(tx *bolt.Tx) error { + { + mapper, err := getMapperMut(tx, mapperId) + if err != nil { + return err + } + + if len(events) > 0 { + latestEventId := strconv.Itoa(events[0].ID) + mapper.Put(LATEST_EVENT, []byte(latestEventId)) + } + + trackers, err := mapper.CreateBucketIfNotExists([]byte("trackers")) + if err != nil { + return err + } + + err = trackers.Put([]byte(channelId), []byte(strconv.Itoa(priority))) + if err != nil { + return err + } + } + { + channels, err := tx.CreateBucketIfNotExists([]byte("channels")) + if err != nil { + return err + } + + channel, err := channels.CreateBucketIfNotExists([]byte(channelId)) + if err != nil { + return err + } + + tracks, err := channel.CreateBucketIfNotExists([]byte("tracks")) + if err != nil { + return err + } + + err = tracks.Put([]byte(strconv.Itoa(mapperId)), []byte(strconv.Itoa(priority))) + if err != nil { + return err + } + } + return nil + }) + return +} + +func (db *Db) Close() { + db.DB.Close() +} + +func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) { + mappers := tx.Bucket([]byte("mapper")) + if mappers == nil { + return nil + } + + mapper = mappers.Bucket([]byte(strconv.Itoa(userId))) + if mapper == nil { + return nil + } + + return +} + +func getMapperMut(tx *bolt.Tx, userId int) (mapper *bolt.Bucket, err error) { + mappers, err := tx.CreateBucketIfNotExists([]byte("mapper")) + if err != nil { + return + } + + mapper, err = mappers.CreateBucketIfNotExists([]byte(strconv.Itoa(userId))) + if err != nil { + return + } + + return +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d7b76e0 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module subscribe-bot + +go 1.14 + +require ( + github.com/BurntSushi/toml v0.3.1 + github.com/bwmarrin/discordgo v0.22.0 + go.etcd.io/bbolt v1.3.5 + golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e39543c --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/bwmarrin/discordgo v0.22.0 h1:uBxY1HmlVCsW1IuaPjpCGT6A2DBwRn0nvOguQIxDdFM= +github.com/bwmarrin/discordgo v0.22.0/go.mod h1:c1WtWUGN6nREDmzIpyTp/iD3VYt4Fpx+bVyfBG7JE+M= +github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= +github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= +go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16 h1:y6ce7gCWtnH+m3dCjzQ1PCuwl28DDIc3VNnvY29DlIA= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 h1:Bx6FllMpG4NWDOfhMBz1VR2QYNp/SAOHPIAsaVmxfPo= +golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go new file mode 100644 index 0000000..e023296 --- /dev/null +++ b/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "flag" + "log" + "os" + "os/signal" + "syscall" +) + +var exit_chan = make(chan int) + +func main() { + configPath := flag.String("config", "config.toml", "Path to the config file (defaults to config.toml)") + flag.Parse() + + config, err := ReadConfig(*configPath) + + requests := make(chan int) + + api := NewOsuapi(&config) + + db, err := OpenDb("db", api) + if err != nil { + log.Fatal(err) + } + + bot, err := NewBot(config.BotToken, db, requests) + if err != nil { + log.Fatal(err) + } + + go RunScraper(bot, db, api, requests) + + signal_chan := make(chan os.Signal, 1) + signal.Notify(signal_chan, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT) + go func() { + for { + s := <-signal_chan + switch s { + case syscall.SIGHUP: + fallthrough + case syscall.SIGINT: + fallthrough + case syscall.SIGTERM: + fallthrough + case syscall.SIGQUIT: + exit_chan <- 0 + default: + exit_chan <- 1 + } + } + }() + code := <-exit_chan + + db.Close() + bot.Close() + os.Exit(code) +} diff --git a/models.go b/models.go new file mode 100644 index 0000000..b8db5a4 --- /dev/null +++ b/models.go @@ -0,0 +1,32 @@ +package main + +type Event struct { + CreatedAt string `json:"created_at"` + ID int `json:"id"` + Type string `json:"type"` + + // type: achievement + Achievement Achievement `json:"achievement,omitempty"` + + // type: beatmapsetApprove + // type: beatmapsetDelete + // type: beatmapsetRevive + // type: beatmapsetUpdate + // type: beatmapsetUpload + Beatmapset Beatmapset `json:"beatmapset,omitempty"` + + User User `json:"user,omitempty"` +} + +type Achievement struct{} + +type Beatmapset struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type User struct { + Username string `json:"username"` + URL string `json:"url"` + PreviousUsername string `json:"previousUsername,omitempty"` +} diff --git a/osuapi.go b/osuapi.go new file mode 100644 index 0000000..97a3778 --- /dev/null +++ b/osuapi.go @@ -0,0 +1,126 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "strings" + "time" + + "golang.org/x/sync/semaphore" +) + +const BASE_URL = "https://osu.ppy.sh/api/v2" + +type Osuapi struct { + lock *semaphore.Weighted + token string + expires time.Time + clientId int + clientSecret string +} + +func NewOsuapi(config *Config) *Osuapi { + // want to cap at around 1000 requests a minute, OSU cap is 1200 + lock := semaphore.NewWeighted(1000) + return &Osuapi{lock, "", time.Now(), config.ClientId, config.ClientSecret} +} + +func (api *Osuapi) Token() (token string, err error) { + if time.Now().Before(api.expires) { + token = api.token + return + } + + data := fmt.Sprintf( + "client_id=%d&client_secret=%s&grant_type=client_credentials&scope=public", + api.clientId, + api.clientSecret, + ) + + resp, err := http.Post("https://osu.ppy.sh/oauth/token", "application/x-www-form-urlencoded", strings.NewReader(data)) + if err != nil { + return + } + + var osuToken OsuToken + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return + } + + err = json.Unmarshal(respBody, &osuToken) + if err != nil { + return + } + + log.Println("got new access token", osuToken.AccessToken[:12]+"...") + api.token = osuToken.AccessToken + api.expires = time.Now().Add(time.Duration(osuToken.ExpiresIn) * time.Second) + token = api.token + return +} + +func (api *Osuapi) Request(action string, url string, result interface{}) (err error) { + err = api.lock.Acquire(context.TODO(), 1) + if err != nil { + return + } + apiUrl := BASE_URL + url + req, err := http.NewRequest(action, apiUrl, nil) + + token, err := api.Token() + if err != nil { + return + } + + req.Header.Add("Authorization", "Bearer "+token) + if err != nil { + return + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return + } + + err = json.Unmarshal(data, result) + if err != nil { + return + } + + // release the lock after 1 minute + go func() { + time.Sleep(time.Minute) + api.lock.Release(1) + }() + return +} + +func (api *Osuapi) GetUserEvents(userId int, limit int, offset int) (events []Event, err error) { + url := fmt.Sprintf( + "/users/%d/recent_activity?limit=%d&offset=%d", + userId, + limit, + offset, + ) + err = api.Request("GET", url, &events) + if err != nil { + return + } + + return +} + +type OsuToken struct { + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` +} diff --git a/scrape.go b/scrape.go new file mode 100644 index 0000000..16ae8fd --- /dev/null +++ b/scrape.go @@ -0,0 +1,120 @@ +package main + +import ( + "fmt" + "log" + "time" +) + +var ( + refreshInterval = 60 * time.Second +) + +func RunScraper(bot *Bot, db *Db, api *Osuapi, requests chan int) { + // start timers + go startTimers(db, requests) + + for userId := range requests { + log.Println("scraping", userId) + newMaps, err := getNewMaps(db, api, userId) + if err != nil { + log.Println("err getting new maps:", err) + exit_chan <- 1 + } + + db.IterTrackingChannels(userId, func(channelId string) error { + for _, beatmap := range newMaps { + bot.ChannelMessageSend(channelId, fmt.Sprintf("new beatmap event [%s](%s)", beatmap.Title, beatmap.URL)) + } + return nil + }) + + // wait a minute and put them back into the queue + go func() { + time.Sleep(refreshInterval) + requests <- userId + }() + } +} + +func getNewMaps(db *Db, api *Osuapi, userId int) (newMaps []Beatmapset, err error) { + // see if there's a last event + hasLastEvent, lastEventId := db.MapperLastEvent(userId) + newMaps = make([]Beatmapset, 0) + var ( + events []Event + newLatestEvent = 0 + updateLatestEvent = false + ) + if hasLastEvent { + log.Printf("last event id for %d is %d\n", userId, lastEventId) + offset := 0 + + loop: + for { + log.Println("loading user events from", offset) + events, err = api.GetUserEvents(userId, 50, offset) + if err != nil { + err = fmt.Errorf("couldn't load events for user %d, offset %d: %w", userId, offset, err) + return + } + if len(events) == 0 { + break + } + + for _, event := range events { + if event.ID == lastEventId { + break loop + } + + if event.ID > newLatestEvent { + updateLatestEvent = true + newLatestEvent = event.ID + } + + if event.Type == "beatmapsetUpload" || + event.Type == "beatmapsetRevive" || + event.Type == "beatmapsetUpdate" { + newMaps = append(newMaps, event.Beatmapset) + } + } + + offset += len(events) + } + } else { + log.Printf("no last event id found for %d\n", userId) + events, err = api.GetUserEvents(userId, 50, 0) + if err != nil { + return + } + + for _, event := range events { + if event.ID > newLatestEvent { + updateLatestEvent = true + newLatestEvent = event.ID + } + + if event.Type == "beatmapsetUpload" || + event.Type == "beatmapsetRevive" || + event.Type == "beatmapsetUpdate" { + newMaps = append(newMaps, event.Beatmapset) + } + } + } + + if updateLatestEvent { + err = db.UpdateMapperLatestEvent(userId, newLatestEvent) + if err != nil { + return + } + } + + return +} + +func startTimers(db *Db, requests chan int) { + db.IterTrackedMappers(func(userId int) error { + requests <- userId + return nil + }) +}