subscribe-bot/discord/bot.go

327 lines
7.4 KiB
Go

package discord
import (
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
"subscribe-bot/config"
"subscribe-bot/db"
"subscribe-bot/osuapi"
)
// Bot is the Discord side of the subscribe bot. It embeds the discordgo
// session, so session methods such as ChannelMessageSend and
// UserChannelCreate are called directly on a *Bot.
type Bot struct {
*discordgo.Session
// mentionRe matches a mention of this bot (<@ID> or <@!ID>) together with
// surrounding whitespace; it is built in NewBot from the session's own user ID.
mentionRe *regexp.Regexp
// db is used to record and list the mappers each Discord channel tracks.
db *db.Db
// api is the osu! API client used to look up users and beatmapsets and to
// download individual .osu files.
api *osuapi.Osuapi
// config supplies the bot token, the root directory for per-beatmapset git
// repositories (Repos), and the public base URL used in embeds (Web.ServedAt).
config *config.Config
}
// NewBot opens a Discord gateway session authenticated with config.BotToken
// and returns a Bot wired to db, api and config. Incoming messages are routed
// to newMessageHandler, wrapped by errWrap so returned errors are reported.
//
// If anything fails after the session has been opened, the session is closed
// before the error is returned so the gateway connection is not leaked.
func NewBot(config *config.Config, db *db.Db, api *osuapi.Osuapi) (bot *Bot, err error) {
	s, err := discordgo.New("Bot " + config.BotToken)
	if err != nil {
		return nil, err
	}
	if err = s.Open(); err != nil {
		return nil, err
	}
	// From here on the gateway connection is live; release it on any failure.
	defer func() {
		if err != nil {
			_ = s.Close()
		}
	}()
	log.Println("connected to discord")

	if s.State == nil || s.State.User == nil {
		return nil, errors.New("discord session has no user after connecting")
	}

	// Matches a mention of this bot (<@ID>, or the nickname form <@!ID>) plus
	// any surrounding whitespace, so it can be stripped from command text.
	re, err := regexp.Compile(`\s*<@!?` + regexp.QuoteMeta(s.State.User.ID) + `>\s*`)
	if err != nil {
		return nil, err
	}

	bot = &Bot{s, re, db, api, config}
	s.AddHandler(bot.errWrap(bot.newMessageHandler))
	return bot, nil
}
// NotifyError formats a message with fmt.Sprintf(format, args...) and sends it
// as a direct message to the hard-coded Discord user 100443064228646912
// (presumably the bot operator — NOTE(review): consider moving this to config).
//
// There is no other channel to report delivery problems on, so failures to
// open the DM channel or send the message are logged instead of panicking.
func (bot *Bot) NotifyError(format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	channel, err := bot.UserChannelCreate("100443064228646912")
	if err != nil {
		log.Printf("couldn't open DM channel to report error %q: %v", msg, err)
		return
	}
	if _, err := bot.ChannelMessageSend(channel.ID, msg); err != nil {
		log.Printf("couldn't send error report %q: %v", msg, err)
	}
}
// errWrap adapts fn into a function with the same parameters and no results,
// the shape discordgo's AddHandler expects. When fn's last result is a non-nil
// error, the wrapper reports it through bot.NotifyError; all other results are
// discarded. Variadic functions are supported.
//
// errWrap panics if fn is not a function; that is a programming error at the
// registration site, not a runtime condition.
func (bot *Bot) errWrap(fn interface{}) interface{} {
	origType := reflect.TypeOf(fn)
	if origType == nil || origType.Kind() != reflect.Func {
		panic(fmt.Sprintf("errWrap: expected a function, got %T", fn))
	}
	val := reflect.ValueOf(fn)

	in := make([]reflect.Type, origType.NumIn())
	for i := range in {
		in[i] = origType.In(i)
	}
	variadic := origType.IsVariadic()
	newType := reflect.FuncOf(in, nil, variadic)

	newFunc := reflect.MakeFunc(newType, func(args []reflect.Value) []reflect.Value {
		// For a variadic signature, MakeFunc hands us the variadic arguments
		// already packed into a slice, so they must be forwarded with CallSlice.
		var res []reflect.Value
		if variadic {
			res = val.CallSlice(args)
		} else {
			res = val.Call(args)
		}
		if len(res) > 0 {
			// A checked assertion avoids IsNil panicking on non-nillable kinds;
			// a nil error interface yields ok == false.
			if err, ok := res[len(res)-1].Interface().(error); ok && err != nil {
				bot.NotifyError("error: %s", err)
			}
		}
		return nil
	})
	return newFunc.Interface()
}
// NotifyNewBeatmap snapshots every beatmapset in newMaps into its own git
// repository and announces the update in each of the given Discord channels.
//
// For each set it:
//   - parses LastUpdated as an RFC 3339 timestamp, used as the commit time and
//     the embed timestamp;
//   - opens (or creates) the repository at <config.Repos>/<UserID>/<ID>;
//   - downloads the set's .osu files into the working tree;
//   - commits the working tree, authored by the set's creator;
//   - diffs the new commit against its parent, if there is one;
//   - sends an embed describing the revision to every channel.
//
// Errors while preparing or committing a set abort immediately and are
// returned. A failed download is only logged: the set is still committed and
// announced, but the embed carries no description. Failures to send to
// individual channels do not stop delivery elsewhere; they are collected and
// returned once every set has been processed.
func (bot *Bot) NotifyNewBeatmap(channels []string, newMaps []osuapi.Beatmapset) error {
	var sendErrs []error
	for _, beatmapSet := range newMaps {
		eventTime, err := time.Parse(time.RFC3339, beatmapSet.LastUpdated)
		if err != nil {
			return err
		}

		repoDir := path.Join(bot.config.Repos, strconv.Itoa(beatmapSet.UserID), strconv.Itoa(beatmapSet.ID))
		repo, err := openBeatmapRepo(repoDir)
		if err != nil {
			return err
		}

		// A failed download is not fatal: whatever is on disk still gets
		// committed and announced, just without diff information.
		gotDownloadedBeatmap := true
		if err := bot.downloadBeatmapTo(&beatmapSet, repo, repoDir); err != nil {
			log.Println("failed to download beatmap:", err)
			gotDownloadedBeatmap = false
		}

		hash, err := commitBeatmapWorktree(repo, repoDir, &beatmapSet, eventTime)
		if err != nil {
			return err
		}
		patch, err := patchAgainstParent(repo, hash)
		if err != nil {
			return err
		}

		embed := bot.beatmapUpdateEmbed(&beatmapSet, eventTime)
		if gotDownloadedBeatmap {
			if patch != nil {
				embed.Description = fmt.Sprintf(
					"Latest revision: %s\n%s",
					hash,
					patch.Stats().String(),
				)
			} else {
				embed.Description = "Newly tracked map; diff information will be reported upon next update!"
			}
		}

		for _, channelID := range channels {
			if _, err := bot.ChannelMessageSendEmbed(channelID, embed); err != nil {
				sendErrs = append(sendErrs, fmt.Errorf("failed to send to %s: %w", channelID, err))
			}
		}
	}

	switch len(sendErrs) {
	case 0:
		return nil
	case 1:
		return sendErrs[0]
	default:
		return fmt.Errorf("%w (and %d more failed sends)", sendErrs[0], len(sendErrs)-1)
	}
}

