// Package channel is the container for first-party framework C2 structures and variables. It
// holds the internal settings for multiple types of C2s and is also passed to other external C2
// components so they can extract values such as lhost and lport.
package channel

import (
	"crypto/rand"
	"encoding/base32"
	"io"
	"net"
	"sync/atomic"
	"time"

	"github.com/vulncheck-oss/go-exploit/output"
)

// Channel holds the internal settings and session-tracking state for a first-party framework C2.
// Per the package comment, it is also passed to external C2 components so they can read values
// such as lhost and lport.
type Channel struct {
	// IPAddr is an IP address for the C2; presumably the lhost mentioned in the package
	// comment — confirm against the C2 implementations.
	IPAddr   string
	// HTTPAddr is an HTTP-related address; its exact role is not established in this file —
	// TODO confirm.
	HTTPAddr string
	// Port is presumably the lport mentioned in the package comment — confirm.
	Port     int
	// HTTPPort presumably pairs with HTTPAddr — confirm.
	HTTPPort int
	// Timeout is a timeout value whose units are not established here — TODO confirm (seconds?).
	Timeout  int
	// IsClient presumably distinguishes a client-mode (connect-out) C2 from a listener — confirm.
	IsClient bool
	// Shutdown is a shared atomic flag. Because it is a pointer, copies of a Channel observe the
	// same value. Presumably set when the C2 should stop — verify with callers.
	Shutdown *atomic.Bool
	// Sessions maps the random IDs generated by AddSession to their Session records. AddSession
	// allocates it on demand, and RemoveSession marks entries inactive instead of deleting them.
	// No method in this file guards it with a lock.
	Sessions map[string]Session
	// Input is an input source for the C2; its consumer is not visible in this file.
	Input    io.Reader
	Output   io.Writer // Currently unused; reserved ahead of time for future output handling.
}

// Session is a single tracked remote peer, as registered by AddSession.
type Session struct {
	// RemoteAddr is the address string supplied to AddSession for this session.
	RemoteAddr     string
	// ConnectionTime is when AddSession registered the session.
	ConnectionTime time.Time
	// conn points at the tracked connection, if any. It may be nil; RemoveSession only closes
	// it when it is non-nil.
	conn           *net.Conn
	// Active is set to true by AddSession and to false by RemoveSession.
	Active         bool
	// LastSeen starts at the registration time (set by AddSession) and is refreshed by
	// UpdateLastSeenByConn.
	LastSeen       time.Time
}

// HasSessions reports whether at least one tracked session is still marked Active. Removed
// sessions stay in the map but are inactive, so they do not count. It can be used to check
// whether a C2 successfully received callbacks:
//
//	c, ok := c2.GetInstance(conf.C2Type)
//	c.Channel().HasSessions()
func (c *Channel) HasSessions() bool {
	active := false
	for id := range c.Sessions {
		if c.Sessions[id].Active {
			active = true

			break
		}
	}

	return active
}

// AddSession registers a remote connection for session tracking under a freshly generated random
// ID. If a network connection is being tracked it can be added here and will be cleaned up and
// closed automatically by the C2 on shutdown (RemoveSession closes it). conn may be nil when there
// is no connection to manage; addr is stored as the session's RemoteAddr.
//
// It returns false only if a random session ID could not be generated, and true otherwise.
func (c *Channel) AddSession(conn *net.Conn, addr string) bool {
	// Allocate lazily, but only when the map is actually nil: an existing (even empty) map may be
	// shared with whoever assigned it, so replacing it would silently detach them.
	if c.Sessions == nil {
		c.Sessions = make(map[string]Session)
	}

	// Session IDs are 16 random bytes (128 bits, comparable to a random UUID's 122 random bits),
	// base32-encoded without padding. This keeps us dependency free while producing URL-safe IDs
	// that contain only A-Z and 2-7, with no special characters.
	k := make([]byte, 16)
	if _, err := rand.Read(k); err != nil {
		output.PrintfFrameworkError("Could not add session: %s", err.Error())

		return false
	}
	id := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(k)

	// Use one timestamp so ConnectionTime and LastSeen start out identical.
	now := time.Now()
	c.Sessions[id] = Session{
		RemoteAddr:     addr,
		ConnectionTime: now,
		LastSeen:       now,
		conn:           conn,
		Active:         true,
	}

	return true
}

// UpdateLastSeenByConn stores timeStamp as the LastSeen time of the session that tracks conn.
// It returns false when no session tracks conn (or the resolved ID is no longer in Sessions),
// and true once the new timestamp has been written back.
func (c *Channel) UpdateLastSeenByConn(conn net.Conn, timeStamp time.Time) bool {
	sessionID, found := c.GetSessionIDByConn(conn)
	if !found {
		return false
	}

	tracked, exists := c.Sessions[sessionID]
	if !exists {
		output.PrintFrameworkError("Session ID does not exist")

		return false
	}

	// Session is stored by value, so modify a copy and write it back.
	tracked.LastSeen = timeStamp
	c.Sessions[sessionID] = tracked

	return true
}

// GetSessionIDByConn returns the ID of the tracked session whose connection equals conn. The
// boolean result is false when no sessions exist or none of them holds conn. Sessions registered
// without a connection (nil conn pointer) are skipped. Sessions already marked inactive by
// RemoveSession remain in the map and are still searched.
func (c *Channel) GetSessionIDByConn(conn net.Conn) (string, bool) {
	if len(c.Sessions) == 0 {
		output.PrintFrameworkDebug("No sessions exist")

		return "", false
	}

	for id, session := range c.Sessions {
		// AddSession accepts a nil *net.Conn (RemoveSession guards against it as well), so the
		// pointer must be checked before dereferencing or this lookup would panic.
		if session.conn != nil && *session.conn == conn {
			return id, true
		}
	}

	output.PrintFrameworkError("Conn does not exist in sessions")

	return "", false
}


// RemoveSession marks the session with the given ID as inactive and, if it tracks a connection,
// closes that connection. The entry is kept in Sessions (with Active set to false) rather than
// deleted. It returns false when there are no sessions or id is unknown, and true otherwise.
func (c *Channel) RemoveSession(id string) bool {
	if len(c.Sessions) == 0 {
		output.PrintFrameworkDebug("No sessions exist")

		return false
	}
	session, ok := c.Sessions[id]
	if !ok {
		output.PrintFrameworkError("Session ID does not exist")

		return false
	}
	// Check both the pointer and the interface it points to: calling Close on a nil net.Conn
	// interface would panic.
	if session.conn != nil && *session.conn != nil {
		// Best-effort close: the connection may already have been closed by the peer or by an
		// earlier removal, and there is nothing useful to do with that error here.
		_ = (*session.conn).Close()
	}
	session.Active = false
	c.Sessions[id] = session

	return true
}

// RemoveSessions runs RemoveSession on every tracked session, which marks each one inactive and
// closes its connection when one is present. It returns false (after a debug message) when there
// is nothing to remove, and true otherwise.
func (c *Channel) RemoveSessions() bool {
	if len(c.Sessions) > 0 {
		// Overwriting existing keys while ranging is permitted, so RemoveSession can update
		// entries in place during the loop.
		for sessionID := range c.Sessions {
			c.RemoveSession(sessionID)
		}

		return true
	}

	output.PrintFrameworkDebug("No sessions exist")

	return false
}
