18
18
19
19
from __future__ import annotations
20
20
21
+ import json
22
+
21
23
from botocore .eventstream import EventStream
22
24
from wrapt import ObjectProxy
23
25
@@ -46,20 +48,21 @@ def __iter__(self):
46
48
def _process_event (self , event ):
47
49
if "messageStart" in event :
48
50
# {'messageStart': {'role': 'assistant'}}
49
- pass
51
+ return
50
52
51
53
if "contentBlockDelta" in event :
52
54
# {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
53
- pass
55
+ return
54
56
55
57
if "contentBlockStop" in event :
56
58
# {'contentBlockStop': {'contentBlockIndex': 0}}
57
- pass
59
+ return
58
60
59
61
if "messageStop" in event :
60
62
# {'messageStop': {'stopReason': 'end_turn'}}
61
63
if stop_reason := event ["messageStop" ].get ("stopReason" ):
62
64
self ._response ["stopReason" ] = stop_reason
65
+ return
63
66
64
67
if "metadata" in event :
65
68
# {'metadata': {'usage': {'inputTokens': 12, 'outputTokens': 15, 'totalTokens': 27}, 'metrics': {'latencyMs': 2980}}}
@@ -72,3 +75,136 @@ def _process_event(self, event):
72
75
self ._response ["usage" ]["outputTokens" ] = output_tokens
73
76
74
77
self ._stream_done_callback (self ._response )
78
+ return
79
+
80
+
81
+ # pylint: disable=abstract-method
82
+ class InvokeModelWithResponseStreamWrapper (ObjectProxy ):
83
+ """Wrapper for botocore.eventstream.EventStream"""
84
+
85
+ def __init__ (
86
+ self ,
87
+ stream : EventStream ,
88
+ stream_done_callback ,
89
+ model_id : str ,
90
+ ):
91
+ super ().__init__ (stream )
92
+
93
+ self ._stream_done_callback = stream_done_callback
94
+ self ._model_id = model_id
95
+
96
+ # accumulating things in the same shape of the Converse API
97
+ # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish"}
98
+ self ._response = {}
99
+
100
+ def __iter__ (self ):
101
+ for event in self .__wrapped__ :
102
+ self ._process_event (event )
103
+ yield event
104
+
105
+ def _process_event (self , event ):
106
+ if "chunk" not in event :
107
+ return
108
+
109
+ json_bytes = event ["chunk" ].get ("bytes" , b"" )
110
+ decoded = json_bytes .decode ("utf-8" )
111
+ try :
112
+ chunk = json .loads (decoded )
113
+ except json .JSONDecodeError :
114
+ return
115
+
116
+ if "amazon.titan" in self ._model_id :
117
+ self ._process_amazon_titan_chunk (chunk )
118
+ elif "amazon.nova" in self ._model_id :
119
+ self ._process_amazon_nova_chunk (chunk )
120
+ elif "anthropic.claude" in self ._model_id :
121
+ self ._process_anthropic_claude_chunk (chunk )
122
+
123
+ def _process_invocation_metrics (self , invocation_metrics ):
124
+ self ._response ["usage" ] = {}
125
+ if input_tokens := invocation_metrics .get ("inputTokenCount" ):
126
+ self ._response ["usage" ]["inputTokens" ] = input_tokens
127
+
128
+ if output_tokens := invocation_metrics .get ("outputTokenCount" ):
129
+ self ._response ["usage" ]["outputTokens" ] = output_tokens
130
+
131
+ def _process_amazon_titan_chunk (self , chunk ):
132
+ if (stop_reason := chunk .get ("completionReason" )) is not None :
133
+ self ._response ["stopReason" ] = stop_reason
134
+
135
+ if invocation_metrics := chunk .get ("amazon-bedrock-invocationMetrics" ):
136
+ # "amazon-bedrock-invocationMetrics":{
137
+ # "inputTokenCount":9,"outputTokenCount":128,"invocationLatency":3569,"firstByteLatency":2180
138
+ # }
139
+ self ._process_invocation_metrics (invocation_metrics )
140
+ self ._stream_done_callback (self ._response )
141
+
142
+ def _process_amazon_nova_chunk (self , chunk ):
143
+ if "messageStart" in chunk :
144
+ # {'messageStart': {'role': 'assistant'}}
145
+ return
146
+
147
+ if "contentBlockDelta" in chunk :
148
+ # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
149
+ return
150
+
151
+ if "contentBlockStop" in chunk :
152
+ # {'contentBlockStop': {'contentBlockIndex': 0}}
153
+ return
154
+
155
+ if "messageStop" in chunk :
156
+ # {'messageStop': {'stopReason': 'end_turn'}}
157
+ if stop_reason := chunk ["messageStop" ].get ("stopReason" ):
158
+ self ._response ["stopReason" ] = stop_reason
159
+ return
160
+
161
+ if "metadata" in chunk :
162
+ # {'metadata': {'usage': {'inputTokens': 8, 'outputTokens': 117}, 'metrics': {}, 'trace': {}}}
163
+ if usage := chunk ["metadata" ].get ("usage" ):
164
+ self ._response ["usage" ] = {}
165
+ if input_tokens := usage .get ("inputTokens" ):
166
+ self ._response ["usage" ]["inputTokens" ] = input_tokens
167
+
168
+ if output_tokens := usage .get ("outputTokens" ):
169
+ self ._response ["usage" ]["outputTokens" ] = output_tokens
170
+
171
+ self ._stream_done_callback (self ._response )
172
+ return
173
+
174
+ def _process_anthropic_claude_chunk (self , chunk ):
175
+ # pylint: disable=too-many-return-statements
176
+ if not (message_type := chunk .get ("type" )):
177
+ return
178
+
179
+ if message_type == "message_start" :
180
+ # {'type': 'message_start', 'message': {'id': 'id', 'type': 'message', 'role': 'assistant', 'model': 'claude-2.0', 'content': [], 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 18, 'output_tokens': 1}}}
181
+ return
182
+
183
+ if message_type == "content_block_start" :
184
+ # {'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}
185
+ return
186
+
187
+ if message_type == "content_block_delta" :
188
+ # {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Here'}}
189
+ return
190
+
191
+ if message_type == "content_block_stop" :
192
+ # {'type': 'content_block_stop', 'index': 0}
193
+ return
194
+
195
+ if message_type == "message_delta" :
196
+ # {'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 123}}
197
+ if (
198
+ stop_reason := chunk .get ("delta" , {}).get ("stop_reason" )
199
+ ) is not None :
200
+ self ._response ["stopReason" ] = stop_reason
201
+ return
202
+
203
+ if message_type == "message_stop" :
204
+ # {'type': 'message_stop', 'amazon-bedrock-invocationMetrics': {'inputTokenCount': 18, 'outputTokenCount': 123, 'invocationLatency': 5250, 'firstByteLatency': 290}}
205
+ if invocation_metrics := chunk .get (
206
+ "amazon-bedrock-invocationMetrics"
207
+ ):
208
+ self ._process_invocation_metrics (invocation_metrics )
209
+ self ._stream_done_callback (self ._response )
210
+ return
0 commit comments