| 
          package main | 
        
        
           | 
          
 | 
        
        
           | 
          import ( | 
        
        
           | 
          	"context" | 
        
        
           | 
          	"encoding/json" | 
        
        
           | 
          	"fmt" | 
        
        
           | 
          	"io" | 
        
        
           | 
          	"log" | 
        
        
           | 
          	"os" | 
        
        
           | 
          	"os/exec" | 
        
        
           | 
          	"sync" | 
        
        
           | 
          	"sync/atomic" | 
        
        
           | 
          	"time" | 
        
        
           | 
          	"unsafe" | 
        
        
           | 
          
 | 
        
        
           | 
          	"github.com/creack/pty" | 
        
        
           | 
          	"github.com/spf13/cobra" | 
        
        
           | 
          	"golang.org/x/sys/unix" | 
        
        
           | 
          	"golang.org/x/term" | 
        
        
           | 
          ) | 
        
        
           | 
          
 | 
        
        
           | 
          // Device and network defaults.
const (
	// IOCTL_VM_SOCKETS_GET_LOCAL_CID is the ioctl request issued on
	// /dev/vsock by getLocalCID to obtain this machine's context ID (CID).
	IOCTL_VM_SOCKETS_GET_LOCAL_CID = 0x7b9
	defaultPort                    = 9999
)

// Channel identifiers carried in the first header byte of every frame, used
// to multiplex several streams over one connection.
const (
	channelStdin = iota
	channelStdout
	channelStderr
	channelControl
)

// Frame layout: a header of one channel byte plus a 4-byte big-endian payload
// length, followed by at most maxFrameSize payload bytes.
const (
	frameHeaderSize = 1 + 4
	maxFrameSize    = 256 << 10 // 256 KiB (raised from 32 KiB for better throughput)
)
        
        
           | 
          
 | 
        
        
           | 
          // Message and frame types exchanged between client and server.
type (
	// ClientRequest is the JSON message the client sends first to choose how
	// the server runs the command.
	ClientRequest struct {
		UsePTY  bool     `json:"use_pty"`
		Command []string `json:"command"` // empty means interactive shell
	}

	// ServerResponse is the JSON message the server sends to report the
	// command's exit code.
	ServerResponse struct {
		ExitCode int `json:"exit_code"`
	}

	// Frame is one decoded unit of the multiplexed stream: the channel it
	// belongs to and its payload bytes.
	Frame struct {
		Channel byte
		Data    []byte
	}
)

