never executed always true always false
    1 module PureClaw.Providers.Anthropic
    2   ( -- * Provider type (constructor intentionally NOT exported)
    3     AnthropicProvider
    4   , mkAnthropicProvider
    5     -- * Errors
    6   , AnthropicError (..)
    7     -- * Request/response encoding (exported for testing)
    8   , encodeRequest
    9   , decodeResponse
   10     -- * SSE parsing (exported for testing)
   11   , parseSSELine
   12   ) where
   13 
   14 import Control.Exception
   15 import Data.Aeson
   16 import Data.Aeson.Types
   17 import Data.ByteString (ByteString)
   18 import Data.ByteString qualified as BS
   19 import Data.ByteString.Lazy qualified as BL
   20 import Data.IORef
   21 import Data.Maybe
   22 import Data.Text (Text)
   23 import Data.Text qualified as T
   24 import Data.Text.Encoding qualified as TE
   25 import Network.HTTP.Client qualified as HTTP
   26 import Network.HTTP.Types.Status qualified as Status
   27 
   28 import PureClaw.Core.Errors
   29 import PureClaw.Core.Types
   30 import PureClaw.Providers.Class
   31 import PureClaw.Security.Secrets
   32 
   33 -- | Anthropic API provider. Constructor is not exported — use
   34 -- 'mkAnthropicProvider'.
   35 data AnthropicProvider = AnthropicProvider
   36   { _ap_manager :: HTTP.Manager
   37   , _ap_apiKey  :: ApiKey
   38   }
   39 
   40 -- | Create an Anthropic provider with an HTTP manager and API key.
   41 mkAnthropicProvider :: HTTP.Manager -> ApiKey -> AnthropicProvider
   42 mkAnthropicProvider = AnthropicProvider
   43 
   44 instance Provider AnthropicProvider where
   45   complete = anthropicComplete
   46   completeStream = anthropicCompleteStream
   47 
   48 -- | Errors from the Anthropic API.
   49 data AnthropicError
   50   = AnthropicAPIError Int ByteString   -- ^ HTTP status code + response body
   51   | AnthropicParseError Text           -- ^ JSON parse/decode error
   52   deriving stock (Show)
   53 
   54 instance Exception AnthropicError
   55 
   56 instance ToPublicError AnthropicError where
   57   toPublicError (AnthropicAPIError 429 _) = RateLimitError
   58   toPublicError (AnthropicAPIError 401 _) = NotAllowedError
   59   toPublicError _                         = TemporaryError "Provider error"
   60 
   61 -- | Anthropic Messages API base URL.
   62 anthropicBaseUrl :: String
   63 anthropicBaseUrl = "https://api.anthropic.com/v1/messages"
   64 
   65 -- | Call the Anthropic Messages API.
   66 anthropicComplete :: AnthropicProvider -> CompletionRequest -> IO CompletionResponse
   67 anthropicComplete provider req = do
   68   initReq <- HTTP.parseRequest anthropicBaseUrl
   69   let httpReq = initReq
   70         { HTTP.method = "POST"
   71         , HTTP.requestBody = HTTP.RequestBodyLBS (encodeRequest req)
   72         , HTTP.requestHeaders =
   73             [ ("x-api-key", withApiKey (_ap_apiKey provider) id)
   74             , ("anthropic-version", "2023-06-01")
   75             , ("content-type", "application/json")
   76             ]
   77         }
   78   resp <- HTTP.httpLbs httpReq (_ap_manager provider)
   79   let status = Status.statusCode (HTTP.responseStatus resp)
   80   if status /= 200
   81     then throwIO (AnthropicAPIError status (BL.toStrict (HTTP.responseBody resp)))
   82     else case decodeResponse (HTTP.responseBody resp) of
   83       Left err -> throwIO (AnthropicParseError (T.pack err))
   84       Right response -> pure response
   85 
   86 -- | Encode a completion request as JSON for the Anthropic API.
   87 encodeRequest :: CompletionRequest -> BL.ByteString
   88 encodeRequest req = encode $ object $
   89   [ "model"      .= unModelId (_cr_model req)
   90   , "max_tokens" .= fromMaybe 4096 (_cr_maxTokens req)
   91   , "messages"   .= map encodeMsg (_cr_messages req)
   92   ]
   93   ++ maybe [] (\s -> ["system" .= s]) (_cr_systemPrompt req)
   94   ++ if null (_cr_tools req)
   95      then maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
   96      else ("tools" .= map encodeTool (_cr_tools req))
   97         : maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
   98 
   99 encodeMsg :: Message -> Value
  100 encodeMsg msg = object
  101   [ "role"    .= roleToText (_msg_role msg)
  102   , "content" .= map encodeContentBlock (_msg_content msg)
  103   ]
  104 
  105 encodeContentBlock :: ContentBlock -> Value
  106 encodeContentBlock (TextBlock t) = object
  107   [ "type" .= ("text" :: Text)
  108   , "text" .= t
  109   ]
  110 encodeContentBlock (ToolUseBlock callId name input) = object
  111   [ "type"  .= ("tool_use" :: Text)
  112   , "id"    .= unToolCallId callId
  113   , "name"  .= name
  114   , "input" .= input
  115   ]
  116 encodeContentBlock (ImageBlock mediaType imageData) = object
  117   [ "type" .= ("image" :: Text)
  118   , "source" .= object
  119       [ "type"       .= ("base64" :: Text)
  120       , "media_type" .= mediaType
  121       , "data"       .= TE.decodeUtf8 imageData
  122       ]
  123   ]
  124 encodeContentBlock (ToolResultBlock callId parts isErr) = object $
  125   [ "type"        .= ("tool_result" :: Text)
  126   , "tool_use_id" .= unToolCallId callId
  127   , "content"     .= map encodeToolResultPart parts
  128   ]
  129   ++ ["is_error" .= True | isErr]
  130 
  131 encodeToolResultPart :: ToolResultPart -> Value
  132 encodeToolResultPart (TRPText t) = object
  133   [ "type" .= ("text" :: Text), "text" .= t ]
  134 encodeToolResultPart (TRPImage mediaType imageData) = object
  135   [ "type" .= ("image" :: Text)
  136   , "source" .= object
  137       [ "type"       .= ("base64" :: Text)
  138       , "media_type" .= mediaType
  139       , "data"       .= TE.decodeUtf8 imageData
  140       ]
  141   ]
  142 
  143 encodeTool :: ToolDefinition -> Value
  144 encodeTool td = object
  145   [ "name"         .= _td_name td
  146   , "description"  .= _td_description td
  147   , "input_schema" .= _td_inputSchema td
  148   ]
  149 
  150 encodeToolChoice :: ToolChoice -> Value
  151 encodeToolChoice AutoTool = object ["type" .= ("auto" :: Text)]
  152 encodeToolChoice AnyTool = object ["type" .= ("any" :: Text)]
  153 encodeToolChoice (SpecificTool name) = object
  154   [ "type" .= ("tool" :: Text)
  155   , "name" .= name
  156   ]
  157 
  158 -- | Decode an Anthropic API response into a 'CompletionResponse'.
  159 decodeResponse :: BL.ByteString -> Either String CompletionResponse
  160 decodeResponse bs = eitherDecode bs >>= parseEither parseResp
  161   where
  162     parseResp :: Value -> Parser CompletionResponse
  163     parseResp = withObject "AnthropicResponse" $ \o -> do
  164       contentArr <- o .: "content"
  165       blocks <- mapM parseBlock contentArr
  166       modelText <- o .: "model"
  167       usageObj <- o .: "usage"
  168       inToks <- usageObj .: "input_tokens"
  169       outToks <- usageObj .: "output_tokens"
  170       pure CompletionResponse
  171         { _crsp_content = blocks
  172         , _crsp_model   = ModelId modelText
  173         , _crsp_usage   = Just (Usage inToks outToks)
  174         }
  175 
  176     parseBlock :: Value -> Parser ContentBlock
  177     parseBlock = withObject "ContentBlock" $ \b -> do
  178       bType <- b .: "type"
  179       case (bType :: Text) of
  180         "text" -> TextBlock <$> b .: "text"
  181         "tool_use" -> do
  182           callId <- b .: "id"
  183           name <- b .: "name"
  184           input <- b .: "input"
  185           pure (ToolUseBlock (ToolCallId callId) name input)
  186         other -> fail $ "Unknown content block type: " <> T.unpack other
  187 
  188 -- | Encode a streaming completion request (adds "stream": true).
  189 encodeStreamRequest :: CompletionRequest -> BL.ByteString
  190 encodeStreamRequest req = encode $ object $
  191   [ "model"      .= unModelId (_cr_model req)
  192   , "max_tokens" .= fromMaybe 4096 (_cr_maxTokens req)
  193   , "messages"   .= map encodeMsg (_cr_messages req)
  194   , "stream"     .= True
  195   ]
  196   ++ maybe [] (\s -> ["system" .= s]) (_cr_systemPrompt req)
  197   ++ if null (_cr_tools req)
  198      then maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  199      else ("tools" .= map encodeTool (_cr_tools req))
  200         : maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  201 
  202 -- | Stream a completion from the Anthropic API.
  203 -- Processes SSE events and emits StreamEvent callbacks. Accumulates
  204 -- the full response for the final StreamDone event.
  205 anthropicCompleteStream :: AnthropicProvider -> CompletionRequest -> (StreamEvent -> IO ()) -> IO ()
  206 anthropicCompleteStream provider req callback = do
  207   initReq <- HTTP.parseRequest anthropicBaseUrl
  208   let httpReq = initReq
  209         { HTTP.method = "POST"
  210         , HTTP.requestBody = HTTP.RequestBodyLBS (encodeStreamRequest req)
  211         , HTTP.requestHeaders =
  212             [ ("x-api-key", withApiKey (_ap_apiKey provider) id)
  213             , ("anthropic-version", "2023-06-01")
  214             , ("content-type", "application/json")
  215             ]
  216         }
  217   HTTP.withResponse httpReq (_ap_manager provider) $ \resp -> do
  218     let status = Status.statusCode (HTTP.responseStatus resp)
  219     if status /= 200
  220       then do
  221         body <- BL.toStrict <$> HTTP.brReadSome (HTTP.responseBody resp) (1024 * 1024)
  222         throwIO (AnthropicAPIError status body)
  223       else do
  224         -- Accumulate content blocks and usage as events arrive
  225         blocksRef <- newIORef ([] :: [ContentBlock])
  226         modelRef <- newIORef (ModelId "")
  227         usageRef <- newIORef (Nothing :: Maybe Usage)
  228         bufRef <- newIORef BS.empty
  229         let readChunks = do
  230               chunk <- HTTP.brRead (HTTP.responseBody resp)
  231               if BS.null chunk
  232                 then do
  233                   -- Stream ended — emit final response
  234                   blocks <- readIORef blocksRef
  235                   model <- readIORef modelRef
  236                   usage <- readIORef usageRef
  237                   callback $ StreamDone CompletionResponse
  238                     { _crsp_content = reverse blocks
  239                     , _crsp_model   = model
  240                     , _crsp_usage   = usage
  241                     }
  242                 else do
  243                   buf <- readIORef bufRef
  244                   let fullBuf = buf <> chunk
  245                       (lines', remaining) = splitSSELines fullBuf
  246                   writeIORef bufRef remaining
  247                   mapM_ (processSSELine blocksRef modelRef usageRef callback) lines'
  248                   readChunks
  249         readChunks
  250 
  251 -- | Split a buffer into complete SSE lines and remaining partial data.
  252 splitSSELines :: ByteString -> ([ByteString], ByteString)
  253 splitSSELines bs =
  254   let parts = BS.splitWith (== 0x0a) bs  -- split on newline
  255   in case parts of
  256     [] -> ([], BS.empty)
  257     _ -> (init parts, last parts)
  258 
  259 -- | Process a single SSE line.
  260 processSSELine :: IORef [ContentBlock] -> IORef ModelId -> IORef (Maybe Usage) -> (StreamEvent -> IO ()) -> ByteString -> IO ()
  261 processSSELine blocksRef modelRef usageRef callback line =
  262   case parseSSELine line of
  263     Nothing -> pure ()
  264     Just json -> case parseEither parseStreamEvent json of
  265       Left _ -> pure ()
  266       Right evt -> case evt of
  267         SSEContentText t -> do
  268           callback (StreamText t)
  269           modifyIORef blocksRef (TextBlock t :)
  270         SSEToolStart callId name ->
  271           callback (StreamToolUse callId name)
  272         SSEToolDelta t ->
  273           callback (StreamToolInput t)
  274         SSEMessageStart model -> writeIORef modelRef model
  275         SSEUsage usage -> writeIORef usageRef (Just usage)
  276         SSEMessageStop -> pure ()
  277 
  278 -- | Parse an SSE "data: ..." line into a JSON value.
  279 parseSSELine :: ByteString -> Maybe Value
  280 parseSSELine bs
  281   | BS.isPrefixOf "data: " bs =
  282       let jsonBs = BS.drop 6 bs
  283       in decode (BL.fromStrict jsonBs)
  284   | otherwise = Nothing
  285 
  286 -- | Internal SSE event types.
  287 data SSEEvent
  288   = SSEContentText Text
  289   | SSEToolStart ToolCallId Text
  290   | SSEToolDelta Text
  291   | SSEMessageStart ModelId
  292   | SSEUsage Usage
  293   | SSEMessageStop
  294 
  295 -- | Parse a JSON SSE event.
  296 parseStreamEvent :: Value -> Parser SSEEvent
  297 parseStreamEvent = withObject "SSEEvent" $ \o -> do
  298   eventType <- o .: "type"
  299   case (eventType :: Text) of
  300     "message_start" -> do
  301       msg <- o .: "message"
  302       model <- msg .: "model"
  303       pure (SSEMessageStart (ModelId model))
  304     "content_block_delta" -> do
  305       delta <- o .: "delta"
  306       deltaType <- delta .: "type"
  307       case (deltaType :: Text) of
  308         "text_delta" -> SSEContentText <$> delta .: "text"
  309         "input_json_delta" -> SSEToolDelta <$> delta .: "partial_json"
  310         _ -> fail $ "Unknown delta type: " <> T.unpack deltaType
  311     "content_block_start" -> do
  312       block <- o .: "content_block"
  313       blockType <- block .: "type"
  314       case (blockType :: Text) of
  315         "tool_use" -> do
  316           callId <- block .: "id"
  317           name <- block .: "name"
  318           pure (SSEToolStart (ToolCallId callId) name)
  319         _ -> fail "Ignored block start"
  320     "message_delta" -> do
  321       usage <- o .:? "usage"
  322       case usage of
  323         Just u -> do
  324           outToks <- u .: "output_tokens"
  325           pure (SSEUsage (Usage 0 outToks))
  326         Nothing -> pure SSEMessageStop
  327     "message_stop" -> pure SSEMessageStop
  328     _ -> fail $ "Unknown event type: " <> T.unpack eventType