// openBeatmapRepo opens the git repository at repoDir, creating the directory
// and initializing a new non-bare repository there if none exists yet.
func openBeatmapRepo(repoDir string) (*git.Repository, error) {
	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(repoDir, 0777); err != nil {
		return nil, fmt.Errorf("couldn't create repo directory %s: %w", repoDir, err)
	}
	repo, err := git.PlainOpen(repoDir)
	if errors.Is(err, git.ErrRepositoryNotExists) {
		repo, err = git.PlainInit(repoDir, false)
	}
	return repo, err
}

// commitBeatmapWorktree stages every top-level entry of repoDir except .git
// and commits it as "update: <set ID>", authored by the set's creator at
// eventTime. It returns the hash of the new commit.
//
// NOTE(review): only files present on disk are staged; files deleted from the
// working tree stay in the index. Consider git.CommitOptions.All if removed
// difficulties should show up in the history.
func commitBeatmapWorktree(repo *git.Repository, repoDir string, beatmapSet *osuapi.Beatmapset, eventTime time.Time) (plumbing.Hash, error) {
	worktree, err := repo.Worktree()
	if err != nil {
		return plumbing.ZeroHash, err
	}
	files, err := ioutil.ReadDir(repoDir)
	if err != nil {
		return plumbing.ZeroHash, err
	}
	for _, f := range files {
		if f.Name() == ".git" {
			continue
		}
		if _, err := worktree.Add(f.Name()); err != nil {
			return plumbing.ZeroHash, fmt.Errorf("couldn't stage %s for %d: %w", f.Name(), beatmapSet.ID, err)
		}
	}
	hash, err := worktree.Commit(
		fmt.Sprintf("update: %d", beatmapSet.ID),
		&git.CommitOptions{
			Author: &object.Signature{
				Name:  beatmapSet.Creator,
				Email: "nobody@localhost",
				When:  eventTime,
			},
		},
	)
	if err != nil {
		return plumbing.ZeroHash, fmt.Errorf("couldn't create commit for %d: %w", beatmapSet.ID, err)
	}
	return hash, nil
}