// Process-wide settings and state (presumably populated from command-line
// flags elsewhere in this file — confirm against the cobra setup).
var (
	port     int   // vsock port used by both server (bind) and client (connect)
	single   bool  // server: handle one connection, then return
	forcePTY bool  // client: request a PTY even when a command is given
	connID   int32 // incremented atomically per connection to label log lines
)
        
        
           | 
          
 | 
        
        
           | 
          // writeFrame writes a frame to the connection using vectored I/O for efficiency | 
        
        
           | 
          func writeFrame(w io.Writer, channel byte, data []byte) error { | 
        
        
           | 
          	if len(data) > maxFrameSize { | 
        
        
           | 
          		return fmt.Errorf("frame data too large: %d > %d", len(data), maxFrameSize) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	header := make([]byte, frameHeaderSize) | 
        
        
           | 
          	header[0] = channel | 
        
        
           | 
          	// Write length in big-endian (network byte order) | 
        
        
           | 
          	header[1] = byte(len(data) >> 24) | 
        
        
           | 
          	header[2] = byte(len(data) >> 16) | 
        
        
           | 
          	header[3] = byte(len(data) >> 8) | 
        
        
           | 
          	header[4] = byte(len(data)) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Try to use writev for efficient writing (single syscall) | 
        
        
           | 
          	if file, ok := w.(*os.File); ok && len(data) > 0 { | 
        
        
           | 
          		// Use writev to write header and data in one syscall | 
        
        
           | 
          		buffers := [][]byte{header, data} | 
        
        
           | 
          		_, err := unix.Writev(int(file.Fd()), buffers) | 
        
        
           | 
          		return err | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	// Fallback for non-file writers | 
        
        
           | 
          	if _, err := w.Write(header); err != nil { | 
        
        
           | 
          		return err | 
        
        
           | 
          	} | 
        
        
           | 
          	if len(data) > 0 { | 
        
        
           | 
          		if _, err := w.Write(data); err != nil { | 
        
        
           | 
          			return err | 
        
        
           | 
          		} | 
        
        
           | 
          	} | 
        
        
           | 
          	return nil | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          // readFrame reads a frame from the connection | 
        
        
           | 
          func readFrame(r io.Reader) (*Frame, error) { | 
        
        
           | 
          	header := make([]byte, frameHeaderSize) | 
        
        
           | 
          	if _, err := io.ReadFull(r, header); err != nil { | 
        
        
           | 
          		return nil, err | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	channel := header[0] | 
        
        
           | 
          	length := int(header[1])<<24 | int(header[2])<<16 | int(header[3])<<8 | int(header[4]) | 
        
        
           | 
          
 | 
        
        
           | 
          	if length > maxFrameSize { | 
        
        
           | 
          		return nil, fmt.Errorf("frame too large: %d > %d", length, maxFrameSize) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	data := make([]byte, length) | 
        
        
           | 
          	if length > 0 { | 
        
        
           | 
          		if _, err := io.ReadFull(r, data); err != nil { | 
        
        
           | 
          			return nil, err | 
        
        
           | 
          		} | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	return &Frame{Channel: channel, Data: data}, nil | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          // getLocalCID retrieves the local CID by performing an ioctl on /dev/vsock | 
        
        
           | 
          func getLocalCID() (uint32, error) { | 
        
        
           | 
          	f, err := os.Open("/dev/vsock") | 
        
        
           | 
          	if err != nil { | 
        
        
           | 
          		return 0, fmt.Errorf("failed to open /dev/vsock: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          	defer f.Close() | 
        
        
           | 
          
 | 
        
        
           | 
          	var cid uint32 | 
        
        
           | 
          	_, _, errno := unix.Syscall( | 
        
        
           | 
          		unix.SYS_IOCTL, | 
        
        
           | 
          		f.Fd(), | 
        
        
           | 
          		IOCTL_VM_SOCKETS_GET_LOCAL_CID, | 
        
        
           | 
          		uintptr(unsafe.Pointer(&cid)), | 
        
        
           | 
          	) | 
        
        
           | 
          	if errno != 0 { | 
        
        
           | 
          		return 0, fmt.Errorf("ioctl failed: %v", errno) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	return cid, nil | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          // runServer starts the vsock server | 
        
        
           | 
          func runServer() error { | 
        
        
           | 
          	cid, err := getLocalCID() | 
        
        
           | 
          	if err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to get local CID: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	log.Printf("Local CID: %d", cid) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Create a vsock socket | 
        
        
           | 
          	fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0) | 
        
        
           | 
          	if err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to create socket: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          	defer unix.Close(fd) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Bind to the vsock address | 
        
        
           | 
          	sockaddr := &unix.SockaddrVM{ | 
        
        
           | 
          		CID:  unix.VMADDR_CID_ANY, | 
        
        
           | 
          		Port: uint32(port), | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	if err := unix.Bind(fd, sockaddr); err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to bind: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	// Listen for connections | 
        
        
           | 
          	if err := unix.Listen(fd, 128); err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to listen: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	log.Printf("Listening on vsock(%d:%d)", cid, port) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Accept connections in a loop | 
        
        
           | 
          	for { | 
        
        
           | 
          		clientFd, _, err := unix.Accept(fd) | 
        
        
           | 
          		if err != nil { | 
        
        
           | 
          			log.Printf("Accept error: %v", err) | 
        
        
           | 
          			continue | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		if single { | 
        
        
           | 
          			// Handle client and exit after it's done | 
        
        
           | 
          			handleClient(clientFd) | 
        
        
           | 
          			return nil | 
        
        
           | 
          		} else { | 
        
        
           | 
          			// Handle each client in a goroutine | 
        
        
           | 
          			go handleClient(clientFd) | 
        
        
           | 
          		} | 
        
        
           | 
          	} | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          // handleClient handles an individual client connection | 
        
        
           | 
          func handleClient(clientFd int) { | 
        
        
           | 
          	id := atomic.AddInt32(&connID, 1) | 
        
        
           | 
          	connName := fmt.Sprintf("conn-%d", id) | 
        
        
           | 
          
 | 
        
        
           | 
          	log.Printf("[%s] Client connected", connName) | 
        
        
           | 
          	defer log.Printf("[%s] Client disconnected", connName) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Create a file from the client fd for easier I/O | 
        
        
           | 
          	// Note: closing clientFile will also close the FD | 
        
        
           | 
          	clientFile := os.NewFile(uintptr(clientFd), "vsock-client") | 
        
        
           | 
          	defer clientFile.Close() | 
        
        
           | 
          
 | 
        
        
           | 
          	// Read the client request | 
        
        
           | 
          	var req ClientRequest | 
        
        
           | 
          	decoder := json.NewDecoder(clientFile) | 
        
        
           | 
          	if err := decoder.Decode(&req); err != nil { | 
        
        
           | 
          		log.Printf("[%s] Failed to read client request: %v", connName, err) | 
        
        
           | 
          		return | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	// Determine what to execute | 
        
        
           | 
          	var cmdArgs []string | 
        
        
           | 
          	if len(req.Command) == 0 { | 
        
        
           | 
          		// No command specified, use default shell | 
        
        
           | 
          		shell := os.Getenv("SHELL") | 
        
        
           | 
          		if shell == "" { | 
        
        
           | 
          			shell = "/bin/sh" | 
        
        
           | 
          		} | 
        
        
           | 
          		cmdArgs = []string{shell} | 
        
        
           | 
          	} else { | 
        
        
           | 
          		cmdArgs = req.Command | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) | 
        
        
           | 
          	cmd.Env = os.Environ() | 
        
        
           | 
          
 | 
        
        
           | 
          	var exitCode int | 
        
        
           | 
          
 | 
        
        
           | 
          	if req.UsePTY { | 
        
        
           | 
          		// PTY mode | 
        
        
           | 
          		ptmx, err := pty.Start(cmd) | 
        
        
           | 
          		if err != nil { | 
        
        
           | 
          			log.Printf("[%s] Failed to start command with PTY: %v", connName, err) | 
        
        
           | 
          			exitCode = 255 | 
        
        
           | 
          			// Send error exit code | 
        
        
           | 
          			resp := ServerResponse{ExitCode: exitCode} | 
        
        
           | 
          			json.NewEncoder(clientFile).Encode(&resp) | 
        
        
           | 
          			return | 
        
        
           | 
          		} | 
        
        
           | 
          		defer ptmx.Close() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Copy data bidirectionally | 
        
        
           | 
          		var wg sync.WaitGroup | 
        
        
           | 
          		wg.Add(2) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Client -> PTY | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			io.Copy(ptmx, clientFile) | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// PTY -> Client | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			io.Copy(clientFile, ptmx) | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for the command to exit | 
        
        
           | 
          		if err := cmd.Wait(); err != nil { | 
        
        
           | 
          			log.Printf("[%s] Command error: %v", connName, err) | 
        
        
           | 
          			if exitErr, ok := err.(*exec.ExitError); ok { | 
        
        
           | 
          				exitCode = exitErr.ExitCode() | 
        
        
           | 
          			} else { | 
        
        
           | 
          				exitCode = 255 | 
        
        
           | 
          			} | 
        
        
           | 
          		} else { | 
        
        
           | 
          			exitCode = 0 | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		// Close the PTY to stop writing to client | 
        
        
           | 
          		ptmx.Close() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Shutdown read side to unblock the Client->PTY goroutine | 
        
        
           | 
          		unix.Shutdown(int(clientFile.Fd()), unix.SHUT_RD) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for goroutines | 
        
        
           | 
          		wg.Wait() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Send exit code to client | 
        
        
           | 
          		resp := ServerResponse{ExitCode: exitCode} | 
        
        
           | 
          		json.NewEncoder(clientFile).Encode(&resp) | 
        
        
           | 
          	} else { | 
        
        
           | 
          		// Non-PTY mode: use pipes for stdin/stdout/stderr and multiplex them | 
        
        
           | 
          
 | 
        
        
           | 
          		// Create pipes for stdin, stdout, stderr | 
        
        
           | 
          		stdinPipe, err := cmd.StdinPipe() | 
        
        
           | 
          		if err != nil { | 
        
        
           | 
          			log.Printf("[%s] Failed to create stdin pipe: %v", connName, err) | 
        
        
           | 
          			exitCode = 255 | 
        
        
           | 
          			respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) | 
        
        
           | 
          			writeFrame(clientFile, channelControl, respData) | 
        
        
           | 
          			return | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		stdoutPipe, err := cmd.StdoutPipe() | 
        
        
           | 
          		if err != nil { | 
        
        
           | 
          			log.Printf("[%s] Failed to create stdout pipe: %v", connName, err) | 
        
        
           | 
          			exitCode = 255 | 
        
        
           | 
          			respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) | 
        
        
           | 
          			writeFrame(clientFile, channelControl, respData) | 
        
        
           | 
          			return | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		stderrPipe, err := cmd.StderrPipe() | 
        
        
           | 
          		if err != nil { | 
        
        
           | 
          			log.Printf("[%s] Failed to create stderr pipe: %v", connName, err) | 
        
        
           | 
          			exitCode = 255 | 
        
        
           | 
          			respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) | 
        
        
           | 
          			writeFrame(clientFile, channelControl, respData) | 
        
        
           | 
          			return | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		if err := cmd.Start(); err != nil { | 
        
        
           | 
          			log.Printf("[%s] Failed to start command: %v", connName, err) | 
        
        
           | 
          			exitCode = 255 | 
        
        
           | 
          			respData, _ := json.Marshal(ServerResponse{ExitCode: exitCode}) | 
        
        
           | 
          			writeFrame(clientFile, channelControl, respData) | 
        
        
           | 
          			return | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		var wg sync.WaitGroup | 
        
        
           | 
          		var stdinWg sync.WaitGroup | 
        
        
           | 
          
 | 
        
        
           | 
          		// Track stdout/stderr completion separately | 
        
        
           | 
          		wg.Add(2) | 
        
        
           | 
          		stdinWg.Add(1) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Client stdin -> Command stdin | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer stdinWg.Done() | 
        
        
           | 
          			defer stdinPipe.Close() | 
        
        
           | 
          			for { | 
        
        
           | 
          				frame, err := readFrame(clientFile) | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          				if frame.Channel == channelStdin { | 
        
        
           | 
          					if len(frame.Data) > 0 { | 
        
        
           | 
          						stdinPipe.Write(frame.Data) | 
        
        
           | 
          					} | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Command stdout -> Client | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			buf := make([]byte, maxFrameSize) | 
        
        
           | 
          			for { | 
        
        
           | 
          				n, err := stdoutPipe.Read(buf) | 
        
        
           | 
          				if n > 0 { | 
        
        
           | 
          					writeFrame(clientFile, channelStdout, buf[:n]) | 
        
        
           | 
          				} | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Command stderr -> Client | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			buf := make([]byte, maxFrameSize) | 
        
        
           | 
          			for { | 
        
        
           | 
          				n, err := stderrPipe.Read(buf) | 
        
        
           | 
          				if n > 0 { | 
        
        
           | 
          					writeFrame(clientFile, channelStderr, buf[:n]) | 
        
        
           | 
          				} | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for stdout/stderr to be fully read (they'll get EOF when command exits) | 
        
        
           | 
          		wg.Wait() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Now it's safe to call Wait - all pipe data has been read | 
        
        
           | 
          		if err := cmd.Wait(); err != nil { | 
        
        
           | 
          			log.Printf("[%s] Command error: %v", connName, err) | 
        
        
           | 
          			if exitErr, ok := err.(*exec.ExitError); ok { | 
        
        
           | 
          				exitCode = exitErr.ExitCode() | 
        
        
           | 
          			} else { | 
        
        
           | 
          				exitCode = 255 | 
        
        
           | 
          			} | 
        
        
           | 
          		} else { | 
        
        
           | 
          			exitCode = 0 | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		// Shutdown read side to unblock stdin reader goroutine | 
        
        
           | 
          		unix.Shutdown(int(clientFile.Fd()), unix.SHUT_RD) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for stdin goroutine to finish | 
        
        
           | 
          		stdinWg.Wait() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Send exit code on control channel | 
        
        
           | 
          		resp := ServerResponse{ExitCode: exitCode} | 
        
        
           | 
          		respData, _ := json.Marshal(resp) | 
        
        
           | 
          		writeFrame(clientFile, channelControl, respData) | 
        
        
           | 
          	} | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          // runClient connects to a vsock server | 
        
        
           | 
          func runClient(cid uint32, command []string) error { | 
        
        
           | 
          	// Create a vsock socket | 
        
        
           | 
          	fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0) | 
        
        
           | 
          	if err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to create socket: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          	defer unix.Close(fd) | 
        
        
           | 
          
 | 
        
        
           | 
          	// Connect to the server | 
        
        
           | 
          	sockaddr := &unix.SockaddrVM{ | 
        
        
           | 
          		CID:  cid, | 
        
        
           | 
          		Port: uint32(port), | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	if err := unix.Connect(fd, sockaddr); err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to connect: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	// Create a file from the socket for easier I/O | 
        
        
           | 
          	conn := os.NewFile(uintptr(fd), "vsock-conn") | 
        
        
           | 
          
 | 
        
        
           | 
          	// Determine if we need PTY | 
        
        
           | 
          	// PTY is needed if: no command provided OR -t flag is set | 
        
        
           | 
          	usePTY := len(command) == 0 || forcePTY | 
        
        
           | 
          
 | 
        
        
           | 
          	// Send the client request | 
        
        
           | 
          	req := ClientRequest{ | 
        
        
           | 
          		UsePTY:  usePTY, | 
        
        
           | 
          		Command: command, | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	encoder := json.NewEncoder(conn) | 
        
        
           | 
          	if err := encoder.Encode(&req); err != nil { | 
        
        
           | 
          		return fmt.Errorf("failed to send request: %w", err) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	if usePTY { | 
        
        
           | 
          		// PTY mode: set terminal to raw mode if stdin is a terminal | 
        
        
           | 
          		var oldState *term.State | 
        
        
           | 
          		if term.IsTerminal(int(os.Stdin.Fd())) { | 
        
        
           | 
          			oldState, err = term.MakeRaw(int(os.Stdin.Fd())) | 
        
        
           | 
          			if err != nil { | 
        
        
           | 
          				return fmt.Errorf("failed to set raw mode: %w", err) | 
        
        
           | 
          			} | 
        
        
           | 
          			defer term.Restore(int(os.Stdin.Fd()), oldState) | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		// Channel to receive exit code | 
        
        
           | 
          		exitCodeChan := make(chan int, 1) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Stdin -> Server | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			io.Copy(conn, os.Stdin) | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Server -> Stdout, then read exit code JSON | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			// Copy all PTY output to stdout until the JSON starts | 
        
        
           | 
          			// We need to detect the JSON exit code message | 
        
        
           | 
          			buf := make([]byte, 32*1024) | 
        
        
           | 
          			var jsonBuf []byte | 
        
        
           | 
          
 | 
        
        
           | 
          			for { | 
        
        
           | 
          				n, err := conn.Read(buf) | 
        
        
           | 
          				if n > 0 { | 
        
        
           | 
          					// Check if this contains the start of our JSON message | 
        
        
           | 
          					data := buf[:n] | 
        
        
           | 
          
 | 
        
        
           | 
          					// Look for {"exit_code": pattern | 
        
        
           | 
          					jsonStart := -1 | 
        
        
           | 
          					for i := 0; i < len(data); i++ { | 
        
        
           | 
          						if i <= len(data)-13 && string(data[i:i+13]) == `{"exit_code":` { | 
        
        
           | 
          							jsonStart = i | 
        
        
           | 
          							break | 
        
        
           | 
          						} | 
        
        
           | 
          					} | 
        
        
           | 
          
 | 
        
        
           | 
          					if jsonStart >= 0 { | 
        
        
           | 
          						// Write everything before the JSON | 
        
        
           | 
          						if jsonStart > 0 { | 
        
        
           | 
          							os.Stdout.Write(data[:jsonStart]) | 
        
        
           | 
          						} | 
        
        
           | 
          						// Start collecting JSON | 
        
        
           | 
          						jsonBuf = append(jsonBuf, data[jsonStart:]...) | 
        
        
           | 
          						break | 
        
        
           | 
          					} else { | 
        
        
           | 
          						// No JSON yet, write all output | 
        
        
           | 
          						os.Stdout.Write(data) | 
        
        
           | 
          					} | 
        
        
           | 
          				} | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					exitCodeChan <- 255 // Error reading | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          
 | 
        
        
           | 
          			// Continue reading to get complete JSON | 
        
        
           | 
          			for { | 
        
        
           | 
          				n, err := conn.Read(buf) | 
        
        
           | 
          				if n > 0 { | 
        
        
           | 
          					jsonBuf = append(jsonBuf, buf[:n]...) | 
        
        
           | 
          				} | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					break | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          
 | 
        
        
           | 
          			// Parse the JSON exit code | 
        
        
           | 
          			var resp ServerResponse | 
        
        
           | 
          			if err := json.Unmarshal(jsonBuf, &resp); err != nil { | 
        
        
           | 
          				exitCodeChan <- 255 | 
        
        
           | 
          			} else { | 
        
        
           | 
          				exitCodeChan <- resp.ExitCode | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for exit code | 
        
        
           | 
          		exitCode := <-exitCodeChan | 
        
        
           | 
          
 | 
        
        
           | 
          		// Restore terminal before exiting | 
        
        
           | 
          		if oldState != nil { | 
        
        
           | 
          			term.Restore(int(os.Stdin.Fd()), oldState) | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		os.Exit(exitCode) | 
        
        
           | 
          	} else { | 
        
        
           | 
          		// Non-PTY mode: use framing protocol for stdin/stdout/stderr | 
        
        
           | 
          
 | 
        
        
           | 
          		ctx, cancel := context.WithCancel(context.Background()) | 
        
        
           | 
          		defer cancel() | 
        
        
           | 
          
 | 
        
        
           | 
          		exitCodeChan := make(chan int, 1) | 
        
        
           | 
          		var wg sync.WaitGroup | 
        
        
           | 
          		wg.Add(2) | 
        
        
           | 
          
 | 
        
        
           | 
          		// Stdin -> Server (on stdin channel) | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			buf := make([]byte, maxFrameSize) | 
        
        
           | 
          			for { | 
        
        
           | 
          				select { | 
        
        
           | 
          				case <-ctx.Done(): | 
        
        
           | 
          					return | 
        
        
           | 
          				default: | 
        
        
           | 
          				} | 
        
        
           | 
          
 | 
        
        
           | 
          				n, err := os.Stdin.Read(buf) | 
        
        
           | 
          				if n > 0 { | 
        
        
           | 
          					if err := writeFrame(conn, channelStdin, buf[:n]); err != nil { | 
        
        
           | 
          						return | 
        
        
           | 
          					} | 
        
        
           | 
          				} | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Server -> Stdout/Stderr | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			defer wg.Done() | 
        
        
           | 
          			for { | 
        
        
           | 
          				frame, err := readFrame(conn) | 
        
        
           | 
          				if err != nil { | 
        
        
           | 
          					exitCodeChan <- 255 | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          
 | 
        
        
           | 
          				switch frame.Channel { | 
        
        
           | 
          				case channelStdout: | 
        
        
           | 
          					os.Stdout.Write(frame.Data) | 
        
        
           | 
          				case channelStderr: | 
        
        
           | 
          					os.Stderr.Write(frame.Data) | 
        
        
           | 
          				case channelControl: | 
        
        
           | 
          					// Parse exit code | 
        
        
           | 
          					var resp ServerResponse | 
        
        
           | 
          					if err := json.Unmarshal(frame.Data, &resp); err != nil { | 
        
        
           | 
          						exitCodeChan <- 255 | 
        
        
           | 
          					} else { | 
        
        
           | 
          						exitCodeChan <- resp.ExitCode | 
        
        
           | 
          					} | 
        
        
           | 
          					// Cancel context to stop stdin reader | 
        
        
           | 
          					cancel() | 
        
        
           | 
          					return | 
        
        
           | 
          				} | 
        
        
           | 
          			} | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		// Wait for exit code | 
        
        
           | 
          		exitCode := <-exitCodeChan | 
        
        
           | 
          
 | 
        
        
           | 
          		// Give goroutines a moment to finish, but don't wait forever | 
        
        
           | 
          		done := make(chan struct{}) | 
        
        
           | 
          		go func() { | 
        
        
           | 
          			wg.Wait() | 
        
        
           | 
          			close(done) | 
        
        
           | 
          		}() | 
        
        
           | 
          
 | 
        
        
           | 
          		select { | 
        
        
           | 
          		case <-done: | 
        
        
           | 
          		case <-time.After(100 * time.Millisecond): | 
        
        
           | 
          			// Force exit if goroutines don't finish quickly | 
        
        
           | 
          		} | 
        
        
           | 
          
 | 
        
        
           | 
          		os.Exit(exitCode) | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	return nil // unreachable | 
        
        
           | 
          } | 
        
        
           | 
          
 | 
        
        
           | 
          func main() { | 
        
        
           | 
          	var rootCmd = &cobra.Command{ | 
        
        
           | 
          		Use:   "vsock-shell", | 
        
        
           | 
          		Short: "A tool for executing commands over VM sockets", | 
        
        
           | 
          		Long:  `vsock-shell provides SSH-like functionality for virtual machine communication using VSOCK sockets.`, | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	var serverCmd = &cobra.Command{ | 
        
        
           | 
          		Use:   "serve", | 
        
        
           | 
          		Short: "Run in server mode", | 
        
        
           | 
          		Long:  `Start a vsock-shell server that listens for client connections and executes commands.`, | 
        
        
           | 
          		RunE: func(cmd *cobra.Command, args []string) error { | 
        
        
           | 
          			return runServer() | 
        
        
           | 
          		}, | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	serverCmd.Flags().IntVarP(&port, "port", "p", defaultPort, "Port to listen on") | 
        
        
           | 
          	serverCmd.Flags().BoolVar(&single, "single", false, "Exit after handling one client connection") | 
        
        
           | 
          	serverCmd.Flags().BoolVarP(&single, "one", "1", false, "Exit after handling one client connection (shorthand)") | 
        
        
           | 
          
 | 
        
        
           | 
          	var clientCmd = &cobra.Command{ | 
        
        
           | 
          		Use:     "exec CID [command...]", | 
        
        
           | 
          		Aliases: []string{"x"}, | 
        
        
           | 
          		Short:   "Run in client mode", | 
        
        
           | 
          		Long:    `Connect to a vsock-shell server and execute commands or open an interactive shell.`, | 
        
        
           | 
          		Args:    cobra.MinimumNArgs(1), | 
        
        
           | 
          		RunE: func(cmd *cobra.Command, args []string) error { | 
        
        
           | 
          			var cid uint32 | 
        
        
           | 
          			if _, err := fmt.Sscanf(args[0], "%d", &cid); err != nil { | 
        
        
           | 
          				return fmt.Errorf("invalid CID: %v", err) | 
        
        
           | 
          			} | 
        
        
           | 
          
 | 
        
        
           | 
          			command := args[1:] | 
        
        
           | 
          			return runClient(cid, command) | 
        
        
           | 
          		}, | 
        
        
           | 
          	} | 
        
        
           | 
          
 | 
        
        
           | 
          	clientCmd.Flags().IntVarP(&port, "port", "p", defaultPort, "Port to connect to") | 
        
        
           | 
          	clientCmd.Flags().BoolVarP(&forcePTY, "tty", "t", false, "Force PTY allocation") | 
        
        
           | 
          
 | 
        
        
           | 
          	rootCmd.AddCommand(serverCmd) | 
        
        
           | 
          	rootCmd.AddCommand(clientCmd) | 
        
        
           | 
          
 | 
        
        
           | 
          	if err := rootCmd.Execute(); err != nil { | 
        
        
           | 
          		os.Exit(1) | 
        
        
           | 
          	} | 
        
        
           | 
          } |