// patchAgainstParent returns the diff from the first parent of the commit
// identified by hash to that commit. It returns a nil patch and nil error when
// the commit has no parent (the first snapshot of a newly tracked set).
func patchAgainstParent(repo *git.Repository, hash plumbing.Hash) (*object.Patch, error) {
	commit, err := repo.CommitObject(hash)
	if err != nil {
		return nil, fmt.Errorf("couldn't find commit with hash %s: %w", hash, err)
	}
	parent, err := commit.Parent(0)
	if errors.Is(err, object.ErrParentNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("couldn't retrieve commit parent: %w", err)
	}
	patch, err := parent.Patch(commit)
	if err != nil {
		return nil, fmt.Errorf("couldn't retrieve patch: %w", err)
	}
	return patch, nil
}

// beatmapUpdateEmbed builds the update announcement for beatmapSet, linking to
// the set's versions page under config.Web.ServedAt and to the creator's osu!
// profile. The description is left for the caller to fill in.
func (bot *Bot) beatmapUpdateEmbed(beatmapSet *osuapi.Beatmapset, eventTime time.Time) *discordgo.MessageEmbed {
	return &discordgo.MessageEmbed{
		URL:       fmt.Sprintf("%s/map/%d/%d/versions", bot.config.Web.ServedAt, beatmapSet.UserID, beatmapSet.ID),
		Title:     fmt.Sprintf("Update: %s - %s", beatmapSet.Artist, beatmapSet.Title),
		Timestamp: eventTime.Format(time.RFC3339),
		Author: &discordgo.MessageEmbedAuthor{
			URL:  "https://osu.ppy.sh/u/" + strconv.Itoa(beatmapSet.UserID),
			Name: beatmapSet.Creator,
			// The current Unix time in the query string presumably defeats
			// Discord's image cache so the latest avatar is shown.
			IconURL: fmt.Sprintf(
				"https://a.ppy.sh/%d?%d.png",
				beatmapSet.UserID,
				time.Now().Unix(),
			),
		},
		Thumbnail: &discordgo.MessageEmbedThumbnail{
			URL: beatmapSet.Covers.SlimCover2x,
		},
	}
}
// BeatmapsetDownloaded presumably describes a beatmapset that has been
// downloaded to disk.
// NOTE(review): nothing in this file populates it; confirm what Path refers to
// (repository directory or a single file), or remove the type if it is unused.
type BeatmapsetDownloaded struct {
Path string
}
// downloadBeatmapTo refreshes the .osu files for beatmapSet in repoDir. Every
// existing top-level *.osu file is deleted first, then each difficulty in
// beatmapSet.Beatmaps is downloaded to <repoDir>/<beatmap ID>.osu.
//
// repo is not used here; staging and committing are left to the caller. The
// first removal or download failure aborts and is returned, which can leave
// repoDir with only some of the difficulties present.
func (bot *Bot) downloadBeatmapTo(beatmapSet *osuapi.Beatmapset, repo *git.Repository, repoDir string) error {
	files, err := ioutil.ReadDir(repoDir)
	if err != nil {
		return err
	}

	// Clear old difficulties so ones removed from the set don't linger in the
	// working tree.
	for _, f := range files {
		if f.IsDir() || !strings.HasSuffix(f.Name(), ".osu") {
			continue
		}
		// ReadDir names are relative to repoDir, not the process's working
		// directory, so they must be joined before removal.
		stale := path.Join(repoDir, f.Name())
		if err := os.Remove(stale); err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("couldn't remove stale beatmap %s: %w", stale, err)
		}
	}

	for _, beatmap := range beatmapSet.Beatmaps {
		dest := path.Join(repoDir, fmt.Sprintf("%d.osu", beatmap.ID))
		if err := bot.api.DownloadSingleBeatmap(beatmap.ID, dest); err != nil {
			return err
		}
	}
	return nil
}
// getBeatmapsetInfo fetches the beatmapset an event refers to. The set ID is
// parsed from event.Beatmapset.URL, which is expected to have the form
// "/s/<id>"; a URL that does not yield an integer ID returns the Atoi error
// and a zero Beatmapset.
func (bot *Bot) getBeatmapsetInfo(event osuapi.Event) (beatmapSet osuapi.Beatmapset, err error) {
	idText := strings.TrimPrefix(event.Beatmapset.URL, "/s/")
	setID, err := strconv.Atoi(idText)
	if err != nil {
		return beatmapSet, err
	}
	log.Println("beatmap set id", setID)
	return bot.api.GetBeatmapSet(setID)
}
// newMessageHandler handles MessageCreate events. Messages that do not
// mention the bot are ignored. Otherwise the mention(s) are stripped and the
// first remaining word is interpreted, case-insensitively, as a command:
//
//   - "track <mapper name>": looks the mapper up via the osu! API, records
//     that this channel tracks them, and confirms in the channel.
//   - "list": replies with the usernames of every mapper this channel tracks.
//
// Bare mentions and unknown commands are ignored. Errors are returned so the
// errWrap wrapper registered in NewBot can report them.
func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCreate) error {
	mentionsMe := false
	for _, user := range m.Mentions {
		if user.ID == s.State.User.ID {
			mentionsMe = true
			break
		}
	}
	if !mentionsMe {
		return nil
	}

	// Fields (rather than Split on " ") tolerates repeated spaces, tabs and
	// newlines around the command and its arguments.
	parts := strings.Fields(bot.mentionRe.ReplaceAllString(m.Content, " "))
	if len(parts) == 0 {
		return nil
	}

	switch strings.ToLower(parts[0]) {
	case "track":
		if len(parts) < 2 {
			return errors.New("track: missing mapper name")
		}
		mapperName := strings.Join(parts[1:], " ")
		mapper, err := bot.api.GetUser(mapperName)
		if err != nil {
			return err
		}
		// NOTE(review): the meaning of the final argument (3) is defined by
		// db.ChannelTrackMapper — confirm and give it a name.
		if err := bot.db.ChannelTrackMapper(m.ChannelID, mapper.ID, 3); err != nil {
			return err
		}
		_, err = bot.ChannelMessageSend(m.ChannelID, fmt.Sprintf("subscribed to %+v", mapper))
		return err

	case "list":
		var mappers []string
		err := bot.db.IterChannelTrackedMappers(m.ChannelID, func(userID int) error {
			mapper, err := bot.api.GetUser(strconv.Itoa(userID))
			if err != nil {
				return err
			}
			mappers = append(mappers, mapper.Username)
			return nil
		})
		// Don't post a partial list when a lookup or the iteration failed.
		if err != nil {
			return err
		}
		_, err = bot.ChannelMessageSend(m.ChannelID, "tracking: "+strings.Join(mappers, ", "))
		return err
	}
	return nil
}
// Close shuts down the underlying Discord session. Close has no error return,
// so a failure to close is logged rather than silently discarded.
func (bot *Bot) Close() {
	if err := bot.Session.Close(); err != nil {
		log.Println("error closing discord session:", err)
	}
}